diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index cc97e043c1..939cce9ffb 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -38,9 +38,8 @@ jobs:
run: |
apt update
apt install -y libpq-dev postgresql-client
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install -r benchmarks/requirements.txt
+ uv pip install -e ".[quality]"
+ uv pip install -r benchmarks/requirements.txt
- name: Environment
run: |
python utils/print_env.py
diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml
index 583853c6d6..1d7be0d6bc 100644
--- a/.github/workflows/build_docker_images.yml
+++ b/.github/workflows/build_docker_images.yml
@@ -72,7 +72,6 @@ jobs:
image-name:
- diffusers-pytorch-cpu
- diffusers-pytorch-cuda
- - diffusers-pytorch-cuda
- diffusers-pytorch-xformers-cuda
- diffusers-pytorch-minimum-cuda
- diffusers-doc-builder
diff --git a/.github/workflows/mirror_community_pipeline.yml b/.github/workflows/mirror_community_pipeline.yml
index 9cf573312b..ab4ded9730 100644
--- a/.github/workflows/mirror_community_pipeline.yml
+++ b/.github/workflows/mirror_community_pipeline.yml
@@ -74,7 +74,7 @@ jobs:
python-version: "3.10"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install --upgrade huggingface_hub
# Check secret is set
diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index 9216564093..c8fa3a7ad9 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -71,10 +71,9 @@ jobs:
run: nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
+ uv pip install -e ".[quality]"
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
@@ -84,7 +83,7 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
@@ -124,11 +123,10 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -139,7 +137,7 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
@@ -152,7 +150,7 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v --make-reports=examples_torch_cuda \
--report-log=examples_torch_cuda.log \
examples/
@@ -191,8 +189,7 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
python utils/print_env.py
@@ -201,7 +198,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -232,11 +229,10 @@ jobs:
run: nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
@@ -247,7 +243,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-m "big_accelerator" \
--make-reports=tests_big_gpu_torch_cuda \
--report-log=tests_big_gpu_torch_cuda.log \
@@ -282,10 +278,9 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -297,7 +292,7 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_version_cuda \
tests/models/test_modeling_common.py \
@@ -340,6 +335,9 @@ jobs:
- backend: "optimum_quanto"
test_location: "quanto"
additional_deps: []
+ - backend: "nvidia_modelopt"
+ test_location: "modelopt"
+ additional_deps: []
runs-on:
group: aws-g6e-xlarge-plus
container:
@@ -354,13 +352,12 @@ jobs:
run: nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install -U ${{ matrix.config.backend }}
+ uv pip install -e ".[quality]"
+ uv pip install -U ${{ matrix.config.backend }}
if [ "${{ join(matrix.config.additional_deps, ' ') }}" != "" ]; then
- python -m uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
+ uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
fi
- python -m uv pip install pytest-reportlog
+ uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
@@ -371,7 +368,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.backend }}_torch_cuda \
--report-log=tests_${{ matrix.config.backend }}_torch_cuda.log \
tests/quantization/${{ matrix.config.test_location }}
@@ -406,10 +403,9 @@ jobs:
run: nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install -U bitsandbytes optimum_quanto
- python -m uv pip install pytest-reportlog
+ uv pip install -e ".[quality]"
+ uv pip install -U bitsandbytes optimum_quanto
+ uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
@@ -420,7 +416,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_pipeline_level_quant_torch_cuda \
--report-log=tests_pipeline_level_quant_torch_cuda.log \
tests/quantization/test_pipeline_level_quantization.py
@@ -520,11 +516,11 @@ jobs:
# - name: Install dependencies
# shell: arch -arch arm64 bash {0}
# run: |
-# ${CONDA_RUN} python -m pip install --upgrade pip uv
-# ${CONDA_RUN} python -m uv pip install -e [quality,test]
-# ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
-# ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
-# ${CONDA_RUN} python -m uv pip install pytest-reportlog
+# ${CONDA_RUN} pip install --upgrade pip uv
+# ${CONDA_RUN} uv pip install -e ".[quality]"
+# ${CONDA_RUN} uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
+# ${CONDA_RUN} uv pip install accelerate@git+https://github.com/huggingface/accelerate
+# ${CONDA_RUN} uv pip install pytest-reportlog
# - name: Environment
# shell: arch -arch arm64 bash {0}
# run: |
@@ -535,7 +531,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
-# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
+# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
@@ -576,11 +572,11 @@ jobs:
# - name: Install dependencies
# shell: arch -arch arm64 bash {0}
# run: |
-# ${CONDA_RUN} python -m pip install --upgrade pip uv
-# ${CONDA_RUN} python -m uv pip install -e [quality,test]
-# ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
-# ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
-# ${CONDA_RUN} python -m uv pip install pytest-reportlog
+# ${CONDA_RUN} pip install --upgrade pip uv
+# ${CONDA_RUN} uv pip install -e ".[quality]"
+# ${CONDA_RUN} uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
+# ${CONDA_RUN} uv pip install accelerate@git+https://github.com/huggingface/accelerate
+# ${CONDA_RUN} uv pip install pytest-reportlog
# - name: Environment
# shell: arch -arch arm64 bash {0}
# run: |
@@ -591,7 +587,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
-# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
+# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
diff --git a/.github/workflows/pr_dependency_test.yml b/.github/workflows/pr_dependency_test.yml
index d9350c09ac..b914d10761 100644
--- a/.github/workflows/pr_dependency_test.yml
+++ b/.github/workflows/pr_dependency_test.yml
@@ -25,11 +25,8 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install --upgrade pip uv
- python -m uv pip install -e .
- python -m uv pip install pytest
+ pip install -e .
+ pip install pytest
- name: Check for soft dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- pytest tests/others/test_dependencies.py
+ pytest tests/others/test_dependencies.py
diff --git a/.github/workflows/pr_flax_dependency_test.yml b/.github/workflows/pr_flax_dependency_test.yml
deleted file mode 100644
index e091b5f2d7..0000000000
--- a/.github/workflows/pr_flax_dependency_test.yml
+++ /dev/null
@@ -1,38 +0,0 @@
-name: Run Flax dependency tests
-
-on:
- pull_request:
- branches:
- - main
- paths:
- - "src/diffusers/**.py"
- push:
- branches:
- - main
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
- cancel-in-progress: true
-
-jobs:
- check_flax_dependencies:
- runs-on: ubuntu-22.04
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python
- uses: actions/setup-python@v4
- with:
- python-version: "3.8"
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install --upgrade pip uv
- python -m uv pip install -e .
- python -m uv pip install "jax[cpu]>=0.2.16,!=0.3.2"
- python -m uv pip install "flax>=0.4.1"
- python -m uv pip install "jaxlib>=0.1.65"
- python -m uv pip install pytest
- - name: Check for soft dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- pytest tests/others/test_dependencies.py
diff --git a/.github/workflows/pr_modular_tests.yml b/.github/workflows/pr_modular_tests.yml
new file mode 100644
index 0000000000..83c84fbab2
--- /dev/null
+++ b/.github/workflows/pr_modular_tests.yml
@@ -0,0 +1,138 @@
+name: Fast PR tests for Modular
+
+on:
+ pull_request:
+ branches: [main]
+ paths:
+ - "src/diffusers/modular_pipelines/**.py"
+ - "src/diffusers/models/modeling_utils.py"
+ - "src/diffusers/models/model_loading_utils.py"
+ - "src/diffusers/pipelines/pipeline_utils.py"
+ - "src/diffusers/pipeline_loading_utils.py"
+ - "src/diffusers/loaders/lora_base.py"
+ - "src/diffusers/loaders/lora_pipeline.py"
+ - "src/diffusers/loaders/peft.py"
+ - "tests/modular_pipelines/**.py"
+ - ".github/**.yml"
+ - "utils/**.py"
+ - "setup.py"
+ push:
+ branches:
+ - ci-*
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+env:
+ DIFFUSERS_IS_CI: yes
+ HF_HUB_ENABLE_HF_TRANSFER: 1
+ OMP_NUM_THREADS: 4
+ MKL_NUM_THREADS: 4
+ PYTEST_TIMEOUT: 60
+
+jobs:
+ check_code_quality:
+ runs-on: ubuntu-22.04
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+ - name: Install dependencies
+ run: |
+ pip install --upgrade pip
+ pip install .[quality]
+ - name: Check quality
+ run: make quality
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
+
+ check_repository_consistency:
+ needs: check_code_quality
+ runs-on: ubuntu-22.04
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.10"
+ - name: Install dependencies
+ run: |
+ pip install --upgrade pip
+ pip install .[quality]
+ - name: Check repo consistency
+ run: |
+ python utils/check_copies.py
+ python utils/check_dummies.py
+ python utils/check_support_list.py
+ make deps_table_check_updated
+ - name: Check if failure
+ if: ${{ failure() }}
+ run: |
+ echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
+
+ run_fast_tests:
+ needs: [check_code_quality, check_repository_consistency]
+ strategy:
+ fail-fast: false
+ matrix:
+ config:
+ - name: Fast PyTorch Modular Pipeline CPU tests
+ framework: pytorch_pipelines
+ runner: aws-highmemory-32-plus
+ image: diffusers/diffusers-pytorch-cpu
+ report: torch_cpu_modular_pipelines
+
+ name: ${{ matrix.config.name }}
+
+ runs-on:
+ group: ${{ matrix.config.runner }}
+
+ container:
+ image: ${{ matrix.config.image }}
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+ defaults:
+ run:
+ shell: bash
+
+ steps:
+ - name: Checkout diffusers
+ uses: actions/checkout@v3
+ with:
+ fetch-depth: 2
+
+ - name: Install dependencies
+ run: |
+ uv pip install -e ".[quality]"
+ uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+
+ - name: Environment
+ run: |
+ python utils/print_env.py
+
+ - name: Run fast PyTorch Pipeline CPU tests
+ if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
+ run: |
+ pytest -n 8 --max-worker-restart=0 --dist=loadfile \
+ -s -v -k "not Flax and not Onnx" \
+ --make-reports=tests_${{ matrix.config.report }} \
+ tests/modular_pipelines
+
+ - name: Failure short reports
+ if: ${{ failure() }}
+ run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+ - name: Test suite reports artifacts
+ if: ${{ always() }}
+ uses: actions/upload-artifact@v4
+ with:
+ name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
+ path: reports
+
+
diff --git a/.github/workflows/pr_test_fetcher.yml b/.github/workflows/pr_test_fetcher.yml
index b032bb8427..83b2ab4edb 100644
--- a/.github/workflows/pr_test_fetcher.yml
+++ b/.github/workflows/pr_test_fetcher.yml
@@ -33,8 +33,7 @@ jobs:
fetch-depth: 0
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
python utils/print_env.py
@@ -90,19 +89,16 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install -e [quality,test]
- python -m pip install accelerate
+ uv pip install -e ".[quality]"
+ uv pip install accelerate
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run all selected tests on CPU
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}
+ pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}
- name: Failure short reports
if: ${{ failure() }}
@@ -148,19 +144,16 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install -e [quality,test]
+ pip install -e ".[quality]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run Hub tests for models, schedulers, and pipelines on a staging env
if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- HUGGINGFACE_CO_STAGING=true python -m pytest \
+ HUGGINGFACE_CO_STAGING=true pytest \
-m "is_staging_test" \
--make-reports=tests_${{ matrix.config.report }} \
tests
diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index 34a344528e..03205fcec6 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -38,7 +38,7 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
@@ -58,7 +58,7 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
@@ -114,21 +114,18 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+ uv pip install -e ".[quality]"
+ uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/pipelines
@@ -136,8 +133,7 @@ jobs:
- name: Run fast PyTorch Model Scheduler CPU tests
if: ${{ matrix.config.framework == 'pytorch_models' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and not Dependency" \
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
@@ -145,9 +141,8 @@ jobs:
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install peft timm
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ uv pip install ".[training]"
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples
@@ -195,19 +190,16 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run Hub tests for models, schedulers, and pipelines on a staging env
if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- HUGGINGFACE_CO_STAGING=true python -m pytest \
+ HUGGINGFACE_CO_STAGING=true pytest \
-m "is_staging_test" \
--make-reports=tests_${{ matrix.config.report }} \
tests
@@ -249,27 +241,24 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
# TODO (sayakpaul, DN6): revisit `--no-deps`
- python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
- python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
- python -m uv pip install -U tokenizers
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+ uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
+ uv pip install -U tokenizers
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+ uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch LoRA tests with PEFT
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_peft_main \
tests/lora/
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_models_lora_peft_main \
tests/models/ -k "lora"
diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml
index 45294c89fe..900a53da94 100644
--- a/.github/workflows/pr_tests_gpu.yml
+++ b/.github/workflows/pr_tests_gpu.yml
@@ -39,7 +39,7 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
@@ -59,7 +59,7 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
+ pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
@@ -88,8 +88,7 @@ jobs:
fetch-depth: 2
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
python utils/print_env.py
@@ -130,10 +129,9 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+ uv pip install -e ".[quality]"
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
@@ -152,13 +150,13 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
@@ -200,11 +198,10 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
@@ -225,10 +222,10 @@ jobs:
run: |
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
if [ -z "$pattern" ]; then
- python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
+ pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
else
- python -m pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
+ pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
fi
@@ -265,22 +262,19 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
- python -m uv pip install -e [quality,test,training]
+ uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install timm
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+ uv pip install ".[training]"
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/pr_torch_dependency_test.yml b/.github/workflows/pr_torch_dependency_test.yml
index c39d5eca2d..4b6160ff71 100644
--- a/.github/workflows/pr_torch_dependency_test.yml
+++ b/.github/workflows/pr_torch_dependency_test.yml
@@ -25,12 +25,8 @@ jobs:
python-version: "3.8"
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pip install --upgrade pip uv
- python -m uv pip install -e .
- python -m uv pip install torch torchvision torchaudio
- python -m uv pip install pytest
+ pip install -e .
+ pip install torch torchvision torchaudio pytest
- name: Check for soft dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- pytest tests/others/test_dependencies.py
+ pytest tests/others/test_dependencies.py
diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index 6896e0145c..a1f5e12153 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -34,8 +34,7 @@ jobs:
fetch-depth: 2
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
python utils/print_env.py
@@ -75,9 +74,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
@@ -87,7 +85,7 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
@@ -126,10 +124,9 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -141,7 +138,7 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_cuda_${{ matrix.module }} \
tests/${{ matrix.module }}
@@ -180,8 +177,7 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
python utils/print_env.py
@@ -190,7 +186,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -223,8 +219,7 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
python utils/print_env.py
@@ -232,7 +227,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -264,21 +259,18 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install timm
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+ uv pip install ".[training]"
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml
index e274cb0218..6b1dd0b2d0 100644
--- a/.github/workflows/push_tests_fast.yml
+++ b/.github/workflows/push_tests_fast.yml
@@ -60,19 +60,16 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run fast PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
@@ -80,9 +77,8 @@ jobs:
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install peft timm
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+ uv pip install ".[training]"
+ pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples
diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml
index 81a34f7a46..808818bead 100644
--- a/.github/workflows/release_tests_fast.yml
+++ b/.github/workflows/release_tests_fast.yml
@@ -32,8 +32,7 @@ jobs:
fetch-depth: 2
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
+ uv pip install -e ".[quality]"
- name: Environment
run: |
python utils/print_env.py
@@ -73,9 +72,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
@@ -85,7 +83,7 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
@@ -124,10 +122,9 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -139,7 +136,7 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
tests/${{ matrix.module }}
@@ -175,10 +172,9 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ uv pip install -e ".[quality]"
+ uv pip install peft@git+https://github.com/huggingface/peft.git
+ uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
@@ -190,7 +186,7 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_cuda \
tests/models/test_modeling_common.py \
@@ -235,8 +231,7 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
python utils/print_env.py
@@ -245,7 +240,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -278,8 +273,7 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
python utils/print_env.py
@@ -287,7 +281,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -321,21 +315,18 @@ jobs:
- name: Install dependencies
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test,training]
+ uv pip install -e ".[quality,training]"
- name: Environment
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install timm
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
+ uv pip install ".[training]"
+ pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/run_tests_from_a_pr.yml b/.github/workflows/run_tests_from_a_pr.yml
index c8eee8dbbc..fa8c579dd7 100644
--- a/.github/workflows/run_tests_from_a_pr.yml
+++ b/.github/workflows/run_tests_from_a_pr.yml
@@ -63,9 +63,8 @@ jobs:
- name: Install pytest
run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- python -m uv pip install peft
+ uv pip install -e ".[quality]"
+ uv pip install peft
- name: Run tests
env:
diff --git a/.gitignore b/.gitignore
index 15617d5fdc..a55026febd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -125,6 +125,9 @@ dmypy.json
.vs
.vscode
+# Cursor
+.cursor
+
# Pycharm
.idea
diff --git a/README.md b/README.md
index dac3b3598a..68202ba095 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ limitations under the License.
## Installation
-We recommend installing 🤗 Diffusers in a virtual environment from PyPI or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/) and [Flax](https://flax.readthedocs.io/en/latest/#installation), please refer to their official documentation.
+We recommend installing 🤗 Diffusers in a virtual environment from PyPI or Conda. For more details about installing [PyTorch](https://pytorch.org/get-started/locally/), please refer to its official documentation.
### PyTorch
@@ -53,14 +53,6 @@ With `conda` (maintained by the community):
conda install -c conda-forge diffusers
```
-### Flax
-
-With `pip` (official package):
-
-```bash
-pip install --upgrade diffusers[flax]
-```
-
### Apple Silicon (M1/M2) support
Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggingface.co/docs/diffusers/optimization/mps) guide.
diff --git a/docker/diffusers-doc-builder/Dockerfile b/docker/diffusers-doc-builder/Dockerfile
index 3a76b3331c..8453ef4e6c 100644
--- a/docker/diffusers-doc-builder/Dockerfile
+++ b/docker/diffusers-doc-builder/Dockerfile
@@ -1,56 +1,45 @@
-FROM ubuntu:20.04
+FROM python:3.10-slim
+ENV PYTHONDONTWRITEBYTECODE=1
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
-RUN apt-get -y update \
- && apt-get install -y software-properties-common \
- && add-apt-repository ppa:deadsnakes/ppa
+RUN apt-get -y update && apt-get install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libglib2.0-0 \
+ libsndfile1-dev \
+ libgl1 \
+ zip \
+ wget
-RUN apt install -y bash \
- build-essential \
- git \
- git-lfs \
- curl \
- ca-certificates \
- libsndfile1-dev \
- python3.10 \
- python3-pip \
- libgl1 \
- zip \
- wget \
- python3.10-venv && \
- rm -rf /var/lib/apt/lists
-
-# make sure to use venv
-RUN python3.10 -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
+ENV UV_PYTHON=/usr/local/bin/python
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
- python3.10 -m uv pip install --no-cache-dir \
- torch \
- torchvision \
- torchaudio \
- invisible_watermark \
- --extra-index-url https://download.pytorch.org/whl/cpu && \
- python3.10 -m uv pip install --no-cache-dir \
- accelerate \
- datasets \
- hf-doc-builder \
- huggingface-hub \
- Jinja2 \
- librosa \
- numpy==1.26.4 \
- scipy \
- tensorboard \
- transformers \
- matplotlib \
- setuptools==69.5.1 \
- bitsandbytes \
- torchao \
- gguf \
- optimum-quanto
+RUN pip install uv
+RUN uv pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ --extra-index-url https://download.pytorch.org/whl/cpu
+
+RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
+
+# Extra dependencies
+RUN uv pip install --no-cache-dir \
+ accelerate \
+ numpy==1.26.4 \
+ hf_transfer \
+ setuptools==69.5.1 \
+ bitsandbytes \
+ torchao \
+ gguf \
+ optimum-quanto
+
+RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
CMD ["/bin/bash"]
diff --git a/docker/diffusers-flax-cpu/Dockerfile b/docker/diffusers-flax-cpu/Dockerfile
deleted file mode 100644
index 051008aa9a..0000000000
--- a/docker/diffusers-flax-cpu/Dockerfile
+++ /dev/null
@@ -1,49 +0,0 @@
-FROM ubuntu:20.04
-LABEL maintainer="Hugging Face"
-LABEL repository="diffusers"
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-RUN apt-get -y update \
- && apt-get install -y software-properties-common \
- && add-apt-repository ppa:deadsnakes/ppa
-
-RUN apt install -y bash \
- build-essential \
- git \
- git-lfs \
- curl \
- ca-certificates \
- libsndfile1-dev \
- libgl1 \
- python3.10 \
- python3-pip \
- python3.10-venv && \
- rm -rf /var/lib/apt/lists
-
-# make sure to use venv
-RUN python3.10 -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
-
-# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
-RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
- python3 -m uv pip install --upgrade --no-cache-dir \
- clu \
- "jax[cpu]>=0.2.16,!=0.3.2" \
- "flax>=0.4.1" \
- "jaxlib>=0.1.65" && \
- python3 -m uv pip install --no-cache-dir \
- accelerate \
- datasets \
- hf-doc-builder \
- huggingface-hub \
- Jinja2 \
- librosa \
- numpy==1.26.4 \
- scipy \
- tensorboard \
- transformers \
- hf_transfer
-
-CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/docker/diffusers-flax-tpu/Dockerfile b/docker/diffusers-flax-tpu/Dockerfile
deleted file mode 100644
index 405f068923..0000000000
--- a/docker/diffusers-flax-tpu/Dockerfile
+++ /dev/null
@@ -1,51 +0,0 @@
-FROM ubuntu:20.04
-LABEL maintainer="Hugging Face"
-LABEL repository="diffusers"
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-RUN apt-get -y update \
- && apt-get install -y software-properties-common \
- && add-apt-repository ppa:deadsnakes/ppa
-
-RUN apt install -y bash \
- build-essential \
- git \
- git-lfs \
- curl \
- ca-certificates \
- libsndfile1-dev \
- libgl1 \
- python3.10 \
- python3-pip \
- python3.10-venv && \
- rm -rf /var/lib/apt/lists
-
-# make sure to use venv
-RUN python3.10 -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
-
-# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
-RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
- python3 -m pip install --no-cache-dir \
- "jax[tpu]>=0.2.16,!=0.3.2" \
- -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
- python3 -m uv pip install --upgrade --no-cache-dir \
- clu \
- "flax>=0.4.1" \
- "jaxlib>=0.1.65" && \
- python3 -m uv pip install --no-cache-dir \
- accelerate \
- datasets \
- hf-doc-builder \
- huggingface-hub \
- Jinja2 \
- librosa \
- numpy==1.26.4 \
- scipy \
- tensorboard \
- transformers \
- hf_transfer
-
-CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/docker/diffusers-pytorch-cpu/Dockerfile b/docker/diffusers-pytorch-cpu/Dockerfile
index 8d98c52598..dc5d0fc71c 100644
--- a/docker/diffusers-pytorch-cpu/Dockerfile
+++ b/docker/diffusers-pytorch-cpu/Dockerfile
@@ -1,50 +1,38 @@
-FROM ubuntu:20.04
+FROM python:3.10-slim
+ENV PYTHONDONTWRITEBYTECODE=1
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ENV DEBIAN_FRONTEND=noninteractive
-RUN apt-get -y update \
- && apt-get install -y software-properties-common \
- && add-apt-repository ppa:deadsnakes/ppa
+RUN apt-get -y update && apt-get install -y bash \
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libglib2.0-0 \
+ libsndfile1-dev \
+ libgl1
-RUN apt install -y bash \
- build-essential \
- git \
- git-lfs \
- curl \
- ca-certificates \
- libsndfile1-dev \
- python3.10 \
- python3.10-dev \
- python3-pip \
- libgl1 \
- python3.10-venv && \
- rm -rf /var/lib/apt/lists
-
-# make sure to use venv
-RUN python3.10 -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
+ENV UV_PYTHON=/usr/local/bin/python
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
- python3.10 -m uv pip install --no-cache-dir \
- torch \
- torchvision \
- torchaudio \
- invisible_watermark \
- --extra-index-url https://download.pytorch.org/whl/cpu && \
- python3.10 -m uv pip install --no-cache-dir \
- accelerate \
- datasets \
- hf-doc-builder \
- huggingface-hub \
- Jinja2 \
- librosa \
- numpy==1.26.4 \
- scipy \
- tensorboard \
- transformers matplotlib \
- hf_transfer
+RUN pip install uv
+RUN uv pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio \
+ --extra-index-url https://download.pytorch.org/whl/cpu
+
+RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
+
+# Extra dependencies
+RUN uv pip install --no-cache-dir \
+ accelerate \
+ numpy==1.26.4 \
+ hf_transfer
+
+RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
CMD ["/bin/bash"]
diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile
index 695f5ed08d..2bdfd409b4 100644
--- a/docker/diffusers-pytorch-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-cuda/Dockerfile
@@ -2,11 +2,13 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
+ARG PYTHON_VERSION=3.12
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
&& apt-get install -y software-properties-common \
- && add-apt-repository ppa:deadsnakes/ppa
+ && add-apt-repository ppa:deadsnakes/ppa && \
+ apt-get update
RUN apt install -y bash \
build-essential \
@@ -14,38 +16,34 @@ RUN apt install -y bash \
git-lfs \
curl \
ca-certificates \
+ libglib2.0-0 \
libsndfile1-dev \
libgl1 \
- python3.10 \
- python3.10-dev \
+ python3 \
python3-pip \
- python3.10-venv && \
- rm -rf /var/lib/apt/lists
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
-# make sure to use venv
-RUN python3.10 -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+ENV PATH="/root/.local/bin:$PATH"
+ENV VIRTUAL_ENV="/opt/venv"
+ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
+RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
- python3.10 -m uv pip install --no-cache-dir \
+RUN uv pip install --no-cache-dir \
torch \
torchvision \
- torchaudio \
- invisible_watermark && \
- python3.10 -m pip install --no-cache-dir \
+ torchaudio
+
+RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
+
+# Extra dependencies
+RUN uv pip install --no-cache-dir \
accelerate \
- datasets \
- hf-doc-builder \
- huggingface-hub \
- hf_transfer \
- Jinja2 \
- librosa \
numpy==1.26.4 \
- scipy \
- tensorboard \
- transformers \
- pytorch-lightning \
+ pytorch-lightning \
hf_transfer
CMD ["/bin/bash"]
diff --git a/docker/diffusers-pytorch-minimum-cuda/Dockerfile b/docker/diffusers-pytorch-minimum-cuda/Dockerfile
index 57ca7657ac..a2ce193f68 100644
--- a/docker/diffusers-pytorch-minimum-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-minimum-cuda/Dockerfile
@@ -2,6 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
+ARG PYTHON_VERSION=3.10
ENV DEBIAN_FRONTEND=noninteractive
ENV MINIMUM_SUPPORTED_TORCH_VERSION="2.1.0"
ENV MINIMUM_SUPPORTED_TORCHVISION_VERSION="0.16.0"
@@ -9,7 +10,8 @@ ENV MINIMUM_SUPPORTED_TORCHAUDIO_VERSION="2.1.0"
RUN apt-get -y update \
&& apt-get install -y software-properties-common \
- && add-apt-repository ppa:deadsnakes/ppa
+ && add-apt-repository ppa:deadsnakes/ppa && \
+ apt-get update
RUN apt install -y bash \
build-essential \
@@ -17,37 +19,34 @@ RUN apt install -y bash \
git-lfs \
curl \
ca-certificates \
+ libglib2.0-0 \
libsndfile1-dev \
libgl1 \
- python3.10 \
- python3.10-dev \
+ python3 \
python3-pip \
- python3.10-venv && \
- rm -rf /var/lib/apt/lists
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
-# make sure to use venv
-RUN python3.10 -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+ENV PATH="/root/.local/bin:$PATH"
+ENV VIRTUAL_ENV="/opt/venv"
+ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
+RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
- python3.10 -m uv pip install --no-cache-dir \
+RUN uv pip install --no-cache-dir \
torch==$MINIMUM_SUPPORTED_TORCH_VERSION \
torchvision==$MINIMUM_SUPPORTED_TORCHVISION_VERSION \
- torchaudio==$MINIMUM_SUPPORTED_TORCHAUDIO_VERSION \
- invisible_watermark && \
- python3.10 -m pip install --no-cache-dir \
+ torchaudio==$MINIMUM_SUPPORTED_TORCHAUDIO_VERSION
+
+RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
+
+# Extra dependencies
+RUN uv pip install --no-cache-dir \
accelerate \
- datasets \
- hf-doc-builder \
- huggingface-hub \
- hf_transfer \
- Jinja2 \
- librosa \
numpy==1.26.4 \
- scipy \
- tensorboard \
- transformers \
+ pytorch-lightning \
hf_transfer
CMD ["/bin/bash"]
diff --git a/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/docker/diffusers-pytorch-xformers-cuda/Dockerfile
index 1693eb2930..1ea258bdb7 100644
--- a/docker/diffusers-pytorch-xformers-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-xformers-cuda/Dockerfile
@@ -2,50 +2,49 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
+ARG PYTHON_VERSION=3.12
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
&& apt-get install -y software-properties-common \
- && add-apt-repository ppa:deadsnakes/ppa
+ && add-apt-repository ppa:deadsnakes/ppa && \
+ apt-get update
RUN apt install -y bash \
- build-essential \
- git \
- git-lfs \
- curl \
- ca-certificates \
- libsndfile1-dev \
- libgl1 \
- python3.10 \
- python3.10-dev \
- python3-pip \
- python3.10-venv && \
- rm -rf /var/lib/apt/lists
+ build-essential \
+ git \
+ git-lfs \
+ curl \
+ ca-certificates \
+ libglib2.0-0 \
+ libsndfile1-dev \
+ libgl1 \
+ python3 \
+ python3-pip \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
-# make sure to use venv
-RUN python3.10 -m venv /opt/venv
-ENV PATH="/opt/venv/bin:$PATH"
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+ENV PATH="/root/.local/bin:$PATH"
+ENV VIRTUAL_ENV="/opt/venv"
+ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python
+RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
+ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
-RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
- python3.10 -m pip install --no-cache-dir \
- torch \
- torchvision \
- torchaudio \
- invisible_watermark && \
- python3.10 -m uv pip install --no-cache-dir \
- accelerate \
- datasets \
- hf-doc-builder \
- huggingface-hub \
- hf_transfer \
- Jinja2 \
- librosa \
- numpy==1.26.4 \
- scipy \
- tensorboard \
- transformers \
- xformers \
- hf_transfer
+RUN uv pip install --no-cache-dir \
+ torch \
+ torchvision \
+ torchaudio
+
+RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
+
+# Extra dependencies
+RUN uv pip install --no-cache-dir \
+ accelerate \
+ numpy==1.26.4 \
+ pytorch-lightning \
+ hf_transfer \
+ xformers
CMD ["/bin/bash"]
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index eb51b4d0da..848e38079e 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -5,31 +5,29 @@
- local: installation
title: Installation
- local: quicktour
- title: Quicktour
+ title: Quickstart
- local: stable_diffusion
- title: Effective and efficient diffusion
+ title: Basic performance
-- title: DiffusionPipeline
+- title: Pipelines
isExpanded: false
sections:
- local: using-diffusers/loading
- title: Load pipelines
+ title: DiffusionPipeline
- local: tutorials/autopipeline
title: AutoPipeline
- local: using-diffusers/custom_pipeline_overview
- title: Load community pipelines and components
+ title: Community pipelines and components
- local: using-diffusers/callback
title: Pipeline callbacks
- local: using-diffusers/reusing_seeds
- title: Reproducible pipelines
+ title: Reproducibility
- local: using-diffusers/schedulers
- title: Load schedulers and models
- - local: using-diffusers/scheduler_features
- title: Scheduler features
+ title: Schedulers
- local: using-diffusers/other-formats
- title: Model files and layouts
+ title: Model formats
- local: using-diffusers/push_to_hub
- title: Push files to the Hub
+ title: Sharing pipelines and models
- title: Adapters
isExpanded: false
@@ -51,21 +49,13 @@
isExpanded: false
sections:
- local: using-diffusers/weighted_prompts
- title: Prompt techniques
+ title: Prompting
- local: using-diffusers/create_a_server
title: Create a server
- local: using-diffusers/batched_inference
title: Batch inference
- local: training/distributed_inference
title: Distributed inference
- - local: using-diffusers/scheduler_features
- title: Scheduler features
- - local: using-diffusers/callback
- title: Pipeline callbacks
- - local: using-diffusers/reusing_seeds
- title: Reproducible pipelines
- - local: using-diffusers/image_quality
- title: Controlling image quality
- title: Inference optimization
isExpanded: false
@@ -74,10 +64,12 @@
title: Accelerate inference
- local: optimization/cache
title: Caching
+ - local: optimization/attention_backends
+ title: Attention backends
- local: optimization/memory
title: Reduce memory usage
- local: optimization/speed-memory-optims
- title: Compile and offloading quantized models
+ title: Compiling and offloading quantized models
- title: Community optimizations
sections:
- local: optimization/pruna
@@ -88,12 +80,16 @@
title: Token merging
- local: optimization/deepcache
title: DeepCache
+ - local: optimization/cache_dit
+ title: CacheDiT
- local: optimization/tgate
title: TGATE
- local: optimization/xdit
title: xDiT
- local: optimization/para_attn
title: ParaAttention
+ - local: using-diffusers/image_quality
+ title: FreeU
- title: Hybrid Inference
isExpanded: false
@@ -112,22 +108,24 @@
sections:
- local: modular_diffusers/overview
title: Overview
- - local: modular_diffusers/modular_pipeline
- title: Modular Pipeline
- - local: modular_diffusers/components_manager
- title: Components Manager
+ - local: modular_diffusers/quickstart
+ title: Quickstart
- local: modular_diffusers/modular_diffusers_states
- title: Modular Diffusers States
+ title: States
- local: modular_diffusers/pipeline_block
- title: Pipeline Block
+ title: ModularPipelineBlocks
- local: modular_diffusers/sequential_pipeline_blocks
- title: Sequential Pipeline Blocks
+ title: SequentialPipelineBlocks
- local: modular_diffusers/loop_sequential_pipeline_blocks
- title: Loop Sequential Pipeline Blocks
+ title: LoopSequentialPipelineBlocks
- local: modular_diffusers/auto_pipeline_blocks
- title: Auto Pipeline Blocks
- - local: modular_diffusers/end_to_end_guide
- title: End-to-End Example
+ title: AutoPipelineBlocks
+ - local: modular_diffusers/modular_pipeline
+ title: ModularPipeline
+ - local: modular_diffusers/components_manager
+ title: ComponentsManager
+ - local: modular_diffusers/guiders
+ title: Guiders
- title: Training
isExpanded: false
@@ -188,12 +186,12 @@
title: torchao
- local: quantization/quanto
title: quanto
+ - local: quantization/modelopt
+ title: NVIDIA ModelOpt
- title: Model accelerators and hardware
isExpanded: false
sections:
- - local: using-diffusers/stable_diffusion_jax_how_to
- title: JAX/Flax
- local: optimization/onnx
title: ONNX
- local: optimization/open_vino
@@ -282,6 +280,20 @@
title: Outputs
- local: api/quantization
title: Quantization
+ - local: api/parallel
+ title: Parallel inference
+ - title: Modular
+ sections:
+ - local: api/modular_diffusers/pipeline
+ title: Pipeline
+ - local: api/modular_diffusers/pipeline_blocks
+ title: Blocks
+ - local: api/modular_diffusers/pipeline_states
+ title: States
+ - local: api/modular_diffusers/pipeline_components
+ title: Components and configs
+ - local: api/modular_diffusers/guiders
+ title: Guiders
- title: Loaders
sections:
- local: api/loaders/ip_adapter
@@ -326,6 +338,8 @@
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
+ - local: api/models/bria_transformer
+ title: BriaTransformer2DModel
- local: api/models/chroma_transformer
title: ChromaTransformer2DModel
- local: api/models/cogvideox_transformer3d
@@ -454,6 +468,8 @@
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP-Diffusion
+ - local: api/pipelines/bria_3_2
+ title: Bria 3.2
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogvideox
diff --git a/docs/source/en/api/configuration.md b/docs/source/en/api/configuration.md
index bc58e190b8..328e109e1e 100644
--- a/docs/source/en/api/configuration.md
+++ b/docs/source/en/api/configuration.md
@@ -14,11 +14,8 @@ specific language governing permissions and limitations under the License.
Schedulers from [`~schedulers.scheduling_utils.SchedulerMixin`] and models from [`ModelMixin`] inherit from [`ConfigMixin`] which stores all the parameters that are passed to their respective `__init__` methods in a JSON-configuration file.
-
-
-To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf auth login`.
-
-
+> [!TIP]
+> To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with `hf auth login`.
## ConfigMixin
diff --git a/docs/source/en/api/image_processor.md b/docs/source/en/api/image_processor.md
index 3e75af026d..82d1837b0b 100644
--- a/docs/source/en/api/image_processor.md
+++ b/docs/source/en/api/image_processor.md
@@ -20,6 +20,12 @@ All pipelines with [`VaeImageProcessor`] accept PIL Image, PyTorch tensor, or Nu
[[autodoc]] image_processor.VaeImageProcessor
+## InpaintProcessor
+
+The [`InpaintProcessor`] accepts `mask` and `image` inputs and processes them together. Optionally, it can accept `padding_mask_crop` and apply a mask overlay.
+
+[[autodoc]] image_processor.InpaintProcessor
+
## VaeImageProcessorLDM3D
The [`VaeImageProcessorLDM3D`] accepts RGB and depth inputs and returns RGB and depth outputs.
diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md
index 0c94bcb220..508e6d4ee6 100644
--- a/docs/source/en/api/loaders/ip_adapter.md
+++ b/docs/source/en/api/loaders/ip_adapter.md
@@ -14,11 +14,8 @@ specific language governing permissions and limitations under the License.
[IP-Adapter](https://hf.co/papers/2308.06721) is a lightweight adapter that enables prompting a diffusion model with an image. This method decouples the cross-attention layers of the image and text features. The image features are generated from an image encoder.
-
-
-Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading](../../using-diffusers/loading_adapters#ip-adapter) guide, and you can see how to use it in the [usage](../../using-diffusers/ip_adapter) guide.
-
-
+> [!TIP]
+> Learn how to load and use an IP-Adapter checkpoint and image in the [IP-Adapter](../../using-diffusers/ip_adapter) guide.
## IPAdapterMixin
diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index da5c3842c6..b1d1ffb634 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -33,11 +33,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
-
-
-To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
-
-
+> [!TIP]
+> To learn more about how to load LoRA weights, see the [LoRA](../../tutorials/using_peft_for_inference) loading guide.
## LoraBaseMixin
diff --git a/docs/source/en/api/loaders/peft.md b/docs/source/en/api/loaders/peft.md
index a371ab9c8e..c514766dd8 100644
--- a/docs/source/en/api/loaders/peft.md
+++ b/docs/source/en/api/loaders/peft.md
@@ -12,13 +12,10 @@ specific language governing permissions and limitations under the License.
# PEFT
-Diffusers supports loading adapters such as [LoRA](../../using-diffusers/loading_adapters) with the [PEFT](https://huggingface.co/docs/peft/index) library with the [`~loaders.peft.PeftAdapterMixin`] class. This allows modeling classes in Diffusers like [`UNet2DConditionModel`], [`SD3Transformer2DModel`] to operate with an adapter.
+Diffusers supports loading adapters such as [LoRA](../../tutorials/using_peft_for_inference) with the [PEFT](https://huggingface.co/docs/peft/index) library with the [`~loaders.peft.PeftAdapterMixin`] class. This allows modeling classes in Diffusers like [`UNet2DConditionModel`], [`SD3Transformer2DModel`] to operate with an adapter.
-
-
-Refer to the [Inference with PEFT](../../tutorials/using_peft_for_inference.md) tutorial for an overview of how to use PEFT in Diffusers for inference.
-
-
+> [!TIP]
+> Refer to the [Inference with PEFT](../../tutorials/using_peft_for_inference.md) tutorial for an overview of how to use PEFT in Diffusers for inference.
## PeftAdapterMixin
diff --git a/docs/source/en/api/loaders/textual_inversion.md b/docs/source/en/api/loaders/textual_inversion.md
index 30d8f5b8d5..5e8bfac255 100644
--- a/docs/source/en/api/loaders/textual_inversion.md
+++ b/docs/source/en/api/loaders/textual_inversion.md
@@ -16,11 +16,8 @@ Textual Inversion is a training method for personalizing models by learning new
[`TextualInversionLoaderMixin`] provides a function for loading Textual Inversion embeddings from Diffusers and Automatic1111 into the text encoder and loading a special token to activate the embeddings.
-
-
-To learn more about how to load Textual Inversion embeddings, see the [Textual Inversion](../../using-diffusers/loading_adapters#textual-inversion) loading guide.
-
-
+> [!TIP]
+> To learn more about how to load Textual Inversion embeddings, see the [Textual Inversion](../../using-diffusers/textual_inversion_inference) loading guide.
## TextualInversionLoaderMixin
diff --git a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md
index 0e7664cdd1..2c8b81b59c 100644
--- a/docs/source/en/api/loaders/transformer_sd3.md
+++ b/docs/source/en/api/loaders/transformer_sd3.md
@@ -16,11 +16,8 @@ This class is useful when *only* loading weights into a [`SD3Transformer2DModel`
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
-
-
-To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
-
-
+> [!TIP]
+> To learn more about how to load LoRA weights, see the [LoRA](../../tutorials/using_peft_for_inference) loading guide.
## SD3Transformer2DLoadersMixin
diff --git a/docs/source/en/api/loaders/unet.md b/docs/source/en/api/loaders/unet.md
index 51b4c4ef48..50d210bbf5 100644
--- a/docs/source/en/api/loaders/unet.md
+++ b/docs/source/en/api/loaders/unet.md
@@ -16,11 +16,8 @@ Some training methods - like LoRA and Custom Diffusion - typically target the UN
The [`UNet2DConditionLoadersMixin`] class provides functions for loading and saving weights, fusing and unfusing LoRAs, disabling and enabling LoRAs, and setting and deleting adapters.
-
-
-To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
-
-
+> [!TIP]
+> To learn more about how to load LoRA weights, see the [LoRA](../../tutorials/using_peft_for_inference) guide.
## UNet2DConditionLoadersMixin
diff --git a/docs/source/en/api/models/autoencoderkl.md b/docs/source/en/api/models/autoencoderkl.md
index baeab4017b..3d949e9bb0 100644
--- a/docs/source/en/api/models/autoencoderkl.md
+++ b/docs/source/en/api/models/autoencoderkl.md
@@ -44,15 +44,3 @@ model = AutoencoderKL.from_single_file(url)
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
-
-## FlaxAutoencoderKL
-
-[[autodoc]] FlaxAutoencoderKL
-
-## FlaxAutoencoderKLOutput
-
-[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput
-
-## FlaxDecoderOutput
-
-[[autodoc]] models.vae_flax.FlaxDecoderOutput
diff --git a/docs/source/en/api/models/bria_transformer.md b/docs/source/en/api/models/bria_transformer.md
new file mode 100644
index 0000000000..9df7eeb6ff
--- /dev/null
+++ b/docs/source/en/api/models/bria_transformer.md
@@ -0,0 +1,19 @@
+
+
+# BriaTransformer2DModel
+
+A modified Flux Transformer model from [Bria](https://huggingface.co/briaai/BRIA-3.2).
+
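+The transformer can also be loaded on its own. The snippet below is a minimal sketch; it assumes the standard `from_pretrained` API and that the weights live in the `transformer` subfolder of the Bria 3.2 repository.
+
+```python
+import torch
+from diffusers import BriaTransformer2DModel
+
+# assumes the transformer weights are stored in the "transformer" subfolder
+transformer = BriaTransformer2DModel.from_pretrained(
+    "briaai/BRIA-3.2", subfolder="transformer", torch_dtype=torch.bfloat16
+)
+```
+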
+## BriaTransformer2DModel
+
+[[autodoc]] BriaTransformer2DModel
diff --git a/docs/source/en/api/models/consistency_decoder_vae.md b/docs/source/en/api/models/consistency_decoder_vae.md
index cf4955a074..fe039df7f9 100644
--- a/docs/source/en/api/models/consistency_decoder_vae.md
+++ b/docs/source/en/api/models/consistency_decoder_vae.md
@@ -16,11 +16,8 @@ Consistency decoder can be used to decode the latents from the denoising UNet in
The original codebase can be found at [openai/consistencydecoder](https://github.com/openai/consistencydecoder).
-
-
-Inference is only supported for 2 iterations as of now.
-
-
+> [!WARNING]
+> Inference is only supported for 2 iterations as of now.
The pipeline could not have been contributed without the help of [madebyollin](https://github.com/madebyollin) and [mrsteyk](https://github.com/mrsteyk) from [this issue](https://github.com/openai/consistencydecoder/issues/1).
diff --git a/docs/source/en/api/models/controlnet.md b/docs/source/en/api/models/controlnet.md
index 7ce14f17d5..f56b7383a0 100644
--- a/docs/source/en/api/models/controlnet.md
+++ b/docs/source/en/api/models/controlnet.md
@@ -40,11 +40,3 @@ pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=contro
## ControlNetOutput
[[autodoc]] models.controlnets.controlnet.ControlNetOutput
-
-## FlaxControlNetModel
-
-[[autodoc]] FlaxControlNetModel
-
-## FlaxControlNetOutput
-
-[[autodoc]] models.controlnets.controlnet_flax.FlaxControlNetOutput
diff --git a/docs/source/en/api/models/overview.md b/docs/source/en/api/models/overview.md
index 1c6a2092e6..eb9722739f 100644
--- a/docs/source/en/api/models/overview.md
+++ b/docs/source/en/api/models/overview.md
@@ -19,10 +19,6 @@ All models are built from the base [`ModelMixin`] class which is a [`torch.nn.Mo
## ModelMixin
[[autodoc]] ModelMixin
-## FlaxModelMixin
-
-[[autodoc]] FlaxModelMixin
-
## PushToHubMixin
[[autodoc]] utils.PushToHubMixin
diff --git a/docs/source/en/api/models/transformer2d.md b/docs/source/en/api/models/transformer2d.md
index 16ae6ace97..d8e0a858b0 100644
--- a/docs/source/en/api/models/transformer2d.md
+++ b/docs/source/en/api/models/transformer2d.md
@@ -22,11 +22,8 @@ When the input is **continuous**:
When the input is **discrete**:
-
-
-It is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised image don't contain a prediction for the masked pixel because the unnoised image cannot be masked.
-
-
+> [!TIP]
+> It is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised image don't contain a prediction for the masked pixel because the unnoised image cannot be masked.
1. Convert input (classes of latent pixels) to embeddings and apply positional embeddings.
2. Apply the Transformer blocks in the standard way.
diff --git a/docs/source/en/api/models/unet2d-cond.md b/docs/source/en/api/models/unet2d-cond.md
index 175fb11220..99a7c41ab2 100644
--- a/docs/source/en/api/models/unet2d-cond.md
+++ b/docs/source/en/api/models/unet2d-cond.md
@@ -23,9 +23,3 @@ The abstract from the paper is:
## UNet2DConditionOutput
[[autodoc]] models.unets.unet_2d_condition.UNet2DConditionOutput
-
-## FlaxUNet2DConditionModel
-[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionModel
-
-## FlaxUNet2DConditionOutput
-[[autodoc]] models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput
diff --git a/docs/source/en/api/modular_diffusers/guiders.md b/docs/source/en/api/modular_diffusers/guiders.md
new file mode 100644
index 0000000000..a24eb72207
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/guiders.md
@@ -0,0 +1,39 @@
+# Guiders
+
+Guiders are components in Modular Diffusers that control how the diffusion process is guided during generation. They implement various guidance techniques to improve generation quality and control.
+
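+Different guiders expose different constructor arguments, but they can generally be instantiated directly and swapped into a modular pipeline. The snippet below is a minimal sketch; the `guidance_scale` value and the `update_components` wiring are assumptions shown for illustration rather than a prescribed setup.
+
+```python
+from diffusers.guiders.classifier_free_guidance import ClassifierFreeGuidance
+
+# instantiate a guider with a custom guidance scale
+# (`guidance_scale` is an assumed constructor argument, shown for illustration)
+guider = ClassifierFreeGuidance(guidance_scale=5.0)
+
+# a ModularPipeline can then be pointed at this guider, e.g. via
+# `pipe.update_components(guider=guider)` (assumed wiring, for illustration)
+```
+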
+## BaseGuidance
+
+[[autodoc]] diffusers.guiders.guider_utils.BaseGuidance
+
+## ClassifierFreeGuidance
+
+[[autodoc]] diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance
+
+## ClassifierFreeZeroStarGuidance
+
+[[autodoc]] diffusers.guiders.classifier_free_zero_star_guidance.ClassifierFreeZeroStarGuidance
+
+## SkipLayerGuidance
+
+[[autodoc]] diffusers.guiders.skip_layer_guidance.SkipLayerGuidance
+
+## SmoothedEnergyGuidance
+
+[[autodoc]] diffusers.guiders.smoothed_energy_guidance.SmoothedEnergyGuidance
+
+## PerturbedAttentionGuidance
+
+[[autodoc]] diffusers.guiders.perturbed_attention_guidance.PerturbedAttentionGuidance
+
+## AdaptiveProjectedGuidance
+
+[[autodoc]] diffusers.guiders.adaptive_projected_guidance.AdaptiveProjectedGuidance
+
+## AutoGuidance
+
+[[autodoc]] diffusers.guiders.auto_guidance.AutoGuidance
+
+## TangentialClassifierFreeGuidance
+
+[[autodoc]] diffusers.guiders.tangential_classifier_free_guidance.TangentialClassifierFreeGuidance
diff --git a/docs/source/en/api/modular_diffusers/pipeline.md b/docs/source/en/api/modular_diffusers/pipeline.md
new file mode 100644
index 0000000000..f60261ea66
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/pipeline.md
@@ -0,0 +1,5 @@
+# Pipeline
+
+## ModularPipeline
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ModularPipeline
diff --git a/docs/source/en/api/modular_diffusers/pipeline_blocks.md b/docs/source/en/api/modular_diffusers/pipeline_blocks.md
new file mode 100644
index 0000000000..8ad581e679
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/pipeline_blocks.md
@@ -0,0 +1,17 @@
+# Pipeline blocks
+
+## ModularPipelineBlocks
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ModularPipelineBlocks
+
+## SequentialPipelineBlocks
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks
+
+## LoopSequentialPipelineBlocks
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.LoopSequentialPipelineBlocks
+
+## AutoPipelineBlocks
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
\ No newline at end of file
diff --git a/docs/source/en/api/modular_diffusers/pipeline_components.md b/docs/source/en/api/modular_diffusers/pipeline_components.md
new file mode 100644
index 0000000000..2d8e10aef6
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/pipeline_components.md
@@ -0,0 +1,17 @@
+# Components and configs
+
+## ComponentSpec
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ComponentSpec
+
+## ConfigSpec
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConfigSpec
+
+## ComponentsManager
+
+[[autodoc]] diffusers.modular_pipelines.components_manager.ComponentsManager
+
+## InsertableDict
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline_utils.InsertableDict
\ No newline at end of file
diff --git a/docs/source/en/api/modular_diffusers/pipeline_states.md b/docs/source/en/api/modular_diffusers/pipeline_states.md
new file mode 100644
index 0000000000..341d18ecb4
--- /dev/null
+++ b/docs/source/en/api/modular_diffusers/pipeline_states.md
@@ -0,0 +1,9 @@
+# Pipeline states
+
+## PipelineState
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.PipelineState
+
+## BlockState
+
+[[autodoc]] diffusers.modular_pipelines.modular_pipeline.BlockState
\ No newline at end of file
diff --git a/docs/source/en/api/outputs.md b/docs/source/en/api/outputs.md
index bed92f10f9..0fba1ab2fa 100644
--- a/docs/source/en/api/outputs.md
+++ b/docs/source/en/api/outputs.md
@@ -39,11 +39,8 @@ For instance, retrieving an image by indexing into it returns the tuple `(output
outputs[:1]
```
-
-
-To check a specific pipeline or model output, refer to its corresponding API documentation.
-
-
+> [!TIP]
+> To check a specific pipeline or model output, refer to its corresponding API documentation.
## BaseOutput
@@ -54,10 +51,6 @@ To check a specific pipeline or model output, refer to its corresponding API doc
[[autodoc]] pipelines.ImagePipelineOutput
-## FlaxImagePipelineOutput
-
-[[autodoc]] pipelines.pipeline_flax_utils.FlaxImagePipelineOutput
-
## AudioPipelineOutput
[[autodoc]] pipelines.AudioPipelineOutput
diff --git a/docs/source/en/api/parallel.md b/docs/source/en/api/parallel.md
new file mode 100644
index 0000000000..f2a6bee391
--- /dev/null
+++ b/docs/source/en/api/parallel.md
@@ -0,0 +1,24 @@
+
+
+# Parallelism
+
+Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference and training times. Refer to the [Distributed inference](../training/distributed_inference) guide to learn more.
+
+## ParallelConfig
+
+[[autodoc]] ParallelConfig
+
+## ContextParallelConfig
+
+[[autodoc]] ContextParallelConfig
+
+[[autodoc]] hooks.apply_context_parallel
diff --git a/docs/source/en/api/pipelines/allegro.md b/docs/source/en/api/pipelines/allegro.md
index 09313c2db0..a981fb1f94 100644
--- a/docs/source/en/api/pipelines/allegro.md
+++ b/docs/source/en/api/pipelines/allegro.md
@@ -17,11 +17,8 @@ The abstract from the paper is:
*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Quantization
diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md
index aeec3254ca..f0188f3c36 100644
--- a/docs/source/en/api/pipelines/animatediff.md
+++ b/docs/source/en/api/pipelines/animatediff.md
@@ -102,11 +102,8 @@ Here are some sample outputs:
-
-
-AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the AnimateDiff checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.
-
-
+> [!TIP]
+> AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the AnimateDiff checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.
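+
+A common way to apply these scheduler recommendations is to rebuild the scheduler from the pipeline's existing config. The snippet below is a minimal sketch; it assumes `pipe` is an already-loaded AnimateDiff pipeline and uses a DDIM-style scheduler, but the same keyword overrides apply to other scheduler classes.
+
+```python
+from diffusers import DDIMScheduler
+
+# rebuild the scheduler with sample clipping disabled and a linear beta schedule,
+# as recommended above
+pipe.scheduler = DDIMScheduler.from_config(
+    pipe.scheduler.config,
+    clip_sample=False,
+    beta_schedule="linear",
+)
+```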
### AnimateDiffControlNetPipeline
@@ -799,17 +796,11 @@ frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
-
+> [!WARNING]
+> FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).
-FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).
-
-
-
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
diff --git a/docs/source/en/api/pipelines/attend_and_excite.md b/docs/source/en/api/pipelines/attend_and_excite.md
index b5ce3bb767..e7d1e1d2b8 100644
--- a/docs/source/en/api/pipelines/attend_and_excite.md
+++ b/docs/source/en/api/pipelines/attend_and_excite.md
@@ -23,11 +23,8 @@ The abstract from the paper is:
You can find additional information about Attend-and-Excite on the [project page](https://attendandexcite.github.io/Attend-and-Excite/), the [original codebase](https://github.com/AttendAndExcite/Attend-and-Excite), or try it out in a [demo](https://huggingface.co/spaces/AttendAndExcite/Attend-and-Excite).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionAttendAndExcitePipeline
diff --git a/docs/source/en/api/pipelines/audioldm.md b/docs/source/en/api/pipelines/audioldm.md
index 6b143d2990..c8073a14ef 100644
--- a/docs/source/en/api/pipelines/audioldm.md
+++ b/docs/source/en/api/pipelines/audioldm.md
@@ -38,11 +38,8 @@ During inference:
* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
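+
+The snippet below makes these two arguments concrete. It is a minimal sketch; the checkpoint name and prompt are only examples.
+
+```python
+import torch
+from diffusers import AudioLDMPipeline
+
+pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm-s-full-v2", torch_dtype=torch.float16)
+pipe = pipe.to("cuda")
+
+# more steps -> higher quality (slower); audio_length_in_s controls the clip length
+audio = pipe(
+    "Techno music with a strong, upbeat tempo and high melodic riffs",
+    num_inference_steps=25,
+    audio_length_in_s=5.0,
+).audios[0]
+```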
## AudioLDMPipeline
[[autodoc]] AudioLDMPipeline
diff --git a/docs/source/en/api/pipelines/audioldm2.md b/docs/source/en/api/pipelines/audioldm2.md
index 1a196099d7..45a9002ea0 100644
--- a/docs/source/en/api/pipelines/audioldm2.md
+++ b/docs/source/en/api/pipelines/audioldm2.md
@@ -58,11 +58,8 @@ See table below for details on the three checkpoints:
The following example demonstrates how to construct good music and speech generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## AudioLDM2Pipeline
[[autodoc]] AudioLDM2Pipeline
diff --git a/docs/source/en/api/pipelines/aura_flow.md b/docs/source/en/api/pipelines/aura_flow.md
index 1d6002335c..67951859b9 100644
--- a/docs/source/en/api/pipelines/aura_flow.md
+++ b/docs/source/en/api/pipelines/aura_flow.md
@@ -16,11 +16,8 @@ AuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stabl
It was developed by the Fal team and more details about it can be found in [this blog post](https://blog.fal.ai/auraflow/).
-
-
-AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details.
-
-
+> [!TIP]
+> AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details.
## Quantization
diff --git a/docs/source/en/api/pipelines/blip_diffusion.md b/docs/source/en/api/pipelines/blip_diffusion.md
index d94281a4a9..b9c6ed7b5f 100644
--- a/docs/source/en/api/pipelines/blip_diffusion.md
+++ b/docs/source/en/api/pipelines/blip_diffusion.md
@@ -26,11 +26,8 @@ The original codebase can be found at [salesforce/LAVIS](https://github.com/sale
`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## BlipDiffusionPipeline
diff --git a/docs/source/en/api/pipelines/bria_3_2.md b/docs/source/en/api/pipelines/bria_3_2.md
new file mode 100644
index 0000000000..059fa01f9f
--- /dev/null
+++ b/docs/source/en/api/pipelines/bria_3_2.md
@@ -0,0 +1,44 @@
+
+
+# Bria 3.2
+
+Bria 3.2 is a next-generation, commercial-ready text-to-image model. With just 4 billion parameters, it provides exceptional aesthetics and text rendering, and has been evaluated to deliver results on par with leading open-source models while outperforming other licensed models.
+In addition to being built entirely on licensed data, Bria 3.2 provides several advantages for enterprise and commercial use:
+
+- Efficient compute: the model is 3x smaller than comparable models on the market (4B parameters vs. 12B parameters for other open-source models).
+- Architecture consistency: same architecture as Bria 3.1, ideal for users looking to upgrade without disruption.
+- Fine-tuning speedup: 2x faster fine-tuning on L40S and A100.
+
+Original model checkpoints for Bria 3.2 can be found [here](https://huggingface.co/briaai/BRIA-3.2).
+The GitHub repository for Bria 3.2 can be found [here](https://github.com/Bria-AI/BRIA-3.2).
+
+If you want to learn more about the Bria platform and get free trial access, please visit [bria.ai](https://bria.ai).
+
+
+## Usage
+
+_As the model is gated, before using it with diffusers you first need to go to the [Bria 3.2 Hugging Face page](https://huggingface.co/briaai/BRIA-3.2), fill in the form, and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate._
+
+Use the command below to log in:
+
+```bash
+hf auth login
+```
+
+
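+Once you are authenticated, the pipeline loads like other text-to-image pipelines. The snippet below is a minimal sketch; the dtype and prompt are illustrative, and the call follows the standard text-to-image pipeline API.
+
+```python
+import torch
+from diffusers import BriaPipeline
+
+pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
+pipe = pipe.to("cuda")
+
+image = pipe("A photo of a red sports car on a coastal road at sunset").images[0]
+image.save("bria_output.png")
+```
+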
+## BriaPipeline
+
+[[autodoc]] BriaPipeline
+ - all
+ - __call__
+
diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
index 40e290e4bd..df03fbb325 100644
--- a/docs/source/en/api/pipelines/chroma.md
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -21,11 +21,8 @@ Chroma is a text to image generation model based on Flux.
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
-
-
-Chroma can use all the same optimizations as Flux.
-
-
+> [!TIP]
+> Chroma can use all the same optimizations as Flux.
## Inference
diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md
index 157e987efd..ec673e0763 100644
--- a/docs/source/en/api/pipelines/cogvideox.md
+++ b/docs/source/en/api/pipelines/cogvideox.md
@@ -50,7 +50,7 @@ from diffusers.utils import export_to_video
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="torchao",
quant_kwargs={"quant_type": "int8wo"},
- components_to_quantize=["transformer"]
+ components_to_quantize="transformer"
)
# fp8 layerwise weight-casting
diff --git a/docs/source/en/api/pipelines/cogview3.md b/docs/source/en/api/pipelines/cogview3.md
index 0180fee300..5ee02e1a70 100644
--- a/docs/source/en/api/pipelines/cogview3.md
+++ b/docs/source/en/api/pipelines/cogview3.md
@@ -21,11 +21,8 @@ The abstract from the paper is:
*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
diff --git a/docs/source/en/api/pipelines/cogview4.md b/docs/source/en/api/pipelines/cogview4.md
index 50ba5baa62..7857dc8c94 100644
--- a/docs/source/en/api/pipelines/cogview4.md
+++ b/docs/source/en/api/pipelines/cogview4.md
@@ -15,11 +15,8 @@
# CogView4
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
diff --git a/docs/source/en/api/pipelines/consisid.md b/docs/source/en/api/pipelines/consisid.md
index db6b5e59ac..bba0472924 100644
--- a/docs/source/en/api/pipelines/consisid.md
+++ b/docs/source/en/api/pipelines/consisid.md
@@ -25,11 +25,8 @@ The abstract from the paper is:
*Identity-preserving text-to-video (IPT2V) generation aims to create high-fidelity videos with consistent human identity. It is an important task in video generation but remains an open problem for generative models. This paper pushes the technical frontier of IPT2V in two directions that have not been resolved in the literature: (1) A tuning-free pipeline without tedious case-by-case finetuning, and (2) A frequency-aware heuristic identity-preserving Diffusion Transformer (DiT)-based control scheme. To achieve these goals, we propose **ConsisID**, a tuning-free DiT-based controllable IPT2V model to keep human-**id**entity **consis**tent in the generated video. Inspired by prior findings in frequency analysis of vision/diffusion transformers, it employs identity-control signals in the frequency domain, where facial features can be decomposed into low-frequency global features (e.g., profile, proportions) and high-frequency intrinsic features (e.g., identity markers that remain unaffected by pose changes). First, from a low-frequency perspective, we introduce a global facial extractor, which encodes the reference image and facial key points into a latent space, generating features enriched with low-frequency information. These features are then integrated into the shallow layers of the network to alleviate training challenges associated with DiT. Second, from a high-frequency perspective, we design a local facial extractor to capture high-frequency details and inject them into the transformer blocks, enhancing the model's ability to preserve fine-grained features. To leverage the frequency information for identity preservation, we propose a hierarchical training strategy, transforming a vanilla pre-trained video generation model into an IPT2V model. Extensive experiments demonstrate that our frequency-aware heuristic scheme provides an optimal control solution for DiT-based models. Thanks to this scheme, our **ConsisID** achieves excellent results in generating high-quality, identity-preserving videos, making strides towards more effective IPT2V. The model weight of ConsID is publicly available at https://github.com/PKU-YuanGroup/ConsisID.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [SHYuanBest](https://github.com/SHYuanBest). The original codebase can be found [here](https://github.com/PKU-YuanGroup/ConsisID). The original weights can be found under [hf.co/BestWishYsh](https://huggingface.co/BestWishYsh).
diff --git a/docs/source/en/api/pipelines/control_flux_inpaint.md b/docs/source/en/api/pipelines/control_flux_inpaint.md
index 03a4fbebb8..4b087f20ef 100644
--- a/docs/source/en/api/pipelines/control_flux_inpaint.md
+++ b/docs/source/en/api/pipelines/control_flux_inpaint.md
@@ -26,11 +26,8 @@ FLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transforme
| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
-
-
-Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
-
-
+> [!TIP]
+> Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
```python
import torch
diff --git a/docs/source/en/api/pipelines/controlnet.md b/docs/source/en/api/pipelines/controlnet.md
index eea3473d36..afc0a4653e 100644
--- a/docs/source/en/api/pipelines/controlnet.md
+++ b/docs/source/en/api/pipelines/controlnet.md
@@ -28,11 +28,8 @@ This model was contributed by [takuma104](https://huggingface.co/takuma104). ❤
The original codebase can be found at [lllyasviel/ControlNet](https://github.com/lllyasviel/ControlNet), and you can find official ControlNet checkpoints on [lllyasviel's](https://huggingface.co/lllyasviel) Hub profile.
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionControlNetPipeline
[[autodoc]] StableDiffusionControlNetPipeline
@@ -72,11 +69,3 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
-
-## FlaxStableDiffusionControlNetPipeline
-[[autodoc]] FlaxStableDiffusionControlNetPipeline
- - all
- - __call__
-
-## FlaxStableDiffusionControlNetPipelineOutput
-[[autodoc]] pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput
diff --git a/docs/source/en/api/pipelines/controlnet_flux.md b/docs/source/en/api/pipelines/controlnet_flux.md
index 9feb736523..ff38ca3f2c 100644
--- a/docs/source/en/api/pipelines/controlnet_flux.md
+++ b/docs/source/en/api/pipelines/controlnet_flux.md
@@ -44,11 +44,8 @@ XLabs ControlNets are also supported, which was contributed by the [XLabs team](
| HED | [The XLabs Team](https://huggingface.co/XLabs-AI) | [Link](https://huggingface.co/XLabs-AI/flux-controlnet-hed-diffusers) |
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## FluxControlNetPipeline
[[autodoc]] FluxControlNetPipeline
diff --git a/docs/source/en/api/pipelines/controlnet_hunyuandit.md b/docs/source/en/api/pipelines/controlnet_hunyuandit.md
index c79b2dbf65..88dc2de10a 100644
--- a/docs/source/en/api/pipelines/controlnet_hunyuandit.md
+++ b/docs/source/en/api/pipelines/controlnet_hunyuandit.md
@@ -24,11 +24,8 @@ The abstract from the paper is:
This code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## HunyuanDiTControlNetPipeline
[[autodoc]] HunyuanDiTControlNetPipeline
diff --git a/docs/source/en/api/pipelines/controlnet_sd3.md b/docs/source/en/api/pipelines/controlnet_sd3.md
index 067c1c6b01..8cdada9edf 100644
--- a/docs/source/en/api/pipelines/controlnet_sd3.md
+++ b/docs/source/en/api/pipelines/controlnet_sd3.md
@@ -38,11 +38,8 @@ This controlnet code is mainly implemented by [The InstantX Team](https://huggin
| Inpainting | [The AlimamaCreative Team](https://huggingface.co/alimama-creative) | [link](https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting) |
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusion3ControlNetPipeline
[[autodoc]] StableDiffusion3ControlNetPipeline
diff --git a/docs/source/en/api/pipelines/controlnet_sdxl.md b/docs/source/en/api/pipelines/controlnet_sdxl.md
index cb0554a1cc..89fc1c3897 100644
--- a/docs/source/en/api/pipelines/controlnet_sdxl.md
+++ b/docs/source/en/api/pipelines/controlnet_sdxl.md
@@ -26,19 +26,13 @@ The abstract from the paper is:
You can find additional smaller Stable Diffusion XL (SDXL) ControlNet checkpoints from the 🤗 [Diffusers](https://huggingface.co/diffusers) Hub organization, and browse [community-trained](https://huggingface.co/models?other=stable-diffusion-xl&other=controlnet) checkpoints on the Hub.
-
-
-🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
-
-
+> [!WARNING]
+> 🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
If you don't see a checkpoint you're interested in, you can train your own SDXL ControlNet with our [training script](../../../../../examples/controlnet/README_sdxl).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionXLControlNetPipeline
[[autodoc]] StableDiffusionXLControlNetPipeline
diff --git a/docs/source/en/api/pipelines/controlnetxs.md b/docs/source/en/api/pipelines/controlnetxs.md
index aea8cb2e86..d44fb0cf0f 100644
--- a/docs/source/en/api/pipelines/controlnetxs.md
+++ b/docs/source/en/api/pipelines/controlnetxs.md
@@ -31,11 +31,8 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionControlNetXSPipeline
[[autodoc]] StableDiffusionControlNetXSPipeline
diff --git a/docs/source/en/api/pipelines/controlnetxs_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
index 76937b16c5..7ae0e2a2a1 100644
--- a/docs/source/en/api/pipelines/controlnetxs_sdxl.md
+++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
@@ -27,17 +27,11 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
-
+> [!WARNING]
+> 🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
-🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
-
-
-
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionXLControlNetXSPipeline
[[autodoc]] StableDiffusionXLControlNetXSPipeline
diff --git a/docs/source/en/api/pipelines/cosmos.md b/docs/source/en/api/pipelines/cosmos.md
index dba807c5ce..fb9453480e 100644
--- a/docs/source/en/api/pipelines/cosmos.md
+++ b/docs/source/en/api/pipelines/cosmos.md
@@ -18,11 +18,8 @@
*Physical AI needs to be trained digitally first. It needs a digital twin of itself, the policy model, and a digital twin of the world, the world model. In this paper, we present the Cosmos World Foundation Model Platform to help developers build customized world models for their Physical AI setups. We position a world foundation model as a general-purpose world model that can be fine-tuned into customized world models for downstream applications. Our platform covers a video curation pipeline, pre-trained world foundation models, examples of post-training of pre-trained world foundation models, and video tokenizers. To help Physical AI builders solve the most critical problems of our society, we make our platform open-source and our models open-weight with permissive licenses available via https://github.com/NVIDIA/Cosmos.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Loading original format checkpoints
diff --git a/docs/source/en/api/pipelines/dance_diffusion.md b/docs/source/en/api/pipelines/dance_diffusion.md
index 5805561e49..0434f63195 100644
--- a/docs/source/en/api/pipelines/dance_diffusion.md
+++ b/docs/source/en/api/pipelines/dance_diffusion.md
@@ -20,11 +20,8 @@ specific language governing permissions and limitations under the License.
Dance Diffusion is the first in a suite of generative audio tools for producers and musicians released by [Harmonai](https://github.com/Harmonai-org).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## DanceDiffusionPipeline
[[autodoc]] DanceDiffusionPipeline
diff --git a/docs/source/en/api/pipelines/ddpm.md b/docs/source/en/api/pipelines/ddpm.md
index 716cf73275..63c2fcaf89 100644
--- a/docs/source/en/api/pipelines/ddpm.md
+++ b/docs/source/en/api/pipelines/ddpm.md
@@ -20,11 +20,8 @@ The abstract from the paper is:
The original codebase can be found at [hohonathanho/diffusion](https://github.com/hojonathanho/diffusion).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
# DDPMPipeline
[[autodoc]] DDPMPipeline
diff --git a/docs/source/en/api/pipelines/dit.md b/docs/source/en/api/pipelines/dit.md
index e87058899b..16d0c99961 100644
--- a/docs/source/en/api/pipelines/dit.md
+++ b/docs/source/en/api/pipelines/dit.md
@@ -20,11 +20,8 @@ The abstract from the paper is:
The original codebase can be found at [facebookresearch/dit](https://github.com/facebookresearch/dit).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## DiTPipeline
[[autodoc]] DiTPipeline
diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md
index ca39d71814..358b8139c7 100644
--- a/docs/source/en/api/pipelines/flux.md
+++ b/docs/source/en/api/pipelines/flux.md
@@ -21,11 +21,10 @@ Flux is a series of text-to-image generation models based on diffusion transform
Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux).
-
-
-Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
-
-
+> [!TIP]
+> Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
+>
+> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
Flux comes in the following variants:
@@ -314,6 +313,67 @@ if integrity_checker.test_image(image_):
raise ValueError("Your image has been flagged. Choose another prompt/image or try again.")
```
+### Kontext Inpainting
+`FluxKontextInpaintPipeline` enables image modification within a fixed mask region. It currently supports both text-based conditioning and image-reference conditioning.
+
+```python
+import torch
+from diffusers import FluxKontextInpaintPipeline
+from diffusers.utils import load_image
+
+prompt = "Change the yellow dinosaur to green one"
+img_url = (
+ "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
+)
+mask_url = (
+ "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
+)
+
+source = load_image(img_url)
+mask = load_image(mask_url)
+
+pipe = FluxKontextInpaintPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
+image.save("kontext_inpainting_normal.png")
+```
+
+```python
+import torch
+from diffusers import FluxKontextInpaintPipeline
+from diffusers.utils import load_image
+
+pipe = FluxKontextInpaintPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+prompt = "Replace this ball"
+img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
+mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
+image_reference_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
+
+source = load_image(img_url)
+mask = load_image(mask_url)
+image_reference = load_image(image_reference_url)
+
+mask = pipe.mask_processor.blur(mask, blur_factor=12)
+image = pipe(
+ prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
+).images[0]
+image.save("kontext_inpainting_ref.png")
+```
+
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
@@ -357,11 +417,8 @@ When unloading the Control LoRA weights, call `pipe.unload_lora_weights(reset_to
## IP-Adapter
-
-
-Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
-
-
+> [!TIP]
+> Check out [IP-Adapter](../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
An IP-Adapter lets you prompt Flux with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images.
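+
+As a minimal, illustrative sketch of the loading flow (the adapter repository, weight file, image encoder, and input image URL below are assumptions; swap in the checkpoint you actually want to use):
+
+```python
+import torch
+from diffusers import FluxPipeline
+from diffusers.utils import load_image
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Load an IP-Adapter and its image encoder so the pipeline can be conditioned on a reference image.
+pipe.load_ip_adapter(
+    "XLabs-AI/flux-ip-adapter",
+    weight_name="ip_adapter.safetensors",
+    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
+)
+pipe.set_ip_adapter_scale(1.0)
+
+# Placeholder reference image.
+ip_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
+
+image = pipe(
+    prompt="a portrait in the style of the reference image",
+    ip_adapter_image=ip_image,
+    guidance_scale=3.5,
+    num_inference_steps=28,
+).images[0]
+image.save("flux_ip_adapter.png")
+```
+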
@@ -541,9 +598,8 @@ image.save("flux.png")
The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
-
-`FP8` inference can be brittle depending on the GPU type, CUDA version, and `torch` version that you are using. It is recommended that you use the `optimum-quanto` library in order to run FP8 inference on your machine.
-
+> [!TIP]
+> `FP8` inference can be brittle depending on the GPU type, CUDA version, and `torch` version that you are using. It is recommended that you use the `optimum-quanto` library in order to run FP8 inference on your machine.
The following example demonstrates how to run Flux with less than 16GB of VRAM.
@@ -644,3 +700,15 @@ image.save("flux-fp8-dev.png")
[[autodoc]] FluxFillPipeline
- all
- __call__
+
+## FluxKontextPipeline
+
+[[autodoc]] FluxKontextPipeline
+ - all
+ - __call__
+
+## FluxKontextInpaintPipeline
+
+[[autodoc]] FluxKontextInpaintPipeline
+ - all
+ - __call__
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/framepack.md b/docs/source/en/api/pipelines/framepack.md
index ba7b2d0dc0..a25cfe24a4 100644
--- a/docs/source/en/api/pipelines/framepack.md
+++ b/docs/source/en/api/pipelines/framepack.md
@@ -22,11 +22,8 @@
*We present a neural network structure, FramePack, to train next-frame (or next-frame-section) prediction models for video generation. The FramePack compresses input frames to make the transformer context length a fixed number regardless of the video length. As a result, we are able to process a large number of frames using video diffusion with computation bottleneck similar to image diffusion. This also makes the training video batch sizes significantly higher (batch sizes become comparable to image diffusion training). We also propose an anti-drifting sampling method that generates frames in inverted temporal order with early-established endpoints to avoid exposure bias (error accumulation over iterations). Finally, we show that existing video diffusion models can be finetuned with FramePack, and their visual quality may be improved because the next-frame prediction supports more balanced diffusion schedulers with less extreme flow shift timesteps.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Available models
diff --git a/docs/source/en/api/pipelines/hidream.md b/docs/source/en/api/pipelines/hidream.md
index 57814a309b..add4ad3132 100644
--- a/docs/source/en/api/pipelines/hidream.md
+++ b/docs/source/en/api/pipelines/hidream.md
@@ -16,15 +16,12 @@
[HiDream-I1](https://huggingface.co/HiDream-ai) by HiDream.ai
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## Available models
-The following models are available for the [`HiDreamImagePipeline`](text-to-image) pipeline:
+The following models are available for the [`HiDreamImagePipeline`] pipeline:
| Model name | Description |
|:---|:---|
diff --git a/docs/source/en/api/pipelines/hunyuan_video.md b/docs/source/en/api/pipelines/hunyuan_video.md
index df52c49b36..cdd81495b6 100644
--- a/docs/source/en/api/pipelines/hunyuan_video.md
+++ b/docs/source/en/api/pipelines/hunyuan_video.md
@@ -54,7 +54,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16
},
- components_to_quantize=["transformer"]
+ components_to_quantize="transformer"
)
pipeline = HunyuanVideoPipeline.from_pretrained(
@@ -91,7 +91,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16
},
- components_to_quantize=["transformer"]
+ components_to_quantize="transformer"
)
pipeline = HunyuanVideoPipeline.from_pretrained(
@@ -139,7 +139,7 @@ export_to_video(video, "output.mp4", fps=15)
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16
},
- components_to_quantize=["transformer"]
+ components_to_quantize="transformer"
)
pipeline = HunyuanVideoPipeline.from_pretrained(
diff --git a/docs/source/en/api/pipelines/hunyuandit.md b/docs/source/en/api/pipelines/hunyuandit.md
index 07e869ba95..3f4db66c6c 100644
--- a/docs/source/en/api/pipelines/hunyuandit.md
+++ b/docs/source/en/api/pipelines/hunyuandit.md
@@ -28,17 +28,11 @@ HunyuanDiT has the following components:
* It uses a diffusion transformer as the backbone
* It combines two text encoders, a bilingual CLIP and a multilingual T5 encoder
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
-
-
-
-You can further improve generation quality by passing the generated image from [`HungyuanDiTPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
-
-
+> [!TIP]
+> You can further improve generation quality by passing the generated image from [`HunyuanDiTPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
## Optimization
diff --git a/docs/source/en/api/pipelines/i2vgenxl.md b/docs/source/en/api/pipelines/i2vgenxl.md
index 76a51a6cd5..711a5625f9 100644
--- a/docs/source/en/api/pipelines/i2vgenxl.md
+++ b/docs/source/en/api/pipelines/i2vgenxl.md
@@ -23,11 +23,8 @@ The abstract from the paper is:
The original codebase can be found [here](https://github.com/ali-vilab/i2vgen-xl/). The model checkpoints can be found [here](https://huggingface.co/ali-vilab/).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section [here](../../using-diffusers/svd#reduce-memory-usage).
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. Also, to learn more about reducing the memory usage of this pipeline, refer to the [Reduce memory usage](../../using-diffusers/svd#reduce-memory-usage) section.
Sample output with I2VGenXL:
diff --git a/docs/source/en/api/pipelines/kandinsky.md b/docs/source/en/api/pipelines/kandinsky.md
index 90c76954ab..7717f2db69 100644
--- a/docs/source/en/api/pipelines/kandinsky.md
+++ b/docs/source/en/api/pipelines/kandinsky.md
@@ -17,17 +17,11 @@ The description from it's GitHub page is:
The original codebase can be found at [ai-forever/Kandinsky-2](https://github.com/ai-forever/Kandinsky-2).
-
+> [!TIP]
+> Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
-Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
-
-
-
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## KandinskyPriorPipeline
diff --git a/docs/source/en/api/pipelines/kandinsky3.md b/docs/source/en/api/pipelines/kandinsky3.md
index 1727387c4a..f08afa8879 100644
--- a/docs/source/en/api/pipelines/kandinsky3.md
+++ b/docs/source/en/api/pipelines/kandinsky3.md
@@ -28,17 +28,11 @@ Its architecture includes 3 main components:
The original codebase can be found at [ai-forever/Kandinsky-3](https://github.com/ai-forever/Kandinsky-3).
-
+> [!TIP]
+> Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
-Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
-
-
-
-
-
-Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Kandinsky3Pipeline
diff --git a/docs/source/en/api/pipelines/kandinsky_v22.md b/docs/source/en/api/pipelines/kandinsky_v22.md
index e68c094e23..0e0ed80db6 100644
--- a/docs/source/en/api/pipelines/kandinsky_v22.md
+++ b/docs/source/en/api/pipelines/kandinsky_v22.md
@@ -17,17 +17,11 @@ The description from it's GitHub page is:
The original codebase can be found at [ai-forever/Kandinsky-2](https://github.com/ai-forever/Kandinsky-2).
-
+> [!TIP]
+> Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
-Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
-
-
-
-
-
-Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## KandinskyV22PriorPipeline
diff --git a/docs/source/en/api/pipelines/kolors.md b/docs/source/en/api/pipelines/kolors.md
index 048f6c1de9..b4c83fe134 100644
--- a/docs/source/en/api/pipelines/kolors.md
+++ b/docs/source/en/api/pipelines/kolors.md
@@ -50,17 +50,11 @@ image.save("kolors_sample.png")
Kolors needs a different IP Adapter to work, and it uses [Openai-CLIP-336](https://huggingface.co/openai/clip-vit-large-patch14-336) as an image encoder.
-
+> [!TIP]
+> Using an IP Adapter with Kolors requires more than 24GB of VRAM. To use it, we recommend using [`~DiffusionPipeline.enable_model_cpu_offload`] on consumer GPUs.
-Using an IP Adapter with Kolors requires more than 24GB of VRAM. To use it, we recommend using [`~DiffusionPipeline.enable_model_cpu_offload`] on consumer GPUs.
-
-
-
-
-
-While Kolors is integrated in Diffusers, you need to load the image encoder from a revision to use the safetensor files. You can still use the main branch of the original repository if you're comfortable loading pickle checkpoints.
-
-
+> [!TIP]
+> While Kolors is integrated in Diffusers, you need to load the image encoder from a specific revision to use the safetensors files. You can still use the main branch of the original repository if you're comfortable loading pickle checkpoints.
```python
import torch
diff --git a/docs/source/en/api/pipelines/latent_diffusion.md b/docs/source/en/api/pipelines/latent_diffusion.md
index 5489d673f5..cefed90e86 100644
--- a/docs/source/en/api/pipelines/latent_diffusion.md
+++ b/docs/source/en/api/pipelines/latent_diffusion.md
@@ -20,11 +20,8 @@ The abstract from the paper is:
The original codebase can be found at [CompVis/latent-diffusion](https://github.com/CompVis/latent-diffusion).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## LDMTextToImagePipeline
[[autodoc]] LDMTextToImagePipeline
diff --git a/docs/source/en/api/pipelines/latte.md b/docs/source/en/api/pipelines/latte.md
index 9d4d12dd4e..c8438c668a 100644
--- a/docs/source/en/api/pipelines/latte.md
+++ b/docs/source/en/api/pipelines/latte.md
@@ -26,11 +26,8 @@ The abstract from the paper is:
This pipeline was contributed by [maxin-cn](https://github.com/maxin-cn). The original codebase can be found [here](https://github.com/Vchitect/Latte). The original weights can be found under [hf.co/maxin-cn](https://huggingface.co/maxin-cn).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
### Inference
diff --git a/docs/source/en/api/pipelines/ledits_pp.md b/docs/source/en/api/pipelines/ledits_pp.md
index 7c08971aa8..103bcf3798 100644
--- a/docs/source/en/api/pipelines/ledits_pp.md
+++ b/docs/source/en/api/pipelines/ledits_pp.md
@@ -22,16 +22,12 @@ The abstract from the paper is:
*Text-to-image diffusion models have recently received increasing interest for their astonishing ability to produce high-fidelity images from solely text inputs. Subsequent research efforts aim to exploit and apply their capabilities to real image editing. However, existing image-to-image methods are often inefficient, imprecise, and of limited versatility. They either require time-consuming fine-tuning, deviate unnecessarily strongly from the input image, and/or lack support for multiple, simultaneous edits. To address these issues, we introduce LEDITS++, an efficient yet versatile and precise textual image manipulation technique. LEDITS++'s novel inversion approach requires no tuning nor optimization and produces high-fidelity results with a few diffusion steps. Second, our methodology supports multiple simultaneous edits and is architecture-agnostic. Third, we use a novel implicit masking technique that limits changes to relevant image regions. We propose the novel TEdBench++ benchmark as part of our exhaustive evaluation. Our results demonstrate the capabilities of LEDITS++ and its improvements over previous methods. The project page is available at https://leditsplusplus-project.static.hf.space .*
-
+> [!TIP]
+> You can find additional information about LEDITS++ on the [project page](https://leditsplusplus-project.static.hf.space/index.html) and try it out in a [demo](https://huggingface.co/spaces/editing-images/leditsplusplus).
-You can find additional information about LEDITS++ on the [project page](https://leditsplusplus-project.static.hf.space/index.html) and try it out in a [demo](https://huggingface.co/spaces/editing-images/leditsplusplus).
-
-
-
-
-Due to some backward compatibility issues with the current diffusers implementation of [`~schedulers.DPMSolverMultistepScheduler`] this implementation of LEdits++ can no longer guarantee perfect inversion.
-This issue is unlikely to have any noticeable effects on applied use-cases. However, we provide an alternative implementation that guarantees perfect inversion in a dedicated [GitHub repo](https://github.com/ml-research/ledits_pp).
-
+> [!WARNING]
+> Due to some backward compatibility issues with the current diffusers implementation of [`~schedulers.DPMSolverMultistepScheduler`], this implementation of LEdits++ can no longer guarantee perfect inversion.
+> This issue is unlikely to have any noticeable effects on applied use-cases. However, we provide an alternative implementation that guarantees perfect inversion in a dedicated [GitHub repo](https://github.com/ml-research/ledits_pp).
We provide two distinct pipelines based on different pre-trained models.
diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md
index 867dd97259..940144538a 100644
--- a/docs/source/en/api/pipelines/ltx_video.md
+++ b/docs/source/en/api/pipelines/ltx_video.md
@@ -88,7 +88,7 @@ export_to_video(video, "output.mp4", fps=24)
-[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
+[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
```py
import torch
diff --git a/docs/source/en/api/pipelines/lumina.md b/docs/source/en/api/pipelines/lumina.md
index 3bd3d9f8e0..0a236d213d 100644
--- a/docs/source/en/api/pipelines/lumina.md
+++ b/docs/source/en/api/pipelines/lumina.md
@@ -45,11 +45,8 @@ Lumina-T2X has the following components:
This pipeline was contributed by [PommesPeter](https://github.com/PommesPeter). The original codebase can be found [here](https://github.com/Alpha-VLLM/Lumina-T2X). The original weights can be found under [hf.co/Alpha-VLLM](https://huggingface.co/Alpha-VLLM).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
### Inference (Text-to-Image)
diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md
index 092d7cde2e..0c4e793404 100644
--- a/docs/source/en/api/pipelines/lumina2.md
+++ b/docs/source/en/api/pipelines/lumina2.md
@@ -24,11 +24,8 @@ The abstract from the paper is:
*We introduce Lumina-Image 2.0, an advanced text-to-image model that surpasses previous state-of-the-art methods across multiple benchmarks, while also shedding light on its potential to evolve into a generalist vision intelligence model. Lumina-Image 2.0 exhibits three key properties: (1) Unification – it adopts a unified architecture that treats text and image tokens as a joint sequence, enabling natural cross-modal interactions and facilitating task expansion. Besides, since high-quality captioners can provide semantically better-aligned text-image training pairs, we introduce a unified captioning system, UniCaptioner, which generates comprehensive and precise captions for the model. This not only accelerates model convergence but also enhances prompt adherence, variable-length prompt handling, and task generalization via prompt templates. (2) Efficiency – to improve the efficiency of the unified architecture, we develop a set of optimization techniques that improve semantic learning and fine-grained texture generation during training while incorporating inference-time acceleration strategies without compromising image quality. (3) Transparency – we open-source all training details, code, and models to ensure full reproducibility, aiming to bridge the gap between well-resourced closed-source research teams and independent developers.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Using Single File loading with Lumina Image 2.0
diff --git a/docs/source/en/api/pipelines/marigold.md b/docs/source/en/api/pipelines/marigold.md
index e9ca0df067..bb6e94de33 100644
--- a/docs/source/en/api/pipelines/marigold.md
+++ b/docs/source/en/api/pipelines/marigold.md
@@ -45,14 +45,11 @@ This work expanded Marigold to support new modalities such as **Surface Normals*
(IID), introduced a training protocol for **Latent Consistency Models** (LCM), and demonstrated **High-Resolution** (HR)
processing capability.
-
-
-The early Marigold models (`v1-0` and earlier) were optimized for best results with at least 10 inference steps.
-LCM models were later developed to enable high-quality inference in just 1 to 4 steps.
-Marigold models `v1-1` and later use the DDIM scheduler to achieve optimal
-results in as few as 1 to 4 steps.
-
-
+> [!TIP]
+> The early Marigold models (`v1-0` and earlier) were optimized for best results with at least 10 inference steps.
+> LCM models were later developed to enable high-quality inference in just 1 to 4 steps.
+> Marigold models `v1-1` and later use the DDIM scheduler to achieve optimal
+> results in as few as 1 to 4 steps.
## Available Pipelines
@@ -78,29 +75,23 @@ The following is a summary of the recommended checkpoints, all of which produce
| [prs-eth/marigold-depth-v1-1](https://huggingface.co/prs-eth/marigold-depth-v1-1) | Depth | Affine-invariant depth prediction assigns each pixel a value between 0 (near plane) and 1 (far plane), with both planes determined by the model during inference. |
| [prs-eth/marigold-normals-v0-1](https://huggingface.co/prs-eth/marigold-normals-v0-1) | Normals | The surface normals predictions are unit-length 3D vectors in the screen space camera, with values in the range from -1 to 1. |
| [prs-eth/marigold-iid-appearance-v1-1](https://huggingface.co/prs-eth/marigold-iid-appearance-v1-1) | Intrinsics | InteriorVerse decomposition is comprised of Albedo and two BRDF material properties: Roughness and Metallicity. |
-| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | Intrinsics | HyperSim decomposition of an image  \\(I\\)  is comprised of Albedo  \\(A\\), Diffuse shading  \\(S\\), and Non-diffuse residual  \\(R\\):  \\(I = A*S+R\\). |
+| [prs-eth/marigold-iid-lighting-v1-1](https://huggingface.co/prs-eth/marigold-iid-lighting-v1-1) | Intrinsics | HyperSim decomposition of an image $I$ is comprised of Albedo $A$, Diffuse shading $S$, and Non-diffuse residual $R$: $I = A*S+R$. |
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff
+> between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to
+> efficiently load the same components into multiple pipelines.
+> Also, to learn more about reducing the memory usage of this pipeline, refer to the
+> [Reduce memory usage](../../using-diffusers/svd#reduce-memory-usage) section.
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff
-between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to
-efficiently load the same components into multiple pipelines.
-Also, to know more about reducing the memory usage of this pipeline, refer to the ["Reduce memory usage"] section
-[here](../../using-diffusers/svd#reduce-memory-usage).
-
-
-
-
-
-Marigold pipelines were designed and tested with the scheduler embedded in the model checkpoint.
-The optimal number of inference steps varies by scheduler, with no universal value that works best across all cases.
-To accommodate this, the `num_inference_steps` parameter in the pipeline's `__call__` method defaults to `None` (see the
-API reference).
-Unless set explicitly, it inherits the value from the `default_denoising_steps` field in the checkpoint configuration
-file (`model_index.json`).
-This ensures high-quality predictions when invoking the pipeline with only the `image` argument.
-
-
+> [!WARNING]
+> Marigold pipelines were designed and tested with the scheduler embedded in the model checkpoint.
+> The optimal number of inference steps varies by scheduler, with no universal value that works best across all cases.
+> To accommodate this, the `num_inference_steps` parameter in the pipeline's `__call__` method defaults to `None` (see the
+> API reference).
+> Unless set explicitly, it inherits the value from the `default_denoising_steps` field in the checkpoint configuration
+> file (`model_index.json`).
+> This ensures high-quality predictions when invoking the pipeline with only the `image` argument.
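+
+As a minimal sketch of that default behavior (the input image URL is a placeholder), calling the depth pipeline with only an image lets the scheduler and step count come from the checkpoint configuration:
+
+```python
+import torch
+from diffusers import MarigoldDepthPipeline
+from diffusers.utils import load_image
+
+pipe = MarigoldDepthPipeline.from_pretrained(
+    "prs-eth/marigold-depth-v1-1", torch_dtype=torch.float16
+).to("cuda")
+
+image = load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
+
+# `num_inference_steps` is left at None, so the value from the checkpoint's
+# `default_denoising_steps` (model_index.json) is used.
+depth = pipe(image)
+
+vis = pipe.image_processor.visualize_depth(depth.prediction)
+vis[0].save("einstein_depth.png")
+```
+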
See also Marigold [usage examples](../../using-diffusers/marigold_usage).
diff --git a/docs/source/en/api/pipelines/mochi.md b/docs/source/en/api/pipelines/mochi.md
index f1260b07b0..f19a9bd575 100644
--- a/docs/source/en/api/pipelines/mochi.md
+++ b/docs/source/en/api/pipelines/mochi.md
@@ -121,15 +121,13 @@ export_to_video(frames, "mochi.mp4", fps=30)
The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the original implementation, please refer to the following example.
-
-The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
+> [!TIP]
+> The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
+>
+> When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
-When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
-
-
-
-Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
-
+> [!TIP]
+> Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
```python
import torch
@@ -231,9 +229,8 @@ export_to_video(frames, "output.mp4", fps=30)
You can use `from_single_file` to load the Mochi transformer in its original format.
-
-Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
-
+> [!TIP]
+> Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
```python
import torch
diff --git a/docs/source/en/api/pipelines/musicldm.md b/docs/source/en/api/pipelines/musicldm.md
index c2297162f7..1a83e5932e 100644
--- a/docs/source/en/api/pipelines/musicldm.md
+++ b/docs/source/en/api/pipelines/musicldm.md
@@ -43,11 +43,8 @@ During inference:
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
* The _length_ of the generated audio sample can be controlled by varying the `audio_length_in_s` argument.
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
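+
+The parameters mentioned above can be combined into a short, illustrative sketch (the checkpoint id and generation settings are assumptions):
+
+```python
+import torch
+import scipy
+from diffusers import MusicLDMPipeline
+
+pipe = MusicLDMPipeline.from_pretrained("ucsd-reach/musicldm", torch_dtype=torch.float16).to("cuda")
+
+audio = pipe(
+    "techno music with a strong, upbeat tempo and high melodic riffs",
+    num_inference_steps=200,
+    audio_length_in_s=10.0,       # controls the length of the generated clip
+    num_waveforms_per_prompt=3,   # >1 enables automatic scoring and ranking against the prompt
+).audios
+
+# The first waveform is the best-ranked one; MusicLDM generates 16 kHz audio.
+scipy.io.wavfile.write("musicldm.wav", rate=16000, data=audio[0])
+```
+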
## MusicLDMPipeline
[[autodoc]] MusicLDMPipeline
diff --git a/docs/source/en/api/pipelines/omnigen.md b/docs/source/en/api/pipelines/omnigen.md
index 074e7b8f01..4fac5c789a 100644
--- a/docs/source/en/api/pipelines/omnigen.md
+++ b/docs/source/en/api/pipelines/omnigen.md
@@ -21,11 +21,8 @@ The abstract from the paper is:
*The emergence of Large Language Models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. However, in the realm of image generation, a unified model capable of handling various tasks within a single framework remains largely unexplored. In this work, we introduce OmniGen, a new diffusion model for unified image generation. OmniGen is characterized by the following features: 1) Unification: OmniGen not only demonstrates text-to-image generation capabilities but also inherently supports various downstream tasks, such as image editing, subject-driven generation, and visual conditional generation. 2) Simplicity: The architecture of OmniGen is highly simplified, eliminating the need for additional plugins. Moreover, compared to existing diffusion models, it is more user-friendly and can complete complex tasks end-to-end through instructions without the need for extra intermediate steps, greatly simplifying the image generation workflow. 3) Knowledge Transfer: Benefit from learning in a unified format, OmniGen effectively transfers knowledge across different tasks, manages unseen tasks and domains, and exhibits novel capabilities. We also explore the model’s reasoning capabilities and potential applications of the chain-of-thought mechanism. This work represents the first attempt at a general-purpose image generation model, and we will release our resources at https://github.com/VectorSpaceLab/OmniGen to foster future advancements.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [staoxiao](https://github.com/staoxiao). The original codebase can be found [here](https://github.com/VectorSpaceLab/OmniGen). The original weights can be found under [hf.co/shitao](https://huggingface.co/Shitao/OmniGen-v1).
diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md
index 4e7a4e5e8d..22fcf560ea 100644
--- a/docs/source/en/api/pipelines/overview.md
+++ b/docs/source/en/api/pipelines/overview.md
@@ -16,15 +16,12 @@ Pipelines provide a simple way to run state-of-the-art diffusion models in infer
All pipelines are built from the base [`DiffusionPipeline`] class which provides basic functionality for loading, downloading, and saving all the components. Specific pipeline types (for example [`StableDiffusionPipeline`]) loaded with [`~DiffusionPipeline.from_pretrained`] are automatically detected and the pipeline components are loaded and passed to the `__init__` function of the pipeline.
-
-
-You shouldn't use the [`DiffusionPipeline`] class for training. Individual components (for example, [`UNet2DModel`] and [`UNet2DConditionModel`]) of diffusion pipelines are usually trained individually, so we suggest directly working with them instead.
-
-
-
-Pipelines do not offer any training functionality. You'll notice PyTorch's autograd is disabled by decorating the [`~DiffusionPipeline.__call__`] method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should not be used for training. If you're interested in training, please take a look at the [Training](../../training/overview) guides instead!
-
-
+> [!WARNING]
+> You shouldn't use the [`DiffusionPipeline`] class for training. Individual components (for example, [`UNet2DModel`] and [`UNet2DConditionModel`]) of diffusion pipelines are usually trained individually, so we suggest directly working with them instead.
+>
+> Pipelines do not offer any training functionality. You'll notice PyTorch's autograd is disabled by decorating the [`~DiffusionPipeline.__call__`] method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should not be used for training. If you're interested in training, please take a look at the [Training](../../training/overview) guides instead!
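+
+As a minimal sketch of the loading flow described above (checkpoint id chosen for illustration):
+
+```python
+import torch
+from diffusers import DiffusionPipeline
+
+# from_pretrained reads the checkpoint's model_index.json, resolves the concrete
+# pipeline class, and loads each component listed there.
+pipe = DiffusionPipeline.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+print(pipe.__class__.__name__)  # StableDiffusionPipeline
+print(list(pipe.components))    # e.g. vae, text_encoder, tokenizer, unet, scheduler, ...
+
+image = pipe("a photo of an astronaut riding a horse on mars").images[0]
+image.save("astronaut.png")
+```
+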
The table below lists all the pipelines currently available in 🤗 Diffusers and the tasks they support. Click on a pipeline to view its abstract and published paper.
@@ -35,8 +32,9 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Attend-and-Excite](attend_and_excite) | text2image |
| [AudioLDM](audioldm) | text2audio |
| [AudioLDM2](audioldm2) | text2audio |
-| [AuraFlow](auraflow) | text2image |
+| [AuraFlow](aura_flow) | text2image |
| [BLIP Diffusion](blip_diffusion) | text2image |
+| [Bria 3.2](bria_3_2) | text2image |
| [CogVideoX](cogvideox) | text2video |
| [Consistency Models](consistency_models) | unconditional image generation |
| [ControlNet](controlnet) | text2image, image2image, inpainting |
@@ -105,10 +103,20 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
[[autodoc]] pipelines.StableDiffusionMixin.disable_freeu
-## FlaxDiffusionPipeline
-
-[[autodoc]] pipelines.pipeline_flax_utils.FlaxDiffusionPipeline
-
## PushToHubMixin
[[autodoc]] utils.PushToHubMixin
+
+## Callbacks
+
+[[autodoc]] callbacks.PipelineCallback
+
+[[autodoc]] callbacks.SDCFGCutoffCallback
+
+[[autodoc]] callbacks.SDXLCFGCutoffCallback
+
+[[autodoc]] callbacks.SDXLControlnetCFGCutoffCallback
+
+[[autodoc]] callbacks.IPAdapterScaleCutoffCallback
+
+[[autodoc]] callbacks.SD3CFGCutoffCallback
diff --git a/docs/source/en/api/pipelines/pag.md b/docs/source/en/api/pipelines/pag.md
index 7b87e58a87..35004b6ad3 100644
--- a/docs/source/en/api/pipelines/pag.md
+++ b/docs/source/en/api/pipelines/pag.md
@@ -31,11 +31,8 @@ PAG can be used by specifying the `pag_applied_layers` as a parameter when insta
- Partial identifier as a RegEx: `down_blocks.2`, or `attn1`
- List of identifiers (can be combo of strings and ReGex): `["blocks.1", "blocks.(14|20)", r"down_blocks\.(2,3)"]`
-
-
-Since RegEx is supported as a way for matching layer identifiers, it is crucial to use it correctly otherwise there might be unexpected behaviour. The recommended way to use PAG is by specifying layers as `blocks.{layer_index}` and `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while doable, may bypass our basic validation checks and give you unexpected results.
-
-
+> [!WARNING]
+> Since RegEx is supported as a way for matching layer identifiers, it is crucial to use it correctly; otherwise, there might be unexpected behaviour. The recommended way to use PAG is by specifying layers as `blocks.{layer_index}` and `blocks.({layer_index_1|layer_index_2|...})`. Using it in any other way, while doable, may bypass our basic validation checks and give you unexpected results.
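+
+For example, a PAG pipeline can be created through [`AutoPipelineForText2Image`] with `enable_pag=True` and the layers specified in the recommended `blocks.{layer_index}` form (the model id and layer choice below are illustrative assumptions):
+
+```python
+import torch
+from diffusers import AutoPipelineForText2Image
+
+pipe = AutoPipelineForText2Image.from_pretrained(
+    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
+    enable_pag=True,
+    pag_applied_layers=["blocks.14"],
+    torch_dtype=torch.float16,
+).to("cuda")
+
+image = pipe(
+    "an insect robot preparing a delicious meal",
+    num_inference_steps=25,
+    guidance_scale=4.5,
+    pag_scale=3.0,   # strength of perturbed attention guidance
+).images[0]
+image.save("pag_pixart_sigma.png")
+```
+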
## AnimateDiffPAGPipeline
[[autodoc]] AnimateDiffPAGPipeline
diff --git a/docs/source/en/api/pipelines/paint_by_example.md b/docs/source/en/api/pipelines/paint_by_example.md
index 362c26de68..02bf6db726 100644
--- a/docs/source/en/api/pipelines/paint_by_example.md
+++ b/docs/source/en/api/pipelines/paint_by_example.md
@@ -27,11 +27,8 @@ The original codebase can be found at [Fantasy-Studio/Paint-by-Example](https://
Paint by Example is supported by the official [Fantasy-Studio/Paint-by-Example](https://huggingface.co/Fantasy-Studio/Paint-by-Example) checkpoint. The checkpoint is warm-started from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) to inpaint partly masked images conditioned on example and reference images.
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## PaintByExamplePipeline
[[autodoc]] PaintByExamplePipeline
diff --git a/docs/source/en/api/pipelines/panorama.md b/docs/source/en/api/pipelines/panorama.md
index 9f61388dd5..b65e05dd0b 100644
--- a/docs/source/en/api/pipelines/panorama.md
+++ b/docs/source/en/api/pipelines/panorama.md
@@ -42,11 +42,8 @@ For example, without circular padding, there is a stitching artifact (default):
But with circular padding, the right and the left parts are matching (`circular_padding=True`):

-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionPanoramaPipeline
[[autodoc]] StableDiffusionPanoramaPipeline
diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md
index 7bd480b49a..eebfa4d4f8 100644
--- a/docs/source/en/api/pipelines/pia.md
+++ b/docs/source/en/api/pipelines/pia.md
@@ -87,11 +87,8 @@ Here are some sample outputs:
-
-
-If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples. Additionally, the PIA checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.
-
-
+> [!TIP]
+> If you plan on using a scheduler that can clip samples, make sure to disable sample clipping by setting `clip_sample=False` in the scheduler, as clipping can also have an adverse effect on generated samples. Additionally, the PIA checkpoints can be sensitive to the beta schedule of the scheduler. We recommend setting this to `linear`.
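+
+A minimal sketch of that scheduler setup (the adapter and base model ids below are assumptions):
+
+```python
+import torch
+from diffusers import DDIMScheduler, MotionAdapter, PIAPipeline
+
+adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
+pipe = PIAPipeline.from_pretrained(
+    "SG161222/Realistic_Vision_V6.0_B1_noVAE",
+    motion_adapter=adapter,
+    torch_dtype=torch.float16,
+)
+
+# Disable sample clipping and switch to a linear beta schedule, as recommended above.
+pipe.scheduler = DDIMScheduler.from_config(
+    pipe.scheduler.config,
+    clip_sample=False,
+    beta_schedule="linear",
+)
+```
+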
## Using FreeInit
@@ -149,11 +146,8 @@ export_to_gif(frames, "pia-freeinit-animation.gif")
-
-
-FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).
-
-
+> [!WARNING]
+> FreeInit is not really free - the improved quality comes at the cost of extra computation. It requires sampling a few extra times depending on the `num_iters` parameter that is set when enabling it. Setting the `use_fast_sampling` parameter to `True` can improve the overall performance (at the cost of lower quality compared to when `use_fast_sampling=False` but still better results than vanilla video generation models).
## PIAPipeline
diff --git a/docs/source/en/api/pipelines/pix2pix.md b/docs/source/en/api/pipelines/pix2pix.md
index 20a74577c1..84eb0cb5e5 100644
--- a/docs/source/en/api/pipelines/pix2pix.md
+++ b/docs/source/en/api/pipelines/pix2pix.md
@@ -24,11 +24,8 @@ The abstract from the paper is:
You can find additional information about InstructPix2Pix on the [project page](https://www.timothybrooks.com/instruct-pix2pix), [original codebase](https://github.com/timothybrooks/instruct-pix2pix), and try it out in a [demo](https://huggingface.co/spaces/timbrooks/instruct-pix2pix).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionInstructPix2PixPipeline
[[autodoc]] StableDiffusionInstructPix2PixPipeline
diff --git a/docs/source/en/api/pipelines/pixart.md b/docs/source/en/api/pipelines/pixart.md
index a36a2a4b7a..dbdc89857e 100644
--- a/docs/source/en/api/pipelines/pixart.md
+++ b/docs/source/en/api/pipelines/pixart.md
@@ -29,11 +29,8 @@ Some notes about this pipeline:
* It is good at producing high-resolution images at different aspect ratios. To get the best results, the authors recommend some size brackets which can be found [here](https://github.com/PixArt-alpha/PixArt-alpha/blob/08fbbd281ec96866109bdd2cdb75f2f58fb17610/diffusion/data/datasets/utils.py).
* It rivals the quality of state-of-the-art text-to-image generation systems (as of this writing) such as Stable Diffusion XL, Imagen, and DALL-E 2, while being more efficient than them.
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Inference with under 8GB GPU VRAM
@@ -112,11 +109,8 @@ del pipe.transformer
flush()
```
-
-
-Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
-
-
+> [!TIP]
+> Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
Once the latents are computed, pass it off to the VAE to decode into a real image:
@@ -133,11 +127,8 @@ By deleting components you aren't using and flushing the GPU VRAM, you should be
If you want a report of your memory-usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e).
-
-
-Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
-
-
+> [!WARNING]
+> Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
While loading the `text_encoder`, you set `load_in_8bit` to `True`. You could also specify `load_in_4bit` to bring your memory requirements down even further to under 7GB.
diff --git a/docs/source/en/api/pipelines/pixart_sigma.md b/docs/source/en/api/pipelines/pixart_sigma.md
index dded4ea2d7..06b54de43b 100644
--- a/docs/source/en/api/pipelines/pixart_sigma.md
+++ b/docs/source/en/api/pipelines/pixart_sigma.md
@@ -31,17 +31,11 @@ Some notes about this pipeline:
* It shows the ability of generating super high resolution images, such as 2048px or even 4K.
* It shows that text-to-image models can grow from a weak model to a stronger one through several improvements (VAEs, datasets, and so on.)
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
-
-
-
-You can further improve generation quality by passing the generated image from [`PixArtSigmaPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
-
-
+> [!TIP]
+> You can further improve generation quality by passing the generated image from [`PixArtSigmaPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
## Inference with under 8GB GPU VRAM
@@ -119,11 +113,8 @@ del pipe.transformer
flush()
```
-
-
-Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
-
-
+> [!TIP]
+> Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
Once the latents are computed, pass it off to the VAE to decode into a real image:
@@ -140,11 +131,8 @@ By deleting components you aren't using and flushing the GPU VRAM, you should be
If you want a report of your memory-usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e).
-
-
-Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
-
-
+> [!WARNING]
+> Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
While loading the `text_encoder`, you set `load_in_8bit` to `True`. You could also specify `load_in_4bit` to bring your memory requirements down even further to under 7GB.
diff --git a/docs/source/en/api/pipelines/qwenimage.md b/docs/source/en/api/pipelines/qwenimage.md
index 8f9529fef7..b3dd3dd936 100644
--- a/docs/source/en/api/pipelines/qwenimage.md
+++ b/docs/source/en/api/pipelines/qwenimage.md
@@ -14,15 +14,105 @@
# QwenImage
+
+
+
+
Qwen-Image from the Qwen team is an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing. Experiments show strong general capabilities in both image generation and editing, with exceptional performance in text rendering, especially for Chinese.
-Check out the model card [here](https://huggingface.co/Qwen/Qwen-Image) to learn more.
+Qwen-Image comes in the following variants:
-
+| model type | model id |
+|:----------:|:--------:|
+| Qwen-Image | [`Qwen/Qwen-Image`](https://huggingface.co/Qwen/Qwen-Image) |
+| Qwen-Image-Edit | [`Qwen/Qwen-Image-Edit`](https://huggingface.co/Qwen/Qwen-Image-Edit) |
+| Qwen-Image-Edit Plus | [`Qwen/Qwen-Image-Edit-2509`](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+> [!TIP]
+> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
-
+## LoRA for faster inference
+
+Use a LoRA from `lightx2v/Qwen-Image-Lightning` to speed up inference by reducing the
+number of steps. Refer to the code snippet below:
+
+
+
+
+```py
+from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
+import torch
+import math
+
+ckpt_id = "Qwen/Qwen-Image"
+
+# From
+# https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
+scheduler_config = {
+ "base_image_seq_len": 256,
+ "base_shift": math.log(3), # We use shift=3 in distillation
+ "invert_sigmas": False,
+ "max_image_seq_len": 8192,
+ "max_shift": math.log(3), # We use shift=3 in distillation
+ "num_train_timesteps": 1000,
+ "shift": 1.0,
+ "shift_terminal": None, # set shift_terminal to None
+ "stochastic_sampling": False,
+ "time_shift_type": "exponential",
+ "use_beta_sigmas": False,
+ "use_dynamic_shifting": True,
+ "use_exponential_sigmas": False,
+ "use_karras_sigmas": False,
+}
+scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+pipe = DiffusionPipeline.from_pretrained(
+ ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
+).to("cuda")
+pipe.load_lora_weights(
+ "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.0.safetensors"
+)
+
+prompt = "a tiny astronaut hatching from an egg on the moon, Ultra HD, 4K, cinematic composition."
+negative_prompt = " "
+image = pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=1024,
+ height=1024,
+ num_inference_steps=8,
+ true_cfg_scale=1.0,
+ generator=torch.manual_seed(0),
+).images[0]
+image.save("qwen_fewsteps.png")
+```
+
+
+
+> [!TIP]
+> The `guidance_scale` parameter exists to support future guidance-distilled models; passing it to this pipeline currently has no effect. To enable classifier-free guidance, pass `true_cfg_scale` together with a `negative_prompt` (even an empty negative prompt like `" "`).
+
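+The sketch below is illustrative (the prompt and values are assumptions, and it assumes a base `Qwen/Qwen-Image` pipeline without the Lightning LoRA); it shows classifier-free guidance enabled through `true_cfg_scale` together with a negative prompt:
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
+
+image = pipe(
+    prompt="a tiny astronaut hatching from an egg on the moon",
+    negative_prompt=" ",  # even an empty negative prompt is enough
+    true_cfg_scale=4.0,   # values > 1.0 enable classifier-free guidance
+    num_inference_steps=50,
+    generator=torch.manual_seed(0),
+).images[0]
+```
+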
+## Multi-image reference with QwenImageEditPlusPipeline
+
+With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
+
+```py
+import torch
+from PIL import Image
+from diffusers import QwenImageEditPlusPipeline
+from diffusers.utils import load_image
+
+pipe = QwenImageEditPlusPipeline.from_pretrained(
+ "Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16
+).to("cuda")
+
+image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
+image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
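+# both loaded images act as references for the edit prompt below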
+image = pipe(
+ image=[image_1, image_2],
+ prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
+ num_inference_steps=50
+).images[0]
+```
## QwenImagePipeline
@@ -30,6 +120,42 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
+## QwenImageImg2ImgPipeline
+
+[[autodoc]] QwenImageImg2ImgPipeline
+ - all
+ - __call__
+
+## QwenImageInpaintPipeline
+
+[[autodoc]] QwenImageInpaintPipeline
+ - all
+ - __call__
+
+## QwenImageEditPipeline
+
+[[autodoc]] QwenImageEditPipeline
+ - all
+ - __call__
+
+## QwenImageEditInpaintPipeline
+
+[[autodoc]] QwenImageEditInpaintPipeline
+ - all
+ - __call__
+
+## QwenImageControlNetPipeline
+
+[[autodoc]] QwenImageControlNetPipeline
+ - all
+ - __call__
+
+## QwenImageEditPlusPipeline
+
+[[autodoc]] QwenImageEditPlusPipeline
+ - all
+ - __call__
+
## QwenImagePipelineOutput
-[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
+[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/sana.md b/docs/source/en/api/pipelines/sana.md
index 7491689fd8..a948620f96 100644
--- a/docs/source/en/api/pipelines/sana.md
+++ b/docs/source/en/api/pipelines/sana.md
@@ -25,11 +25,8 @@ The abstract from the paper is:
*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj) and [chenjy2003](https://github.com/chenjy2003). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model).
@@ -49,11 +46,8 @@ Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-67
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
-
-
-Make sure to pass the `variant` argument for downloaded checkpoints to use lower disk space. Set it to `"fp16"` for models with recommended dtype as `torch.float16`, and `"bf16"` for models with recommended dtype as `torch.bfloat16`. By default, `torch.float32` weights are downloaded, which use twice the amount of disk storage. Additionally, `torch.float32` weights can be downcasted on-the-fly by specifying the `torch_dtype` argument. Read about it in the [docs](https://huggingface.co/docs/diffusers/v0.31.0/en/api/pipelines/overview#diffusers.DiffusionPipeline.from_pretrained).
-
-
+> [!TIP]
+> Make sure to pass the `variant` argument for downloaded checkpoints to use lower disk space. Set it to `"fp16"` for models with recommended dtype as `torch.float16`, and `"bf16"` for models with recommended dtype as `torch.bfloat16`. By default, `torch.float32` weights are downloaded, which use twice the amount of disk storage. Additionally, `torch.float32` weights can be downcasted on-the-fly by specifying the `torch_dtype` argument. Read about it in the [docs](https://huggingface.co/docs/diffusers/v0.31.0/en/api/pipelines/overview#diffusers.DiffusionPipeline.from_pretrained).
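+
+As an illustration, here is a minimal sketch of loading a `bf16` variant (the checkpoint id is an assumption; check the collection linked above for the exact repo and its recommended dtype):
+
+```py
+import torch
+from diffusers import SanaPipeline
+
+pipe = SanaPipeline.from_pretrained(
+    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",  # assumed repo id
+    variant="bf16",               # download the smaller bf16 weights instead of fp32
+    torch_dtype=torch.bfloat16,   # keep compute dtype consistent with the variant
+)
+```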
## Quantization
diff --git a/docs/source/en/api/pipelines/sana_sprint.md b/docs/source/en/api/pipelines/sana_sprint.md
index 93ab9fe418..357d7e406d 100644
--- a/docs/source/en/api/pipelines/sana_sprint.md
+++ b/docs/source/en/api/pipelines/sana_sprint.md
@@ -24,11 +24,8 @@ The abstract from the paper is:
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
diff --git a/docs/source/en/api/pipelines/self_attention_guidance.md b/docs/source/en/api/pipelines/self_attention_guidance.md
index 5578fdfa63..8d411598ae 100644
--- a/docs/source/en/api/pipelines/self_attention_guidance.md
+++ b/docs/source/en/api/pipelines/self_attention_guidance.md
@@ -23,11 +23,8 @@ The abstract from the paper is:
You can find additional information about Self-Attention Guidance on the [project page](https://ku-cvlab.github.io/Self-Attention-Guidance), [original codebase](https://github.com/KU-CVLAB/Self-Attention-Guidance), and try it out in a [demo](https://huggingface.co/spaces/susunghong/Self-Attention-Guidance) or [notebook](https://colab.research.google.com/github/SusungHong/Self-Attention-Guidance/blob/main/SAG_Stable.ipynb).
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## StableDiffusionSAGPipeline
[[autodoc]] StableDiffusionSAGPipeline
diff --git a/docs/source/en/api/pipelines/semantic_stable_diffusion.md b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
index 1ce44cf2de..dda428e80f 100644
--- a/docs/source/en/api/pipelines/semantic_stable_diffusion.md
+++ b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
@@ -22,11 +22,8 @@ The abstract from the paper is:
*Text-to-image diffusion models have recently received a lot of interest for their astonishing ability to produce high-fidelity images from text only. However, achieving one-shot generation that aligns with the user's intent is nearly impossible, yet small changes to the input prompt often result in very different images. This leaves the user with little semantic control. To put the user in control, we show how to interact with the diffusion process to flexibly steer it along semantic directions. This semantic guidance (SEGA) generalizes to any generative architecture using classifier-free guidance. More importantly, it allows for subtle and extensive edits, changes in composition and style, as well as optimizing the overall artistic conception. We demonstrate SEGA's effectiveness on both latent and pixel-based diffusion models such as Stable Diffusion, Paella, and DeepFloyd-IF using a variety of tasks, thus providing strong evidence for its versatility, flexibility, and improvements over existing methods.*
-
-
-Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## SemanticStableDiffusionPipeline
[[autodoc]] SemanticStableDiffusionPipeline
diff --git a/docs/source/en/api/pipelines/shap_e.md b/docs/source/en/api/pipelines/shap_e.md
index 5e5af0656a..3e505894ca 100644
--- a/docs/source/en/api/pipelines/shap_e.md
+++ b/docs/source/en/api/pipelines/shap_e.md
@@ -17,11 +17,8 @@ The abstract from the paper is:
The original codebase can be found at [openai/shap-e](https://github.com/openai/shap-e).
-
-
-See the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
-
-
+> [!TIP]
+> See the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## ShapEPipeline
[[autodoc]] ShapEPipeline
diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md
index cd94f2a75c..6730f15516 100644
--- a/docs/source/en/api/pipelines/skyreels_v2.md
+++ b/docs/source/en/api/pipelines/skyreels_v2.md
@@ -1,4 +1,4 @@
-
-# Components Manager
+# ComponentsManager
-
+The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), prevents duplicate model instances, and supports offloading.
-🧪 **Experimental Feature**: This is an experimental feature we are actively developing. The API may be subject to breaking changes.
+This guide will show you how to use [`ComponentsManager`] to manage components and device memory.
-
+## Add a component
-The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading (i.e. `enable_model_cpu_offload` and `enable_sequential_cpu_offload`) which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows.
+The [`ComponentsManager`] should be created alongside a [`ModularPipeline`] in either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`].
+> [!TIP]
+> The `collection` parameter is optional but makes it easier to organize and manage components.
-## Basic Operations
+
+
-Let's start with the most basic operations. First, create a Components Manager:
+```py
+from diffusers import ModularPipeline, ComponentsManager
+
+comp = ComponentsManager()
+pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
+```
+
+
+
```py
from diffusers import ComponentsManager
-comp = ComponentsManager()
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+
+t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+components = ComponentsManager()
+t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
```
-Use the `add(name, component)` method to register a component. It returns a unique ID that combines the component name with the object's unique identifier (using Python's `id()` function):
+
+
+
+Components are only loaded and registered when [`~ModularPipeline.load_components`] is called. The example below calls [`~ModularPipeline.load_components`] on the first pipeline, then creates a second pipeline, assigned to a different collection, that will reuse all of the first pipeline's components.
+
+```py
+pipe.load_components()
+pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
+```
+
+Use the [`~ModularPipeline.null_component_names`] property to identify any components that need to be loaded, retrieve them with [`~ComponentsManager.get_components_by_names`], and then call [`~ModularPipeline.update_components`] to add the missing components.
+
+```py
+pipe2.null_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
+
+comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
+pipe2.update_components(**comp_dict)
+```
+
+To add individual components, use the [`~ComponentsManager.add`] method. This registers a component with a unique id.
```py
from diffusers import AutoModel
+
text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
-# Returns component_id like 'text_encoder_139917733042864'
component_id = comp.add("text_encoder", text_encoder)
+comp
```
-You can view all registered components and their metadata:
-
-```py
->>> comp
-Components:
-===============================================================================================================================================
-Models:
------------------------------------------------------------------------------------------------------------------------------------------------
-Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
------------------------------------------------------------------------------------------------------------------------------------------------
-text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
------------------------------------------------------------------------------------------------------------------------------------------------
-
-Additional Component Info:
-==================================================
-```
-
-And remove components using their unique ID:
+Use [`~ComponentsManager.remove`] to remove a component by its id.
```py
comp.remove("text_encoder_139917733042864")
```
-## Duplicate Detection
+## Retrieve a component
-The Components Manager automatically detects and prevents duplicate model instances to save memory and avoid confusion. Let's walk through how this works in practice.
+The [`ComponentsManager`] provides several methods to retrieve registered components.
-When you try to add the same object twice, the manager will warn you and return the existing ID:
+### get_one
+
+The [`~ComponentsManager.get_one`] method returns a single component and supports pattern matching for the `name` parameter. If multiple components match, [`~ComponentsManager.get_one`] raises an error.
+
+| Pattern | Example | Description |
+|-------------|----------------------------------|-------------------------------------------|
+| exact | `comp.get_one(name="unet")` | exact name match |
+| wildcard | `comp.get_one(name="unet*")` | names starting with "unet" |
+| exclusion | `comp.get_one(name="!unet")` | exclude components named "unet" |
+| or | `comp.get_one(name="unet\|vae")` | name is "unet" or "vae" |
+
+[`~ComponentsManager.get_one`] also filters components by the `collection` argument or `load_id` argument.
```py
->>> comp.add("text_encoder", text_encoder)
-'text_encoder_139917733042864'
->>> comp.add("text_encoder", text_encoder)
-ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917733042864'
-'text_encoder_139917733042864'
+comp.get_one(name="unet", collection="sdxl")
```
-Even if you add the same object under a different name, it will still be detected as a duplicate:
+### get_components_by_names
+
+The [`~ComponentsManager.get_components_by_names`] method accepts a list of names and returns a dictionary mapping names to components. This is especially useful with a [`ModularPipeline`] since it provides a list of required component names, and the returned dictionary can be passed directly to [`~ModularPipeline.update_components`].
```py
->>> comp.add("clip", text_encoder)
-ComponentsManager: adding component 'clip' as 'clip_139917733042864', but it is duplicate of 'text_encoder_139917733042864'
-To remove a duplicate, call `components_manager.remove('')`.
-'clip_139917733042864'
+component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
+{"text_encoder": component1, "unet": component2, "vae": component3}
```
-However, there's a more subtle case where duplicate detection becomes tricky. When you load the same model into different objects, the manager can't detect duplicates unless you use `ComponentSpec`. For example:
+## Duplicate detection
-```py
->>> text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
->>> comp.add("text_encoder", text_encoder_2)
-'text_encoder_139917732983664'
-```
-
-This creates a problem - you now have two copies of the same model consuming double the memory:
-
-```py
->>> comp
-Components:
-===============================================================================================================================================
-Models:
------------------------------------------------------------------------------------------------------------------------------------------------
-Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
------------------------------------------------------------------------------------------------------------------------------------------------
-text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
-clip_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
-text_encoder_139917732983664 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
------------------------------------------------------------------------------------------------------------------------------------------------
-
-Additional Component Info:
-==================================================
-```
-
-We recommend using `ComponentSpec` to load your models. Models loaded with `ComponentSpec` get tagged with a unique ID that encodes their loading parameters, allowing the Components Manager to detect when different objects represent the same underlying checkpoint:
+It is recommended to load model components with [`ComponentSpec`] to assign components with a unique id that encodes their loading parameters. This allows [`ComponentsManager`] to automatically detect and prevent duplicate model instances even when different objects represent the same underlying checkpoint.
```py
from diffusers import ComponentSpec, ComponentsManager
from transformers import CLIPTextModel
+
comp = ComponentsManager()
# Create ComponentSpec for the first text encoder
spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
-# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from same repo/subfolder)
+# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from the same repo/subfolder)
spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)
# Load and add both components - the manager will detect they're the same model
@@ -129,42 +134,36 @@ comp.add("text_encoder", spec.load())
comp.add("text_encoder_duplicated", spec_duplicated.load())
```
-Now the manager detects the duplicate and warns you:
+This returns a warning with instructions for removing the duplicate.
-```out
+```py
ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('')`.
'text_encoder_duplicated_139917580682672'
```
-Both models now show the same `load_id`, making it clear they're the same model:
+You can also add a component without [`ComponentSpec`]. Duplicate detection still works in most cases, even if you add the same object under a different name.
+
+However, [`ComponentsManager`] can't detect duplicates when the same checkpoint is loaded into different objects, as shown below. In that case, load the model with [`ComponentSpec`] instead.
```py
->>> comp
-Components:
-======================================================================================================================================================================================================
-Models:
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-text_encoder_139918506246832 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
-text_encoder_duplicated_139917580682672 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-
-Additional Component Info:
-==================================================
+text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
+comp.add("text_encoder", text_encoder_2)
+'text_encoder_139917732983664'
```
## Collections
-Collections are labels you can assign to components for better organization and management. You add a component under a collection by passing the `collection=` parameter when you add the component to the manager, i.e. `add(name, component, collection=...)`. Within each collection, only one component per name is allowed - if you add a second component with the same name, the first one is automatically removed.
+Collections are labels assigned to components for better organization and management. Add a component to a collection with the `collection` argument in [`~ComponentsManager.add`].
-Here's how collections work in practice:
+Only one component per name is allowed in each collection. Adding a second component with the same name automatically removes the first component.
```py
+from diffusers import ComponentSpec, ComponentsManager
+
comp = ComponentsManager()
-# Create ComponentSpec for the first UNet (SDXL base)
+# Create ComponentSpec for the first UNet
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
-# Create ComponentSpec for a different UNet (Juggernaut-XL)
+# Create ComponentSpec for a different UNet
spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
# Add both UNets to the same collection - the second one will replace the first
@@ -172,343 +171,20 @@ comp.add("unet", spec.load(), collection="sdxl")
comp.add("unet", spec2.load(), collection="sdxl")
```
-The manager automatically removes the old UNet and adds the new one:
+This makes it convenient to work with node-based systems because you can:
-```out
-ComponentsManager: removing existing unet from collection 'sdxl': unet_139917723891888
-'unet_139917723893136'
-```
+- Mark all models loaded from one node with the same `collection` label.
+- Automatically replace models when new checkpoints are loaded under the same name.
+- Batch delete all models in a collection when a node is removed.
-Only one UNet remains in the collection:
+## Offloading
-```py
->>> comp
-Components:
-====================================================================================================================================================================
-Models:
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
-Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
-unet_139917723893136 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | sdxl
---------------------------------------------------------------------------------------------------------------------------------------------------------------------
-
-Additional Component Info:
-==================================================
-```
-
-For example, in node-based systems, you can mark all models loaded from one node with the same collection label, automatically replace models when user loads new checkpoints under same name, batch delete all models in a collection when a node is removed.
-
-## Retrieving Components
-
-The Components Manager provides several methods to retrieve registered components.
-
-The `get_one()` method returns a single component and supports pattern matching for the `name` parameter. You can use:
-- exact matches like `comp.get_one(name="unet")`
-- wildcards like `comp.get_one(name="unet*")` for components starting with "unet"
-- exclusion patterns like `comp.get_one(name="!unet")` to exclude components named "unet"
-- OR patterns like `comp.get_one(name="unet|vae")` to match either "unet" OR "vae".
-
-Optionally, You can add collection and load_id as filters e.g. `comp.get_one(name="unet", collection="sdxl")`. If multiple components match, `get_one()` throws an error.
-
-Another useful method is `get_components_by_names()`, which takes a list of names and returns a dictionary mapping names to components. This is particularly helpful with modular pipelines since they provide lists of required component names, and the returned dictionary can be directly passed to `pipeline.update_components()`.
-
-```py
-# Get components by name list
-component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
-# Returns: {"text_encoder": component1, "unet": component2, "vae": component3}
-```
-
-## Using Components Manager with Modular Pipelines
-
-The Components Manager integrates seamlessly with Modular Pipelines. All you need to do is pass a Components Manager instance to `from_pretrained()` or `init_pipeline()` with an optional `collection` parameter:
-
-```py
-from diffusers import ModularPipeline, ComponentsManager
-comp = ComponentsManager()
-pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
-```
-
-By default, modular pipelines don't load components immediately, so both the pipeline and Components Manager start empty:
-
-```py
->>> comp
-Components:
-==================================================
-No components registered.
-==================================================
-```
-
-When you load components on the pipeline, they are automatically registered in the Components Manager:
-
-```py
->>> pipe.load_components(names="unet")
->>> comp
-Components:
-==============================================================================================================================================================
-Models:
---------------------------------------------------------------------------------------------------------------------------------------------------------------
-Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
---------------------------------------------------------------------------------------------------------------------------------------------------------------
-unet_139917726686304 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
---------------------------------------------------------------------------------------------------------------------------------------------------------------
-
-Additional Component Info:
-==================================================
-```
-
-Now let's load all default components and then create a second pipeline that reuses all components from the first one. We pass the same Components Manager to the second pipeline but with a different collection:
-
-```py
-# Load all default components
->>> pipe.load_default_components()
-
-# Create a second pipeline using the same Components Manager but with a different collection
->>> pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
-```
-
-As mentioned earlier, `ModularPipeline` has a property `null_component_names` that returns a list of component names it needs to load. We can conveniently use this list with the `get_components_by_names` method on the Components Manager:
-
-```py
-# Get the list of components that pipe2 needs to load
->>> pipe2.null_component_names
-['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
-
-# Retrieve all required components from the Components Manager
->>> comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
-
-# Update the pipeline with the retrieved components
->>> pipe2.update_components(**comp_dict)
-```
-
-The warnings that follow are expected and indicate that the Components Manager is correctly identifying that these components already exist and will be reused rather than creating duplicates:
-
-```out
-ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917586016400'
-ComponentsManager: component 'text_encoder_2' already exists as 'text_encoder_2_139917699973424'
-ComponentsManager: component 'tokenizer' already exists as 'tokenizer_139917580599504'
-ComponentsManager: component 'tokenizer_2' already exists as 'tokenizer_2_139915763443904'
-ComponentsManager: component 'image_encoder' already exists as 'image_encoder_139917722468304'
-ComponentsManager: component 'unet' already exists as 'unet_139917580609632'
-ComponentsManager: component 'vae' already exists as 'vae_139917722459040'
-ComponentsManager: component 'scheduler' already exists as 'scheduler_139916266559408'
-ComponentsManager: component 'controlnet' already exists as 'controlnet_139917722454432'
-```
-
-
-The pipeline is now fully loaded:
-
-```py
-# null_component_names return empty list, meaning everything are loaded
->>> pipe2.null_component_names
-[]
-```
-
-No new components were added to the Components Manager - we're reusing everything. All models are now associated with both `test1` and `test2` collections, showing that these components are shared across multiple pipelines:
-```py
->>> comp
-Components:
-========================================================================================================================================================================================
-Models:
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-text_encoder_139917586016400 | CLIPTextModel | cpu | torch.float32 | 0.46 | SG161222/RealVisXL_V4.0|text_encoder|null|null | test1
- | | | | | | test2
-text_encoder_2_139917699973424 | CLIPTextModelWithProjection | cpu | torch.float32 | 2.59 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | test1
- | | | | | | test2
-unet_139917580609632 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
- | | | | | | test2
-controlnet_139917722454432 | ControlNetModel | cpu | torch.float32 | 4.66 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | test1
- | | | | | | test2
-vae_139917722459040 | AutoencoderKL | cpu | torch.float32 | 0.31 | SG161222/RealVisXL_V4.0|vae|null|null | test1
- | | | | | | test2
-image_encoder_139917722468304 | CLIPVisionModelWithProjection | cpu | torch.float32 | 6.87 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | test1
- | | | | | | test2
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-
-Other Components:
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-ID | Class | Collection
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-tokenizer_139917580599504 | CLIPTokenizer | test1
- | | test2
-scheduler_139916266559408 | EulerDiscreteScheduler | test1
- | | test2
-tokenizer_2_139915763443904 | CLIPTokenizer | test1
- | | test2
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-
-Additional Component Info:
-==================================================
-```
-
-
-## Automatic Memory Management
-
-The Components Manager provides a global offloading strategy across all models, regardless of which pipeline is using them:
+The [`~ComponentsManager.enable_auto_cpu_offload`] method is a global offloading strategy that works across all models, regardless of which pipeline uses them. Once enabled, you don't need to worry about device placement, even as you add or remove components.
```py
comp.enable_auto_cpu_offload(device="cuda")
```
-When enabled, all models start on CPU. The manager moves models to the device right before they're used and moves other models back to CPU when GPU memory runs low. You can set your own rules for which models to offload first. This works smoothly as you add or remove components. Once it's on, you don't need to worry about device placement - you can focus on your workflow.
-
-
-
-## Practical Example: Building Modular Workflows with Component Reuse
-
-Now that we've covered the basics of the Components Manager, let's walk through a practical example that shows how to build workflows in a modular setting and use the Components Manager to reuse components across multiple pipelines. This example demonstrates the true power of Modular Diffusers by working with multiple pipelines that can share components.
-
-In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline.
-
-Let's create a modular text-to-image workflow by separating it into three workflows: `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images.
-
-```py
-import torch
-from diffusers.modular_pipelines import SequentialPipelineBlocks
-from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
-
-# Create modular blocks and separate text encoding and decoding steps
-t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"])
-text_blocks = t2i_blocks.sub_blocks.pop("text_encoder")
-decoder_blocks = t2i_blocks.sub_blocks.pop("decode")
-```
-
-Now we will convert them into runnalbe pipelines and set up the Components Manager with auto offloading and organize components under a "t2i" collection
-
-Since we now have 3 different workflows that share components, we create a separate pipeline that serves as a dedicated loader to load all the components, register them to the component manager, and then reuse them across different workflows.
-
-```py
-from diffusers import ComponentsManager, ModularPipeline
-
-# Set up Components Manager with auto offloading
-components = ComponentsManager()
-components.enable_auto_cpu_offload(device="cuda")
-
-# Create a new pipeline to load the components
-t2i_repo = "YiYiXu/modular-demo-auto"
-t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i")
-
-# convert the 3 blocks into pipelines and attach the same components manager to all 3
-text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components)
-decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components)
-t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components)
-```
-
-Load all components into the loader pipeline, they should all be automatically registered to Components Manager under the "t2i" collection:
-
-```py
-# Load all components (including IP-Adapter and ControlNet for later use)
-t2i_loader_pipe.load_default_components(torch_dtype=torch.float16)
-```
-
-Now distribute the loaded components to each pipeline:
-
-```py
-# Get VAE for decoder (using get_one since there's only one)
-vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null")
-decoder_node.update_components(vae=vae)
-
-# Get text components for text node (using get_components_by_names for multiple components)
-text_components = components.get_components_by_names(text_node.null_component_names)
-text_node.update_components(**text_components)
-
-# Get remaining components for t2i pipeline
-t2i_components = components.get_components_by_names(t2i_pipe.null_component_names)
-t2i_pipe.update_components(**t2i_components)
-```
-
-Now we can generate images using our modular workflow:
-
-```py
-# Generate text embeddings
-prompt = "an astronaut"
-text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"])
-
-# Generate latents and decode to image
-generator = torch.Generator(device="cuda").manual_seed(0)
-latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
-image = decoder_node(latents=latents_t2i, output="images")[0]
-image.save("modular_part2_t2i.png")
-```
-
-Let's add a LoRA:
-
-```py
-# Load LoRA weights
->>> t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face")
->>> components
-Components:
-============================================================================================================================================================
-...
-Additional Component Info:
-==================================================
-
-unet:
- Adapters: ['toy_face']
-```
-
-You can see that the Components Manager tracks adapters metadata for all models it manages, and in our case, only Unet has lora loaded. This means we can reuse existing text embeddings.
-
-```py
-# Generate with LoRA (reusing existing text embeddings)
-generator = torch.Generator(device="cuda").manual_seed(0)
-latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
-image = decoder_node(latents=latents_lora, output="images")[0]
-image.save("modular_part2_lora.png")
-```
-
-
-Now let's create a refiner pipeline that reuses components from our text-to-image workflow:
-
-```py
-# Create refiner blocks (removing image_encoder and decode since we work with latents)
-refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"])
-refiner_blocks.sub_blocks.pop("image_encoder")
-refiner_blocks.sub_blocks.pop("decode")
-
-# Create refiner pipeline with different repo and collection,
-# Attach the same component manager to it
-refiner_repo = "YiYiXu/modular_refiner"
-refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner")
-```
-
-We pass the **same Components Manager** (`components`) to the refiner pipeline, but with a **different collection** (`"refiner"`). This allows the refiner to access and reuse components from the "t2i" collection while organizing its own components (like the refiner UNet) under the "refiner" collection.
-
-```py
-# Load only the refiner UNet (different from t2i UNet)
-refiner_pipe.load_components(names="unet", torch_dtype=torch.float16)
-
-# Reuse components from t2i pipeline using pattern matching
-reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2")
-refiner_pipe.update_components(**reuse_components)
-```
-
-When we reuse components from the "t2i" collection, they automatically get added to the "refiner" collection as well. You can verify this by checking the Components Manager - you'll see components like `vae`, `scheduler`, etc. listed under both collections, indicating they're shared between workflows.
-
-Now we can refine any of our generated latents:
-
-```py
-# Refine all our different latents
-refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents")
-refined_image = decoder_node(latents=refined_latents, output="images")[0]
-refined_image.save("modular_part2_t2i_refine_out.png")
-
-refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents")
-refined_image = decoder_node(latents=refined_latents, output="images")[0]
-refined_image.save("modular_part2_lora_refine_out.png")
-```
-
-
-Here are the results from our modular pipeline examples.
-
-#### Base Text-to-Image Generation
-| Base Text-to-Image | Base Text-to-Image (Refined) |
-|-------------------|------------------------------|
-|  |  |
-
-#### LoRA
-| LoRA | LoRA (Refined) |
-|-------------------|------------------------------|
-|  |  |
+All models begin on the CPU. [`ComponentsManager`] moves a model to the execution device right before it is needed and moves other models back to the CPU when GPU memory runs low. You can set your own rules for which models to offload first.
diff --git a/docs/source/en/modular_diffusers/end_to_end_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md
deleted file mode 100644
index cb7b87552a..0000000000
--- a/docs/source/en/modular_diffusers/end_to_end_guide.md
+++ /dev/null
@@ -1,648 +0,0 @@
-
-
-# End-to-End Developer Guide: Building with Modular Diffusers
-
-
-
-🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
-
-
-
-
-In this tutorial we will walk through the process of adding a new pipeline to the modular framework using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node.
-
-We'll also demonstrate the 4-step framework process we use for implementing new basic pipelines in the modular system.
-
-1. **Start with an existing pipeline as a base**
- - Identify which existing pipeline is most similar to the one you want to implement
- - Determine what part of the pipeline needs modification
-
-2. **Build a working pipeline structure first**
- - Assemble the complete pipeline structure
- - Use existing blocks wherever possible
- - For new blocks, create placeholders (e.g. you can copy from similar blocks and change the name) without implementing custom logic just yet
-
-3. **Set up an example**
- - Create a simple inference script with expected inputs/outputs
-
-4. **Implement your custom logic and test incrementally**
- - Add the custom logics the blocks you want to change
- - Test incrementally, and inspect pipeline states and debug as needed
-
-Let's see how this works with the Differential Diffusion example.
-
-
-## Differential Diffusion Pipeline
-
-### Start with an existing pipeline
-
-Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them.
-
-```py
->>> from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
->>> IMAGE2IMAGE_BLOCKS = InsertableDict([
-... ("text_encoder", StableDiffusionXLTextEncoderStep),
-... ("image_encoder", StableDiffusionXLVaeEncoderStep),
-... ("input", StableDiffusionXLInputStep),
-... ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
-... ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
-... ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
-... ("denoise", StableDiffusionXLDenoiseStep),
-... ("decode", StableDiffusionXLDecodeStep)
-... ])
-```
-
-Note that "denoise" (`StableDiffusionXLDenoiseStep`) is a `LoopSequentialPipelineBlocks` that contains 3 loop blocks (more on LoopSequentialPipelineBlocks [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#loopsequentialpipelineblocks))
-
-```py
->>> denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
->>> print(denoise_blocks)
-```
-
-```out
-StableDiffusionXLDenoiseStep(
- Class: StableDiffusionXLDenoiseLoopWrapper
-
- Description: Denoise step that iteratively denoise the latents.
- Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
- At each iteration, it runs blocks defined in `sub_blocks` sequencially:
- - `StableDiffusionXLLoopBeforeDenoiser`
- - `StableDiffusionXLLoopDenoiser`
- - `StableDiffusionXLLoopAfterDenoiser`
- This block supports both text2img and img2img tasks.
-
-
- Components:
- scheduler (`EulerDiscreteScheduler`)
- guider (`ClassifierFreeGuidance`)
- unet (`UNet2DConditionModel`)
-
- Sub-Blocks:
- [0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser)
- Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
-
- [1] denoiser (StableDiffusionXLLoopDenoiser)
- Description: Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
-
- [2] after_denoiser (StableDiffusionXLLoopAfterDenoiser)
- Description: step within the denoising loop that update the latents. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
-
-)
-```
-
-Let's compare standard image-to-image and differential diffusion! The key algorithmic difference is that standard image-to-image diffusion applies uniform noise across all pixels based on a single `strength` parameter, while differential diffusion uses a change map where each pixel's value determines when that region starts denoising. Regions with lower values get "frozen" earlier by replacing them with noised original latents, preserving more of the original image.
-
-Therefore, the key differences when it comes to pipeline implementation would be:
-1. The `prepare_latents` step (which prepares the change map and pre-computes noised latents for all timesteps)
-2. The `denoise` step (which selectively applies denoising based on the change map)
-3. Since differential diffusion doesn't use the `strength` parameter, we'll use the text-to-image `set_timesteps` step instead of the image-to-image version
-
-To implement differential diffusion, we can reuse most blocks from the image-to-image and text-to-image workflows, only modifying the `prepare_latents` step and the first part of the `denoise` step (i.e. `before_denoiser (StableDiffusionXLLoopBeforeDenoiser)`).
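-
-As a quick conceptual sketch (illustrative only, not the actual block code), the change map is compared against a per-step threshold to decide which regions are still allowed to change at each step:
-
-```py
-import torch
-
-num_inference_steps = 25
-change_map = torch.rand(1, 1, 64, 64)  # per-pixel values in [0, 1]
-thresholds = torch.arange(num_inference_steps) / num_inference_steps
-
-# masks[i] is True where a region may still change at step i;
-# elsewhere the noised original latents are kept ("frozen")
-masks = change_map > thresholds.view(-1, 1, 1, 1)
-```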
-
-Here's a flowchart showing the pipeline structure and the changes we need to make:
-
-
-
-
-
-### Build a Working Pipeline Structure
-
-Now that we've identified the blocks to modify, let's build the pipeline skeleton first. At this stage, our goal is to get the pipeline structure working end-to-end (even though it still behaves like img2img). We can simply create placeholder blocks by copying from existing ones:
-
-```py
->>> # Copy existing blocks as placeholders
->>> class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
-... """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
-... # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
-...
->>> class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
-... """Copied from StableDiffusionXLLoopBeforeDenoiser - will modify later"""
-... # ... same implementation as StableDiffusionXLLoopBeforeDenoiser
-```
-
-`SDXLDiffDiffLoopBeforeDenoiser` is the part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseStep`.
-
-```py
->>> class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
-... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
-... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
-```
-
-Now we can put together our differential diffusion pipeline.
-
-```py
->>> DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
->>> DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
->>> DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
->>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
->>>
->>> dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
->>> print(dd_blocks)
->>> # At this point, the pipeline works exactly like img2img since our blocks are just copies
-```
-
-### Set up an example
-
-Now that our blocks compile without errors, we can move on to the next step. Let's set up a simple example so we can run the pipeline as we build it. Differential diffusion uses the same model checkpoints as SDXL, so we can fetch the models from a regular SDXL repository.
-
-```py
->>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
->>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
->>> dd_pipeline.to("cuda")
-```
-
-We will use this example script:
-
-```py
->>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
->>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
->>>
->>> prompt = "a green pear"
->>> negative_prompt = "blurry"
->>>
->>> image = dd_pipeline(
-... prompt=prompt,
-... negative_prompt=negative_prompt,
-... num_inference_steps=25,
-... diffdiff_map=mask,
-... image=image,
-... output="images"
-... )[0]
->>>
->>> image.save("diffdiff_out.png")
-```
-
-If you run the script right now, you will get a complaint about the unexpected input `diffdiff_map`, and you would get the same result as the original img2img pipeline.
-
-### Implement your custom logic and test incrementally
-
-Let's modify the pipeline so that we get the expected result with this example script.
-
-We'll start with the `prepare_latents` step. The main changes are:
-- Requires a new user input `diffdiff_map`
-- Requires new component `mask_processor` to process the `diffdiff_map`
-- Requires new intermediate inputs:
- - Need `timesteps` instead of `latent_timestep` to precompute all the latents
- - Need `num_inference_steps` to create the `diffdiff_masks`
-- Creates new outputs `diffdiff_masks` and `original_latents`
-
-
-
-💡 Use `print(dd_pipeline.doc)` to check the compiled inputs and outputs of the built pipeline.
-
-For example, after we add `diffdiff_map` as an input in this step, we can run `print(dd_pipeline.doc)` to verify that it shows up in the docstring as a user input.
-
-
-
-Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We create 2 new variables: the change map `diffdiff_masks` and the pre-computed noised latents for all timesteps, `original_latents`.
-
-
-
-💡 Implement incrementally! Run the example script as you go, and insert `print(state)` and `print(block_state)` everywhere inside the `__call__` method to inspect the intermediate results. This helps you understand what's going on and what each line you just added does.
-
-
-
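-Illustrative only: this is what such temporary debug prints look like inside a block's `__call__` while you develop it.
-
-```py
-class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
-    # ... properties as defined above ...
-
-    def __call__(self, components, state):
-        block_state = self.get_block_state(state)
-        print(state)        # full PipelineState: user inputs and intermediate values
-        print(block_state)  # local view of the variables this block works with
-        # ... diff-diff logic ...
-        self.set_block_state(state, block_state)
-        return components, state
-```
-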
-Here are the key changes we made to implement differential diffusion:
-
-**1. Modified `prepare_latents` step:**
-```diff
-class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
- @property
- def expected_components(self) -> List[ComponentSpec]:
- return [
- ComponentSpec("vae", AutoencoderKL),
- ComponentSpec("scheduler", EulerDiscreteScheduler),
-+ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
- ]
-
- @property
- def inputs(self) -> List[Tuple[str, Any]]:
- return [
-+ InputParam("diffdiff_map", required=True),
- ]
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- return [
- InputParam("generator"),
-- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
-+ InputParam("timesteps", type_hint=torch.Tensor),
-+ InputParam("num_inference_steps", type_hint=int),
- ]
-
- @property
- def intermediate_outputs(self) -> List[OutputParam]:
- return [
-+ OutputParam("original_latents", type_hint=torch.Tensor),
-+ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
- ]
-
- def __call__(self, components, state: PipelineState):
- # ... existing logic ...
-+ # Process change map and create masks
-+ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
-+ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
-+ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
-+ block_state.original_latents = block_state.latents
-```
-
-**2. Modified `before_denoiser` step:**
-```diff
-class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
- @property
- def description(self) -> str:
- return (
- "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
- )
-
-+ @property
-+ def inputs(self) -> List[Tuple[str, Any]]:
-+ return [
-+ InputParam("denoising_start"),
-+ ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
- InputParam("latents", required=True, type_hint=torch.Tensor),
-+ InputParam("original_latents", type_hint=torch.Tensor),
-+ InputParam("diffdiff_masks", type_hint=torch.Tensor),
- ]
-
- def __call__(self, components, block_state, i, t):
-+ # Apply differential diffusion logic
-+ if i == 0 and block_state.denoising_start is None:
-+ block_state.latents = block_state.original_latents[:1]
-+ else:
-+ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
-+ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
-
- # ... rest of existing logic ...
-```
-
-That's all there is to it! We've just created a simple sequential pipeline by mixing and matching some existing and new pipeline blocks.
-
-Now we use the process we prepared in step 2 to build the pipeline and inspect it.
-
-
-```py
->>> dd_pipeline
-SequentialPipelineBlocks(
- Class: ModularPipelineBlocks
-
- Description:
-
-
- Components:
- text_encoder (`CLIPTextModel`)
- text_encoder_2 (`CLIPTextModelWithProjection`)
- tokenizer (`CLIPTokenizer`)
- tokenizer_2 (`CLIPTokenizer`)
- guider (`ClassifierFreeGuidance`)
- vae (`AutoencoderKL`)
- image_processor (`VaeImageProcessor`)
- scheduler (`EulerDiscreteScheduler`)
- mask_processor (`VaeImageProcessor`)
- unet (`UNet2DConditionModel`)
-
- Configs:
- force_zeros_for_empty_prompt (default: True)
- requires_aesthetics_score (default: False)
-
- Blocks:
- [0] text_encoder (StableDiffusionXLTextEncoderStep)
- Description: Text Encoder step that generate text_embeddings to guide the image generation
-
- [1] image_encoder (StableDiffusionXLVaeEncoderStep)
- Description: Vae Encoder step that encode the input image into a latent representation
-
- [2] input (StableDiffusionXLInputStep)
- Description: Input processing step that:
- 1. Determines `batch_size` and `dtype` based on `prompt_embeds`
- 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
-
- All input tensors are expected to have either batch_size=1 or match the batch_size
- of prompt_embeds. The tensors will be duplicated across the batch dimension to
- have a final batch_size of batch_size * num_images_per_prompt.
-
- [3] set_timesteps (StableDiffusionXLSetTimestepsStep)
- Description: Step that sets the scheduler's timesteps for inference
-
- [4] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
- Description: Step that prepares the latents for the differential diffusion generation process
-
- [5] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
- Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
-
- [6] denoise (SDXLDiffDiffDenoiseStep)
- Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
-
- [7] decode (StableDiffusionXLDecodeStep)
- Description: Step that decodes the denoised latents into images
-
-)
-```
-
-Run the example now, and you should see an apple with its right half transformed into a green pear.
-
-
-
-
-## Adding IP-adapter
-
-We provide an auto IP-adapter block that you can plug into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](./auto_pipeline_blocks.md).
-
-We talked about how to add an IP-adapter to your workflow in the [Modular Pipeline Guide](./modular_pipeline.md). Let's go ahead and create the IP-adapter block.
-
-```py
->>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
->>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
-```
-
-We can directly add the IP-adapter block instance to the `dd_blocks` we created before. The `sub_blocks` attribute is an `InsertableDict`, so we can insert it at a specific position (index `0` here).
-
-```py
->>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
-```
-
-Take a look at the new diff-diff pipeline with ip-adapter!
-
-```py
->>> print(dd_blocks)
-```
-
-The pipeline now lists ip-adapter as its first block, and tells you that it will run only if `ip_adapter_image` is provided. It also includes the two new components from IP-adapter: `image_encoder` and `feature_extractor`.
-
-```out
-SequentialPipelineBlocks(
- Class: ModularPipelineBlocks
-
- ====================================================================================================
- This pipeline contains blocks that are selected at runtime based on inputs.
- Trigger Inputs: {'ip_adapter_image'}
- Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`).
- ====================================================================================================
-
-
- Description:
-
-
- Components:
- image_encoder (`CLIPVisionModelWithProjection`)
- feature_extractor (`CLIPImageProcessor`)
- unet (`UNet2DConditionModel`)
- guider (`ClassifierFreeGuidance`)
- text_encoder (`CLIPTextModel`)
- text_encoder_2 (`CLIPTextModelWithProjection`)
- tokenizer (`CLIPTokenizer`)
- tokenizer_2 (`CLIPTokenizer`)
- vae (`AutoencoderKL`)
- image_processor (`VaeImageProcessor`)
- scheduler (`EulerDiscreteScheduler`)
- mask_processor (`VaeImageProcessor`)
-
- Configs:
- force_zeros_for_empty_prompt (default: True)
- requires_aesthetics_score (default: False)
-
- Blocks:
- [0] ip_adapter (StableDiffusionXLAutoIPAdapterStep)
- Description: Run IP Adapter step if `ip_adapter_image` is provided.
-
- [1] text_encoder (StableDiffusionXLTextEncoderStep)
- Description: Text Encoder step that generate text_embeddings to guide the image generation
-
- [2] image_encoder (StableDiffusionXLVaeEncoderStep)
- Description: Vae Encoder step that encode the input image into a latent representation
-
- [3] input (StableDiffusionXLInputStep)
- Description: Input processing step that:
- 1. Determines `batch_size` and `dtype` based on `prompt_embeds`
- 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
-
- All input tensors are expected to have either batch_size=1 or match the batch_size
- of prompt_embeds. The tensors will be duplicated across the batch dimension to
- have a final batch_size of batch_size * num_images_per_prompt.
-
- [4] set_timesteps (StableDiffusionXLSetTimestepsStep)
- Description: Step that sets the scheduler's timesteps for inference
-
- [5] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
- Description: Step that prepares the latents for the differential diffusion generation process
-
- [6] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
- Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
-
- [7] denoise (SDXLDiffDiffDenoiseStep)
- Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
-
- [8] decode (StableDiffusionXLDecodeStep)
- Description: Step that decodes the denoised latents into images
-
-)
-```
-
-Let's test it out. We used an orange image to condition the generation via IP-adapter, and we can see a slight orange color and texture in the final output.
-
-
-```py
->>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
->>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
->>>
->>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
->>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
->>> dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
->>> dd_pipeline.loader.set_ip_adapter_scale(0.6)
->>> dd_pipeline = dd_pipeline.to(device)
->>>
->>> ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
->>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
->>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
->>>
->>> prompt = "a green pear"
->>> negative_prompt = "blurry"
->>> generator = torch.Generator(device=device).manual_seed(42)
->>>
->>> image = dd_pipeline(
-... prompt=prompt,
-... negative_prompt=negative_prompt,
-... num_inference_steps=25,
-... generator=generator,
-... ip_adapter_image=ip_adapter_image,
-... diffdiff_map=mask,
-... image=image,
-... output="images"
-... )[0]
-```
-
-## Working with ControlNets
-
-What about controlnet? Can differential diffusion work with controlnet? The key differences between a regular pipeline and a ControlNet pipeline are:
-1. A ControlNet input step that prepares the control condition
-2. Inside the denoising loop, a modified denoiser step where the control image is first processed through ControlNet, then control information is injected into the UNet
-
-From looking at the code workflow: differential diffusion only modifies the "before denoiser" step, while ControlNet operates within the "denoiser" itself. Since they intervene at different points in the pipeline, they should work together without conflicts.
-
-Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs.
-
-With this understanding, let's assemble the diffdiff-controlnet loop by combining the diffdiff before-denoiser step and controlnet denoiser step.
-
-```py
->>> class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
-... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
-... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
->>>
->>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
->>> # print(controlnet_denoise_block)
-```
-
-We provide an auto ControlNet input block that you can directly put into your workflow to process the `control_image`. Similar to the auto IP-adapter block, this step will only run if the `control_image` input is passed by the user. It works with both ControlNet and ControlNet Union.
-
-
-```py
->>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
->>> control_input_block = StableDiffusionXLAutoControlNetInputStep()
->>> print(control_input_block)
-```
-
-```out
-StableDiffusionXLAutoControlNetInputStep(
- Class: AutoPipelineBlocks
-
- ====================================================================================================
- This pipeline contains blocks that are selected at runtime based on inputs.
- Trigger Inputs: ['control_image', 'control_mode']
- ====================================================================================================
-
-
- Description: Controlnet Input step that prepare the controlnet input.
- This is an auto pipeline block that works for both controlnet and controlnet_union.
- (it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.
- - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped.
-
-
- Components:
- controlnet (`ControlNetUnionModel`)
- control_image_processor (`VaeImageProcessor`)
-
- Sub-Blocks:
- • controlnet_union [trigger: control_mode] (StableDiffusionXLControlNetUnionInputStep)
- Description: step that prepares inputs for the ControlNetUnion model
-
- • controlnet [trigger: control_image] (StableDiffusionXLControlNetInputStep)
- Description: step that prepare inputs for controlnet
-
-)
-
-```
-
-Let's assemble the blocks and run an example using ControlNet + differential diffusion. We used a tomato as the `control_image`, so in the output you can see that the right half that transformed into a pear has a tomato-like shape.
-
-```py
->>> dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
->>> dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
->>>
->>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
->>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
->>> dd_pipeline = dd_pipeline.to(device)
->>>
->>> control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
->>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
->>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
->>>
->>> prompt = "a green pear"
->>> negative_prompt = "blurry"
->>> generator = torch.Generator(device=device).manual_seed(42)
->>>
->>> image = dd_pipeline(
-... prompt=prompt,
-... negative_prompt=negative_prompt,
-... num_inference_steps=25,
-... generator=generator,
-... control_image=control_image,
-... controlnet_conditioning_scale=0.5,
-... diffdiff_map=mask,
-... image=image,
-... output="images"
-... )[0]
-```
-
-Optionally, we can combine `SDXLDiffDiffControlNetDenoiseStep` and `SDXLDiffDiffDenoiseStep` into an `AutoPipelineBlocks` so that the same workflow can work with or without ControlNet.
-
-
-```py
->>> class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
-... block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
-... block_names = ["controlnet_denoise", "denoise"]
-... block_trigger_inputs = ["controlnet_cond", None]
-```
-
-`SDXLDiffDiffAutoDenoiseStep` will run the ControlNet denoise step if `control_image` input is provided, otherwise it will run the regular denoise step.
-
-
-
- Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected.
-
-
-
-Now you can create the differential diffusion preset that works with IP-adapter and ControlNet.
-
-```py
->>> DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
->>> DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
->>> DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
->>> DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
->>> DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
->>> DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
->>>
->>> print(DIFFDIFF_AUTO_BLOCKS)
-```
-
-To use it:
-
-```py
->>> dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
->>> dd_pipeline = dd_auto_blocks.init_pipeline(...)
-```
-## Creating a Modular Repo
-
-You can easily share your differential diffusion workflow on the Hub by creating a modular repo. Here is one created using the code we just wrote together: https://huggingface.co/YiYiXu/modular-diffdiff
-
-To create a Modular Repo and share it on the Hub, you just need to run `save_pretrained()` with the `push_to_hub=True` flag. Note that if your pipeline contains custom blocks, you need to manually upload the code to the Hub, but we are working on a command-line tool to make this easier.
-
-```py
-dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
-```
-
-With a modular repo, it is very easy for the community to use the workflow you just created! Here is an example of using the differential diffusion pipeline we just created and shared.
-
-```py
->>> from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
->>> import torch
->>> from diffusers.utils import load_image
->>>
->>> repo_id = "YiYiXu/modular-diffdiff-0704"
->>>
->>> components = ComponentsManager()
->>>
->>> diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, components_manager=components, collection="diffdiff")
->>> diffdiff_pipeline.load_default_components(torch_dtype=torch.float16)
->>> components.enable_auto_cpu_offload()
-```
-
-See more usage examples on the model card.
-
-## Deploy a Mellon node
-
-[YIYI TODO: for now, here is an example of mellon node https://huggingface.co/YiYiXu/diff-diff-mellon]
diff --git a/docs/source/en/modular_diffusers/guiders.md b/docs/source/en/modular_diffusers/guiders.md
new file mode 100644
index 0000000000..fd0d278442
--- /dev/null
+++ b/docs/source/en/modular_diffusers/guiders.md
@@ -0,0 +1,175 @@
+
+
+# Guiders
+
+[Classifier-free guidance](https://huggingface.co/papers/2207.12598) steers model generation so it better matches a prompt and is commonly used to improve generation quality, control, and adherence to prompts. There are different types of guidance methods, and in Diffusers, they are known as *guiders*. Like blocks, it is easy to switch and use different guiders for different use cases without rewriting the pipeline.
+
+This guide will show you how to switch guiders, adjust guider parameters, and load and share them on the Hub.
+
+## Switching guiders
+
+[`ClassifierFreeGuidance`] is the default guider and is created when a pipeline is initialized with [`~ModularPipelineBlocks.init_pipeline`]. It is created with `from_config`, which means it doesn't require a loading specification from a modular repository. A guider created this way won't be listed in `modular_model_index.json`.
+
+Use [`~ModularPipeline.get_component_spec`] to inspect a guider.
+
+```py
+t2i_pipeline.get_component_spec("guider")
+ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 7.5), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['start', 'guidance_rescale', 'stop', 'use_original_formulation'])]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
+```
+
+Switch to a different guider by passing the new guider to [`~ModularPipeline.update_components`].
+
+> [!TIP]
+> Changing guiders prints a message letting you know you're changing the guider type.
+> ```bash
+> ModularPipeline.update_components: adding guider with new type: PerturbedAttentionGuidance, previous type: ClassifierFreeGuidance
+> ```
+
+```py
+from diffusers import LayerSkipConfig, PerturbedAttentionGuidance
+
+config = LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_attention_scores=True, skip_ff=False)
+guider = PerturbedAttentionGuidance(
+ guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config
+)
+t2i_pipeline.update_components(guider=guider)
+```
+
+Use [`~ModularPipeline.get_component_spec`] again to verify the guider type is different.
+
+```py
+t2i_pipeline.get_component_spec("guider")
+ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
+```
+
+## Loading custom guiders
+
+Guiders that are already saved on the Hub with a `modular_model_index.json` file are treated as `from_pretrained` components instead of `from_config` components.
+
+```json
+{
+ "guider": [
+ null,
+ null,
+ {
+ "repo": "YiYiXu/modular-loader-t2i-guider",
+ "revision": null,
+ "subfolder": "pag_guider",
+ "type_hint": [
+ "diffusers",
+ "PerturbedAttentionGuidance"
+ ],
+ "variant": null
+ }
+ ]
+}
+```
+
+The guider is only created after calling [`~ModularPipeline.load_components`] based on the loading specification in `modular_model_index.json`.
+
+```py
+t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider")
+# not created during init
+assert t2i_pipeline.guider is None
+t2i_pipeline.load_components()
+# loaded as PAG guider
+t2i_pipeline.guider
+```
+
+
+## Changing guider parameters
+
+The guider parameters can be adjusted with either the [`~ComponentSpec.create`] method or with [`~ModularPipeline.update_components`]. The example below changes the `guidance_scale` value.
+
+
+
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider = guider_spec.create(guidance_scale=10)
+t2i_pipeline.update_components(guider=guider)
+```
+
+
+
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider_spec.config["guidance_scale"] = 10
+t2i_pipeline.update_components(guider=guider_spec)
+```
+
+
+
+
+## Uploading custom guiders
+
+Call the [`~utils.PushToHubMixin.push_to_hub`] method on a custom guider to share it to the Hub.
+
+```py
+guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
+```
+
+To make this guider available to the pipeline, either modify the `modular_model_index.json` file or use the [`~ModularPipeline.update_components`] method.
+
+
+
+
+Edit the `modular_model_index.json` file and add a loading specification for the guider by pointing to a folder containing the guider config.
+
+```json
+{
+ "guider": [
+ "diffusers",
+ "PerturbedAttentionGuidance",
+ {
+ "repo": "YiYiXu/modular-loader-t2i-guider",
+ "revision": null,
+ "subfolder": "pag_guider",
+ "type_hint": [
+ "diffusers",
+ "PerturbedAttentionGuidance"
+ ],
+ "variant": null
+ }
+ ],
+```
+
+
+
+
+Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and use [`~ModularPipeline.update_components`] to update the guider and component specifications as well as the pipeline config.
+
+> [!TIP]
+> Changing the creation method prints a message letting you know you're changing the creation type to `from_pretrained`.
+> ```bash
+> ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained.
+> ```
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider_spec.default_creation_method="from_pretrained"
+guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+guider_spec.subfolder="pag_guider"
+pag_guider = guider_spec.load()
+t2i_pipeline.update_components(guider=pag_guider)
+```
+
+To make it the default guider for a pipeline, call [`~utils.PushToHubMixin.push_to_hub`]. This is an optional step and not necessary if you are only experimenting locally.
+
+```py
+t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider")
+```
+
+
+
diff --git a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
index e95cdc7163..86c82b5145 100644
--- a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
+++ b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
@@ -12,67 +12,22 @@ specific language governing permissions and limitations under the License.
# LoopSequentialPipelineBlocks
-
+[`~modular_pipelines.LoopSequentialPipelineBlocks`] is a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly through `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop, which is iterative by default.
-🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
-
+## Loop wrapper
-`LoopSequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that composes other blocks together in a loop, creating iterative workflows where blocks run multiple times with evolving state. It's particularly useful for denoising loops requiring repeated execution of the same blocks.
+[`~modular_pipelines.LoopSequentialPipelineBlocks`] is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
-
-
-Other types of multi-blocks include [SequentialPipelineBlocks](./sequential_pipeline_blocks.md) (for linear workflows) and [AutoPipelineBlocks](./auto_pipeline_blocks.md) (for conditional block selection). For information on creating individual blocks, see the [PipelineBlock guide](./pipeline_block.md).
-
-Additionally, like all `ModularPipelineBlocks`, `LoopSequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
-
-
-
-You could create a loop using `PipelineBlock` like this:
-
-```python
-class DenoiseLoop(PipelineBlock):
- def __call__(self, components, state):
- block_state = self.get_block_state(state)
- for t in range(block_state.num_inference_steps):
- # ... loop logic here
- pass
- self.set_block_state(state, block_state)
- return components, state
-```
-
-But in this tutorial, we will focus on how to use `LoopSequentialPipelineBlocks` to create a "composable" denoising loop where you can add or remove blocks within the loop or reuse the same loop structure with different block combinations.
-
-It involves two parts: a **loop wrapper** and **loop blocks**
-
-* The **loop wrapper** (`LoopSequentialPipelineBlocks`) defines the loop structure, e.g. it defines the iteration variables, and loop configurations such as progress bar.
-
-* The **loop blocks** are basically standard pipeline blocks you add to the loop wrapper.
- - they run sequentially for each iteration of the loop
- - they receive the current iteration index as an additional parameter
- - they share the same block_state throughout the entire loop
-
-Unlike regular `SequentialPipelineBlocks` where each block gets its own state, loop blocks share a single state that persists and evolves across iterations.
-
-We will build a simple loop block to demonstrate these concepts. Creating a loop block involves three steps:
-1. defining the loop wrapper class
-2. creating the loop blocks
-3. adding the loop blocks to the loop wrapper class to create the loop wrapper instance
-
-**Step 1: Define the Loop Wrapper**
-
-To create a `LoopSequentialPipelineBlocks` class, you need to define:
-
-* `loop_inputs`: User input variables (equivalent to `PipelineBlock.inputs`)
-* `loop_intermediate_inputs`: Intermediate variables needed from the mutable pipeline state (equivalent to `PipelineBlock.intermediates_inputs`)
-* `loop_intermediate_outputs`: New intermediate variables this block will add to the mutable pipeline state (equivalent to `PipelineBlock.intermediates_outputs`)
-* `__call__` method: Defines the loop structure and iteration logic
-
-Here is an example of a loop wrapper:
+- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
+- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
+- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
+- `__call__` method defines the loop structure and iteration logic.
```py
import torch
-from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, PipelineBlock, InputParam, OutputParam
+from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, ModularPipelineBlocks, InputParam, OutputParam
class LoopWrapper(LoopSequentialPipelineBlocks):
model_name = "test"
@@ -93,16 +48,20 @@ class LoopWrapper(LoopSequentialPipelineBlocks):
return components, state
```
-**Step 2: Create Loop Blocks**
+The loop wrapper can pass additional arguments, like the current iteration index, to the loop blocks.
-Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently:
-* It receives the iteration variable (e.g., `i`) passed by the loop wrapper
-* It works directly with `block_state` instead of pipeline state
-* No need to call `self.get_block_state()` or `self.set_block_state()`
+## Loop blocks
+
+A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.
+
+- It receives the iteration variable from the loop wrapper.
+- It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].
+- It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].
+
+Loop blocks share the same [`~modular_pipelines.BlockState`] to allow values to accumulate and change for each iteration in the loop.
```py
-class LoopBlock(PipelineBlock):
- # this is used to identify the model family, we won't worry about it in this example
+class LoopBlock(ModularPipelineBlocks):
model_name = "test"
@property
def inputs(self):
@@ -119,76 +78,16 @@ class LoopBlock(PipelineBlock):
return components, block_state
```
-**Step 3: Combine Everything**
+## LoopSequentialPipelineBlocks
-Finally, assemble your loop by adding the block(s) to the wrapper:
+Use the [`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`] method to add the loop block to the loop wrapper to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock})
```
-Now you've created a loop with one step:
-
-```py
->>> loop
-LoopWrapper(
- Class: LoopSequentialPipelineBlocks
-
- Description: I'm a loop!!
-
- Sub-Blocks:
- [0] block1 (LoopBlock)
- Description: I'm a block used inside the `LoopWrapper` class
-
-)
-```
-
-It has two inputs: `x` (used at each step within the loop) and `num_steps` used to define the loop.
-
-```py
->>> print(loop.doc)
-class LoopWrapper
-
- I'm a loop!!
-
- Inputs:
-
- x (`None`, *optional*):
-
- num_steps (`None`, *optional*):
-
- Outputs:
-
- x (`None`):
-```
-
-**Running the Loop:**
-
-```py
-# run the loop
-loop_pipeline = loop.init_pipeline()
-x = loop_pipeline(num_steps=10, x=0, output="x")
-assert x == 10
-```
-
-**Adding Multiple Blocks:**
-
-We can add multiple blocks to run within each iteration. Let's run the loop block twice within each iteration:
+Add more loop blocks to run within each iteration with [`~modular_pipelines.LoopSequentialPipelineBlocks.from_blocks_dict`]. This allows you to modify the blocks without changing the loop logic itself.
```py
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
-loop_pipeline = loop.init_pipeline()
-x = loop_pipeline(num_steps=10, x=0, output="x")
-assert x == 20 # Each iteration runs 2 blocks, so 10 iterations * 2 = 20
-```
-
-**Key Differences from SequentialPipelineBlocks:**
-
-The main difference is that loop blocks share the same `block_state` across all iterations, allowing values to accumulate and evolve throughout the loop. Loop blocks could receive additional arguments (like the current iteration index) depending on the loop wrapper's implementation, since the wrapper defines how loop blocks are called. You can easily add, remove, or reorder blocks within the loop without changing the loop logic itself.
-
-The officially supported denoising loops in Modular Diffusers are implemented using `LoopSequentialPipelineBlocks`. You can explore the actual implementation to see how these concepts work in practice:
-
-```py
-from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseStep
-StableDiffusionXLDenoiseStep()
```
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/modular_diffusers_states.md b/docs/source/en/modular_diffusers/modular_diffusers_states.md
index 744089fcf6..eb55b524e4 100644
--- a/docs/source/en/modular_diffusers/modular_diffusers_states.md
+++ b/docs/source/en/modular_diffusers/modular_diffusers_states.md
@@ -10,43 +10,42 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# PipelineState and BlockState
+# States
-
+Blocks rely on the [`~modular_pipelines.PipelineState`] and [`~modular_pipelines.BlockState`] data structures for communicating and sharing data.
-🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+| State | Description |
+|-------|-------------|
+| [`~modular_pipelines.PipelineState`] | Maintains the overall data required for a pipeline's execution and allows blocks to read and update its data. |
+| [`~modular_pipelines.BlockState`] | Allows each block to perform its computation with the necessary data from `inputs`. |
-
+This guide explains how states work and how they connect blocks.
-In Modular Diffusers, `PipelineState` and `BlockState` are the core data structures that enable blocks to communicate and share data. The concept is fundamental to understand how blocks interact with each other and the pipeline system.
+## PipelineState
-In the modular diffusers system, `PipelineState` acts as the global state container that all pipeline blocks operate on. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.
+The [`~modular_pipelines.PipelineState`] is a global state container for all blocks. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.
-A `PipelineState` consists of two distinct states:
+There are two dicts in [`~modular_pipelines.PipelineState`] for structuring data.
-- **The immutable state** (i.e. the `inputs` dict) contains a copy of values provided by users. Once a value is added to the immutable state, it cannot be changed. Blocks can read from the immutable state but cannot write to it.
-
-- **The mutable state** (i.e. the `intermediates` dict) contains variables that are passed between blocks and can be modified by them.
-
-Here's an example of what a `PipelineState` looks like:
+- The `values` dict is a **mutable** state containing a copy of user provided input values and intermediate output values generated by blocks. If a block modifies an `input`, it will be reflected in the `values` dict after calling `set_block_state`.
```py
PipelineState(
- inputs={
+ values={
'prompt': 'a cat'
'guidance_scale': 7.0
'num_inference_steps': 25
- },
- intermediates={
'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))
'negative_prompt_embeds': None
},
)
```
-Each pipeline blocks define what parts of that state they can read from and write to through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties. At run time, they gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes.
+## BlockState
-For example, if a block defines an input `image`, inside the block's `__call__` method, the `BlockState` would contain:
+The [`~modular_pipelines.BlockState`] is a local view of the relevant variables an individual block needs from [`~modular_pipelines.PipelineState`] for performing its computations.
+
+Access these variables directly as attributes like `block_state.image`.
```py
BlockState(
@@ -54,6 +53,23 @@ BlockState(
)
```
-You can access the variables directly as attributes: `block_state.image`.
+When a block's `__call__` method is executed, it retrieves the [`BlockState`] with `self.get_block_state(state)`, performs its operations, and updates [`~modular_pipelines.PipelineState`] with `self.set_block_state(state, block_state)`.
-We will explore more on how blocks interact with pipeline state through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties, see the [PipelineBlock guide](./pipeline_block.md).
\ No newline at end of file
+```py
+def __call__(self, components, state):
+ # retrieve BlockState
+ block_state = self.get_block_state(state)
+
+ # computation logic on inputs
+
+ # update PipelineState
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+## State interaction
+
+[`~modular_pipelines.PipelineState`] and [`~modular_pipelines.BlockState`] interaction is defined by a block's `inputs` and `intermediate_outputs`.
+
+- `inputs`: a block can modify an input, like `block_state.image`, and the change is propagated globally to [`~modular_pipelines.PipelineState`] by calling `set_block_state`.
+- `intermediate_outputs`: a new variable created by a block. It is added to the [`~modular_pipelines.PipelineState`]'s `values` dict and is available to subsequent blocks or accessible to users as a final output from the pipeline.
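+
+Here is a minimal sketch (the block, input, and output names are illustrative, not part of the library) of how `inputs` and `intermediate_outputs` tie [`~modular_pipelines.BlockState`] to [`~modular_pipelines.PipelineState`].
+
+```py
+from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
+
+class ScaleImageBlock(ModularPipelineBlocks):
+    model_name = "test"
+
+    @property
+    def inputs(self):
+        # values this block reads (and may modify) from PipelineState
+        return [InputParam("image"), InputParam("scale")]
+
+    @property
+    def intermediate_outputs(self):
+        # new value this block adds to PipelineState
+        return [OutputParam("scaled_image")]
+
+    def __call__(self, components, state):
+        block_state = self.get_block_state(state)
+        block_state.scaled_image = block_state.image * block_state.scale
+        # propagate the new value (and any modified inputs) back to PipelineState
+        self.set_block_state(state, block_state)
+        return components, state
+```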
diff --git a/docs/source/en/modular_diffusers/modular_pipeline.md b/docs/source/en/modular_diffusers/modular_pipeline.md
index 55182b921f..0e0a7bd75d 100644
--- a/docs/source/en/modular_diffusers/modular_pipeline.md
+++ b/docs/source/en/modular_diffusers/modular_pipeline.md
@@ -12,963 +12,11 @@ specific language governing permissions and limitations under the License.
# ModularPipeline
-
+[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`] into an executable pipeline that loads models and performs the computation steps defined in the blocks. It is the main interface for running a pipeline, and it is very similar to the [`DiffusionPipeline`] API.
-🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+The main difference is that a [`ModularPipeline`] call includes an expected `output` argument.
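+
+For example, a minimal sketch of a call, assuming `t2i_pipeline` is an already-loaded text-to-image [`ModularPipeline`]:
+
+```py
+# the `output` argument tells the pipeline which value to return from its state
+image = t2i_pipeline(prompt="a cat", num_inference_steps=25, output="images")[0]
+```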
-
-
-`ModularPipeline` is the main interface for end users to run pipelines in Modular Diffusers. It takes pipeline blocks and converts them into a runnable pipeline that can load models and execute the computation steps.
-
-In this guide, we will focus on how to build pipelines using the blocks we officially support at diffusers 🧨. We'll cover how to use predefined blocks and convert them into a `ModularPipeline` for execution.
-
-
-
-This guide shows you how to use predefined blocks. If you want to learn how to create your own pipeline blocks, see the [PipelineBlock guide](pipeline_block.md) for creating individual blocks, and the multi-block guides for connecting them together:
-- [SequentialPipelineBlocks](sequential_pipeline_blocks.md) (for linear workflows)
-- [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows)
-- [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional workflows)
-
-For information on how data flows through pipelines, see the [PipelineState and BlockState guide](modular_diffusers_states.md).
-
-
-
-
-## Create ModularPipelineBlocks
-
-In the Modular Diffusers system, you build pipelines using pipeline blocks. Pipeline blocks are fundamental building blocks - they define what components, inputs/outputs, and computation logic are needed. They are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. But they are just definitions and don't actually run anything. To execute blocks, you need to put them into a `ModularPipeline`. We'll first learn how to create predefined blocks here before talking about how to run them using `ModularPipeline`.
-
-All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including:
-
-- [`PipelineBlock`]: The most granular block - you define the input/output/components requirements and computation logic.
-- [`SequentialPipelineBlocks`]: A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block.
-- [`LoopSequentialPipelineBlocks`]: A special type of `SequentialPipelineBlocks` that runs the same sequence of blocks multiple times (loops), typically used for iterative processes like denoising steps in diffusion models.
-- [`AutoPipelineBlocks`]: A multi-block composed of multiple blocks that are selected at runtime based on the inputs.
-
-It is very easy to use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers
-
-```py
-from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep
-
-text_encoder_block = StableDiffusionXLTextEncoderStep()
-```
-
-This is a single `PipelineBlock`. You'll see that this text encoder block uses 2 text encoders and 2 tokenizers as well as a guider component. It takes user inputs such as `prompt` and `negative_prompt`, and returns text embedding outputs such as `prompt_embeds` and `negative_prompt_embeds`.
-
-```py
->>> text_encoder_block
-StableDiffusionXLTextEncoderStep(
- Class: PipelineBlock
- Description: Text Encoder step that generate text_embeddings to guide the image generation
- Components:
- text_encoder (`CLIPTextModel`)
- text_encoder_2 (`CLIPTextModelWithProjection`)
- tokenizer (`CLIPTokenizer`)
- tokenizer_2 (`CLIPTokenizer`)
- guider (`ClassifierFreeGuidance`)
- Configs:
- force_zeros_for_empty_prompt (default: True)
- Inputs:
- prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, clip_skip=None
- Intermediates:
- - outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-)
-```
-
-More commonly, you need multiple blocks to build your workflow. You can create a `SequentialPipelineBlocks` using block class presets from 🧨 Diffusers. `TEXT2IMAGE_BLOCKS` is a dict containing all the blocks needed for text-to-image generation.
-
-```py
-from diffusers.modular_pipelines import SequentialPipelineBlocks
-from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
-t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
-```
-
-This creates a `SequentialPipelineBlocks`. Unlike the `text_encoder_block` we saw earlier, this is a multi-block and its `sub_blocks` attribute contains a list of other blocks (text_encoder, input, set_timesteps, prepare_latents, prepare_add_cond, denoise, decode). Its requirements for components, inputs, and intermediate inputs are combined from the blocks that compose it. At runtime, it executes its sub-blocks sequentially and passes the pipeline state from one block to another.
-
-```py
->>> t2i_blocks
-SequentialPipelineBlocks(
- Class: ModularPipelineBlocks
-
- Description:
-
-
- Components:
- text_encoder (`CLIPTextModel`)
- text_encoder_2 (`CLIPTextModelWithProjection`)
- tokenizer (`CLIPTokenizer`)
- tokenizer_2 (`CLIPTokenizer`)
- guider (`ClassifierFreeGuidance`)
- scheduler (`EulerDiscreteScheduler`)
- unet (`UNet2DConditionModel`)
- vae (`AutoencoderKL`)
- image_processor (`VaeImageProcessor`)
-
- Configs:
- force_zeros_for_empty_prompt (default: True)
-
- Sub-Blocks:
- [0] text_encoder (StableDiffusionXLTextEncoderStep)
- Description: Text Encoder step that generate text_embeddings to guide the image generation
-
- [1] input (StableDiffusionXLInputStep)
- Description: Input processing step that:
- 1. Determines `batch_size` and `dtype` based on `prompt_embeds`
- 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
-
- All input tensors are expected to have either batch_size=1 or match the batch_size
- of prompt_embeds. The tensors will be duplicated across the batch dimension to
- have a final batch_size of batch_size * num_images_per_prompt.
-
- [2] set_timesteps (StableDiffusionXLSetTimestepsStep)
- Description: Step that sets the scheduler's timesteps for inference
-
- [3] prepare_latents (StableDiffusionXLPrepareLatentsStep)
- Description: Prepare latents step that prepares the latents for the text-to-image generation process
-
- [4] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep)
- Description: Step that prepares the additional conditioning for the text-to-image generation process
-
- [5] denoise (StableDiffusionXLDenoiseStep)
- Description: Denoise step that iteratively denoise the latents.
- Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
- At each iteration, it runs blocks defined in `sub_blocks` sequencially:
- - `StableDiffusionXLLoopBeforeDenoiser`
- - `StableDiffusionXLLoopDenoiser`
- - `StableDiffusionXLLoopAfterDenoiser`
- This block supports both text2img and img2img tasks.
-
- [6] decode (StableDiffusionXLDecodeStep)
- Description: Step that decodes the denoised latents into images
-
-)
-```
-
-This is the block classes preset (`TEXT2IMAGE_BLOCKS`) we used. It is just a dictionary that maps names to `ModularPipelineBlocks` classes:
-
-```py
->>> TEXT2IMAGE_BLOCKS
-InsertableDict([
- 0: ('text_encoder', ),
- 1: ('input', ),
- 2: ('set_timesteps', ),
- 3: ('prepare_latents', ),
- 4: ('prepare_add_cond', ),
- 5: ('denoise', ),
- 6: ('decode', )
-])
-```
-
-When we create a `SequentialPipelineBlocks` from this preset, it instantiates each block class into actual block objects. Its `sub_blocks` attribute now contains these instantiated objects:
-
-```py
->>> t2i_blocks.sub_blocks
-InsertableDict([
- 0: ('text_encoder', ),
- 1: ('input', ),
- 2: ('set_timesteps', ),
- 3: ('prepare_latents', ),
- 4: ('prepare_add_cond', ),
- 5: ('denoise', ),
- 6: ('decode', )
-])
-```
-
-Note that both the block classes preset and the `sub_blocks` attribute are `InsertableDict` objects. This is a custom dictionary that extends `OrderedDict` with the ability to insert items at specific positions. You can perform all standard dictionary operations (get, set, delete) plus insert items at any index, which is particularly useful for reordering or inserting blocks in the middle of a pipeline.
-
-**Add a block:**
-```py
-# BLOCKS is a dict of block classes, so you add a class to it
-BLOCKS.insert("block_name", BlockClass, index)
-# the sub_blocks attribute contains instances, so you add a block instance to it
-t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
-```
-
-**Remove a block:**
-```py
-# remove a block class from preset
-BLOCKS.pop("text_encoder")
-# split out a block instance on its own
-text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
-```
-
-**Swap block:**
-```py
-# Replace block class in preset
-BLOCKS["prepare_latents"] = CustomPrepareLatents
-# Replace in the sub_blocks attribute using a block instance
-t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
-```
-
-This means you can mix-and-match blocks in very flexible ways. Let's see some real examples:
-
-**Example 1: Adding IP-Adapter to the Block Classes Preset**
-Let's make a new block classes preset by inserting IP-Adapter at index 0 (before the text_encoder block), and create a text-to-image pipeline with IP-Adapter support:
-
-```py
-from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep
-CUSTOM_BLOCKS = TEXT2IMAGE_BLOCKS.copy()
-# CUSTOM_BLOCKS is now a preset including ip_adapter
-CUSTOM_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
-# create a blocks instance from the preset
-custom_blocks = SequentialPipelineBlocks.from_blocks_dict(CUSTOM_BLOCKS)
-```
-
-**Example 2: Extracting a block from a multi-block**
-You can extract a block instance from the multi-block to use it independently. A common pattern is to use text_encoder to process prompts once, then reuse the text embeddings outputs to generate multiple images with different settings (schedulers, seeds, inference steps). We can do this by simply extracting the text_encoder block from the pipeline.
-
-```py
-# this gives you StableDiffusionXLTextEncoderStep()
->>> text_encoder_blocks = t2i_blocks.sub_blocks.pop("text_encoder")
->>> text_encoder_blocks
-```
-
-The multi-block now has fewer components and no longer has the `text_encoder` block. If you check its docstring `t2i_blocks.doc`, you will see that it no longer accepts `prompt` as input - you will need to pass the embeddings instead.
-
-```py
->>> t2i_blocks
-SequentialPipelineBlocks(
- Class: ModularPipelineBlocks
-
- Description:
-
- Components:
- scheduler (`EulerDiscreteScheduler`)
- guider (`ClassifierFreeGuidance`)
- unet (`UNet2DConditionModel`)
- vae (`AutoencoderKL`)
- image_processor (`VaeImageProcessor`)
-
- Blocks:
- [0] input (StableDiffusionXLInputStep)
- Description: Input processing step that:
- 1. Determines `batch_size` and `dtype` based on `prompt_embeds`
- 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
-
- All input tensors are expected to have either batch_size=1 or match the batch_size
- of prompt_embeds. The tensors will be duplicated across the batch dimension to
- have a final batch_size of batch_size * num_images_per_prompt.
-
- [1] set_timesteps (StableDiffusionXLSetTimestepsStep)
- Description: Step that sets the scheduler's timesteps for inference
-
- [2] prepare_latents (StableDiffusionXLPrepareLatentsStep)
- Description: Prepare latents step that prepares the latents for the text-to-image generation process
-
- [3] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep)
- Description: Step that prepares the additional conditioning for the text-to-image generation process
-
- [4] denoise (StableDiffusionXLDenoiseLoop)
- Description: Denoise step that iteratively denoise the latents.
- Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
- At each iteration, it runs blocks defined in `blocks` sequencially:
- - `StableDiffusionXLLoopBeforeDenoiser`
- - `StableDiffusionXLLoopDenoiser`
- - `StableDiffusionXLLoopAfterDenoiser`
-
-
- [5] decode (StableDiffusionXLDecodeStep)
- Description: Step that decodes the denoised latents into images
-
-)
-```
-
-
-
-💡 You can find all the block classes presets we support for each model in `ALL_BLOCKS`.
-
-```py
-# For Stable Diffusion XL
-from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
-ALL_BLOCKS
-# For other models...
-from diffusers.modular_pipelines.<model_name> import ALL_BLOCKS
-```
-
-Each model provides a dictionary that maps all supported tasks/techniques to their corresponding block classes presets. For SDXL, it is
-
-```py
-ALL_BLOCKS = {
- "text2img": TEXT2IMAGE_BLOCKS,
- "img2img": IMAGE2IMAGE_BLOCKS,
- "inpaint": INPAINT_BLOCKS,
- "controlnet": CONTROLNET_BLOCKS,
- "ip_adapter": IP_ADAPTER_BLOCKS,
- "auto": AUTO_BLOCKS,
-}
-```
-
-
-
-This covers the essentials of pipeline blocks! Like we have already mentioned, **pipeline blocks are not runnable by themselves**. They are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a `ModularPipeline` object.
-
-
-## Modular Repo
-
-To convert blocks into a runnable pipeline, you may need a repository if your blocks contain **pretrained components** (models with checkpoints that need to be loaded from the Hub). Pipeline blocks define what components they need (like a UNet, text encoder, etc.), as well as how to create them: components can be either created using **from_pretrained** method (with checkpoints) or **from_config** (initialized from scratch with default configuration, usually stateless like a guider or scheduler).
-
-If your pipeline contains **pretrained components**, you typically need to use a repository to provide the loading specifications and metadata.
-
-`ModularPipeline` works specifically with modular repositories, which offer more flexibility in component loading compared to traditional repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff).
-
-A `DiffusionPipeline` defines `model_index.json` to configure its components. However, repositories for Modular Diffusers work with `modular_model_index.json`. Let's walk through the differences here.
-
-In standard `model_index.json`, each component entry is a `(library, class)` tuple:
-```py
-"text_encoder": [
- "transformers",
- "CLIPTextModel"
-],
-```
-
-In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs_dict)`
-
-- `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (will be `null` if not loaded)
-- `loading_specs_dict`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint`.
-
-```py
-"text_encoder": [
- null, # library of actual loaded component (same as in model_index.json)
-    null, # class of actual loaded component (same as in model_index.json)
- { # loading specs map (unique to modular_model_index.json)
- "repo": "stabilityai/stable-diffusion-xl-base-1.0", # can be a different repo
- "revision": null,
- "subfolder": "text_encoder",
- "type_hint": [ # (library, class) for the expected component
- "transformers",
- "CLIPTextModel"
- ],
- "variant": null
- }
-],
-```
-
-Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs_dict`. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories.
-
-
-## Creating a `ModularPipeline` from `ModularPipelineBlocks`
-
-Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications.
-
-Let's convert our `t2i_blocks` (which we created earlier) into a runnable `ModularPipeline`. We'll use a `ComponentsManager` to handle device placement, memory management, and component reuse automatically:
-
-```py
-# We already have this from earlier
-t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
-
-# Now convert it to a ModularPipeline
-from diffusers import ComponentsManager
-modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
-components = ComponentsManager()
-t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
-```
-
-
-
-💡 **ComponentsManager** is the model registry and management system in diffusers. It tracks all the models in one place and lets you add, remove, and reuse them across different workflows in the most efficient way. Without it, you'd need to manually manage GPU memory, device placement, and component sharing between workflows. See the [Components Manager guide](components_manager.md) for detailed information.
-
-
-
-The `init_pipeline()` method creates a ModularPipeline and loads component specifications from the repository's `modular_model_index.json` file, but doesn't load the actual models yet.
-
-
-## Creating a `ModularPipeline` with `from_pretrained`
-
-You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo:
-
-```py
-from diffusers import ModularPipeline, ComponentsManager
-components = ComponentsManager()
-pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
-```
-
-Loading custom code is also supported:
-
-```py
-from diffusers import ModularPipeline, ComponentsManager
-components = ComponentsManager()
-modular_repo_id = "YiYiXu/modular-diffdiff-0704"
-diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components)
-```
-
-This modular repository contains custom code. The folder contains these files:
-
-```
-modular-diffdiff-0704/
-├── block.py # Custom pipeline blocks implementation
-├── config.json # Pipeline configuration and auto_map
-└── modular_model_index.json # Component loading specifications
-```
-
-The [`config.json`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file defines a custom `DiffDiffBlocks` class and points to its implementation:
-
-```json
-{
- "_class_name": "DiffDiffBlocks",
- "auto_map": {
- "ModularPipelineBlocks": "block.DiffDiffBlocks"
- }
-}
-```
-
-The `auto_map` tells the pipeline where to find the custom blocks definition - in this case, it's looking for `DiffDiffBlocks` in the `block.py` file. The actual `DiffDiffBlocks` class is defined in [`block.py`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/block.py) within the repository.
-
-When `diffdiff_pipeline.blocks` is created, it's based on the `DiffDiffBlocks` definition from the custom code in the repository, allowing you to use specialized blocks that aren't part of the standard diffusers library.
-
-## Loading components into a `ModularPipeline`
-
-Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whether using `from_pretrained` or converting from pipeline blocks), its components aren't loaded automatically. You need to explicitly load model components using `load_default_components` or `load_components(names=...)`:
-
-```py
-# This will load ALL the expected components into pipeline
-import torch
-t2i_pipeline.load_default_components(torch_dtype=torch.float16)
-t2i_pipeline.to("cuda")
-```
-
-All expected components are now loaded into the pipeline. You can also partially load specific components using the `names` argument. For example, to only load unet and vae:
-
-```py
->>> t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
-```
-
-You can inspect the pipeline's loading status by simply printing the pipeline itself. It helps you understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. Let's print out the `t2i_pipeline`:
-
-```py
->>> t2i_pipeline
-StableDiffusionXLModularPipeline {
- "_blocks_class_name": "SequentialPipelineBlocks",
- "_class_name": "StableDiffusionXLModularPipeline",
- "_diffusers_version": "0.35.0.dev0",
- "force_zeros_for_empty_prompt": true,
- "scheduler": [
- null,
- null,
- {
- "repo": "stabilityai/stable-diffusion-xl-base-1.0",
- "revision": null,
- "subfolder": "scheduler",
- "type_hint": [
- "diffusers",
- "EulerDiscreteScheduler"
- ],
- "variant": null
- }
- ],
- "text_encoder": [
- null,
- null,
- {
- "repo": "stabilityai/stable-diffusion-xl-base-1.0",
- "revision": null,
- "subfolder": "text_encoder",
- "type_hint": [
- "transformers",
- "CLIPTextModel"
- ],
- "variant": null
- }
- ],
- "text_encoder_2": [
- null,
- null,
- {
- "repo": "stabilityai/stable-diffusion-xl-base-1.0",
- "revision": null,
- "subfolder": "text_encoder_2",
- "type_hint": [
- "transformers",
- "CLIPTextModelWithProjection"
- ],
- "variant": null
- }
- ],
- "tokenizer": [
- null,
- null,
- {
- "repo": "stabilityai/stable-diffusion-xl-base-1.0",
- "revision": null,
- "subfolder": "tokenizer",
- "type_hint": [
- "transformers",
- "CLIPTokenizer"
- ],
- "variant": null
- }
- ],
- "tokenizer_2": [
- null,
- null,
- {
- "repo": "stabilityai/stable-diffusion-xl-base-1.0",
- "revision": null,
- "subfolder": "tokenizer_2",
- "type_hint": [
- "transformers",
- "CLIPTokenizer"
- ],
- "variant": null
- }
- ],
- "unet": [
- "diffusers",
- "UNet2DConditionModel",
- {
- "repo": "RunDiffusion/Juggernaut-XL-v9",
- "revision": null,
- "subfolder": "unet",
- "type_hint": [
- "diffusers",
- "UNet2DConditionModel"
- ],
- "variant": "fp16"
- }
- ],
- "vae": [
- "diffusers",
- "AutoencoderKL",
- {
- "repo": "madebyollin/sdxl-vae-fp16-fix",
- "revision": null,
- "subfolder": null,
- "type_hint": [
- "diffusers",
- "AutoencoderKL"
- ],
- "variant": null
- }
- ]
-}
-```
-
-You can see all the **pretrained components** that will be loaded using `from_pretrained` method are listed as entries. Each entry contains 3 elements: `(library, class, loading_specs_dict)`:
-
-- **`library` and `class`**: Show the actual loaded component info. If `null`, the component is not loaded yet.
-- **`loading_specs_dict`**: Contains all the information needed to load the component (repo, subfolder, variant, etc.)
-
-In this example:
-- **Loaded components**: `vae` and `unet` (their `library` and `class` fields show the actual loaded models)
-- **Not loaded yet**: `scheduler`, `text_encoder`, `text_encoder_2`, `tokenizer`, `tokenizer_2` (their `library` and `class` fields are `null`, but you can see their loading specs to know where they'll be loaded from when you call `load_components()`)
-
-You're looking at essentially the pipeline's config dict that's synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements.
-
-For example, if your pipeline needs a `text_encoder` component, it will include the loading spec for `text_encoder` from the modular repo during the `init_pipeline`. If the pipeline doesn't need a component (like `controlnet` in a basic text-to-image pipeline), that component won't be included even if it exists in the modular repo.
-
-There are also a few properties that can provide a quick summary of component loading status:
-
-```py
-# All components expected by the pipeline
->>> t2i_pipeline.component_names
-['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
-
-# Components that are not loaded yet (will be loaded with from_pretrained)
->>> t2i_pipeline.null_component_names
-['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
-
-# Components that will be loaded from pretrained models
->>> t2i_pipeline.pretrained_component_names
-['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
-
-# Components that are created with default config (no repo needed)
->>> t2i_pipeline.config_component_names
-['guider', 'image_processor']
-```
-
-From-config components (like `guider` and `image_processor`) are not included in the pipeline output above because they don't need loading specs - they're already initialized during pipeline creation. You can see this because they're not listed in `null_component_names`.
-
-## Modifying Loading Specs
-
-When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. You can change where components are loaded from by modifying the `modular_model_index.json` in the repository. Just find the file on the Hub and click edit - you can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc.
-
-```py
-# Original spec in modular_model_index.json
-"unet": [
- null, null,
- {
- "repo": "stabilityai/stable-diffusion-xl-base-1.0",
- "subfolder": "unet",
- "variant": "fp16"
- }
-]
-
-# Modified spec - changed repo, subfolder, and variant
-"unet": [
- null, null,
- {
- "repo": "RunDiffusion/Juggernaut-XL-v9",
- "subfolder": "unet",
- "variant": "fp16"
- }
-]
-```
-
-Now if you create a pipeline using the same blocks and updated repository, it will by default load from the new repository.
-
-```py
-pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
-pipeline.load_components(names="unet")
-```
-
-
-## Updating components in a `ModularPipeline`
-
-Similar to `DiffusionPipeline`, you can load components separately to replace the default ones in the pipeline. In Modular Diffusers, the approach depends on the component type:
-
-- **Pretrained components** (`default_creation_method='from_pretrained'`): Must use `ComponentSpec` to load them to update the existing one.
-- **Config components** (`default_creation_method='from_config'`): These are components that don't need loading specs - they're created during pipeline initialization with default config. To update them, you can either pass the object directly or pass a ComponentSpec directly.
-
-
-
-💡 **Component Type Changes**: The component type (pretrained vs config-based) can change when you update components. These types are initially defined in pipeline blocks' `expected_components` field using `ComponentSpec` with `default_creation_method`. See the [Customizing Guidance Techniques](#customizing-guidance-techniques) section for examples of how this works in practice.
-
-
-
-`ComponentSpec` defines how to create or load components and can actually create them using its `create()` method (for ConfigMixin objects) or `load()` method (wrapper around `from_pretrained()`). When a component is loaded with a ComponentSpec, it gets tagged with a unique ID that encodes its creation parameters, allowing you to always extract the original specification using `ComponentSpec.from_component()`.
-
-Now let's look at how to update pretrained components in practice:
-
-So instead of
-
-```py
-from diffusers import UNet2DConditionModel
-import torch
-unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16", torch_dtype=torch.float16)
-```
-You should load your model like this
-
-```py
-from diffusers import ComponentSpec, UNet2DConditionModel
-unet_spec = ComponentSpec(name="unet",type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
-unet2 = unet_spec.load(torch_dtype=torch.float16)
-```
-
-The key difference is that the second unet retains its loading specs, so you can extract the spec and recreate the unet:
-
-```py
-# component -> spec
->>> spec = ComponentSpec.from_component("unet", unet2)
->>> spec
-ComponentSpec(name='unet', type_hint=, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
-# spec -> component
->>> unet2_recreated = spec.load(torch_dtype=torch.float16)
-```
-
-To replace the unet in the pipeline
-
-```
-t2i_pipeline.update_components(unet=unet2)
-```
-
-Not only is the `unet` component swapped, but its loading specs are also updated from "RunDiffusion/Juggernaut-XL-v9" to "stabilityai/stable-diffusion-xl-base-1.0" in pipeline config. This means that if you save the pipeline now and load it back with `from_pretrained`, the new pipeline will by default load the SDXL original unet.
-
-```
->>> t2i_pipeline
-StableDiffusionXLModularPipeline {
- ...
- "unet": [
- "diffusers",
- "UNet2DConditionModel",
- {
- "repo": "stabilityai/stable-diffusion-xl-base-1.0",
- "revision": null,
- "subfolder": "unet",
- "type_hint": [
- "diffusers",
- "UNet2DConditionModel"
- ],
- "variant": "fp16"
- }
- ],
- ...
-}
-```
-
-
-💡 **Modifying Component Specs**: You can get a copy of the current component spec from the pipeline using `get_component_spec()`. This makes it easy to modify the spec and update components.
-
-```py
->>> unet_spec = t2i_pipeline.get_component_spec("unet")
->>> unet_spec
-ComponentSpec(
- name='unet',
- type_hint=,
- repo='RunDiffusion/Juggernaut-XL-v9',
- subfolder='unet',
- variant='fp16',
- default_creation_method='from_pretrained'
-)
-
-# Modify the spec to load from a different repository
->>> unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
-
-# Load the component with the modified spec
->>> unet = unet_spec.load(torch_dtype=torch.float16)
-```
-
-
-
-## Customizing Guidance Techniques
-
-Guiders are implementations of different [classifier-free guidance](https://huggingface.co/papers/2207.12598) techniques that can be applied during the denoising process to improve generation quality, control, and adherence to prompts. They work by steering the model predictions towards desired directions and away from undesired directions. In diffusers, guiders are implemented as subclasses of `BaseGuidance`. They can easily be integrated into modular pipelines and provide a flexible way to enhance generation quality without modifying the underlying diffusion models.
-
-**ClassifierFreeGuidance (CFG)** is the first and most common guidance technique, used in all our standard pipelines. We also offer many other guidance techniques from the latest research in this area - **PerturbedAttentionGuidance (PAG)**, **SkipLayerGuidance (SLG)**, **SmoothedEnergyGuidance (SEG)**, and others that can provide better results for specific use cases.
-
-This section demonstrates how to use guiders using the component updating methods we just learned. Since `BaseGuidance` components are stateless (similar to schedulers), they are typically created with default configurations during pipeline initialization using `default_creation_method='from_config'`. This means they don't require loading specs from the repository - you won't see guider listed in `modular_model_index.json` files.
-
-Let's take a look at the default guider configuration:
-
-```py
->>> t2i_pipeline.get_component_spec("guider")
-ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 7.5), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['start', 'guidance_rescale', 'stop', 'use_original_formulation'])]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
-```
-
-As you can see, the guider is configured to use `ClassifierFreeGuidance` with default parameters and `default_creation_method='from_config'`, meaning it's created during pipeline initialization rather than loaded from a repository. Let's verify this: here we run `init_pipeline()` without a modular repo, and there it is, a guider with the default configuration we just saw.
-
-
-```py
->>> pipeline = t2i_blocks.init_pipeline()
->>> pipeline.guider
-ClassifierFreeGuidance {
- "_class_name": "ClassifierFreeGuidance",
- "_diffusers_version": "0.35.0.dev0",
- "guidance_rescale": 0.0,
- "guidance_scale": 7.5,
- "start": 0.0,
- "stop": 1.0,
- "use_original_formulation": false
-}
-```
-
-#### Modify Parameters of the Same Guider Type
-
-To change parameters of the same guider type (e.g., adjusting the `guidance_scale` for CFG), you have two options:
-
-**Option 1: Use ComponentSpec.create() method**
-
-You just need to pass the parameter with the new value to override the default one.
-
-```python
->>> guider_spec = t2i_pipeline.get_component_spec("guider")
->>> guider = guider_spec.create(guidance_scale=10)
->>> t2i_pipeline.update_components(guider=guider)
-```
-
-**Option 2: Pass ComponentSpec directly**
-
-Update the spec directly and pass it to `update_components()`.
-
-```python
->>> guider_spec = t2i_pipeline.get_component_spec("guider")
->>> guider_spec.config["guidance_scale"] = 10
->>> t2i_pipeline.update_components(guider=guider_spec)
-```
-
-Both approaches produce the same result:
-```python
->>> t2i_pipeline.guider
-ClassifierFreeGuidance {
- "_class_name": "ClassifierFreeGuidance",
- "_diffusers_version": "0.35.0.dev0",
- "guidance_rescale": 0.0,
- "guidance_scale": 10,
- "start": 0.0,
- "stop": 1.0,
- "use_original_formulation": false
-}
-```
-
-#### Switch to a Different Guider Type
-
-Switching between guidance techniques is as simple as passing a guider object of that technique:
-
-```py
-from diffusers import LayerSkipConfig, PerturbedAttentionGuidance
-config = LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_attention_scores=True, skip_ff=False)
-guider = PerturbedAttentionGuidance(
- guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config
-)
-t2i_pipeline.update_components(guider=guider)
-```
-
-Note that you will get a warning about changing the guider type, which is expected:
-
-```
-ModularPipeline.update_components: adding guider with new type: PerturbedAttentionGuidance, previous type: ClassifierFreeGuidance
-```
-
-
-
-- For `from_config` components (like guiders, schedulers): You can pass an object of required type OR pass a ComponentSpec directly (which calls `create()` under the hood)
-- For `from_pretrained` components (like models): You must use ComponentSpec to ensure proper tagging and loading
-
-
-
-Let's verify that the guider has been updated:
-
-```py
->>> t2i_pipeline.guider
-PerturbedAttentionGuidance {
- "_class_name": "PerturbedAttentionGuidance",
- "_diffusers_version": "0.35.0.dev0",
- "guidance_rescale": 0.0,
- "guidance_scale": 5.0,
- "perturbed_guidance_config": {
- "dropout": 1.0,
- "fqn": "mid_block.attentions.0.transformer_blocks",
- "indices": [
- 2,
- 9
- ],
- "skip_attention": false,
- "skip_attention_scores": true,
- "skip_ff": false
- },
- "perturbed_guidance_layers": null,
- "perturbed_guidance_scale": 2.5,
- "perturbed_guidance_start": 0.01,
- "perturbed_guidance_stop": 0.2,
- "start": 0.0,
- "stop": 1.0,
- "use_original_formulation": false
-}
-
-```
-
-The component spec has also been updated to reflect the new guider type:
-
-```py
->>> t2i_pipeline.get_component_spec("guider")
-ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
-```
-
-The "guider" is still a `from_config` component: is still not included in the pipeline config and will not be saved into the `modular_model_index.json`.
-
-```py
->>> assert "guider" not in t2i_pipeline.config
-```
-
-However, you can change it to a `from_pretrained` component, which allows you to upload your customized guider to the Hub and load it into your pipeline.
-
-#### Loading Custom Guiders from Hub
-
-If you already have a guider saved on the Hub and a `modular_model_index.json` with the loading spec for that guider, it will automatically be changed to a `from_pretrained` component during pipeline initialization.
-
-For example, this `modular_model_index.json` includes loading specs for the guider:
-
-```json
-{
- "guider": [
- null,
- null,
- {
- "repo": "YiYiXu/modular-loader-t2i-guider",
- "revision": null,
- "subfolder": "pag_guider",
- "type_hint": [
- "diffusers",
- "PerturbedAttentionGuidance"
- ],
- "variant": null
- }
- ]
-}
-```
-
-When you use this repository to create a pipeline with the same blocks (that originally configured guider as a `from_config` component), the guider becomes a `from_pretrained` component. This means it doesn't get created during initialization, and after you call `load_default_components()`, it loads based on the spec - resulting in the PAG guider instead of the default CFG.
-
-```py
-t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider")
-assert t2i_pipeline.guider is None # Not created during init
-t2i_pipeline.load_default_components()
-t2i_pipeline.guider # Now loaded as PAG guider
-```
-
-#### Upload Custom Guider to Hub for Easy Loading & Sharing
-
-Now let's see how we can share the guider on the Hub and change it to a `from_pretrained` component.
-
-```py
-guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
-```
-
-Voilà! Now you have a subfolder called `pag_guider` on that repository.
-
-You have a few options to make this guider available in your pipeline:
-
-1. **Directly modify the `modular_model_index.json`** to add a loading spec for the guider by pointing to a folder containing the desired guider config.
-
-2. **Use the `update_components` method** to change it to a `from_pretrained` component for your pipeline. This is easier if you just want to try it out with different repositories.
-
-Let's use the second approach and change our guider_spec to use `from_pretrained` as the default creation method and update the loading spec to use this subfolder we just created:
-
-```python
-guider_spec = t2i_pipeline.get_component_spec("guider")
-guider_spec.default_creation_method="from_pretrained"
-guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
-guider_spec.subfolder="pag_guider"
-pag_guider = guider_spec.load()
-t2i_pipeline.update_components(guider=pag_guider)
-```
-
-You will get a warning about changing the creation method:
-
-```
-ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained.
-```
-
-Now not only the `guider` component and its component_spec are updated, but so is the pipeline config.
-
-If you want to change the default behavior for future pipelines, you can push the updated pipeline to the Hub. This way, when others use your repository, they'll get the PAG guider by default. However, this is optional - you don't have to do this if you just want to experiment locally.
-
-```py
-t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider")
-```
-
-
-
-
-Experiment with different techniques and parameters to find what works best for your specific use case! You can find all the guider classes we support [here](TODO: API doc).
-
-Additionally, you can write your own guider implementations, for example, CFG Zero* combined with Skip Layer Guidance, and they should be compatible out-of-the-box with modular diffusers!
-
-
-
-## Running a `ModularPipeline`
-
-The API to run the `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`:
-
-```py
->>> image = pipeline(prompt="a cat", num_inference_steps=15, output="images")[0]
-```
-
-There are a few key differences though:
-1. You can also pass a `PipelineState` object directly to the pipeline instead of individual arguments
-2. If you do not specify the `output` argument, it returns the `PipelineState` object
-3. You can pass a list as `output`, e.g. `pipeline(... output=["images", "latents"])` will return a dictionary containing both the generated image and the final denoised latents
-
-Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pipeline blocks' `__call__` method: it creates a `PipelineState` object and populates it with user inputs, then returns the output to the user based on the `output` argument. It also ensures that all pipeline-level config and components are exposed to all pipeline blocks by preparing and passing a `components` input.
-
-
-
-You can inspect the docstring of a `ModularPipeline` to check what arguments the pipeline accepts and how to specify the `output` you want. It will list all available outputs (basically everything in the intermediate pipeline state) so you can choose from the list.
-
-```py
-t2i_pipeline.doc
-```
-
-**Important**: Always check the docstring because arguments can differ from the standard pipelines you're familiar with. For example, in Modular Diffusers we standardized the controlnet image input as `control_image`, but regular pipelines use inconsistent names, e.g. controlnet text-to-image uses `image` while SDXL controlnet img2img uses `control_image`.
-
-**Note**: The `output` list might be longer than you expected - it includes everything in the intermediate state that you can choose to return. Most of the time, you'll just want `output="images"` or `output="latents"`.
-
-
-
-#### Text-to-Image, Image-to-Image, and Inpainting
-
-These are minimal inference examples for the basic tasks: text-to-image, image-to-image, and inpainting. The process to create the different pipelines is the same - the only difference is the block classes preset. Inference is also more or less the same as with standard pipelines, but always check `.doc` for the correct input names and remember to pass `output="images"`.
-
-
-
+
```py
@@ -976,16 +24,14 @@ import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
-# create pipeline from official blocks preset
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
pipeline = blocks.init_pipeline(modular_repo_id)
-pipeline.load_default_components(torch_dtype=torch.float16)
+pipeline.load_components(torch_dtype=torch.float16)
pipeline.to("cuda")
-# run pipeline, need to pass a "output=images" argument
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
image.save("modular_t2i_out.png")
```
@@ -998,13 +44,12 @@ import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
-# create pipeline from blocks preset
blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
pipeline = blocks.init_pipeline(modular_repo_id)
-pipeline.load_default_components(torch_dtype=torch.float16)
+pipeline.load_components(torch_dtype=torch.float16)
pipeline.to("cuda")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
@@ -1023,13 +68,12 @@ from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
-# create pipeline from blocks preset
blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
pipeline = blocks.init_pipeline(modular_repo_id)
-pipeline.load_default_components(torch_dtype=torch.float16)
+pipeline.load_components(torch_dtype=torch.float16)
pipeline.to("cuda")
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
@@ -1046,192 +90,269 @@ image.save("moduar_inpaint_out.png")
-#### ControlNet
+This guide will show you how to create a [`ModularPipeline`] and manage the components in it.
-For ControlNet, we provide one auto block you can place at the `denoise` step. Let's create it and inspect it to see what it tells us.
+## Adding blocks
-
+Block classes presets and the `sub_blocks` attribute are [`InsertableDict`] objects, so blocks can be inserted at specific positions, providing a flexible way to mix-and-match blocks.
-💡 **How to explore new tasks**: When you want to figure out how to do a specific task in Modular Diffusers, it is a good idea to start by checking what block classes presets we offer in `ALL_BLOCKS`. Then create the block instance and inspect it - it will show you the required components, description, and sub-blocks. This is crucial for understanding what each block does and what it needs.
-
-
+Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.insert`] on either the block class or `sub_blocks` attribute to add a block.
```py
->>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
->>> ALL_BLOCKS["controlnet"]
-InsertableDict([
- 0: ('denoise', )
-])
->>> controlnet_blocks = ALL_BLOCKS["controlnet"]["denoise"]()
->>> controlnet_blocks
-StableDiffusionXLAutoControlnetStep(
- Class: SequentialPipelineBlocks
-
- ====================================================================================================
- This pipeline contains blocks that are selected at runtime based on inputs.
- Trigger Inputs: {'mask', 'control_mode', 'control_image', 'controlnet_cond'}
- Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('mask')`).
- ====================================================================================================
-
-
- Description: Controlnet auto step that prepare the controlnet input and denoise the latents. It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks. (it should be replace at 'denoise' step)
-
-
- Components:
- controlnet (`ControlNetUnionModel`)
- control_image_processor (`VaeImageProcessor`)
- scheduler (`EulerDiscreteScheduler`)
- unet (`UNet2DConditionModel`)
- guider (`ClassifierFreeGuidance`)
-
- Sub-Blocks:
- [0] controlnet_input (StableDiffusionXLAutoControlNetInputStep)
- Description: Controlnet Input step that prepare the controlnet input.
- This is an auto pipeline block that works for both controlnet and controlnet_union.
- (it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.
- - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped.
-
- [1] controlnet_denoise (StableDiffusionXLAutoControlNetDenoiseStep)
- Description: Denoise step that iteratively denoise the latents with controlnet. This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks.This block should not be used without a controlnet_cond input - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided. - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided. - If neither mask nor controlnet_cond are provided, step will be skipped.
-
-)
+# BLOCKS is a dict of block classes, so add a block class to it
+BLOCKS.insert("block_name", BlockClass, index)
+# the sub_blocks attribute contains instances, so add a block instance to it
+t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
```
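+
+For example, you can copy the text-to-image preset, insert the IP-Adapter block class before the text encoder, and build the blocks from the extended preset. This is a short sketch that reuses the SDXL `TEXT2IMAGE_BLOCKS` preset and the `StableDiffusionXLAutoIPAdapterStep` block.
+
+```py
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoIPAdapterStep
+
+# copy the preset so the original stays untouched
+CUSTOM_BLOCKS = TEXT2IMAGE_BLOCKS.copy()
+# insert the IP-Adapter block class at index 0, before the text_encoder block
+CUSTOM_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
+# instantiate the blocks from the extended preset
+custom_blocks = SequentialPipelineBlocks.from_blocks_dict(CUSTOM_BLOCKS)
+```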
-
-
-💡 **Auto Blocks**: This is the first time we meet Auto Blocks! `AutoPipelineBlocks` automatically adapts to your inputs by combining multiple workflows with conditional logic. This is why one convenient block can work for all tasks and controlnet types. See the [Auto Blocks Guide](./auto_pipeline_blocks.md) for more details.
-
-
-
-The block shows us it has two steps (prepare inputs + denoise) and supports all tasks with both controlnet and controlnet union. Most importantly, it tells us to place it at the 'denoise' step. Let's do exactly that:
+Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.pop`] on either the block class or `sub_blocks` attribute to remove a block.
```py
-import torch
-from diffusers.modular_pipelines import SequentialPipelineBlocks
-from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoControlnetStep
-from diffusers.utils import load_image
-
-# create pipeline from blocks preset
-blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
-
-# these two lines apply controlnet
-controlnet_blocks = StableDiffusionXLAutoControlnetStep()
-blocks.sub_blocks["denoise"] = controlnet_blocks
+# remove a block class from preset
+BLOCKS.pop("text_encoder")
+# split out a block instance on its own
+text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
```
-Before we convert the blocks into a pipeline and load its components, let's inspect the blocks and their docs again to make sure everything was assembled correctly. You should be able to see that `controlnet` and `control_image_processor` are now listed as `Components`, so we should initialize the pipeline with a repo that contains the desired loading specs for these 2 components.
+Swap blocks by setting the existing block to the new block.
```py
-# make sure to use a modular repo that includes controlnet
-modular_repo_id = "YiYiXu/modular-demo-auto"
-pipeline = blocks.init_pipeline(modular_repo_id)
-pipeline.load_default_components(torch_dtype=torch.float16)
-pipeline.to("cuda")
-
-# generate
-canny_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
-)
-image = pipeline(
- prompt="a bird", controlnet_conditioning_scale=0.5, control_image=canny_image, output="images"
-)[0]
-image.save("modular_control_out.png")
+# Replace block class in preset
+BLOCKS["prepare_latents"] = CustomPrepareLatents
+# Replace in the sub_blocks attribute using a block instance
+t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
```
-#### IP-Adapter
+## Creating a pipeline
-**Challenge time!** Before we show you how to apply IP-adapter, try doing it yourself! Use the same process we just walked you through with ControlNet: check the official blocks preset, inspect the block instance and docstring `.doc`, and adapt a regular IP-adapter example to modular.
+There are two ways to create a [`ModularPipeline`]. Assemble and create a pipeline from [`ModularPipelineBlocks`] or load an existing pipeline with [`~ModularPipeline.from_pretrained`].
-Let's walk through the steps:
+You should also initialize a [`ComponentsManager`] to handle device placement, memory management, and component reuse.
-1. Check blocks preset
+> [!TIP]
+> Refer to the [ComponentsManager](./components_manager) doc for more details about how it can help manage components across different workflows.
+
+
+
+
+Use the [`~ModularPipelineBlocks.init_pipeline`] method to create a [`ModularPipeline`] from the component and configuration specifications. This method loads the *specifications* from a `modular_model_index.json` file, but it doesn't load the *models* yet.
```py
->>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
->>> ALL_BLOCKS["ip_adapter"]
-InsertableDict([
- 0: ('ip_adapter', )
-])
-```
-
-2. inspect the block & doc
-
-```
->>> from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep
->>> ip_adapter_blocks = StableDiffusionXLAutoIPAdapterStep()
->>> ip_adapter_blocks
-StableDiffusionXLAutoIPAdapterStep(
- Class: AutoPipelineBlocks
-
- ====================================================================================================
- This pipeline contains blocks that are selected at runtime based on inputs.
- Trigger Inputs: {'ip_adapter_image'}
- Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`).
- ====================================================================================================
-
-
- Description: Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.
-
-
-
- Components:
- image_encoder (`CLIPVisionModelWithProjection`)
- feature_extractor (`CLIPImageProcessor`)
- unet (`UNet2DConditionModel`)
- guider (`ClassifierFreeGuidance`)
-
- Sub-Blocks:
- • ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep)
- Description: IP Adapter step that prepares ip adapter image embeddings.
- Note that this step only prepares the embeddings - in order for it to work correctly, you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().
- See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details
-
-)
-```
-3. follow the instructions to build
-
-```py
-import torch
+from diffusers import ComponentsManager
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
-# create pipeline from official blocks preset
-blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
-# insert ip_adapter_blocks before the input step as instructed
-blocks.sub_blocks.insert("ip_adapter", ip_adapter_blocks, 1)
-
-# inspect the blocks before you convert them into a pipeline,
-# and make sure to use a repo that contains the loading spec for all components
-# for ip-adapter, you need image_encoder & feature_extractor
-modular_repo_id = "YiYiXu/modular-demo-auto"
-pipeline = blocks.init_pipeline(modular_repo_id)
-
-pipeline.load_default_components(torch_dtype=torch.float16)
-pipeline.load_ip_adapter(
- "h94/IP-Adapter",
- subfolder="sdxl_models",
- weight_name="ip-adapter_sdxl.bin"
-)
-pipeline.set_ip_adapter_scale(0.8)
-pipeline.to("cuda")
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+components = ComponentsManager()
+t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
```
-4. adapt an example to modular
-
-We are using [this one](https://huggingface.co/docs/diffusers/using-diffusers/ip_adapter?ipadapter-variants=IP-Adapter+Plus#ip-adapter) from our IP-Adapter doc!
+
+
+The [`~ModularPipeline.from_pretrained`] method creates a [`ModularPipeline`] from a modular repository on the Hub.
```py
-from diffusers.utils import load_image
-image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")
-image = pipeline(
- prompt="a polar bear sitting in a chair drinking a milkshake",
- ip_adapter_image=image,
- negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
- output="images"
-)[0]
-image.save("modular_ipa_out.png")
+from diffusers import ModularPipeline, ComponentsManager
+
+components = ComponentsManager()
+pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
```
+Add the `trust_remote_code` argument to load a custom [`ModularPipeline`].
+```py
+from diffusers import ModularPipeline, ComponentsManager
+
+components = ComponentsManager()
+modular_repo_id = "YiYiXu/modular-diffdiff-0704"
+diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components)
+```
+
+
+
+
+## Loading components
+
+A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load all components at once with [`~ModularPipeline.load_components`] or only specific components by passing their names to [`~ModularPipeline.load_components`].
+
+
+
+
+```py
+import torch
+
+t2i_pipeline.load_components(torch_dtype=torch.float16)
+t2i_pipeline.to("cuda")
+```
+
+
+
+
+The example below only loads the UNet and VAE.
+
+```py
+import torch
+
+t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
+```
+
+
+
+
+Print the pipeline to inspect the loaded pretrained components.
+
+```py
+t2i_pipeline
+```
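+
+Each pretrained component appears as a `(library, class, loading_specs_dict)` entry. Loaded components show their actual library and class, while components that haven't been loaded yet show `null` but still carry their loading specs. Below is an abridged, annotated sketch of what the output looks like; the exact entries depend on the repository.
+
+```py
+StableDiffusionXLModularPipeline {
+  "unet": [
+    "diffusers", "UNet2DConditionModel",  # already loaded
+    {"repo": "RunDiffusion/Juggernaut-XL-v9", "subfolder": "unet", "variant": "fp16", ...}
+  ],
+  "text_encoder": [
+    null, null,  # not loaded yet
+    {"repo": "stabilityai/stable-diffusion-xl-base-1.0", "subfolder": "text_encoder", ...}
+  ],
+  ...
+}
+```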
+
+This should match the `modular_model_index.json` file from the modular repository a pipeline is initialized from. If a pipeline doesn't need a component, it won't be included even if it exists in the modular repository.
+
+To modify where components are loaded from, edit the `modular_model_index.json` file in the repository and change it to your desired loading path. The example below loads a UNet from a different repository.
+
+```json
+# original
+"unet": [
+ null, null,
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "subfolder": "unet",
+ "variant": "fp16"
+ }
+]
+
+# modified
+"unet": [
+ null, null,
+ {
+ "repo": "RunDiffusion/Juggernaut-XL-v9",
+ "subfolder": "unet",
+ "variant": "fp16"
+ }
+]
+```
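+
+A pipeline created from the same blocks and the updated repository now loads the UNet from the new location by default. This is a sketch that reuses the repository id and [`ComponentsManager`] from above.
+
+```py
+pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
+# only the UNet is loaded here, following the modified loading spec
+pipeline.load_components(names="unet")
+```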
+
+### Component loading status
+
+The pipeline properties below provide more information about which components are loaded.
+
+Use `component_names` to return all expected components.
+
+```py
+t2i_pipeline.component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
+```
+
+Use `null_component_names` to return components that aren't loaded yet. Load these components with [`~ModularPipeline.load_components`].
+
+```py
+t2i_pipeline.null_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
+```
+
+Use `pretrained_component_names` to return components that will be loaded from pretrained models.
+
+```py
+t2i_pipeline.pretrained_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
+```
+
+Use `config_component_names` to return components that are created with the default config (not loaded from a modular repository). Config components don't need loading specs and are already initialized during pipeline creation, which is why they aren't listed in `null_component_names`.
+
+```py
+t2i_pipeline.config_component_names
+['guider', 'image_processor']
+```
+
+## Updating components
+
+How a component is updated depends on whether it is a *pretrained component* or a *config component*.
+
+> [!WARNING]
+> A component's type (pretrained or config) may change when you update it. The component type is initially defined in a block's `expected_components` field.
+
+A pretrained component must be updated with [`ComponentSpec`], whereas a config component can be updated either by passing the object directly or by passing a [`ComponentSpec`].
+
+The [`ComponentSpec`] shows `default_creation_method="from_pretrained"` for a pretrained component and `default_creation_method="from_config"` for a config component.
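+
+A config component can be updated by recreating it from its spec with new parameters. The sketch below assumes the `t2i_pipeline` from above with its default `guider`.
+
+```py
+guider_spec = t2i_pipeline.get_component_spec("guider")
+# create() builds a config component, overriding the default guidance_scale
+guider = guider_spec.create(guidance_scale=10)
+t2i_pipeline.update_components(guider=guider)
+```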
+
+To update a pretrained component, create a [`ComponentSpec`] with the name of the component and where to load it from. Use the [`~ComponentSpec.load`] method to load the component.
+
+```py
+from diffusers import ComponentSpec, UNet2DConditionModel
+import torch
+
+unet_spec = ComponentSpec(name="unet", type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
+unet = unet_spec.load(torch_dtype=torch.float16)
+```
+
+The [`~ModularPipeline.update_components`] method replaces the component with a new one.
+
+```py
+t2i_pipeline.update_components(unet=unet)
+```
+
+When a component is updated, the loading specifications are also updated in the pipeline config.
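+
+For example, after swapping in the UNet loaded from `stabilityai/stable-diffusion-xl-base-1.0` above, the spec stored in the pipeline points at the new repository. You can confirm this with [`~ModularPipeline.get_component_spec`] (covered below); a small sketch, where the exact value depends on your setup.
+
+```py
+t2i_pipeline.get_component_spec("unet").repo
+# 'stabilityai/stable-diffusion-xl-base-1.0'
+```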
+
+### Component extraction and modification
+
+When you use [`~ComponentSpec.load`], the new component maintains its loading specifications. This makes it possible to extract the specification and recreate the component.
+
+```py
+spec = ComponentSpec.from_component("unet", unet)
+spec
+ComponentSpec(name='unet', type_hint=, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
+unet_recreated = spec.load(torch_dtype=torch.float16)
+```
+
+The [`~ModularPipeline.get_component_spec`] method gets a copy of the current component specification to modify or update.
+
+```py
+unet_spec = t2i_pipeline.get_component_spec("unet")
+unet_spec
+ComponentSpec(
+ name='unet',
+ type_hint=,
+ repo='RunDiffusion/Juggernaut-XL-v9',
+ subfolder='unet',
+ variant='fp16',
+ default_creation_method='from_pretrained'
+)
+
+# modify to load from a different repository
+unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+
+# load component with modified spec
+unet = unet_spec.load(torch_dtype=torch.float16)
+```
+
+## Modular repository
+
+A repository is required if the pipeline blocks use *pretrained components*. The repository supplies loading specifications and metadata.
+
+[`ModularPipeline`] specifically requires *modular repositories* (see [example repository](https://huggingface.co/YiYiXu/modular-diffdiff)), which are more flexible than a typical repository. A modular repository contains a `modular_model_index.json` file in which each component entry has 3 elements.
+
+- `library` and `class` show which library the component was loaded from and its class. If `null`, the component hasn't been loaded yet.
+- `loading_specs_dict` contains the information required to load the component, such as the repository and subfolder it is loaded from.
+
+Unlike standard repositories, a modular repository can fetch components from different repositories based on the `loading_specs_dict`. Components don't need to exist in the same repository.
+
+A modular repository may contain custom code for loading a [`ModularPipeline`]. This allows you to use specialized blocks that aren't native to Diffusers.
+
+```
+modular-diffdiff-0704/
+├── block.py # Custom pipeline blocks implementation
+├── config.json # Pipeline configuration and auto_map
+└── modular_model_index.json # Component loading specifications
+```
+
+The [config.json](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file contains an `auto_map` key that points to where a custom block is defined in `block.py`.
+
+```json
+{
+ "_class_name": "DiffDiffBlocks",
+ "auto_map": {
+ "ModularPipelineBlocks": "block.DiffDiffBlocks"
+ }
+}
+```
diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md
index 9702cea063..7d07c4b734 100644
--- a/docs/source/en/modular_diffusers/overview.md
+++ b/docs/source/en/modular_diffusers/overview.md
@@ -10,33 +10,32 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Getting Started with Modular Diffusers
+# Overview
-
+> [!WARNING]
+> Modular Diffusers is under active development and its API may change.
-🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+Modular Diffusers is a unified pipeline system that simplifies your workflow with *pipeline blocks*.
-
+- Blocks are reusable and you only need to create new blocks that are unique to your pipeline.
+- Blocks can be mixed and matched to adapt to or create a pipeline for a specific workflow or multiple workflows.
-With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you:
+The Modular Diffusers docs are organized as shown below.
-**Write Only What's New**: You won't need to write an entire pipeline from scratch every time you have a new use case. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities.
+## Quickstart
-**Assemble Like LEGO®**: You can mix and match between blocks in flexible ways. This allows you to write dedicated blocks unique to specific workflows, and then assemble different blocks into a pipeline that can be used more conveniently for multiple workflows.
+- A [quickstart](./quickstart) demonstrating how to implement an example workflow with Modular Diffusers.
+## ModularPipelineBlocks
-Here's how our guides are organized to help you navigate the Modular Diffusers documentation:
+- [States](./modular_diffusers_states) explains how data is shared and communicated between blocks and [`ModularPipeline`].
+- [ModularPipelineBlocks](./pipeline_block) is the most basic unit of a [`ModularPipeline`] and this guide shows you how to create one.
+- [SequentialPipelineBlocks](./sequential_pipeline_blocks) is a type of block that chains multiple blocks so they run one after another, passing data along the chain. This guide shows you how to create [`~modular_pipelines.SequentialPipelineBlocks`] and how they connect and work together.
+- [LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks) is a type of block that runs a series of blocks in a loop. This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
+- [AutoPipelineBlocks](./auto_pipeline_blocks) is a type of block that automatically chooses which blocks to run based on the input. This guide shows you how to create [`~modular_pipelines.AutoPipelineBlocks`].
-### 🚀 Running Pipelines
-- **[Modular Pipeline Guide](./modular_pipeline.md)** - How to use predefined blocks to build a pipeline and run it
-- **[Components Manager Guide](./components_manager.md)** - How to manage and reuse components across multiple pipelines
+## ModularPipeline
-### 📚 Creating PipelineBlocks
-- **[Pipeline and Block States](./modular_diffusers_states.md)** - Understanding PipelineState and BlockState
-- **[Pipeline Block](./pipeline_block.md)** - How to write custom PipelineBlocks
-- **[SequentialPipelineBlocks](sequential_pipeline_blocks.md)** - Connecting blocks in sequence
-- **[LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks.md)** - Creating iterative workflows
-- **[AutoPipelineBlocks](./auto_pipeline_blocks.md)** - Conditional block selection
-
-### 🎯 Practical Examples
-- **[End-to-End Example](./end_to_end_guide.md)** - Complete end-to-end examples including sharing your workflow in huggingface hub and deplying UI nodes
+- [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`].
+- [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines.
+- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/pipeline_block.md b/docs/source/en/modular_diffusers/pipeline_block.md
index 17a819732f..66d26b0214 100644
--- a/docs/source/en/modular_diffusers/pipeline_block.md
+++ b/docs/source/en/modular_diffusers/pipeline_block.md
@@ -10,126 +10,101 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# PipelineBlock
+# ModularPipelineBlocks
-
+[`~modular_pipelines.ModularPipelineBlocks`] is the basic block for building a [`ModularPipeline`]. It defines the components, the inputs and outputs, and the computation a block performs for a specific step in a pipeline. A [`~modular_pipelines.ModularPipelineBlocks`] connects with other blocks, using [state](./modular_diffusers_states), to enable the modular construction of workflows.
-🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+A [`~modular_pipelines.ModularPipelineBlocks`] on its own can't be executed. It is a blueprint for what a step should do in a pipeline. To actually run a pipeline, the [`~modular_pipelines.ModularPipelineBlocks`] needs to be converted into a [`ModularPipeline`].
-
+This guide will show you how to create a [`~modular_pipelines.ModularPipelineBlocks`].
-In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows.
+## Inputs and outputs
-
+> [!TIP]
+> Refer to the [States](./modular_diffusers_states) guide if you aren't familiar with how state works in Modular Diffusers.
-**Important**: `PipelineBlock`s are definitions/specifications, not runnable pipelines. They define what a block should do and what data it needs, but you need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](./modular_pipeline.md).
+A [`~modular_pipelines.ModularPipelineBlocks`] defines its data requirements with `inputs`, `intermediate_inputs`, and `intermediate_outputs`.
-
+- `inputs` are immutable values provided by the user and retrieved from the [`~modular_pipelines.PipelineState`]. This is useful when a workflow resizes an image but still needs the original image, because the [`~modular_pipelines.PipelineState`] preserves the unmodified original.
-In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with the pipeline state.
+ Use `InputParam` to define `inputs`.
-## PipelineState
+ ```py
+ from diffusers.modular_pipelines import InputParam
-Before we dive into creating `PipelineBlock`s, make sure you have a basic understanding of `PipelineState`. It acts as the global state container that all blocks operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. See the [PipelineState and BlockState guide](./modular_diffusers_states.md) for more details.
+ user_inputs = [
+ InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
+ ]
+ ```
-## Define a `PipelineBlock`
+- `intermediate_inputs` are values typically created by a previous block, but they can also be provided directly if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
-To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce.
+ Use `InputParam` to define `intermediate_inputs`.
-The three main properties you need to define are:
-- `inputs`: Immutable values from the user that cannot be modified
-- `intermediate_inputs`: Mutable values from previous blocks that can be read and modified
-- `intermediate_outputs`: New values your block creates for subsequent blocks and user access
+ ```py
+ user_intermediate_inputs = [
+ InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
+ ]
+ ```
-Let's explore each one and understand how they work with the pipeline state.
+- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or as final outputs from running the pipeline.
-**Inputs: Immutable User Values**
+ Use `OutputParam` to define `intermediate_outputs`.
-Inputs are variables your block needs from the immutable pipeline state - these are user-provided values that cannot be modified by any block. You define them using `InputParam`:
+ ```py
+ from diffusers.modular_pipelines import OutputParam
-```py
-user_inputs = [
- InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
-]
-```
+ user_intermediate_outputs = [
+ OutputParam(name="image_latents", description="latents representing the image")
+ ]
+ ```
-When you list something as an input, you're saying "I need this value directly from the end user, and I will talk to them directly, telling them what I need in the 'description' field. They will provide it and it will come to me unchanged."
+Intermediate inputs and outputs are the connection points between blocks, sharing data as it flows through the pipeline. They are accessible at any point, which lets you track the workflow's progress, as sketched below.
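+
+As a sketch, assuming `pipe` is a [`ModularPipeline`] built from a block that declares `image_latents` as an intermediate output, and assuming the `output` argument accepts the name of any declared intermediate output (as it does for `"images"` in the quickstart), you could retrieve it directly when running the pipeline.
+
+```py
+# hypothetical pipeline built from the block above; `output` names the intermediate output to return
+image_latents = pipe(image=image, output="image_latents")
+```
+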
-This is especially useful for raw values that serve as the "source of truth" in your workflow. For example, with a raw image, many workflows require preprocessing steps like resizing that a previous block might have performed. But in many cases, you also want the raw PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency.
+## Computation logic
-**Intermediate Inputs: Mutable Values from Previous Blocks, or Users**
+The computation a block performs is defined in the `__call__` method and it follows a specific structure.
-Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but could also be directly provided by the user if not the case:
-
-```py
-user_intermediate_inputs = [
- InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
-]
-```
-
-When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want use something different."
-
-**Intermediate Outputs: New Values for Subsequent Blocks and User Access**
-
-Intermediate outputs are new variables your block creates and adds to the mutable pipeline state. They serve two purposes:
-
-1. **For subsequent blocks**: They can be used as intermediate inputs by other blocks in the pipeline
-2. **For users**: They become available as final outputs that users can access when running the pipeline
-
-```py
-user_intermediate_outputs = [
- OutputParam(name="image_latents", description="latents representing the image")
-]
-```
-
-Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match.
-
-Additionally, all intermediate outputs are accessible to users when they run the pipeline, typically you would only need the final images, but they are also able to access intermediate results like latents, embeddings, or other processing steps.
-
-**The `__call__` Method Structure**
-
-Your `PipelineBlock`'s `__call__` method should follow this structure:
+1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
+2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
+3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
+4. Return the components and state, which become available to the next block.
```py
def __call__(self, components, state):
# Get a local view of the state variables this block needs
block_state = self.get_block_state(state)
-
+
# Your computation logic here
# block_state contains all your inputs and intermediate_inputs
- # You can access them like: block_state.image, block_state.processed_image
-
+ # Access them like: block_state.image, block_state.processed_image
+
# Update the pipeline state with your updated block_states
self.set_block_state(state, block_state)
return components, state
```
-The `block_state` object contains all the variables you defined in `inputs` and `intermediate_inputs`, making them easily accessible for your computation.
+### Components and configs
-**Components and Configs**
+The components and pipeline-level configs a block needs are specified in [`ComponentSpec`] and [`~modular_pipelines.ConfigSpec`].
-You can define the components and pipeline-level configs your block needs using `ComponentSpec` and `ConfigSpec`:
+- [`ComponentSpec`] contains the expected components used by a block. You need the `name` of the component and ideally a `type_hint` that specifies exactly what the component is.
+- [`~modular_pipelines.ConfigSpec`] contains pipeline-level settings that control behavior across all blocks.
```py
from diffusers import ComponentSpec, ConfigSpec
-# Define components your block needs
expected_components = [
ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
]
-# Define pipeline-level configs
expected_config = [
ConfigSpec("force_zeros_for_empty_prompt", True)
]
```
-**Components**: In the `ComponentSpec`, you must provide a `name` and ideally a `type_hint`. You can also specify a `default_creation_method` to indicate whether the component should be loaded from a pretrained model or created with default configurations. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [Modular Pipeline Guide](./modular_pipeline.md).
-
-**Configs**: Pipeline-level settings that control behavior across all blocks.
-
-When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block as the first argument of the `__call__` method. You can access any component you need using dot notation:
+When the blocks are converted into a pipeline, the components become available to the block as the first argument in `__call__`.
```py
def __call__(self, components, state):
@@ -137,156 +112,4 @@ def __call__(self, components, state):
unet = components.unet
vae = components.vae
scheduler = components.scheduler
-```
-
-That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks
-
-Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that.
-
-**Helper Function**
-
-```py
-from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
-import torch
-
-def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
- class TestBlock(PipelineBlock):
- model_name = "test"
-
- @property
- def inputs(self):
- return inputs
-
- @property
- def intermediate_inputs(self):
- return intermediate_inputs
-
- @property
- def intermediate_outputs(self):
- return intermediate_outputs
-
- @property
- def description(self):
- return description if description is not None else ""
-
- def __call__(self, components, state):
- block_state = self.get_block_state(state)
- if block_fn is not None:
- block_state = block_fn(block_state, state)
- self.set_block_state(state, block_state)
- return components, state
-
- return TestBlock
-```
-
-## Example: Creating a Simple Pipeline Block
-
-Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them:
-
-```py
-inputs = [
- InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
-]
-
-intermediate_inputs = [InputParam(name="batch_size", type_hint=int)]
-
-intermediate_outputs = [
- OutputParam(name="image_latents", description="latents representing the image")
-]
-
-def image_encoder_block_fn(block_state, pipeline_state):
- print(f"pipeline_state (before update): {pipeline_state}")
- print(f"block_state (before update): {block_state}")
-
- # Simulate processing the image
- block_state.image = torch.randn(1, 3, 512, 512)
- block_state.batch_size = block_state.batch_size * 2
- block_state.processed_image = [torch.randn(1, 3, 512, 512)] * block_state.batch_size
- block_state.image_latents = torch.randn(1, 4, 64, 64)
-
- print(f"block_state (after update): {block_state}")
- return block_state
-
-# Create a block with our definitions
-image_encoder_block_cls = make_block(
- inputs=inputs,
- intermediate_inputs=intermediate_inputs,
- intermediate_outputs=intermediate_outputs,
- block_fn=image_encoder_block_fn,
- description="Encode raw image into its latent presentation"
-)
-image_encoder_block = image_encoder_block_cls()
-pipe = image_encoder_block.init_pipeline()
-```
-
-Let's check the pipeline's docstring to see what inputs it expects:
-```py
->>> print(pipe.doc)
-class TestBlock
-
- Encode raw image into its latent presentation
-
- Inputs:
-
- image (`PIL.Image`, *optional*):
- raw input image to process
-
- batch_size (`int`, *optional*):
-
- Outputs:
-
- image_latents (`None`):
- latents representing the image
-```
-
-Notice that `batch_size` appears as an input even though we defined it as an intermediate input. This happens because no previous block provided it, so the pipeline makes it available as a user input. However, unlike regular inputs, this value goes directly into the mutable intermediate state.
-
-Now let's run the pipeline:
-
-```py
-from diffusers.utils import load_image
-
-image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png")
-state = pipe(image=image, batch_size=2)
-print(f"pipeline_state (after update): {state}")
-```
-```out
-pipeline_state (before update): PipelineState(
- inputs={
- image:
- },
- intermediates={
- batch_size: 2
- },
-)
-block_state (before update): BlockState(
- image:
- batch_size: 2
-)
-
-block_state (after update): BlockState(
- image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512]))
- batch_size: 4
- processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])]
- image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
-)
-pipeline_state (after update): PipelineState(
- inputs={
- image:
- },
- intermediates={
- batch_size: 4
- image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
- },
-)
-```
-
-**Key Observations:**
-
-1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`.
-
-2. **After the update**:
- - **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only.
- - **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict)
- - **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output
- - **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output
\ No newline at end of file
+```
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/quickstart.md b/docs/source/en/modular_diffusers/quickstart.md
new file mode 100644
index 0000000000..9d4eaa0c0c
--- /dev/null
+++ b/docs/source/en/modular_diffusers/quickstart.md
@@ -0,0 +1,344 @@
+
+
+# Quickstart
+
+Modular Diffusers is a framework for quickly building flexible and customizable pipelines. At the core of Modular Diffusers are [`ModularPipelineBlocks`] that can be combined with other blocks to adapt to new workflows. The blocks are converted into a [`ModularPipeline`], a user-friendly interface for actually running the workflow.
+
+This doc will show you how to implement a [Differential Diffusion](https://differential-diffusion.github.io/) pipeline with the modular framework.
+
+## ModularPipelineBlocks
+
+[`ModularPipelineBlocks`] are *definitions* that specify the components, inputs, outputs, and computation logic for a single step in a pipeline. There are four types of blocks.
+
+- [`ModularPipelineBlocks`] is the most basic block for a single step.
+- [`SequentialPipelineBlocks`] is a multi-block that composes other blocks linearly. The outputs of one block are the inputs to the next block.
+- [`LoopSequentialPipelineBlocks`] is a multi-block that runs iteratively and is designed for iterative workflows.
+- [`AutoPipelineBlocks`] is a collection of blocks for different workflows and it selects which block to run based on the input. It is designed to conveniently package multiple workflows into a single pipeline.
+
+[Differential Diffusion](https://differential-diffusion.github.io/) is an image-to-image workflow. Start with the `IMAGE2IMAGE_BLOCKS` preset, a collection of `ModularPipelineBlocks` for image-to-image generation. The preset contains the blocks listed below.
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
+IMAGE2IMAGE_BLOCKS = InsertableDict([
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("image_encoder", StableDiffusionXLVaeEncoderStep),
+ ("input", StableDiffusionXLInputStep),
+ ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
+ ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
+ ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
+ ("denoise", StableDiffusionXLDenoiseStep),
+ ("decode", StableDiffusionXLDecodeStep)
+])
+```
+
+## Pipeline and block states
+
+Modular Diffusers uses *state* to communicate data between blocks. There are two types of states.
+
+- [`PipelineState`] is a global state that can be used to track all inputs and outputs across all blocks.
+- [`BlockState`] is a local view of relevant variables from [`PipelineState`] for an individual block.
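+
+As a rough sketch, every block reads and writes these states through the same pattern (covered in more detail in the [ModularPipelineBlocks](./pipeline_block) guide): a local [`BlockState`] is pulled from the global [`PipelineState`], modified, and pushed back.
+
+```py
+def __call__(self, components, state):
+    # pull a local BlockState view of the variables this block declared
+    block_state = self.get_block_state(state)
+    # ... computation on block_state ...
+    # push the local changes back to the global PipelineState
+    self.set_block_state(state, block_state)
+    return components, state
+```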
+
+## Customizing blocks
+
+[Differential Diffusion](https://differential-diffusion.github.io/) differs from standard image-to-image in its `prepare_latents` and `denoise` blocks. All the other blocks can be reused, but you'll need to modify these two.
+
+Create placeholder `ModularPipelineBlocks` for `prepare_latents` and `denoise` by copying and modifying the existing ones.
+
+Print the `denoise` block to see that it is composed of [`LoopSequentialPipelineBlocks`] with three sub-blocks, `before_denoiser`, `denoiser`, and `after_denoiser`. Only the `before_denoiser` sub-block needs to be modified to prepare the latent input for the denoiser based on the change map.
+
+```py
+denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
+print(denoise_blocks)
+```
+
+Replace the `StableDiffusionXLLoopBeforeDenoiser` sub-block with the new `SDXLDiffDiffLoopBeforeDenoiser` block.
+
+```py
+# Copy existing blocks as placeholders
+class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
+ """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
+ # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
+
+class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+```
+
+### prepare_latents
+
+The `prepare_latents` block requires the following changes.
+
+- a `mask_processor` component to process the change map
+- new `inputs` to accept the user-provided change map, `timesteps` for precomputing all the latents, and `num_inference_steps` to create the masks for updating the image regions
+- updated computation in the `__call__` method to process the change map, create the masks, and store them in the [`BlockState`]
+
+```diff
+class SDXLDiffDiffPrepareLatentsStep(ModularPipelineBlocks):
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
++ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
+ ]
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("generator"),
++ InputParam("diffdiff_map", required=True),
+- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
++ InputParam("timesteps", type_hint=torch.Tensor),
++ InputParam("num_inference_steps", type_hint=int),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
++ OutputParam("original_latents", type_hint=torch.Tensor),
++ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
+ ]
+ def __call__(self, components, state: PipelineState):
+ # ... existing logic ...
++ # Process change map and create masks
++ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
++ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
++ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
++ block_state.original_latents = block_state.latents
+```
+
+### denoise
+
+The `before_denoiser` sub-block requires the following changes.
+
+- new `inputs` to accept a `denoising_start` parameter, as well as `original_latents` and `diffdiff_masks` from the `prepare_latents` block
+- updated computation in the `__call__` method to apply Differential Diffusion
+
+```diff
+class SDXLDiffDiffLoopBeforeDenoiser(ModularPipelineBlocks):
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
+ )
+
+ @property
+ def inputs(self) -> List[str]:
+ return [
+ InputParam("latents", required=True, type_hint=torch.Tensor),
++ InputParam("denoising_start"),
++ InputParam("original_latents", type_hint=torch.Tensor),
++ InputParam("diffdiff_masks", type_hint=torch.Tensor),
+ ]
+
+ def __call__(self, components, block_state, i, t):
++ # Apply differential diffusion logic
++ if i == 0 and block_state.denoising_start is None:
++ block_state.latents = block_state.original_latents[:1]
++ else:
++ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
++ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
+
+ # ... rest of existing logic ...
+```
+
+## Assembling the blocks
+
+You should have all the blocks you need at this point to create a [`ModularPipeline`].
+
+Copy the existing `IMAGE2IMAGE_BLOCKS` preset. For the `set_timesteps` block, use the one from `TEXT2IMAGE_BLOCKS` instead because Differential Diffusion doesn't require a `strength` parameter.
+
+Set the `prepare_latents` and `denoise` blocks to the `SDXLDiffDiffPrepareLatentsStep` and `SDXLDiffDiffDenoiseStep` blocks you just modified.
+
+Call [`SequentialPipelineBlocks.from_blocks_dict`] on the blocks to create a `SequentialPipelineBlocks`.
+
+```py
+DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
+DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
+DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
+DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
+
+dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
+print(dd_blocks)
+```
+
+## ModularPipeline
+
+Convert the [`SequentialPipelineBlocks`] into a [`ModularPipeline`] with the [`ModularPipeline.init_pipeline`] method. This sets up the expected components and their loading specs from a `modular_model_index.json` file, but doesn't load them yet. Explicitly load the components by calling [`ModularPipeline.load_components`].
+
+It is a good idea to initialize the [`ComponentsManager`] with the pipeline to help manage the different components. Once you call [`~ModularPipeline.load_components`], the components are registered to the [`ComponentsManager`] and can be shared between workflows. The example below uses the `collection` argument to assign the components a `"diffdiff"` label for better organization.
+
+```py
+import torch
+from diffusers.modular_pipelines import ComponentsManager
+
+components = ComponentsManager()
+
+dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+dd_pipeline.to("cuda")
+```
+
+## Adding workflows
+
+Other workflows can be added to the [`ModularPipeline`] to support additional features without rewriting the entire pipeline from scratch.
+
+This section demonstrates how to add an IP-Adapter or ControlNet.
+
+### IP-Adapter
+
+Stable Diffusion XL already has a preset IP-Adapter block that you can use without any changes to the existing Differential Diffusion pipeline.
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
+
+ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
+```
+
+Use the `sub_blocks.insert` method to insert it into the existing blocks. The example below inserts the `ip_adapter_block` at position `0`. Print the blocks to see that the `ip_adapter_block` is added and that it requires an `ip_adapter_image`. This also adds two components to the pipeline, the `image_encoder` and `feature_extractor`.
+
+```py
+dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
+```
+
+Call [`~ModularPipeline.init_pipeline`] to initialize a [`ModularPipeline`] and use [`~ModularPipeline.load_components`] to load the model components. Load and set the IP-Adapter to run the pipeline.
+
+```py
+import torch
+from diffusers.utils import load_image
+
+device = "cuda"
+
+dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+dd_pipeline.loader.set_ip_adapter_scale(0.6)
+dd_pipeline = dd_pipeline.to(device)
+
+ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
+image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
+mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
+
+prompt = "a green pear"
+negative_prompt = "blurry"
+generator = torch.Generator(device=device).manual_seed(42)
+
+image = dd_pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=25,
+ generator=generator,
+ ip_adapter_image=ip_adapter_image,
+ diffdiff_map=mask,
+ image=image,
+ output="images"
+)[0]
+```
+
+### ControlNet
+
+Stable Diffusion XL already has a preset ControlNet block that can readily be used.
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
+
+control_input_block = StableDiffusionXLAutoControlNetInputStep()
+```
+
+However, it requires modifying the `denoise` block because that's where the ControlNet injects the control information into the UNet.
+
+Modify the `denoise` block by replacing the `StableDiffusionXLLoopDenoiser` sub-block with the `StableDiffusionXLControlNetLoopDenoiser`.
+
+```py
+class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+    block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
+```
+
+Insert the `controlnet_input` block and replace the `denoise` block with the new `controlnet_denoise_block`. Initialize a [`ModularPipeline`] and load the model components into it with [`~ModularPipeline.load_components`].
+
+```py
+dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
+dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
+
+dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+dd_pipeline = dd_pipeline.to(device)
+
+control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
+image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
+mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
+
+prompt = "a green pear"
+negative_prompt = "blurry"
+generator = torch.Generator(device=device).manual_seed(42)
+
+image = dd_pipeline(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ num_inference_steps=25,
+ generator=generator,
+ control_image=control_image,
+ controlnet_conditioning_scale=0.5,
+ diffdiff_map=mask,
+ image=image,
+ output="images"
+)[0]
+```
+
+### AutoPipelineBlocks
+
+The Differential Diffusion, IP-Adapter, and ControlNet workflows can be bundled into a single [`ModularPipeline`] by using [`AutoPipelineBlocks`]. This automatically selects which sub-blocks to run based on inputs like `control_image` or `ip_adapter_image`. If none of these inputs are passed, the pipeline defaults to Differential Diffusion.
+
+Use `block_trigger_inputs` to only run the `SDXLDiffDiffControlNetDenoiseStep` block if a ControlNet conditioning input (`controlnet_cond`) is present. Otherwise, the `SDXLDiffDiffDenoiseStep` is used.
+
+```py
+class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
+ block_names = ["controlnet_denoise", "denoise"]
+ block_trigger_inputs = ["controlnet_cond", None]
+```
+
+Add the `ip_adapter` and `controlnet_input` blocks.
+
+```py
+DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
+DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
+DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
+DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
+DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
+DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
+```
+
+Call [`SequentialPipelineBlocks.from_blocks_dict`] to create a [`SequentialPipelineBlocks`], then create a [`ModularPipeline`] and load the model components to run it.
+
+```py
+dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
+dd_pipeline = dd_auto_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+dd_pipeline.load_components(torch_dtype=torch.float16)
+```
+
+## Share
+
+Add your [`ModularPipeline`] to the Hub with [`~ModularPipeline.save_pretrained`] and set the `push_to_hub` argument to `True`.
+
+```py
+dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
+```
+
+Other users can load the [`ModularPipeline`] with [`~ModularPipeline.from_pretrained`].
+
+```py
+import torch
+from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
+
+components = ComponentsManager()
+
+diffdiff_pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-diffdiff-0704", trust_remote_code=True, components_manager=components, collection="diffdiff")
+diffdiff_pipeline.load_components(torch_dtype=torch.float16)
+```
diff --git a/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md
index a683f0d065..bbeb28aae5 100644
--- a/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md
+++ b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md
@@ -12,178 +12,102 @@ specific language governing permissions and limitations under the License.
# SequentialPipelineBlocks
-
+[`~modular_pipelines.SequentialPipelineBlocks`] is a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
-🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
-
+Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value, and the second block, `ImageEncoderBlock`, uses the `batch_size` from the first block as an intermediate input.
-`SequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. Unlike `PipelineBlock`, it is a multi-block that composes other blocks together in sequence, creating modular workflows where data flows from one block to the next. It's one of the most common ways to build complex pipelines by combining simpler building blocks.
-
-
-
-Other types of multi-blocks include [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional block selection) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md).
-
-Additionally, like all `ModularPipelineBlocks`, `SequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
-
-
-
-In this tutorial, we will focus on how to create `SequentialPipelineBlocks` and how blocks connect and work together.
-
-The key insight is that blocks connect through their intermediate inputs and outputs - the "studs and anti-studs" we discussed in the [PipelineBlock guide](pipeline_block.md). When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks.
-
-Let's explore this through an example. We will use the same helper function from the PipelineBlock guide to create blocks.
+
+
+
+```py
+from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
+
+class InputBlock(ModularPipelineBlocks):
+
+ @property
+ def inputs(self):
+ return [
+ InputParam(name="prompt", type_hint=list, description="list of text prompts"),
+ InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt"),
+ ]
+
+ @property
+ def intermediate_outputs(self):
+ return [
+ OutputParam(name="batch_size", description="calculated batch size"),
+ ]
+
+ @property
+ def description(self):
+ return "A block that determines batch_size based on the number of prompts and num_images_per_prompt argument."
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ batch_size = len(block_state.prompt)
+ block_state.batch_size = batch_size * block_state.num_images_per_prompt
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+
+
```py
-from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
import torch
+from diffusers.modular_pipelines import ModularPipelineBlocks, InputParam, OutputParam
-def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
- class TestBlock(PipelineBlock):
- model_name = "test"
-
- @property
- def inputs(self):
- return inputs
-
- @property
- def intermediate_inputs(self):
- return intermediate_inputs
-
- @property
- def intermediate_outputs(self):
- return intermediate_outputs
-
- @property
- def description(self):
- return description if description is not None else ""
-
- def __call__(self, components, state):
- block_state = self.get_block_state(state)
- if block_fn is not None:
- block_state = block_fn(block_state, state)
- self.set_block_state(state, block_state)
- return components, state
-
- return TestBlock
+class ImageEncoderBlock(ModularPipelineBlocks):
+
+ @property
+ def inputs(self):
+ return [
+ InputParam(name="image", type_hint="PIL.Image", description="raw input image to process"),
+ InputParam(name="batch_size", type_hint=int),
+ ]
+
+ @property
+ def intermediate_outputs(self):
+ return [
+ OutputParam(name="image_latents", description="latents representing the image"),
+ ]
+
+ @property
+ def description(self):
+ return "Encode raw image into its latent presentation"
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ # Simulate processing the image
+ # This will change the state of the image from a PIL image to a tensor for all blocks
+ block_state.image = torch.randn(1, 3, 512, 512)
+ block_state.batch_size = block_state.batch_size * 2
+ block_state.image_latents = torch.randn(1, 4, 64, 64)
+ self.set_block_state(state, block_state)
+ return components, state
```
-Let's create a block that produces `batch_size`, which we'll call "input_block":
+
+
-```py
-def input_block_fn(block_state, pipeline_state):
-
- batch_size = len(block_state.prompt)
- block_state.batch_size = batch_size * block_state.num_images_per_prompt
-
- return block_state
+Connect the two blocks by defining an [`InsertableDict`] to map the block names to the block instances. Blocks are executed in the order they're registered in `blocks_dict`.
-input_block_cls = make_block(
- inputs=[
- InputParam(name="prompt", type_hint=list, description="list of text prompts"),
- InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt")
- ],
- intermediate_outputs=[
- OutputParam(name="batch_size", description="calculated batch size")
- ],
- block_fn=input_block_fn,
- description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument."
-)
-input_block = input_block_cls()
-```
-
-Now let's create a second block that uses the `batch_size` from the first block:
-
-```py
-def image_encoder_block_fn(block_state, pipeline_state):
- # Simulate processing the image
- block_state.image = torch.randn(1, 3, 512, 512)
- block_state.batch_size = block_state.batch_size * 2
- block_state.image_latents = torch.randn(1, 4, 64, 64)
- return block_state
-
-image_encoder_block_cls = make_block(
- inputs=[
- InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
- ],
- intermediate_inputs=[
- InputParam(name="batch_size", type_hint=int)
- ],
- intermediate_outputs=[
- OutputParam(name="image_latents", description="latents representing the image")
- ],
- block_fn=image_encoder_block_fn,
- description="Encode raw image into its latent presentation"
-)
-image_encoder_block = image_encoder_block_cls()
-```
-
-Now let's connect these blocks to create a `SequentialPipelineBlocks`:
+Use [`~modular_pipelines.SequentialPipelineBlocks.from_blocks_dict`] to create a [`~modular_pipelines.SequentialPipelineBlocks`].
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
-# Define a dict mapping block names to block instances
blocks_dict = InsertableDict()
blocks_dict["input"] = input_block
blocks_dict["image_encoder"] = image_encoder_block
-# Create the SequentialPipelineBlocks
blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
```
-Now you have a `SequentialPipelineBlocks` with 2 blocks:
+Inspect the sub-blocks in a [`~modular_pipelines.SequentialPipelineBlocks`] by printing `blocks`, and for more details about the inputs and outputs, access the `doc` attribute.
```py
->>> blocks
-SequentialPipelineBlocks(
- Class: ModularPipelineBlocks
-
- Description:
-
-
- Sub-Blocks:
- [0] input (TestBlock)
- Description: A block that determines batch_size based on the number of prompts and num_images_per_prompt argument.
-
- [1] image_encoder (TestBlock)
- Description: Encode raw image into its latent presentation
-
-)
-```
-
-When you inspect `blocks.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it.
-
-```py
->>> print(blocks.doc)
-class SequentialPipelineBlocks
-
- Inputs:
-
- prompt (`None`, *optional*):
-
- num_images_per_prompt (`None`, *optional*):
-
- image (`PIL.Image`, *optional*):
- raw input image to process
-
- Outputs:
-
- batch_size (`None`):
-
- image_latents (`None`):
- latents representing the image
-```
-
-At runtime, you have data flow like this:
-
-
-
-**How SequentialPipelineBlocks Works:**
-
-1. Blocks are executed in the order they're registered in the `blocks_dict`
-2. Outputs from one block become available as intermediate inputs to all subsequent blocks
-3. The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks
-4. Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces
-
-What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs.
\ No newline at end of file
+print(blocks)
+print(blocks.doc)
+```
\ No newline at end of file
diff --git a/docs/source/en/optimization/attention_backends.md b/docs/source/en/optimization/attention_backends.md
new file mode 100644
index 0000000000..e603878a63
--- /dev/null
+++ b/docs/source/en/optimization/attention_backends.md
@@ -0,0 +1,114 @@
+
+
+# Attention backends
+
+> [!NOTE]
+> The attention dispatcher is an experimental feature. Please open an issue if you have any feedback or encounter any problems.
+
+Diffusers provides several optimized attention algorithms that are more memory and computationally efficient through its *attention dispatcher*. The dispatcher acts as a router for managing and switching between different attention implementations and provides a unified interface for interacting with them.
+
+Refer to the table below for an overview of the available attention families and to the [Available backends](#available-backends) section for a more complete list.
+
+| attention family | main feature |
+|---|---|
+| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
+| SageAttention | quantizes attention to int8 |
+| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
+| xFormers | memory-efficient attention with support for various attention kernels |
+
+This guide will show you how to set and use the different attention backends.
+
+## set_attention_backend
+
+The [`~ModelMixin.set_attention_backend`] method iterates through all the modules in the model and sets the appropriate attention backend to use. The attention backend setting persists until [`~ModelMixin.reset_attention_backend`] is called.
+
+The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [kernel](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
+
+> [!NOTE]
+> FlashAttention-3 is only supported on Hopper GPUs. On other architectures, use FlashAttention-2 with `set_attention_backend("flash")`.
+
+```py
+import torch
+from diffusers import QwenImagePipeline
+
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+pipeline.transformer.set_attention_backend("_flash_3_hub")
+
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
+```
+
+To restore the default attention backend, call [`~ModelMixin.reset_attention_backend`].
+
+```py
+pipeline.transformer.reset_attention_backend()
+```
+
+## attention_backend context manager
+
+The [attention_backend](https://github.com/huggingface/diffusers/blob/5e181eddfe7e44c1444a2511b0d8e21d177850a0/src/diffusers/models/attention_dispatch.py#L225) context manager temporarily sets an attention backend for a model within the context. Outside the context, the default attention (PyTorch's native scaled dot product attention) is used. This is useful if you want to use different backends for different parts of a pipeline or if you want to test the different backends.
+
+```py
+import torch
+from diffusers import QwenImagePipeline
+# the attention_backend context manager is defined in diffusers.models.attention_dispatch (linked above)
+from diffusers.models.attention_dispatch import attention_backend
+
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+
+with attention_backend("_flash_3_hub"):
+ image = pipeline(prompt).images[0]
+```
+
+> [!TIP]
+> Most attention backends support `torch.compile` without graph breaks and can be used to further speed up inference.
+
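+For example, here is a minimal sketch, assuming the `pipeline` and `prompt` defined above, that combines an attention backend with `torch.compile`.
+
+```py
+import torch
+
+pipeline.transformer.set_attention_backend("flash")
+# compile the transformer; most backends run without graph breaks
+pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
+
+image = pipeline(prompt).images[0]
+```
+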
+## Available backends
+
+Refer to the table below for a complete list of available attention backends and their variants.
+
+<details>
+<summary>Expand</summary>
+
+| Backend Name | Family | Description |
+|--------------|--------|-------------|
+| `native` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Default backend using PyTorch's scaled_dot_product_attention |
+| `flex` | [FlexAttention](https://docs.pytorch.org/docs/stable/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention) | PyTorch FlexAttention implementation |
+| `_native_cudnn` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | CuDNN-optimized attention |
+| `_native_efficient` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Memory-efficient attention |
+| `_native_flash` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | PyTorch's FlashAttention |
+| `_native_math` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | Math-based attention (fallback) |
+| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
+| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
+| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
+| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
+| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
+| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
+| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
+| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
+| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
+| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
+| `_sage_qk_int8_pv_fp16_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (CUDA) |
+| `_sage_qk_int8_pv_fp16_triton` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP16 PV (Triton) |
+| `xformers` | [xFormers](https://github.com/facebookresearch/xformers) | Memory-efficient attention |
+
+</details>
\ No newline at end of file
diff --git a/docs/source/en/optimization/cache_dit.md b/docs/source/en/optimization/cache_dit.md
new file mode 100644
index 0000000000..1261423212
--- /dev/null
+++ b/docs/source/en/optimization/cache_dit.md
@@ -0,0 +1,270 @@
+# CacheDiT
+
+CacheDiT is a unified, flexible, and training-free cache acceleration framework designed to support nearly all Diffusers' DiT-based pipelines. It provides a unified cache API that supports automatic block adapter, DBCache, and more.
+
+To learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.
+
+Install a stable release of CacheDiT from PyPI or you can install the latest version from GitHub.
+
+
+
+
+```bash
+pip3 install -U cache-dit
+```
+
+
+
+
+```bash
+pip3 install git+https://github.com/vipshop/cache-dit.git
+```
+
+
+
+
+Call `cache_dit.supported_pipelines()` to view the supported DiT pipelines.
+
+```python
+>>> import cache_dit
+>>> cache_dit.supported_pipelines()
+(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
+'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
+'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
+'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])
+```
+
+For a complete benchmark, please refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).
+
+
+## Unified Cache API
+
+CacheDiT works by matching specific input/output patterns as shown below.
+
+
+
+Call the `enable_cache()` function on a pipeline to enable cache acceleration. This function is the entry point to many of CacheDiT's features.
+
+```python
+import cache_dit
+from diffusers import DiffusionPipeline
+
+# Can be any diffusion pipeline
+pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
+
+# One-line code with default cache options.
+cache_dit.enable_cache(pipe)
+
+# Just call the pipe as normal.
+output = pipe(...)
+
+# Disable cache and run original pipe.
+cache_dit.disable_cache(pipe)
+```
+
+## Automatic Block Adapter
+
+For custom or modified pipelines or transformers not included in Diffusers, use the `BlockAdapter` in `auto` mode or via manual configuration. Please check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details. Refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.
+
+
+```python
+from cache_dit import ForwardPattern, BlockAdapter
+
+# Use 🔥BlockAdapter with `auto` mode.
+cache_dit.enable_cache(
+ BlockAdapter(
+ # Any DiffusionPipeline, Qwen-Image, etc.
+ pipe=pipe, auto=True,
+ # Check `📚Forward Pattern Matching` documentation and hack the code of
+ # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
+ forward_pattern=ForwardPattern.Pattern_1,
+ ),
+)
+
+# Or, manually setup transformer configurations.
+cache_dit.enable_cache(
+ BlockAdapter(
+ pipe=pipe, # Qwen-Image, etc.
+ transformer=pipe.transformer,
+ blocks=pipe.transformer.transformer_blocks,
+ forward_pattern=ForwardPattern.Pattern_1,
+ ),
+)
+```
+
+Sometimes, a transformer class contains more than one set of transformer blocks. For example, FLUX.1 (and also HiDream, Chroma, etc.) contains `transformer_blocks` and `single_transformer_blocks`, which have different forward patterns. The `BlockAdapter` is able to detect this hybrid pattern type as well.
+Refer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.
+
+```python
+# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
+# single_transformer_blocks have different forward patterns.
+cache_dit.enable_cache(
+ BlockAdapter(
+ pipe=pipe, # FLUX.1, etc.
+ transformer=pipe.transformer,
+ blocks=[
+ pipe.transformer.transformer_blocks,
+ pipe.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[
+ ForwardPattern.Pattern_1,
+ ForwardPattern.Pattern_3,
+ ],
+ ),
+)
+```
+
+This also works if there is more than one transformer (namely `transformer` and `transformer_2`) in its structure. Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.
+
+## Patch Functor
+
+For any pattern not included in CacheDiT, use the Patch Functor to convert the pattern into a known pattern. You need to subclass the Patch Functor, and you may also need to fuse the operations in the for-loop over blocks into the block's `forward` method. After implementing a Patch Functor, set the `patch_functor` property in `BlockAdapter`.
+
+
+
+Some Patch Functors are already provided in CacheDiT, [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py), [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py), etc.
+
+```python
+@BlockAdapterRegistry.register("HiDream")
+def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
+ from diffusers import HiDreamImageTransformer2DModel
+ from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
+
+ assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
+ return BlockAdapter(
+ pipe=pipe,
+ transformer=pipe.transformer,
+ blocks=[
+ pipe.transformer.double_stream_blocks,
+ pipe.transformer.single_stream_blocks,
+ ],
+ forward_pattern=[
+ ForwardPattern.Pattern_0,
+ ForwardPattern.Pattern_3,
+ ],
+ # NOTE: Setup your custom patch functor here.
+ patch_functor=HiDreamPatchFunctor(),
+ **kwargs,
+ )
+```
+
+Finally, you can call the `cache_dit.summary()` function on a pipeline after it has completed inference to get the cache acceleration details.
+
+```python
+stats = cache_dit.summary(pipe)
+```
+
+```
+⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline
+
+| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
+|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
+| 23 | 0.045 | 0.084 | 0.114 | 0.147 | 0.241 | 0.297 |
+```
+
+## DBCache: Dual Block Cache
+
+
+
+DBCache (Dual Block Caching) supports different configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.
+- Fn_compute_blocks: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
+- Bn_compute_blocks: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
+
+
+```python
+import torch
+
+import cache_dit
+from diffusers import FluxPipeline
+
+pipe_or_adapter = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B0, 8 warmup steps, and unlimited cached
+# steps for good balance between performance and precision
+cache_dit.enable_cache(pipe_or_adapter)
+
+# Custom options, F8B8, higher precision
+from cache_dit import BasicCacheConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=BasicCacheConfig(
+ max_warmup_steps=8, # steps do not cache
+ max_cached_steps=-1, # -1 means no limit
+ Fn_compute_blocks=8, # Fn, F8, etc.
+ Bn_compute_blocks=8, # Bn, B8, etc.
+ residual_diff_threshold=0.12,
+ ),
+)
+```
+Check the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.
+
+## TaylorSeer Calibrator
+
+The [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache when the number of cached steps is large (hybrid TaylorSeer + DBCache). At timesteps separated by large intervals, feature similarity in diffusion models drops substantially, which harms generation quality.
+
+TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. F_pred can be a residual cache or a hidden-state cache.
+
+```python
+from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ # Basic DBCache w/ FnBn configurations
+ cache_config=BasicCacheConfig(
+ max_warmup_steps=8, # steps do not cache
+ max_cached_steps=-1, # -1 means no limit
+ Fn_compute_blocks=8, # Fn, F8, etc.
+ Bn_compute_blocks=8, # Bn, B8, etc.
+ residual_diff_threshold=0.12,
+ ),
+ # Then, you can use the TaylorSeer Calibrator to approximate
+ # the values in cached steps, taylorseer_order default is 1.
+ calibrator_config=TaylorSeerCalibratorConfig(
+ taylorseer_order=1,
+ ),
+)
+```
+
+> [!TIP]
+> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. DBCache's `Bn_compute_blocks` also acts as a calibrator, so choose either `Bn_compute_blocks` > 0 or TaylorSeer. We recommend the TaylorSeer + DBCache FnB0 configuration, as sketched below.
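+
+A concrete sketch of that recommended scheme, reusing the configuration from the example above with `Bn_compute_blocks` set to `0`:
+
+```python
+from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig
+
+# TaylorSeer + DBCache F8B0
+cache_dit.enable_cache(
+    pipe_or_adapter,
+    cache_config=BasicCacheConfig(
+        max_warmup_steps=8,      # steps that are never cached
+        max_cached_steps=-1,     # -1 means no limit
+        Fn_compute_blocks=8,     # F8
+        Bn_compute_blocks=0,     # B0, rely on TaylorSeer as the calibrator
+        residual_diff_threshold=0.12,
+    ),
+    calibrator_config=TaylorSeerCalibratorConfig(taylorseer_order=1),
+)
+```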
+
+## Hybrid Cache CFG
+
+CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or that do not perform CFG in the forward step, leave the `enable_separate_cfg` parameter as `False` (default: `None`). Otherwise, set it to `True`.
+
+```python
+from cache_dit import BasicCacheConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=BasicCacheConfig(
+ ...,
+        # For example, set it to True for Wan 2.1 and Qwen-Image,
+        # and to False for FLUX.1, HunyuanVideo, etc.
+ enable_separate_cfg=True,
+ ),
+)
+```
+
+## torch.compile
+
+CacheDiT is designed to work with torch.compile for even better performance. Call `torch.compile` after enabling the cache.
+
+
+```python
+import torch
+
+cache_dit.enable_cache(pipe)
+
+# Compile the Transformer module
+pipe.transformer = torch.compile(pipe.transformer)
+```
+
+If you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode.
+
+```python
+torch._dynamo.config.recompile_limit = 96 # default is 8
+torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
+```
+
+Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.
diff --git a/docs/source/en/optimization/coreml.md b/docs/source/en/optimization/coreml.md
index cd0e662bb7..71da1e3dc1 100644
--- a/docs/source/en/optimization/coreml.md
+++ b/docs/source/en/optimization/coreml.md
@@ -16,11 +16,8 @@ specific language governing permissions and limitations under the License.
Core ML models can leverage all the compute engines available in Apple devices: the CPU, the GPU, and the Apple Neural Engine (or ANE, a tensor-optimized accelerator available in Apple Silicon Macs and modern iPhones/iPads). Depending on the model and the device it's running on, Core ML can mix and match compute engines too, so some portions of the model may run on the CPU while others run on GPU, for example.
-
-
-You can also run the `diffusers` Python codebase on Apple Silicon Macs using the `mps` accelerator built into PyTorch. This approach is explained in depth in [the mps guide](mps), but it is not compatible with native apps.
-
-
+> [!TIP]
+> You can also run the `diffusers` Python codebase on Apple Silicon Macs using the `mps` accelerator built into PyTorch. This approach is explained in depth in [the mps guide](mps), but it is not compatible with native apps.
## Stable Diffusion Core ML Checkpoints
diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md
index e32cbec917..941f53604c 100644
--- a/docs/source/en/optimization/fp16.md
+++ b/docs/source/en/optimization/fp16.md
@@ -209,7 +209,7 @@ There is also a [compile_regions](https://github.com/huggingface/accelerate/blob
# pip install -U accelerate
import torch
from diffusers import StableDiffusionXLPipeline
-from accelerate.utils import compile regions
+from accelerate.utils import compile_regions
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
@@ -239,11 +239,8 @@ The `step()` function is [called](https://github.com/huggingface/diffusers/blob/
In general, the `sigmas` should [stay on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240) to avoid the communication sync and latency.
-
-
-Refer to the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post for maximizing performance with `torch.compile` for diffusion models.
-
-
+> [!TIP]
+> Refer to the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post for maximizing performance with `torch.compile` for diffusion models.
### Benchmarks
diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md
index 78fd96e027..611e07ec76 100644
--- a/docs/source/en/optimization/memory.md
+++ b/docs/source/en/optimization/memory.md
@@ -291,13 +291,53 @@ Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://
> [!WARNING]
> Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism.
-Call [`~ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
-
-The `offload_type` parameter can be set to `block_level` or `leaf_level`.
+Enable group offloading by configuring the `offload_type` parameter to `block_level` or `leaf_level`.
- `block_level` offloads groups of layers based on the `num_blocks_per_group` parameter. For example, if `num_blocks_per_group=2` on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). This drastically reduces memory requirements.
- `leaf_level` offloads individual layers at the lowest level and is equivalent to [CPU offloading](#cpu-offloading). But it can be made faster if you use streams without giving up inference speed.
+Group offloading is supported for entire pipelines or individual models. Applying group offloading to the entire pipeline is the easiest option while selectively applying it to individual models gives users more flexibility to use different offloading techniques for different models.
+
+
+
+
+Call [`~DiffusionPipeline.enable_group_offload`] on a pipeline.
+
+```py
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
+
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+
+pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipeline.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ use_stream=True
+)
+
+prompt = (
+ "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ "atmosphere of this unique musical performance."
+)
+video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+export_to_video(video, "output.mp4", fps=8)
+```
+
+
+
+
+Call [`~ModelMixin.enable_group_offload`] on standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
+
```py
import torch
from diffusers import CogVideoXPipeline
@@ -328,6 +368,9 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
export_to_video(video, "output.mp4", fps=8)
```
+
+
+
#### CUDA stream
The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.
diff --git a/docs/source/en/optimization/mps.md b/docs/source/en/optimization/mps.md
index 7e4c2716ac..b5afa25b2f 100644
--- a/docs/source/en/optimization/mps.md
+++ b/docs/source/en/optimization/mps.md
@@ -38,11 +38,8 @@ image = pipe(prompt).images[0]
image
```
-
-
-The PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) backend does not support NDArray sizes greater than `2**32`. Please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) if you encounter this problem so we can investigate.
-
-
+> [!WARNING]
+> The PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) backend does not support NDArray sizes greater than `2**32`. Please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) if you encounter this problem so we can investigate.
If you're using **PyTorch 1.13**, you need to "prime" the pipeline with an additional one-time pass through it. This is a temporary workaround for an issue where the first inference pass produces slightly different results than subsequent ones. You only need to do this pass once, and after just one inference step you can discard the result.
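+
+A minimal sketch of that one-time priming pass (assuming a pipeline `pipe` and `prompt` as defined above; the single-step result is discarded):
+
+```py
+# One-time "priming" pass, only needed on PyTorch 1.13.
+_ = pipe(prompt, num_inference_steps=1)
+
+# Subsequent calls produce consistent results.
+image = pipe(prompt).images[0]
+```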
diff --git a/docs/source/en/optimization/neuron.md b/docs/source/en/optimization/neuron.md
index fa933317b4..6a45bd0563 100644
--- a/docs/source/en/optimization/neuron.md
+++ b/docs/source/en/optimization/neuron.md
@@ -20,11 +20,8 @@ Diffusers functionalities are available on [AWS Inf2 instances](https://aws.amaz
python -m pip install --upgrade-strategy eager optimum[neuronx]
```
-
-
-We provide pre-built [Hugging Face Neuron Deep Learning AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) (DLAMI) and Optimum Neuron containers for Amazon SageMaker. It's recommended to correctly set up your environment.
-
-
+> [!TIP]
+> We provide pre-built [Hugging Face Neuron Deep Learning AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) (DLAMI) and Optimum Neuron containers for Amazon SageMaker. It's recommended to correctly set up your environment.
The example below demonstrates how to generate images with the Stable Diffusion XL model on an inf2.8xlarge instance (you can switch to cheaper inf2.xlarge instances once the model is compiled). To generate some images, use the [`~optimum.neuron.NeuronStableDiffusionXLPipeline`] class, which is similar to the [`StableDiffusionXLPipeline`] class in Diffusers.
diff --git a/docs/source/en/optimization/onnx.md b/docs/source/en/optimization/onnx.md
index d160dcffe8..620f2af994 100644
--- a/docs/source/en/optimization/onnx.md
+++ b/docs/source/en/optimization/onnx.md
@@ -34,11 +34,8 @@ image = pipeline(prompt).images[0]
pipeline.save_pretrained("./onnx-stable-diffusion-v1-5")
```
-
-
-Generating multiple prompts in a batch seems to take too much memory. While we look into it, you may need to iterate instead of batching.
-
-
+> [!WARNING]
+> Generating multiple prompts in a batch seems to take too much memory. While we look into it, you may need to iterate instead of batching.
To export the pipeline in the ONNX format offline and use it later for inference,
use the [`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) command:
diff --git a/docs/source/en/optimization/speed-memory-optims.md b/docs/source/en/optimization/speed-memory-optims.md
index f43e60bc74..80c6c79a3c 100644
--- a/docs/source/en/optimization/speed-memory-optims.md
+++ b/docs/source/en/optimization/speed-memory-optims.md
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Compile and offloading quantized models
+# Compiling and offloading quantized models
Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
@@ -28,7 +28,8 @@ The table below provides a comparison of optimization strategy combinations and
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
-These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.
+
+These results are benchmarked on Flux with an RTX 4090. The transformer and text_encoder components are quantized. Refer to the benchmarking script if you're interested in evaluating your own model.
This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
diff --git a/docs/source/en/optimization/xformers.md b/docs/source/en/optimization/xformers.md
index 3e2792fd5f..523e815595 100644
--- a/docs/source/en/optimization/xformers.md
+++ b/docs/source/en/optimization/xformers.md
@@ -20,16 +20,10 @@ Install xFormers from `pip`:
pip install xformers
```
-
-
-The xFormers `pip` package requires the latest version of PyTorch. If you need to use a previous version of PyTorch, then we recommend [installing xFormers from the source](https://github.com/facebookresearch/xformers#installing-xformers).
-
-
+> [!TIP]
+> The xFormers `pip` package requires the latest version of PyTorch. If you need to use a previous version of PyTorch, then we recommend [installing xFormers from the source](https://github.com/facebookresearch/xformers#installing-xformers).
After xFormers is installed, you can use `enable_xformers_memory_efficient_attention()` for faster inference and reduced memory consumption as shown in this [section](memory#memory-efficient-attention).
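+
+A minimal sketch of enabling it on a loaded pipeline (the checkpoint here is only an illustrative example):
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# Swap the default attention processor for xFormers memory-efficient attention.
+pipeline.enable_xformers_memory_efficient_attention()
+```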
-
-
-According to this [issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training (fine-tune or DreamBooth) in some GPUs. If you observe this problem, please install a development version as indicated in the issue comments.
-
-
+> [!WARNING]
+> According to this [issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training (fine-tune or DreamBooth) in some GPUs. If you observe this problem, please install a development version as indicated in the issue comments.
diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md
index f97119d5f4..0729472744 100644
--- a/docs/source/en/quantization/bitsandbytes.md
+++ b/docs/source/en/quantization/bitsandbytes.md
@@ -206,11 +206,8 @@ Once a model is quantized, you can push the model to the Hub with the [`~ModelMi
-
-
-Training with 8-bit and 4-bit weights are only supported for training *extra* parameters.
-
-
+> [!WARNING]
+> Training with 8-bit and 4-bit weights is only supported for training *extra* parameters.
Check your memory footprint with the `get_memory_footprint` method:
@@ -234,11 +231,8 @@ model_4bit = AutoModel.from_pretrained(
## 8-bit (LLM.int8() algorithm)
-
-
-Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)!
-
-
+> [!TIP]
+> Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)!
This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion.
@@ -283,11 +277,8 @@ model_8bit = SD3Transformer2DModel.from_pretrained(
## 4-bit (QLoRA algorithm)
-
-
-Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
-
-
+> [!TIP]
+> Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization.
diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md
index 71321d5568..47804c102d 100644
--- a/docs/source/en/quantization/gguf.md
+++ b/docs/source/en/quantization/gguf.md
@@ -77,3 +77,44 @@ Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels
- Q5_K
- Q6_K
+## Convert to GGUF
+
+Use the Space below to convert a Diffusers checkpoint into the GGUF format for inference, then load the converted checkpoint as shown in the example below:
+
+
+
+
+```py
+import torch
+
+from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
+
+ckpt_path = (
+ "https://huggingface.co/sayakpaul/different-lora-from-civitai/blob/main/flux_dev_diffusers-q4_0.gguf"
+)
+transformer = FluxTransformer2DModel.from_single_file(
+ ckpt_path,
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ config="black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+)
+pipe = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
+image.save("flux-gguf.png")
+```
+
+When using Diffusers format GGUF checkpoints, you must provide the model `config` path. If the
+model config resides in a `subfolder`, that needs to be specified too.
\ No newline at end of file
diff --git a/docs/source/en/quantization/modelopt.md b/docs/source/en/quantization/modelopt.md
new file mode 100644
index 0000000000..06933d47c2
--- /dev/null
+++ b/docs/source/en/quantization/modelopt.md
@@ -0,0 +1,141 @@
+
+
+# NVIDIA ModelOpt
+
+[NVIDIA-ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a unified library of state-of-the-art model optimization techniques like quantization, pruning, distillation, speculative decoding, etc. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed.
+
+Before you begin, make sure you have nvidia_modelopt installed.
+
+```bash
+pip install -U "nvidia_modelopt[hf]"
+```
+
+Quantize a model by passing [`NVIDIAModelOptConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+
+The example below only quantizes the weights to FP8.
+
+```python
+import torch
+from diffusers import AutoModel, SanaPipeline, NVIDIAModelOptConfig
+
+model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
+dtype = torch.bfloat16
+
+quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")
+transformer = AutoModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=dtype,
+)
+pipe = SanaPipeline.from_pretrained(
+ model_id,
+ transformer=transformer,
+ torch_dtype=dtype,
+)
+pipe.to("cuda")
+
+print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+ prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
+).images[0]
+image.save("output.png")
+```
+
+> **Note:**
+>
+> The quantization methods in NVIDIA-ModelOpt are designed to reduce the memory footprint of model weights using various QAT (Quantization-Aware Training) and PTQ (Post-Training Quantization) techniques while maintaining model performance. However, the actual performance gain during inference depends on the deployment framework (e.g., TRT-LLM, TensorRT) and the specific hardware configuration.
+>
+> More details can be found [here](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples).
+
+## NVIDIAModelOptConfig
+
+The `NVIDIAModelOptConfig` class accepts the following parameters, illustrated in the sketch after this list:
+- `quant_type`: A string value mentioning one of the quantization types below.
+- `modules_to_not_convert`: A list of full or partial module names for which quantization should not be performed. For example, to skip quantizing the [`SD3Transformer2DModel`]'s pos_embed projection blocks, specify `modules_to_not_convert=["pos_embed.proj.weight"]`.
+- `disable_conv_quantization`: A boolean value which when set to `True` disables quantization for all convolutional layers in the model. This is useful as channel and block quantization generally don't work well with convolutional layers (used with INT4, NF4, NVFP4). If you want to disable quantization for specific convolutional layers, use `modules_to_not_convert` instead.
+- `algorithm`: The algorithm to use for determining scale, defaults to `"max"`. You can check modelopt documentation for more algorithms and details.
+- `forward_loop`: The forward loop function to use for calibrating activation during quantization. If not provided, it relies on static scale values computed using the weights only.
+- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
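+
+The sketch below illustrates these parameters. It is an assumption-laden example: the module name in `modules_to_not_convert` is a placeholder, and any channel or block settings from the table in the next section are assumed to be forwarded as additional keyword arguments via `kwargs`.
+
+```python
+from diffusers import NVIDIAModelOptConfig
+
+quant_config = NVIDIAModelOptConfig(
+    quant_type="FP8",
+    quant_method="modelopt",
+    # Placeholder partial module name to skip during quantization.
+    modules_to_not_convert=["pos_embed"],
+    # Skip all convolutional layers.
+    disable_conv_quantization=True,
+)
+```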
+
+## Supported quantization types
+
+ModelOpt supports weight-only, per-channel, and block quantization with int8, fp8, int4, nf4, and nvfp4. These quantization methods reduce the memory footprint of the model weights while maintaining model performance during inference.
+
+Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
+
+The quantization methods supported are as follows:
+
+| **Quantization Type** | **Supported Schemes** | **Required Kwargs** | **Additional Notes** |
+|-----------------------|-----------------------|---------------------|----------------------|
+| **INT8**              | `int8 weight only`, `int8 channel quantization`, `int8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |
+| **FP8**               | `fp8 weight only`, `fp8 channel quantization`, `fp8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |
+| **INT4** | `int4 weight only`, `int4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now`|
+| **NF4** | `nf4 weight only`, `nf4 double block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize + scale_channel_quantize` + `scale_block_quantize` | `channel_quantize = -1 and scale_channel_quantize = -1 are only supported for now` |
+| **NVFP4** | `nvfp4 weight only`, `nvfp4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now`|
+
+
+Refer to the [official modelopt documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
+
+## Serializing and Deserializing quantized models
+
+To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
+
+```python
+import torch
+from diffusers import AutoModel, NVIDIAModelOptConfig
+from modelopt.torch.opt import enable_huggingface_checkpointing
+
+enable_huggingface_checkpointing()
+
+model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
+quant_config_fp8 = {"quant_type": "FP8", "quant_method": "modelopt"}
+quant_config_fp8 = NVIDIAModelOptConfig(**quant_config_fp8)
+model = AutoModel.from_pretrained(
+ model_id,
+ subfolder="transformer",
+ quantization_config=quant_config_fp8,
+ torch_dtype=torch.bfloat16,
+)
+model.save_pretrained('path/to/sana_fp8', safe_serialization=False)
+```
+
+To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
+
+```python
+import torch
+from diffusers import AutoModel, NVIDIAModelOptConfig, SanaPipeline
+from modelopt.torch.opt import enable_huggingface_checkpointing
+
+enable_huggingface_checkpointing()
+
+quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt")
+transformer = AutoModel.from_pretrained(
+ "path/to/sana_fp8",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+)
+pipe = SanaPipeline.from_pretrained(
+ "Efficient-Large-Model/Sana_600M_1024px_diffusers",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+)
+pipe.to("cuda")
+prompt = "A cat holding a sign that says hello world"
+image = pipe(
+ prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
+).images[0]
+image.save("output.png")
+```
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index 12c39f52e4..38abeeac6d 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -34,7 +34,9 @@ Initialize [`~quantizers.PipelineQuantizationConfig`] with the following paramet
> [!TIP]
> These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.
-- `components_to_quantize` specifies which components of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
+- `components_to_quantize` specifies which component(s) of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
+
+ `components_to_quantize` accepts either a list for multiple models or a string for a single model.
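+
+A minimal sketch of both forms (the bitsandbytes backend and kwargs are placeholders for illustration):
+
+```py
+from diffusers.quantizers import PipelineQuantizationConfig
+
+# Single component as a string.
+quant_config = PipelineQuantizationConfig(
+    quant_backend="bitsandbytes_4bit",
+    quant_kwargs={"load_in_4bit": True},
+    components_to_quantize="transformer",
+)
+
+# Multiple components as a list.
+quant_config = PipelineQuantizationConfig(
+    quant_backend="bitsandbytes_4bit",
+    quant_kwargs={"load_in_4bit": True},
+    components_to_quantize=["transformer", "text_encoder"],
+)
+```
+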
The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.
@@ -62,6 +64,7 @@ pipe = DiffusionPipeline.from_pretrained(
image = pipe("photo of a cute dog").images[0]
```
+
### Advanced quantization
The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.
diff --git a/docs/source/en/quantization/torchao.md b/docs/source/en/quantization/torchao.md
index 5c7578dcbb..18cc109e07 100644
--- a/docs/source/en/quantization/torchao.md
+++ b/docs/source/en/quantization/torchao.md
@@ -11,69 +11,96 @@ specific language governing permissions and limitations under the License. -->
# torchao
-[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
+[torchao](https://github.com/pytorch/ao) provides high-performance dtypes and optimizations based on quantization and sparsity for inference and training PyTorch models. It is supported for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
-Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed.
+Make sure Pytorch 2.5+ and torchao are installed with the command below.
```bash
-pip install -U torch torchao
+uv pip install -U torch torchao
```
+Each quantization dtype is available as a separate instance of a [AOBaseConfig](https://docs.pytorch.org/ao/main/api_ref_quantization.html#inference-apis-for-quantize) class. This provides more flexible configuration options by exposing more available arguments.
-Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+Pass the `AOBaseConfig` of a quantization dtype, like [Int4WeightOnlyConfig](https://docs.pytorch.org/ao/main/generated/torchao.quantization.Int4WeightOnlyConfig) to [`TorchAoConfig`] in [`~ModelMixin.from_pretrained`].
-The example below only quantizes the weights to int8.
+```py
+import torch
+from diffusers import DiffusionPipeline, TorchAoConfig
+from diffusers.quantizers import PipelineQuantizationConfig
+from torchao.quantization import Int8WeightOnlyConfig
+
+pipeline_quant_config = PipelineQuantizationConfig(
+    quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig(group_size=128))}
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+    quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+```
+
+For simple use cases, you can also provide a string identifier in [`TorchAoConfig`] as shown below.
+
+```py
+import torch
+from diffusers import DiffusionPipeline, TorchAoConfig
+from diffusers.quantizers import PipelineQuantizationConfig
+
+pipeline_quant_config = PipelineQuantizationConfig(
+ quant_mapping={"transformer": TorchAoConfig("int8wo")}
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+    quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+```
+
+## torch.compile
+
+torchao supports [torch.compile](../optimization/fp16#torchcompile) which can speed up inference with one line of code.
```python
import torch
-from diffusers import FluxPipeline, AutoModel, TorchAoConfig
+from diffusers import DiffusionPipeline, TorchAoConfig
+from diffusers.quantizers import PipelineQuantizationConfig
+from torchao.quantization import Int4WeightOnlyConfig
-model_id = "black-forest-labs/FLUX.1-dev"
-dtype = torch.bfloat16
-
-quantization_config = TorchAoConfig("int8wo")
-transformer = AutoModel.from_pretrained(
- model_id,
- subfolder="transformer",
- quantization_config=quantization_config,
- torch_dtype=dtype,
+pipeline_quant_config = PipelineQuantizationConfig(
+    quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
)
-pipe = FluxPipeline.from_pretrained(
- model_id,
- transformer=transformer,
- torch_dtype=dtype,
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+    quantization_config=pipeline_quant_config,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
)
-pipe.to("cuda")
-# Without quantization: ~31.447 GB
-# With quantization: ~20.40 GB
-print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
-
-prompt = "A cat holding a sign that says hello world"
-image = pipe(
- prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
-).images[0]
-image.save("output.png")
+pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
```
-TorchAO is fully compatible with [torch.compile](../optimization/fp16#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
-
-```python
-# In the above code, add the following after initializing the transformer
-transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
-```
-
-For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
+Refer to this [table](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450) for inference speed and memory usage benchmarks with Flux and CogVideoX. More benchmarks on various hardware are also available in the torchao [repository](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
> [!TIP]
> The FP8 post-training quantization schemes in torchao are effective for GPUs with compute capability of at least 8.9 (RTX-4090, Hopper, etc.). FP8 often provides the best speed, memory, and quality trade-off when generating images and videos. We recommend combining FP8 and torch.compile if your GPU is compatible.
-torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
+## autoquant
-The `TorchAoConfig` class accepts three parameters:
-- `quant_type`: A string value mentioning one of the quantization types below.
-- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`.
-- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
+torchao provides [autoquant](https://docs.pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant), an automatic quantization API. Autoquantization chooses the best quantization strategy by comparing the performance of each strategy on chosen input types and shapes. At the moment, this is only supported in Diffusers for individual models.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from torchao.quantization import autoquant
+
+# Load the pipeline
+pipeline = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-schnell",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+
+pipeline.transformer = autoquant(pipeline.transformer)
+```
## Supported quantization types
diff --git a/docs/source/en/quicktour.md b/docs/source/en/quicktour.md
index 820b03c02a..1ccc8eeadc 100644
--- a/docs/source/en/quicktour.md
+++ b/docs/source/en/quicktour.md
@@ -10,314 +10,223 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-[[open-in-colab]]
+# Quickstart
-# Quicktour
+Diffusers is a library for developers and researchers that provides an easy inference API for generating images, videos and audio, as well as the building blocks for implementing new workflows.
-Diffusion models are trained to denoise random Gaussian noise step-by-step to generate a sample of interest, such as an image or audio. This has sparked a tremendous amount of interest in generative AI, and you have probably seen examples of diffusion generated images on the internet. 🧨 Diffusers is a library aimed at making diffusion models widely accessible to everyone.
+Diffusers provides many optimizations out-of-the-box that make it possible to load and run large models on setups with limited memory or to accelerate inference.
-Whether you're a developer or an everyday user, this quicktour will introduce you to 🧨 Diffusers and help you get up and generating quickly! There are three main components of the library to know about:
+This Quickstart will give you an overview of Diffusers and get you up and generating quickly.
-* The [`DiffusionPipeline`] is a high-level end-to-end class designed to rapidly generate samples from pretrained diffusion models for inference.
-* Popular pretrained [model](./api/models) architectures and modules that can be used as building blocks for creating diffusion systems.
-* Many different [schedulers](./api/schedulers/overview) - algorithms that control how noise is added for training, and how to generate denoised images during inference.
+> [!TIP]
+> Before you begin, make sure you have a Hugging Face [account](https://huggingface.co/join) in order to use gated models like [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev).
-The quicktour will show you how to use the [`DiffusionPipeline`] for inference, and then walk you through how to combine a model and scheduler to replicate what's happening inside the [`DiffusionPipeline`].
-
-
-
-The quicktour is a simplified version of the introductory 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) to help you get started quickly. If you want to learn more about 🧨 Diffusers' goal, design philosophy, and additional details about its core API, check out the notebook!
-
-
-
-Before you begin, make sure you have all the necessary libraries installed:
-
-```py
-# uncomment to install the necessary libraries in Colab
-#!pip install --upgrade diffusers accelerate transformers
-```
-
-- [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) speeds up model loading for inference and training.
-- [🤗 Transformers](https://huggingface.co/docs/transformers/index) is required to run the most popular diffusion models, such as [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview).
+Follow the [Installation](./installation) guide to install Diffusers if it's not already installed.
## DiffusionPipeline
-The [`DiffusionPipeline`] is the easiest way to use a pretrained diffusion system for inference. It is an end-to-end system containing the model and the scheduler. You can use the [`DiffusionPipeline`] out-of-the-box for many tasks. Take a look at the table below for some supported tasks, and for a complete list of supported tasks, check out the [🧨 Diffusers Summary](./api/pipelines/overview#diffusers-summary) table.
+A diffusion model combines multiple components to generate outputs in any modality based on an input, such as a text description, image or both.
-| **Task** | **Description** | **Pipeline**
-|------------------------------|--------------------------------------------------------------------------------------------------------------|-----------------|
-| Unconditional Image Generation | generate an image from Gaussian noise | [unconditional_image_generation](./using-diffusers/unconditional_image_generation) |
-| Text-Guided Image Generation | generate an image given a text prompt | [conditional_image_generation](./using-diffusers/conditional_image_generation) |
-| Text-Guided Image-to-Image Translation | adapt an image guided by a text prompt | [img2img](./using-diffusers/img2img) |
-| Text-Guided Image-Inpainting | fill the masked part of an image given the image, the mask and a text prompt | [inpaint](./using-diffusers/inpaint) |
-| Text-Guided Depth-to-Image Translation | adapt parts of an image guided by a text prompt while preserving structure via depth estimation | [depth2img](./using-diffusers/depth2img) |
+For a standard text-to-image model:
-Start by creating an instance of a [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
-You can use the [`DiffusionPipeline`] for any [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) stored on the Hugging Face Hub.
-In this quicktour, you'll load the [`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint for text-to-image generation.
+1. A text encoder turns a prompt into embeddings that guide the denoising process. Some models have more than one text encoder.
+2. A scheduler contains the algorithmic specifics for gradually denoising initial random noise into clean outputs. Different schedulers affect generation speed and quality.
+3. A UNet or diffusion transformer (DiT) is the workhorse of a diffusion model.
-
+ At each step, it performs the denoising predictions, such as how much noise to remove or the general direction in which to steer the noise to generate better quality outputs.
-For [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) models, please carefully read the [license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) first before running the model. 🧨 Diffusers implements a [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) to prevent offensive or harmful content, but the model's improved image generation capabilities can still produce potentially harmful content.
+   The UNet or DiT repeats this loop for a set number of steps to generate the final output.
+
+4. A variational autoencoder (VAE) encodes and decodes pixels to and from a spatially compressed latent space. *Latents* are compressed representations of an image and are more efficient to work with. The UNet or DiT operates on latents, and the clean latents at the end are decoded back into images.
-
+The [`DiffusionPipeline`] packages all these components into a single class for inference. There are several arguments in [`~DiffusionPipeline.__call__`] you can change, such as `num_inference_steps`, that affect the diffusion process. Try different values and arguments to see how they change generation quality or speed.
-Load the model with the [`~DiffusionPipeline.from_pretrained`] method:
+Load a model with [`~DiffusionPipeline.from_pretrained`] and describe what you'd like to generate. The example below uses the default argument values.
-```python
->>> from diffusers import DiffusionPipeline
+
+
->>> pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-```
-
-The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. You'll see that the Stable Diffusion pipeline is composed of the [`UNet2DConditionModel`] and [`PNDMScheduler`] among other things:
+Use `.images[0]` to access the generated image output.
```py
->>> pipeline
-StableDiffusionPipeline {
- "_class_name": "StableDiffusionPipeline",
- "_diffusers_version": "0.21.4",
- ...,
- "scheduler": [
- "diffusers",
- "PNDMScheduler"
- ],
- ...,
- "unet": [
- "diffusers",
- "UNet2DConditionModel"
- ],
- "vae": [
- "diffusers",
- "AutoencoderKL"
- ]
-}
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
```
-We strongly recommend running the pipeline on a GPU because the model consists of roughly 1.4 billion parameters.
-You can move the generator object to a GPU, just like you would in PyTorch:
+
+
-```python
->>> pipeline.to("cuda")
-```
-
-Now you can pass a text prompt to the `pipeline` to generate an image, and then access the denoised image. By default, the image output is wrapped in a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
-
-```python
->>> image = pipeline("An image of a squirrel in Picasso style").images[0]
->>> image
-```
-
-
-
-
-
-Save the image by calling `save`:
-
-```python
->>> image.save("image_of_squirrel_painting.png")
-```
-
-### Local pipeline
-
-You can also use the pipeline locally. The only difference is you need to download the weights first:
-
-```bash
-!git lfs install
-!git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5
-```
-
-Then load the saved weights into the pipeline:
-
-```python
->>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", use_safetensors=True)
-```
-
-Now, you can run the pipeline as you would in the section above.
-
-### Swapping schedulers
-
-Different schedulers come with different denoising speeds and quality trade-offs. The best way to find out which one works best for you is to try them out! One of the main features of 🧨 Diffusers is to allow you to easily switch between schedulers. For example, to replace the default [`PNDMScheduler`] with the [`EulerDiscreteScheduler`], load it with the [`~diffusers.ConfigMixin.from_config`] method:
+Use `.frames[0]` to access the generated video output and [`~utils.export_to_video`] to save the video.
```py
->>> from diffusers import EulerDiscreteScheduler
+import torch
+from diffusers import AutoencoderKLWan, DiffusionPipeline
+from diffusers.utils import export_to_video
->>> pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
->>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
+vae = AutoencoderKLWan.from_pretrained(
+ "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ subfolder="vae",
+ torch_dtype=torch.float32
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+    vae=vae,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+
+prompt = """
+Cinematic video of a sleek cat lounging on a colorful inflatable in a crystal-clear turquoise pool in Palm Springs,
+sipping a salt-rimmed margarita through a straw. Golden-hour sunlight glows over mid-century modern homes and swaying palms.
+Shot in rich Sony a7S III: with moody, glamorous color grading, subtle lens flares, and soft vintage film grain.
+Ripples shimmer as a warm desert breeze stirs the water, blending luxury and playful charm in an epic, gorgeously composed frame.
+"""
+video = pipeline(prompt=prompt, num_frames=81, num_inference_steps=40).frames[0]
+export_to_video(video, "output.mp4", fps=16)
```
-Try generating an image with the new scheduler and see if you notice a difference!
+
+
-In the next section, you'll take a closer look at the components - the model and scheduler - that make up the [`DiffusionPipeline`] and learn how to use these components to generate an image of a cat.
+## LoRA
-## Models
+Adapters insert a small number of trainable parameters into the original base model. Only the inserted parameters are fine-tuned while the rest of the model weights remain frozen. This makes it fast and cheap to fine-tune a model on a new style. Among adapters, [LoRAs](./tutorials/using_peft_for_inference) are the most popular.
-Most models take a noisy sample, and at each timestep it predicts the *noise residual* (other models learn to predict the previous sample directly or the velocity or [`v-prediction`](https://github.com/huggingface/diffusers/blob/5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef/src/diffusers/schedulers/scheduling_ddim.py#L110)), the difference between a less noisy image and the input image. You can mix and match models to create other diffusion systems.
-
-Models are initiated with the [`~ModelMixin.from_pretrained`] method which also locally caches the model weights so it is faster the next time you load the model. For the quicktour, you'll load the [`UNet2DModel`], a basic unconditional image generation model with a checkpoint trained on cat images:
+Add a LoRA to a pipeline with the [`~loaders.QwenImageLoraLoaderMixin.load_lora_weights`] method. Some LoRAs require a special word to trigger them, such as `Realism` in the example below. Check a LoRA's model card to see if it requires a trigger word.
```py
->>> from diffusers import UNet2DModel
+import torch
+from diffusers import DiffusionPipeline
->>> repo_id = "google/ddpm-cat-256"
->>> model = UNet2DModel.from_pretrained(repo_id, use_safetensors=True)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+pipeline.load_lora_weights(
+ "flymy-ai/qwen-image-realism-lora",
+)
+
+prompt = """
+super Realism cinematic film still of a cat sipping a margarita in a pool in Palm Springs in the style of umempart, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
```
+Check out the [LoRA](./tutorials/using_peft_for_inference) docs or Adapters section to learn more.
+
+## Quantization
+
+[Quantization](./quantization/overview) stores data in fewer bits to reduce memory usage. It may also speed up inference because it takes less time to perform calculations with fewer bits.
+
+Diffusers provides several quantization backends and picking one depends on your use case. For example, [bitsandbytes](./quantization/bitsandbytes) and [torchao](./quantization/torchao) are both simple and easy to use for inference, but torchao supports more [quantization types](./quantization/torchao#supported-quantization-types) like fp8.
+
+Configure [`PipelineQuantizationConfig`] with the backend to use, the specific arguments (refer to the [API](./api/quantization) reference for available arguments) for that backend, and which components to quantize. The example below quantizes the model to 4-bits and only uses 14.93GB of memory.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+
+quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder"],
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image",
+ torch_dtype=torch.bfloat16,
+ quantization_config=quant_config,
+ device_map="cuda"
+)
+
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+```
+
+Take a look at the [Quantization](./quantization/overview) section for more details.
+
+## Optimizations
+
> [!TIP]
-> Use the [`AutoModel`] API to automatically select a model class if you're unsure of which one to use.
+> Optimization is dependent on hardware specs such as memory. Use this [Space](https://huggingface.co/spaces/diffusers/optimized-diffusers-code) to generate code examples that include all of Diffusers' available memory and speed optimization techniques for any model you're using.
-To access the model parameters, call `model.config`:
+Modern diffusion models are very large and have billions of parameters. The iterative denoising process is also computationally intensive and slow. Diffusers provides techniques for reducing memory usage and boosting inference speed. These techniques can be combined with quantization to optimize for both memory usage and inference speed.
+
+### Memory usage
+
+The text encoders and UNet or DiT can use up as much as ~30GB of memory, exceeding the amount available on many free-tier or consumer GPUs.
+
+Offloading stores weights that aren't currently used on the CPU and only moves them to the GPU when they're needed. There are a few offloading types and the example below uses [model offloading](./optimization/memory#model-offloading). This moves an entire model, like a text encoder or transformer, to the CPU when it isn't actively being used.
+
+Call [`~DiffusionPipeline.enable_model_cpu_offload`] to activate it. By combining quantization and offloading, the following example only requires ~12.54GB of memory.
```py
->>> model.config
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+
+quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+ components_to_quantize=["transformer", "text_encoder"],
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image",
+ torch_dtype=torch.bfloat16,
+ quantization_config=quant_config,
+ device_map="cuda"
+)
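+# move models to the CPU when they aren't in use and back to the GPU when needed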
+pipeline.enable_model_cpu_offload()
+
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```
-The model configuration is a 🧊 frozen 🧊 dictionary, which means those parameters can't be changed after the model is created. This is intentional and ensures that the parameters used to define the model architecture at the start remain the same, while other parameters can still be adjusted during inference.
+Refer to the [Reduce memory usage](./optimization/memory) docs to learn more about other memory reducing techniques.
-Some of the most important parameters are:
+### Inference speed
-* `sample_size`: the height and width dimension of the input sample.
-* `in_channels`: the number of input channels of the input sample.
-* `down_block_types` and `up_block_types`: the type of down- and upsampling blocks used to create the UNet architecture.
-* `block_out_channels`: the number of output channels of the downsampling blocks; also used in reverse order for the number of input channels of the upsampling blocks.
-* `layers_per_block`: the number of ResNet blocks present in each UNet block.
+The denoising loop performs a lot of computations and can be slow. Methods like [torch.compile](./optimization/fp16#torchcompile) increase inference speed by compiling the computations into optimized kernels. Compilation is slow for the first generation, but successive generations should be much faster.
-To use the model for inference, create the image shape with random Gaussian noise. It should have a `batch` axis because the model can receive multiple random noises, a `channel` axis corresponding to the number of input channels, and a `sample_size` axis for the height and width of the image:
+The example below uses [regional compilation](./optimization/fp16#regional-compilation) to compile only small, repeated regions of a model. It reduces cold-start latency while still providing a runtime speedup.
+
+Call [`~ModelMixin.compile_repeated_blocks`] on the model to activate it.
```py
->>> import torch
+import torch
+from diffusers import DiffusionPipeline
->>> torch.manual_seed(0)
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
->>> noisy_sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
->>> noisy_sample.shape
-torch.Size([1, 3, 256, 256])
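+# compile only the repeated transformer blocks to reduce cold-start compile time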
+pipeline.transformer.compile_repeated_blocks(
+ fullgraph=True,
+)
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
```
-For inference, pass the noisy image and a `timestep` to the model. The `timestep` indicates how noisy the input image is, with more noise at the beginning and less at the end. This helps the model determine its position in the diffusion process, whether it is closer to the start or the end. Use the `sample` method to get the model output:
-
-```py
->>> with torch.no_grad():
-... noisy_residual = model(sample=noisy_sample, timestep=2).sample
-```
-
-To generate actual examples though, you'll need a scheduler to guide the denoising process. In the next section, you'll learn how to couple a model with a scheduler.
-
-## Schedulers
-
-Schedulers manage going from a noisy sample to a less noisy sample given the model output - in this case, it is the `noisy_residual`.
-
-
-
-🧨 Diffusers is a toolbox for building diffusion systems. While the [`DiffusionPipeline`] is a convenient way to get started with a pre-built diffusion system, you can also choose your own model and scheduler components separately to build a custom diffusion system.
-
-
-
-For the quicktour, you'll instantiate the [`DDPMScheduler`] with its [`~diffusers.ConfigMixin.from_config`] method:
-
-```py
->>> from diffusers import DDPMScheduler
-
->>> scheduler = DDPMScheduler.from_pretrained(repo_id)
->>> scheduler
-DDPMScheduler {
- "_class_name": "DDPMScheduler",
- "_diffusers_version": "0.21.4",
- "beta_end": 0.02,
- "beta_schedule": "linear",
- "beta_start": 0.0001,
- "clip_sample": true,
- "clip_sample_range": 1.0,
- "dynamic_thresholding_ratio": 0.995,
- "num_train_timesteps": 1000,
- "prediction_type": "epsilon",
- "sample_max_value": 1.0,
- "steps_offset": 0,
- "thresholding": false,
- "timestep_spacing": "leading",
- "trained_betas": null,
- "variance_type": "fixed_small"
-}
-```
-
-
-
-💡 Unlike a model, a scheduler does not have trainable weights and is parameter-free!
-
-
-
-Some of the most important parameters are:
-
-* `num_train_timesteps`: the length of the denoising process or, in other words, the number of timesteps required to process random Gaussian noise into a data sample.
-* `beta_schedule`: the type of noise schedule to use for inference and training.
-* `beta_start` and `beta_end`: the start and end noise values for the noise schedule.
-
-To predict a slightly less noisy image, pass the following to the scheduler's [`~diffusers.DDPMScheduler.step`] method: model output, `timestep`, and current `sample`.
-
-```py
->>> less_noisy_sample = scheduler.step(model_output=noisy_residual, timestep=2, sample=noisy_sample).prev_sample
->>> less_noisy_sample.shape
-torch.Size([1, 3, 256, 256])
-```
-
-The `less_noisy_sample` can be passed to the next `timestep` where it'll get even less noisy! Let's bring it all together now and visualize the entire denoising process.
-
-First, create a function that postprocesses and displays the denoised image as a `PIL.Image`:
-
-```py
->>> import PIL.Image
->>> import numpy as np
-
-
->>> def display_sample(sample, i):
-... image_processed = sample.cpu().permute(0, 2, 3, 1)
-... image_processed = (image_processed + 1.0) * 127.5
-... image_processed = image_processed.numpy().astype(np.uint8)
-
-... image_pil = PIL.Image.fromarray(image_processed[0])
-... display(f"Image at step {i}")
-... display(image_pil)
-```
-
-To speed up the denoising process, move the input and model to a GPU:
-
-```py
->>> model.to("cuda")
->>> noisy_sample = noisy_sample.to("cuda")
-```
-
-Now create a denoising loop that predicts the residual of the less noisy sample, and computes the less noisy sample with the scheduler:
-
-```py
->>> import tqdm
-
->>> sample = noisy_sample
-
->>> for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
-... # 1. predict noise residual
-... with torch.no_grad():
-... residual = model(sample, t).sample
-
-... # 2. compute less noisy image and set x_t -> x_t-1
-... sample = scheduler.step(residual, t, sample).prev_sample
-
-... # 3. optionally look at image
-... if (i + 1) % 50 == 0:
-... display_sample(sample, i + 1)
-```
-
-Sit back and watch as a cat is generated from nothing but noise! 😻
-
-
-
-
-
-## Next steps
-
-Hopefully, you generated some cool images with 🧨 Diffusers in this quicktour! For your next steps, you can:
-
-* Train or finetune a model to generate your own images in the [training](./tutorials/basic_training) tutorial.
-* See example official and community [training or finetuning scripts](https://github.com/huggingface/diffusers/tree/main/examples#-diffusers-examples) for a variety of use cases.
-* Learn more about loading, accessing, changing, and comparing schedulers in the [Using different Schedulers](./using-diffusers/schedulers) guide.
-* Explore prompt engineering, speed and memory optimizations, and tips and tricks for generating higher-quality images with the [Stable Diffusion](./stable_diffusion) guide.
-* Dive deeper into speeding up 🧨 Diffusers with guides on [optimized PyTorch on a GPU](./optimization/fp16), and inference guides for running [Stable Diffusion on Apple Silicon (M1/M2)](./optimization/mps) and [ONNX Runtime](./optimization/onnx).
+Check out the [Accelerate inference](./optimization/fp16) or [Caching](./optimization/cache) docs for more methods that speed up inference.
\ No newline at end of file
diff --git a/docs/source/en/stable_diffusion.md b/docs/source/en/stable_diffusion.md
index e43bcf3eaa..93e399d3db 100644
--- a/docs/source/en/stable_diffusion.md
+++ b/docs/source/en/stable_diffusion.md
@@ -10,252 +10,123 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Effective and efficient diffusion
-
[[open-in-colab]]
-Getting the [`DiffusionPipeline`] to generate images in a certain style or include what you want can be tricky. Often times, you have to run the [`DiffusionPipeline`] several times before you end up with an image you're happy with. But generating something out of nothing is a computationally intensive process, especially if you're running inference over and over again.
+# Basic performance
-This is why it's important to get the most *computational* (speed) and *memory* (GPU vRAM) efficiency from the pipeline to reduce the time between inference cycles so you can iterate faster.
+Diffusion is a random process that is computationally demanding. You may need to run the [`DiffusionPipeline`] several times before getting a desired output. That's why it's important to carefully balance generation speed and memory usage so you can iterate faster.
-This tutorial walks you through how to generate faster and better with the [`DiffusionPipeline`].
+This guide covers some basic performance tips for using the [`DiffusionPipeline`]. Refer to the Inference Optimization docs, such as [Accelerate inference](./optimization/fp16) or [Reduce memory usage](./optimization/memory), for more detailed performance guides.
-Begin by loading the [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) model:
+## Memory usage
-```python
+Reducing the amount of memory used indirectly speeds up generation and can help a model fit on device.
+
+The [`~DiffusionPipeline.enable_model_cpu_offload`] method moves a model to the CPU when it is not in use to save GPU memory.
+
+```py
+import torch
from diffusers import DiffusionPipeline
-model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+pipeline.enable_model_cpu_offload()
+
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+pipeline(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
```
-The example prompt you'll use is a portrait of an old warrior chief, but feel free to use your own prompt:
+## Inference speed
-```python
-prompt = "portrait photo of a old warrior chief"
-```
+Denoising is the most computationally demanding process during diffusion. Methods that optimize this process accelerate inference. Try the following methods for a speed up.
-## Speed
+- Add `device_map="cuda"` to place the pipeline on a GPU. Placing a model on an accelerator, like a GPU, increases speed because it performs computations in parallel.
+- Set `torch_dtype=torch.bfloat16` to execute the pipeline in half-precision. Reducing the data type precision increases speed because it takes less time to perform computations in a lower precision.
-
-
-💡 If you don't have access to a GPU, you can use one for free from a GPU provider like [Colab](https://colab.research.google.com/)!
-
-
-
-One of the simplest ways to speed up inference is to place the pipeline on a GPU the same way you would with any PyTorch module:
-
-```python
-pipeline = pipeline.to("cuda")
-```
-
-To make sure you can use the same image and improve on it, use a [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed for [reproducibility](./using-diffusers/reusing_seeds):
-
-```python
+```py
import torch
+import time
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
-generator = torch.Generator("cuda").manual_seed(0)
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda
+)
```
-Now you can generate an image:
-
-```python
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-
-
-
-This process took ~30 seconds on a T4 GPU (it might be faster if your allocated GPU is better than a T4). By default, the [`DiffusionPipeline`] runs inference with full `float32` precision for 50 inference steps. You can speed this up by switching to a lower precision like `float16` or running fewer inference steps.
-
-Let's start by loading the model in `float16` and generate an image:
-
-```python
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
-pipeline = pipeline.to("cuda")
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-
-
-
-This time, it only took ~11 seconds to generate the image, which is almost 3x faster than before!
-
-
-
-💡 We strongly suggest always running your pipelines in `float16`, and so far, we've rarely seen any degradation in output quality.
-
-
-
-Another option is to reduce the number of inference steps. Choosing a more efficient scheduler could help decrease the number of steps without sacrificing output quality. You can find which schedulers are compatible with the current model in the [`DiffusionPipeline`] by calling the `compatibles` method:
-
-```python
-pipeline.scheduler.compatibles
-[
- diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
- diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
- diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
- diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
- diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
- diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
- diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
- diffusers.utils.dummy_torch_and_torchsde_objects.DPMSolverSDEScheduler,
- diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
- diffusers.schedulers.scheduling_pndm.PNDMScheduler,
- diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
- diffusers.schedulers.scheduling_ddim.DDIMScheduler,
-]
-```
-
-The Stable Diffusion model uses the [`PNDMScheduler`] by default which usually requires ~50 inference steps, but more performant schedulers like [`DPMSolverMultistepScheduler`], require only ~20 or 25 inference steps. Use the [`~ConfigMixin.from_config`] method to load a new scheduler:
-
-```python
-from diffusers import DPMSolverMultistepScheduler
+- Use a faster scheduler, such as [`DPMSolverMultistepScheduler`], which only requires ~20-25 steps.
+- Set `num_inference_steps` to a lower value. Reducing the number of inference steps reduces the overall number of computations. However, this can result in lower generation quality.
+```py
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+
+start_time = time.perf_counter()
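+# optionally pass a lower num_inference_steps (for example, num_inference_steps=25)
+# to the call below to reduce the number of denoising steps for an extra speed up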
+image = pipeline(prompt).images[0]
+end_time = time.perf_counter()
+
+print(f"Image generation took {end_time - start_time:.3f} seconds")
```
-Now set the `num_inference_steps` to 20:
+## Generation quality
-```python
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
-image
-```
+Many modern diffusion models deliver high-quality images out-of-the-box. However, you can still improve generation quality by trying the following.
-
-
-
+- Try a more detailed and descriptive prompt. Include details such as the image medium, subject, style, and aesthetic. A negative prompt may also help by steering the model away from undesirable features with words like "low quality" or "blurry".
-Great, you've managed to cut the inference time to just 4 seconds! ⚡️
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline
-## Memory
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+ )
-The other key to improving pipeline performance is consuming less memory, which indirectly implies more speed, since you're often trying to maximize the number of images generated per second. The easiest way to see how many images you can generate at once is to try out different batch sizes until you get an `OutOfMemoryError` (OOM).
+ prompt = """
+ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+ highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+ """
+ negative_prompt = "low quality, blurry, ugly, poor details"
+ pipeline(prompt, negative_prompt=negative_prompt).images[0]
+ ```
-Create a function that'll generate a batch of images from a list of prompts and `Generators`. Make sure to assign each `Generator` a seed so you can reuse it if it produces a good result.
+ For more details about creating better prompts, take a look at the [Prompt techniques](./using-diffusers/weighted_prompts) doc.
-```python
-def get_inputs(batch_size=1):
- generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
- prompts = batch_size * [prompt]
- num_inference_steps = 20
+- Try a different scheduler, like [`HeunDiscreteScheduler`] or [`LMSDiscreteScheduler`], that trades generation speed for quality.
- return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
-```
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline, HeunDiscreteScheduler
-Start with `batch_size=4` and see how much memory you've consumed:
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+ )
+ pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)
-```python
-from diffusers.utils import make_image_grid
-
-images = pipeline(**get_inputs(batch_size=4)).images
-make_image_grid(images, 2, 2)
-```
-
-Unless you have a GPU with more vRAM, the code above probably returned an `OOM` error! Most of the memory is taken up by the cross-attention layers. Instead of running this operation in a batch, you can run it sequentially to save a significant amount of memory. All you have to do is configure the pipeline to use the [`~DiffusionPipeline.enable_attention_slicing`] function:
-
-```python
-pipeline.enable_attention_slicing()
-```
-
-Now try increasing the `batch_size` to 8!
-
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-
-
-Whereas before you couldn't even generate a batch of 4 images, now you can generate a batch of 8 images at ~3.5 seconds per image! This is probably the fastest you can go on a T4 GPU without sacrificing quality.
-
-## Quality
-
-In the last two sections, you learned how to optimize the speed of your pipeline by using `fp16`, reducing the number of inference steps by using a more performant scheduler, and enabling attention slicing to reduce memory consumption. Now you're going to focus on how to improve the quality of generated images.
-
-### Better checkpoints
-
-The most obvious step is to use better checkpoints. The Stable Diffusion model is a good starting point, and since its official launch, several improved versions have also been released. However, using a newer version doesn't automatically mean you'll get better results. You'll still have to experiment with different checkpoints yourself, and do a little research (such as using [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) to get the best results.
-
-As the field grows, there are more and more high-quality checkpoints finetuned to produce certain styles. Try exploring the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) and [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) to find one you're interested in!
-
-### Better pipeline components
-
-You can also try replacing the current pipeline components with a newer version. Let's try loading the latest [autoencoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) from Stability AI into the pipeline, and generate some images:
-
-```python
-from diffusers import AutoencoderKL
-
-vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
-pipeline.vae = vae
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-
-
-### Better prompt engineering
-
-The text prompt you use to generate an image is super important, so much so that it is called *prompt engineering*. Some considerations to keep during prompt engineering are:
-
-- How is the image or similar images of the one I want to generate stored on the internet?
-- What additional detail can I give that steers the model towards the style I want?
-
-With this in mind, let's improve the prompt to include color and higher quality details:
-
-```python
-prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
-prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
-```
-
-Generate a batch of images with the new prompt:
-
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-
-
-Pretty impressive! Let's tweak the second image - corresponding to the `Generator` with a seed of `1` - a bit more by adding some text about the age of the subject:
-
-```python
-prompts = [
- "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of an old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
-]
-
-generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
-images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
-make_image_grid(images, 2, 2)
-```
-
-
-
-
+ prompt = """
+ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+ highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+ """
+ negative_prompt = "low quality, blurry, ugly, poor details"
+ pipeline(prompt, negative_prompt=negative_prompt).images[0]
+ ```
## Next steps
-In this tutorial, you learned how to optimize a [`DiffusionPipeline`] for computational and memory efficiency as well as improving the quality of generated outputs. If you're interested in making your pipeline even faster, take a look at the following resources:
-
-- Learn how [PyTorch 2.0](./optimization/fp16) and [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) can yield 5 - 300% faster inference speed. On an A100 GPU, inference can be up to 50% faster!
-- If you can't use PyTorch 2, we recommend you install [xFormers](./optimization/xformers). Its memory-efficient attention mechanism works great with PyTorch 1.13.1 for faster speed and reduced memory consumption.
-- Other optimization techniques, such as model offloading, are covered in [this guide](./optimization/fp16).
+Diffusers offers more advanced and powerful optimizations such as [group-offloading](./optimization/memory#group-offloading) and [regional compilation](./optimization/fp16#regional-compilation). To learn more about how to maximize performance, take a look at the Inference Optimization section.
\ No newline at end of file
diff --git a/docs/source/en/training/controlnet.md b/docs/source/en/training/controlnet.md
index 0170ff3da9..840130d2b4 100644
--- a/docs/source/en/training/controlnet.md
+++ b/docs/source/en/training/controlnet.md
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
[ControlNet](https://hf.co/papers/2302.05543) models are adapters trained on top of another pretrained model. It allows for a greater degree of control over image generation by conditioning the model with an additional input image. The input image can be a canny edge, depth map, human pose, and many more.
-If you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing`, `gradient_accumulation_steps`, and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers). JAX/Flax training is also supported for efficient training on TPUs and GPUs, but it doesn't support gradient checkpointing or xFormers. You should have a GPU with >30GB of memory if you want to train faster with Flax.
+If you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing`, `gradient_accumulation_steps`, and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers).
This guide will explore the [train_controlnet.py](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.
@@ -28,51 +28,13 @@ pip install .
Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
-
-
```bash
cd examples/controlnet
pip install -r requirements.txt
```
-
-
-If you have access to a TPU, the Flax training script runs even faster! Let's run the training script on the [Google Cloud TPU VM](https://cloud.google.com/tpu/docs/run-calculation-jax). Create a single TPU v4-8 VM and connect to it:
-
-```bash
-ZONE=us-central2-b
-TPU_TYPE=v4-8
-VM_NAME=hg_flax
-
-gcloud alpha compute tpus tpu-vm create $VM_NAME \
- --zone $ZONE \
- --accelerator-type $TPU_TYPE \
- --version tpu-vm-v4-base
-
-gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
-```
-
-Install JAX 0.4.5:
-
-```bash
-pip install "jax[tpu]==0.4.5" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
-```
-
-Then install the required dependencies for the Flax script:
-
-```bash
-cd examples/controlnet
-pip install -r requirements_flax.txt
-```
-
-
-
-
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -96,11 +58,8 @@ write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py) and let us know if you have any questions or concerns.
## Script parameters
@@ -120,7 +79,7 @@ Many of the basic and important parameters are described in the [Text-to-image](
### Min-SNR weighting
-The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.
Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
@@ -135,11 +94,8 @@ As with the script parameters, a general walkthrough of the training script is p
The training script has a [`make_train_dataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L582) function for preprocessing the dataset with image transforms and caption tokenization. You'll see that in addition to the usual caption tokenization and image transforms, the script also includes transforms for the conditioning image.
-
-
-If you're streaming a dataset on a TPU, performance may be bottlenecked by the 🤗 Datasets library which is not optimized for images. To ensure maximum throughput, you're encouraged to explore other dataset formats like [WebDataset](https://webdataset.github.io/webdataset/), [TorchData](https://github.com/pytorch/data), and [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds).
-
-
+> [!TIP]
+> If you're streaming a dataset on a TPU, performance may be bottlenecked by the 🤗 Datasets library which is not optimized for images. To ensure maximum throughput, you're encouraged to explore other dataset formats like [WebDataset](https://webdataset.github.io/webdataset/), [TorchData](https://github.com/pytorch/data), and [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds).
```py
conditioning_image_transforms = transforms.Compose(
@@ -272,9 +228,6 @@ That's it! You don't need to add any additional parameters to your training comm
-
-
-
```bash
export MODEL_DIR="stable-diffusion-v1-5/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/save/model"
@@ -292,47 +245,6 @@ accelerate launch train_controlnet.py \
--push_to_hub
```
-
-
-
-With Flax, you can [profile your code](https://jax.readthedocs.io/en/latest/profiling.html) by adding the `--profile_steps==5` parameter to your training command. Install the Tensorboard profile plugin:
-
-```bash
-pip install tensorflow tensorboard-plugin-profile
-tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
-```
-
-Then you can inspect the profile at [http://localhost:6006/#profile](http://localhost:6006/#profile).
-
-
-
-If you run into version conflicts with the plugin, try uninstalling and reinstalling all versions of TensorFlow and Tensorboard. The debugging functionality of the profile plugin is still experimental, and not all views are fully functional. The `trace_viewer` cuts off events after 1M, which can result in all your device traces getting lost if for example, you profile the compilation step by accident.
-
-
-
-```bash
-python3 train_controlnet_flax.py \
- --pretrained_model_name_or_path=$MODEL_DIR \
- --output_dir=$OUTPUT_DIR \
- --dataset_name=fusing/fill50k \
- --resolution=512 \
- --learning_rate=1e-5 \
- --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
- --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
- --validation_steps=1000 \
- --train_batch_size=2 \
- --revision="non-ema" \
- --from_pt \
- --report_to="wandb" \
- --tracker_project_name=$HUB_MODEL_ID \
- --num_train_epochs=11 \
- --push_to_hub \
- --hub_model_id=$HUB_MODEL_ID
-```
-
-
-
-
Once training is complete, you can use your newly trained model for inference!
```py
diff --git a/docs/source/en/training/create_dataset.md b/docs/source/en/training/create_dataset.md
index 8e0d6f9200..725f143bba 100644
--- a/docs/source/en/training/create_dataset.md
+++ b/docs/source/en/training/create_dataset.md
@@ -7,11 +7,8 @@ This guide will show you two ways to create a dataset to finetune on:
- provide a folder of images to the `--train_data_dir` argument
- upload a dataset to the Hub and pass the dataset repository id to the `--dataset_name` argument
-
-
-💡 Learn more about how to create an image dataset for training in the [Create an image dataset](https://huggingface.co/docs/datasets/image_dataset) guide.
-
-
+> [!TIP]
+> 💡 Learn more about how to create an image dataset for training in the [Create an image dataset](https://huggingface.co/docs/datasets/image_dataset) guide.
## Provide a dataset as a folder
@@ -33,11 +30,8 @@ accelerate launch train_unconditional.py \
## Upload your data to the Hub
-
-
-💡 For more details and context about creating and uploading a dataset to the Hub, take a look at the [Image search with 🤗 Datasets](https://huggingface.co/blog/image-search-datasets) post.
-
-
+> [!TIP]
+> 💡 For more details and context about creating and uploading a dataset to the Hub, take a look at the [Image search with 🤗 Datasets](https://huggingface.co/blog/image-search-datasets) post.
Start by creating a dataset with the [`ImageFolder`](https://huggingface.co/docs/datasets/image_load#imagefolder) feature, which creates an `image` column containing the PIL-encoded images.
diff --git a/docs/source/en/training/custom_diffusion.md b/docs/source/en/training/custom_diffusion.md
index e803448b5f..bfa4fe6f9e 100644
--- a/docs/source/en/training/custom_diffusion.md
+++ b/docs/source/en/training/custom_diffusion.md
@@ -34,11 +34,8 @@ pip install -r requirements.txt
pip install clip-retrieval
```
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -62,11 +59,8 @@ write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion/train_custom_diffusion.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion/train_custom_diffusion.py) and let us know if you have any questions or concerns.
## Script parameters
@@ -117,11 +111,8 @@ accelerate launch train_custom_diffusion.py \
## Training script
-
-
-A lot of the code in the Custom Diffusion training script is similar to the [DreamBooth](dreambooth#training-script) script. This guide instead focuses on the code that is relevant to Custom Diffusion.
-
-
+> [!TIP]
+> A lot of the code in the Custom Diffusion training script is similar to the [DreamBooth](dreambooth#training-script) script. This guide instead focuses on the code that is relevant to Custom Diffusion.
The Custom Diffusion training script has two dataset classes:
@@ -224,16 +215,13 @@ Set the environment variable `MODEL_NAME` to a model id on the Hub or a path to
To monitor training progress with Weights and Biases, add the `--report_to=wandb` parameter to the training command and specify a validation prompt with `--validation_prompt`. This is useful for debugging and saving intermediate results.
-
-
-If you're training on human faces, the Custom Diffusion team has found the following parameters to work well:
-
-- `--learning_rate=5e-6`
-- `--max_train_steps` can be anywhere between 1000 and 2000
-- `--freeze_model=crossattn`
-- use at least 15-20 images to train with
-
-
+> [!TIP]
+> If you're training on human faces, the Custom Diffusion team has found the following parameters to work well:
+>
+> - `--learning_rate=5e-6`
+> - `--max_train_steps` can be anywhere between 1000 and 2000
+> - `--freeze_model=crossattn`
+> - use at least 15-20 images to train with
diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md
index 64b1ea9f04..f9756e1a67 100644
--- a/docs/source/en/training/distributed_inference.md
+++ b/docs/source/en/training/distributed_inference.md
@@ -12,17 +12,23 @@ specific language governing permissions and limitations under the License.
# Distributed inference
-On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel.
+Distributed inference splits the workload across multiple GPUs. It is a useful technique for fitting larger models in memory, and it can process multiple prompts in parallel for higher throughput.
-This guide will show you how to use 🤗 Accelerate and PyTorch Distributed for distributed inference.
+This guide will show you how to use [Accelerate](https://huggingface.co/docs/accelerate/index) and [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html) for distributed inference.
-## 🤗 Accelerate
+## Accelerate
-🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.
+Accelerate is a library designed to simplify inference and training on multiple accelerators by handling the setup, allowing users to focus on their PyTorch code.
-To begin, create a Python file and initialize an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `distributed_state.device` to assign a GPU to each process.
+Install Accelerate with the following command.
-Now use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes.
+```bash
+uv pip install accelerate
+```
+
+Initialize an [`accelerate.PartialState`] class in a Python file to create a distributed environment. The [`accelerate.PartialState`] class handles process management, device placement and distribution, and process coordination.
+
+Move the [`DiffusionPipeline`] to [`accelerate.PartialState.device`] to assign a GPU to each process.
```py
import torch
@@ -30,33 +36,34 @@ from accelerate import PartialState
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+ "Qwen/Qwen-Image", torch_dtype=torch.float16
)
distributed_state = PartialState()
pipeline.to(distributed_state.device)
+```
+Use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes.
+
+```py
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
result = pipeline(prompt).images[0]
result.save(f"result_{distributed_state.process_index}.png")
```
-Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:
+Call `accelerate launch` to run the script and use the `--num_processes` argument to set the number of GPUs to use.
```bash
accelerate launch run_distributed.py --num_processes=2
```
-
-
-Refer to this minimal example [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for running inference across multiple GPUs. To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
-
-
+> [!TIP]
+> Refer to this minimal example [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for running inference across multiple GPUs. To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
## PyTorch Distributed
-PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
+PyTorch [DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) enables [data parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=data_parallelism), which replicates the same model on each device to process different batches of data in parallel.
-To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]:
+Import `torch.distributed` and `torch.multiprocessing` into a Python file to set up the distributed process group and to spawn the processes for inference on each GPU.
```py
import torch
@@ -65,20 +72,20 @@ import torch.multiprocessing as mp
from diffusers import DiffusionPipeline
-sd = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.float16,
)
```
-You'll want to create a function to run inference; [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) handles creating a distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size` or the number of processes participating. If you're running inference in parallel over 2 GPUs, then the `world_size` is 2.
+Create a function for inference with [init_process_group](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group). This method creates a distributed environment with the backend type, the `rank` of the current process, and the `world_size` or number of processes participating (for example, 2 GPUs would be `world_size=2`).
-Move the [`DiffusionPipeline`] to `rank` and use `get_rank` to assign a GPU to each process, where each process handles a different prompt:
+Move the pipeline to `rank` and use `get_rank` to assign a GPU to each process. Each process handles a different prompt.
```py
def run_inference(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
- sd.to(rank)
+ pipeline.to(rank)
if torch.distributed.get_rank() == 0:
prompt = "a dog"
@@ -89,7 +96,7 @@ def run_inference(rank, world_size):
image.save(f"./{'_'.join(prompt)}.png")
```
-To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`:
+Use [mp.spawn](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to create the number of processes defined in `world_size`.
```py
def main():
@@ -101,31 +108,26 @@ if __name__ == "__main__":
main()
```
-Once you've completed the inference script, use the `--nproc_per_node` argument to specify the number of GPUs to use and call `torchrun` to run the script:
+Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
```bash
torchrun run_distributed.py --nproc_per_node=2
```
-> [!TIP]
-> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
+## device_map
-## Model sharding
+The `device_map` argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn't fit on a single GPU. You can use `device_map` to selectively load and unload the required model components at a given stage, as shown in the example below (which assumes two GPUs are available).
-Modern diffusion systems such as [Flux](../api/pipelines/flux) are very large and have multiple models. For example, [Flux.1-Dev](https://hf.co/black-forest-labs/FLUX.1-dev) is made up of two text encoders - [T5-XXL](https://hf.co/google/t5-v1_1-xxl) and [CLIP-L](https://hf.co/openai/clip-vit-large-patch14) - a [diffusion transformer](../api/models/flux_transformer), and a [VAE](../api/models/autoencoderkl). With a model this size, it can be challenging to run inference on consumer GPUs.
-
-Model sharding is a technique that distributes models across GPUs when the models don't fit on a single GPU. The example below assumes two 16GB GPUs are available for inference.
-
-Start by computing the text embeddings with the text encoders. Keep the text encoders on two GPUs by setting `device_map="balanced"`. The `balanced` strategy evenly distributes the model on all available GPUs. Use the `max_memory` parameter to allocate the maximum amount of memory for each text encoder on each GPU.
-
-> [!TIP]
-> **Only** load the text encoders for this step! The diffusion transformer and VAE are loaded in a later step to preserve memory.
+Set `device_map="balanced"` to evenly distributes the text encoders on all available GPUs. You can use the `max_memory` argument to allocate a maximum amount of memory for each text encoder. Don't load any other pipeline components to avoid memory usage.
```py
from diffusers import FluxPipeline
import torch
-prompt = "a photo of a dog with cat-like look"
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
@@ -142,7 +144,7 @@ with torch.no_grad():
)
```
-Once the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.
+After the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.
```py
import gc
@@ -162,7 +164,7 @@ del pipeline
flush()
```
-Load the diffusion transformer next which has 12.5B parameters. This time, set `device_map="auto"` to automatically distribute the model across two 16GB GPUs. The `auto` strategy is backed by [Accelerate](https://hf.co/docs/accelerate/index) and available as a part of the [Big Model Inference](https://hf.co/docs/accelerate/concept_guides/big_model_inference) feature. It starts by distributing a model across the fastest device first (GPU) before moving to slower devices like the CPU and hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.
+Set `device_map="auto"` to automatically distribute the model on the two GPUs. This strategy places a model on the fastest device first before placing a model on a slower device like a CPU or hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.
```py
from diffusers import AutoModel
@@ -177,9 +179,9 @@ transformer = AutoModel.from_pretrained(
```
> [!TIP]
-> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices.
+> Inspect `pipeline.hf_device_map` to see how the various models are distributed across devices. This is useful for tracking model device placement. You can also check `transformer.hf_device_map` to see how the transformer model is sharded across devices.
-Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.
+Add the transformer model to the pipeline and set `output_type="latent"` to generate the latents.
```py
pipeline = FluxPipeline.from_pretrained(
@@ -206,24 +208,15 @@ latents = pipeline(
).images
```
-Remove the pipeline and transformer from memory as they're no longer needed.
-
-```py
-del pipeline.transformer
-del pipeline
-
-flush()
-```
-
-Finally, decode the latents with the VAE into an image. The VAE is typically small enough to be loaded on a single GPU.
+Remove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.
```py
+import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
-import torch
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
-vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
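+# block_out_channels has one more entry than the number of 2x downsampling steps, hence the -1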
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
with torch.no_grad():
@@ -237,3 +230,63 @@ with torch.no_grad():
```
By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
+
+## Context parallelism
+
+[Context parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism) splits input sequences across multiple GPUs to reduce memory usage. Each GPU processes its own slice of the sequence.
+
+Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
+
+### Ring Attention
+
+Key (K) and value (V) representations are exchanged between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each sequence split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.
+
+Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transformer model. The config supports the `ring_degree` argument that determines how many devices to use for Ring Attention.
+
+```py
+import torch
+from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig
+
+try:
+ torch.distributed.init_process_group("nccl")
+ rank = torch.distributed.get_rank()
+ device = torch.device("cuda", rank % torch.cuda.device_count())
+ torch.cuda.set_device(device)
+
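+    # ring_degree=2 assumes this script is launched with 2 processes (for example, torchrun --nproc_per_node=2)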
+ transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
+ pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
+ pipeline.transformer.set_attention_backend("flash")
+
+ prompt = """
+ cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+ highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+ """
+
+ # Must specify generator so all ranks start with same latents (or pass your own)
+ generator = torch.Generator().manual_seed(42)
+ image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
+
+ if rank == 0:
+ image.save("output.png")
+
+except Exception as e:
+ print(f"An error occurred: {e}")
+ torch.distributed.breakpoint()
+ raise
+
+finally:
+ if torch.distributed.is_initialized():
+ torch.distributed.destroy_process_group()
+```
+
+### Ulysses Attention
+
+[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends and receives data from every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally over all tokens for its subset of heads, then performs another all-to-all to regroup the results by tokens for the next layer.
+
+[`ContextParallelConfig`] supports Ulysses Attention through the `ulysses_degree` argument. This determines how many devices to use for Ulysses Attention.
+
+Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
+
+```py
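+# assumes the same distributed setup (init_process_group, device selection) as the Ring Attention example above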
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
+```
\ No newline at end of file
diff --git a/docs/source/en/training/dreambooth.md b/docs/source/en/training/dreambooth.md
index cff2bb500d..2302739a0e 100644
--- a/docs/source/en/training/dreambooth.md
+++ b/docs/source/en/training/dreambooth.md
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
[DreamBooth](https://huggingface.co/papers/2208.12242) is a training technique that updates the entire diffusion model by training on just a few images of a subject or style. It works by associating a special word in the prompt with the example images.
-If you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing` and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers). JAX/Flax training is also supported for efficient training on TPUs and GPUs, but it doesn't support gradient checkpointing or xFormers. You should have a GPU with >30GB of memory if you want to train faster with Flax.
+If you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing` and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers).
This guide will explore the [train_dreambooth.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
@@ -28,30 +28,13 @@ pip install .
Navigate to the example folder with the training script and install the required dependencies for the script you're using:
-
-
-
```bash
cd examples/dreambooth
pip install -r requirements.txt
```
-
-
-
-```bash
-cd examples/dreambooth
-pip install -r requirements_flax.txt
-```
-
-
-
-
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -75,19 +58,13 @@ write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) and let us know if you have any questions or concerns.
## Script parameters
-
-
-DreamBooth is very sensitive to training hyperparameters, and it is easy to overfit. Read the [Training Stable Diffusion with Dreambooth using 🧨 Diffusers](https://huggingface.co/blog/dreambooth) blog post for recommended settings for different subjects to help you choose the appropriate hyperparameters.
-
-
+> [!WARNING]
+> DreamBooth is very sensitive to training hyperparameters, and it is easy to overfit. Read the [Training Stable Diffusion with Dreambooth using 🧨 Diffusers](https://huggingface.co/blog/dreambooth) blog post for recommended settings for different subjects to help you choose the appropriate hyperparameters.
The training script offers many parameters for customizing your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L228) function. The parameters are set with default values that should work pretty well out-of-the-box, but you can also set your own values in the training command if you'd like.
@@ -110,7 +87,7 @@ Some basic and important parameters to know and specify are:
### Min-SNR weighting
-The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.
Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
@@ -311,9 +288,6 @@ That's it! You don't need to add any additional parameters to your training comm
-
-
-
```bash
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
export INSTANCE_DIR="./dog"
@@ -334,57 +308,28 @@ accelerate launch train_dreambooth.py \
--push_to_hub
```
-
-
-
-```bash
-export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export INSTANCE_DIR="./dog"
-export OUTPUT_DIR="path-to-save-model"
-
-python train_dreambooth_flax.py \
- --pretrained_model_name_or_path=$MODEL_NAME \
- --instance_data_dir=$INSTANCE_DIR \
- --output_dir=$OUTPUT_DIR \
- --instance_prompt="a photo of sks dog" \
- --resolution=512 \
- --train_batch_size=1 \
- --learning_rate=5e-6 \
- --max_train_steps=400 \
- --push_to_hub
-```
-
-
-
-
Once training is complete, you can use your newly trained model for inference!
-
-
-Can't wait to try your model for inference before training is complete? 🤭 Make sure you have the latest version of 🤗 Accelerate installed.
-
-```py
-from diffusers import DiffusionPipeline, UNet2DConditionModel
-from transformers import CLIPTextModel
-import torch
-
-unet = UNet2DConditionModel.from_pretrained("path/to/model/checkpoint-100/unet")
-
-# if you have trained with `--args.train_text_encoder` make sure to also load the text encoder
-text_encoder = CLIPTextModel.from_pretrained("path/to/model/checkpoint-100/checkpoint-100/text_encoder")
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, text_encoder=text_encoder, dtype=torch.float16,
-).to("cuda")
-
-image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
-image.save("dog-bucket.png")
-```
-
-
-
-
-
+> [!TIP]
+> Can't wait to try your model for inference before training is complete? 🤭 Make sure you have the latest version of 🤗 Accelerate installed.
+>
+> ```py
+> from diffusers import DiffusionPipeline, UNet2DConditionModel
+> from transformers import CLIPTextModel
+> import torch
+>
+> unet = UNet2DConditionModel.from_pretrained("path/to/model/checkpoint-100/unet")
+>
+> # if you trained with `--train_text_encoder`, make sure to also load the text encoder
+> text_encoder = CLIPTextModel.from_pretrained("path/to/model/checkpoint-100/text_encoder")
+>
+> pipeline = DiffusionPipeline.from_pretrained(
+> "stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, text_encoder=text_encoder, dtype=torch.float16,
+> ).to("cuda")
+>
+> image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
+> image.save("dog-bucket.png")
+> ```
```py
from diffusers import DiffusionPipeline
@@ -395,39 +340,6 @@ image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guida
image.save("dog-bucket.png")
```
-
-
-
-```py
-import jax
-import numpy as np
-from flax.jax_utils import replicate
-from flax.training.common_utils import shard
-from diffusers import FlaxStableDiffusionPipeline
-
-pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("path-to-your-trained-model", dtype=jax.numpy.bfloat16)
-
-prompt = "A photo of sks dog in a bucket"
-prng_seed = jax.random.PRNGKey(0)
-num_inference_steps = 50
-
-num_samples = jax.device_count()
-prompt = num_samples * [prompt]
-prompt_ids = pipeline.prepare_inputs(prompt)
-
-# shard inputs and rng
-params = replicate(params)
-prng_seed = jax.random.split(prng_seed, jax.device_count())
-prompt_ids = shard(prompt_ids)
-
-images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
-image.save("dog-bucket.png")
-```
-
-
-
-
## LoRA
LoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MBs). Use the [train_dreambooth_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py) script to train with LoRA.
@@ -636,4 +548,4 @@ Training the DeepFloyd IF model can be challenging, but here are some tips that
Congratulations on training your DreamBooth model! To learn more about how to use your new model, the following guide may be helpful:
-- Learn how to [load a DreamBooth](../using-diffusers/loading_adapters) model for inference if you trained your model with LoRA.
\ No newline at end of file
+- Learn how to [load a DreamBooth](../using-diffusers/dreambooth) model for inference if you trained your model with LoRA.
\ No newline at end of file
diff --git a/docs/source/en/training/instructpix2pix.md b/docs/source/en/training/instructpix2pix.md
index c1ba5d870a..a1c94bb33f 100644
--- a/docs/source/en/training/instructpix2pix.md
+++ b/docs/source/en/training/instructpix2pix.md
@@ -31,11 +31,8 @@ cd examples/instruct_pix2pix
pip install -r requirements.txt
```
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -59,11 +56,8 @@ write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) and let us know if you have any questions or concerns.
## Script parameters
@@ -174,15 +168,12 @@ This guide uses the [fusing/instructpix2pix-1000-samples](https://huggingface.co
Set the `MODEL_NAME` environment variable to the name of the model (can be a model id on the Hub or a path to a local model), and the `DATASET_ID` to the name of the dataset on the Hub. The script creates and saves all the components (feature extractor, scheduler, text encoder, UNet, etc.) to a subfolder in your repository.
-
-
-For better results, try longer training runs with a larger dataset. We've only tested this training script on a smaller-scale dataset.
-
-
-
-To monitor training progress with Weights and Biases, add the `--report_to=wandb` parameter to the training command and specify a validation image with `--val_image_url` and a validation prompt with `--validation_prompt`. This can be really useful for debugging the model.
-
-
+> [!TIP]
+> For better results, try longer training runs with a larger dataset. We've only tested this training script on a smaller-scale dataset.
+>
+> To monitor training progress with Weights and Biases, add the `--report_to=wandb` parameter to the training command and specify a validation image with `--val_image_url` and a validation prompt with `--validation_prompt`. This can be really useful for debugging the model.
If you’re training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.
diff --git a/docs/source/en/training/kandinsky.md b/docs/source/en/training/kandinsky.md
index 77f7af03b8..6cfd9f8d60 100644
--- a/docs/source/en/training/kandinsky.md
+++ b/docs/source/en/training/kandinsky.md
@@ -12,11 +12,8 @@ specific language governing permissions and limitations under the License.
# Kandinsky 2.2
-
-
-This script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.
-
-
+> [!WARNING]
+> This script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.
Kandinsky 2.2 is a multilingual text-to-image model capable of producing more photorealistic images. The model includes an image prior model for creating image embeddings from text prompts, and a decoder model that generates images based on the prior model's embeddings. That's why you'll find two separate scripts in Diffusers for Kandinsky 2.2, one for training the prior model and one for training the decoder model. You can train both models separately, but to get the best results, you should train both the prior and decoder models.
@@ -39,11 +36,8 @@ cd examples/kandinsky2_2/text_to_image
pip install -r requirements.txt
```
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -67,11 +61,8 @@ write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
-
-
-The following sections highlight parts of the training scripts that are important for understanding how to modify it, but it doesn't cover every aspect of the scripts in detail. If you're interested in learning more, feel free to read through the scripts and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training scripts that are important for understanding how to modify it, but it doesn't cover every aspect of the scripts in detail. If you're interested in learning more, feel free to read through the scripts and let us know if you have any questions or concerns.
## Script parameters
@@ -88,7 +79,7 @@ Most of the parameters are identical to the parameters in the [Text-to-image](te
### Min-SNR weighting
-The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.
Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
@@ -209,11 +200,8 @@ You'll train on the [Naruto BLIP captions](https://huggingface.co/datasets/lambd
If you’re training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.
-
-
-To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompt` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
-
-
+> [!TIP]
+> To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompt` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
@@ -283,11 +271,8 @@ prompt="A robot naruto, 4k photo"
image = pipeline(prompt=prompt, negative_prompt=negative_prompt).images[0]
```
-
-
-Feel free to replace `kandinsky-community/kandinsky-2-2-decoder` with your own trained decoder checkpoint!
-
-
+> [!TIP]
+> Feel free to replace `kandinsky-community/kandinsky-2-2-decoder` with your own trained decoder checkpoint!
diff --git a/docs/source/en/training/lcm_distill.md b/docs/source/en/training/lcm_distill.md
index 280b6469f6..4750f15036 100644
--- a/docs/source/en/training/lcm_distill.md
+++ b/docs/source/en/training/lcm_distill.md
@@ -33,11 +33,8 @@ cd examples/consistency_distillation
pip install -r requirements.txt
```
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment (try enabling `torch.compile` to significantly speedup training):
@@ -63,11 +60,8 @@ Lastly, if you want to train a model on your own dataset, take a look at the [Cr
## Script parameters
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) and let us know if you have any questions or concerns.
The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L419) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
@@ -81,7 +75,7 @@ accelerate launch train_lcm_distill_sd_wds.py \
Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so you'll focus on the parameters that are relevant to latent consistency distillation in this guide.
- `--pretrained_teacher_model`: the path to a pretrained latent diffusion model to use as the teacher model
-- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify an alternative VAE (like this [VAE]((https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)) by madebyollin which works in fp16)
+- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify an alternative VAE (like this [VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) by madebyollin which works in fp16)
- `--w_min` and `--w_max`: the minimum and maximum guidance scale values for guidance scale sampling
- `--num_ddim_timesteps`: the number of timesteps for DDIM sampling
- `--loss_type`: the type of loss (L2 or Huber) to calculate for latent consistency distillation; Huber loss is generally preferred because it's more robust to outliers
@@ -251,5 +245,5 @@ The SDXL training script is discussed in more detail in the [SDXL training](sdxl
Congratulations on distilling a LCM model! To learn more about LCM, the following may be helpful:
-- Learn how to use [LCMs for inference](../using-diffusers/lcm) for text-to-image, image-to-image, and with LoRA checkpoints.
+- Learn how to use [LCMs for inference](../using-diffusers/inference_with_lcm) for text-to-image, image-to-image, and with LoRA checkpoints.
- Read the [SDXL in 4 steps with Latent Consistency LoRAs](https://huggingface.co/blog/lcm_lora) blog post to learn more about SDXL LCM-LoRA's for super fast inference, quality comparisons, benchmarks, and more.
diff --git a/docs/source/en/training/lora.md b/docs/source/en/training/lora.md
index 9a3512dd76..efb170e329 100644
--- a/docs/source/en/training/lora.md
+++ b/docs/source/en/training/lora.md
@@ -12,19 +12,13 @@ specific language governing permissions and limitations under the License.
# LoRA
-
-
-This is experimental and the API may change in the future.
-
-
+> [!WARNING]
+> This is experimental and the API may change in the future.
[LoRA (Low-Rank Adaptation of Large Language Models)](https://hf.co/papers/2106.09685) is a popular and lightweight training technique that significantly reduces the number of trainable parameters. It works by inserting a smaller number of new weights into the model and only these are trained. This makes training with LoRA much faster, memory-efficient, and produces smaller model weights (a few hundred MBs), which are easier to store and share. LoRA can also be combined with other training techniques like DreamBooth to speedup training.
-
-
-LoRA is very versatile and supported for [DreamBooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py), [Kandinsky 2.2](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py), [Stable Diffusion XL](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py), [text-to-image](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py), and [Wuerstchen](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py).
-
-
+> [!TIP]
+> LoRA is very versatile and supported for [DreamBooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py), [Kandinsky 2.2](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py), [Stable Diffusion XL](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py), [text-to-image](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py), and [Wuerstchen](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py).
This guide will explore the [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
@@ -38,30 +32,13 @@ pip install .
Navigate to the example folder with the training script and install the required dependencies for the script you're using:
-
-
-
```bash
cd examples/text_to_image
pip install -r requirements.txt
```
-
-
-
-```bash
-cd examples/text_to_image
-pip install -r requirements_flax.txt
-```
-
-
-
-
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -85,11 +62,8 @@ write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) and let us know if you have any questions or concerns.
## Script parameters
@@ -177,11 +151,8 @@ Let's train on the [Naruto BLIP captions](https://huggingface.co/datasets/lambda
If you're training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.
-
-
-A full training run takes ~5 hours on a 2080 Ti GPU with 11GB of VRAM.
-
-
+> [!WARNING]
+> A full training run takes ~5 hours on a 2080 Ti GPU with 11GB of VRAM.
```bash
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
@@ -227,5 +198,5 @@ image = pipeline("A naruto with blue eyes").images[0]
Congratulations on training a new model with LoRA! To learn more about how to use your new model, the following guides may be helpful:
-- Learn how to [load different LoRA formats](../using-diffusers/loading_adapters#LoRA) trained using community trainers like Kohya and TheLastBen.
+- Learn how to [load different LoRA formats](../tutorials/using_peft_for_inference) trained using community trainers like Kohya and TheLastBen.
- Learn how to use and [combine multiple LoRA's](../tutorials/using_peft_for_inference) with PEFT for inference.
diff --git a/docs/source/en/training/overview.md b/docs/source/en/training/overview.md
index 032900d9ac..55d6b19661 100644
--- a/docs/source/en/training/overview.md
+++ b/docs/source/en/training/overview.md
@@ -23,18 +23,18 @@ Each training script is:
Our current collection of training scripts include:
-| Training | SDXL-support | LoRA-support | Flax-support |
-|---|---|---|---|
-| [unconditional image generation](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | | | |
-| [text-to-image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) | 👍 | 👍 | 👍 |
-| [textual inversion](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) | | | 👍 |
-| [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) | 👍 | 👍 | 👍 |
-| [ControlNet](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) | 👍 | | 👍 |
-| [InstructPix2Pix](https://github.com/huggingface/diffusers/tree/main/examples/instruct_pix2pix) | 👍 | | |
-| [Custom Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion) | | | |
-| [T2I-Adapters](https://github.com/huggingface/diffusers/tree/main/examples/t2i_adapter) | 👍 | | |
-| [Kandinsky 2.2](https://github.com/huggingface/diffusers/tree/main/examples/kandinsky2_2/text_to_image) | | 👍 | |
-| [Wuerstchen](https://github.com/huggingface/diffusers/tree/main/examples/wuerstchen/text_to_image) | | 👍 | |
+| Training | SDXL-support | LoRA-support |
+|---|---|---|
+| [unconditional image generation](https://github.com/huggingface/diffusers/tree/main/examples/unconditional_image_generation) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) | | |
+| [text-to-image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) | 👍 | 👍 |
+| [textual inversion](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) | | |
+| [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) | 👍 | 👍 |
+| [ControlNet](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) | 👍 | |
+| [InstructPix2Pix](https://github.com/huggingface/diffusers/tree/main/examples/instruct_pix2pix) | 👍 | |
+| [Custom Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion) | | |
+| [T2I-Adapters](https://github.com/huggingface/diffusers/tree/main/examples/t2i_adapter) | 👍 | |
+| [Kandinsky 2.2](https://github.com/huggingface/diffusers/tree/main/examples/kandinsky2_2/text_to_image) | | 👍 |
+| [Wuerstchen](https://github.com/huggingface/diffusers/tree/main/examples/wuerstchen/text_to_image) | | 👍 |
These examples are **actively** maintained, so please feel free to open an issue if they aren't working as expected. If you feel like another training example should be included, you're more than welcome to start a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=) to discuss your feature idea with us and whether it meets our criteria of being self-contained, easy-to-tweak, beginner-friendly, and single-purpose.
@@ -48,7 +48,7 @@ cd diffusers
pip install .
```
-Then navigate to the folder of the training script (for example, [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)) and install the `requirements.txt` file. Some training scripts have a specific requirement file for SDXL, LoRA or Flax. If you're using one of these scripts, make sure you install its corresponding requirements file.
+Then navigate to the folder of the training script (for example, [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)) and install the `requirements.txt` file. Some training scripts have a specific requirement file for SDXL or LoRA. If you're using one of these scripts, make sure you install its corresponding requirements file.
```bash
cd examples/dreambooth
diff --git a/docs/source/en/training/sdxl.md b/docs/source/en/training/sdxl.md
index da8b93b6d6..266bbc7d61 100644
--- a/docs/source/en/training/sdxl.md
+++ b/docs/source/en/training/sdxl.md
@@ -12,11 +12,8 @@ specific language governing permissions and limitations under the License.
# Stable Diffusion XL
-
-
-This script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.
-
-
+> [!WARNING]
+> This script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.
[Stable Diffusion XL (SDXL)](https://hf.co/papers/2307.01952) is a larger and more powerful iteration of the Stable Diffusion model, capable of producing higher resolution images.
@@ -39,11 +36,8 @@ cd examples/text_to_image
pip install -r requirements_sdxl.txt
```
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -69,11 +63,8 @@ Lastly, if you want to train a model on your own dataset, take a look at the [Cr
## Script parameters
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_sdxl.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_sdxl.py) and let us know if you have any questions or concerns.
The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L129) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
@@ -96,7 +87,7 @@ Most of the parameters are identical to the parameters in the [Text-to-image](te
### Min-SNR weighting
-The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting either `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting either `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.
Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
@@ -178,11 +169,8 @@ Once you’ve made all your changes or you’re okay with the default configurat
Let’s train on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset to generate your own Naruto characters. Set the environment variables `MODEL_NAME` and `DATASET_NAME` to the model and the dataset (either from the Hub or a local path). You should also specify a VAE other than the SDXL VAE (either from the Hub or a local path) with `VAE_NAME` to avoid numerical instabilities.
-
-
-To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompt` and `--validation_epochs` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
-
-
+> [!TIP]
+> To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompt` and `--validation_epochs` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
diff --git a/docs/source/en/training/t2i_adapters.md b/docs/source/en/training/t2i_adapters.md
index 243c591bea..6d76004073 100644
--- a/docs/source/en/training/t2i_adapters.md
+++ b/docs/source/en/training/t2i_adapters.md
@@ -33,11 +33,8 @@ cd examples/t2i_adapter
pip install -r requirements.txt
```
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -61,11 +58,8 @@ write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/t2i_adapter/train_t2i_adapter_sdxl.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/t2i_adapter/train_t2i_adapter_sdxl.py) and let us know if you have any questions or concerns.
## Script parameters
@@ -166,11 +160,8 @@ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/ma
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```
-
-
-To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You'll also need to add the `--validation_image`, `--validation_prompt`, and `--validation_steps` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
-
-
+> [!TIP]
+> To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You'll also need to add the `--validation_image`, `--validation_prompt`, and `--validation_steps` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
diff --git a/docs/source/en/training/text2image.md b/docs/source/en/training/text2image.md
index 182621e89b..d11e55e910 100644
--- a/docs/source/en/training/text2image.md
+++ b/docs/source/en/training/text2image.md
@@ -12,15 +12,12 @@ specific language governing permissions and limitations under the License.
# Text-to-image
-
-
-The text-to-image script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.
-
-
+> [!WARNING]
+> The text-to-image script is experimental, and it's easy to overfit and run into issues like catastrophic forgetting. Try exploring different hyperparameters to get the best results on your dataset.
Text-to-image models like Stable Diffusion are conditioned to generate images given a text prompt.
-Training a model can be taxing on your hardware, but if you enable `gradient_checkpointing` and `mixed_precision`, it is possible to train a model on a single 24GB GPU. If you're training with larger batch sizes or want to train faster, it's better to use GPUs with more than 30GB of memory. You can reduce your memory footprint by enabling memory-efficient attention with [xFormers](../optimization/xformers). JAX/Flax training is also supported for efficient training on TPUs and GPUs, but it doesn't support gradient checkpointing, gradient accumulation or xFormers. A GPU with at least 30GB of memory or a TPU v3 is recommended for training with Flax.
+Training a model can be taxing on your hardware, but if you enable `gradient_checkpointing` and `mixed_precision`, it is possible to train a model on a single 24GB GPU. If you're training with larger batch sizes or want to train faster, it's better to use GPUs with more than 30GB of memory. You can reduce your memory footprint by enabling memory-efficient attention with [xFormers](../optimization/xformers).
This guide will explore the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.
@@ -34,26 +31,13 @@ pip install .
Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
-
-
```bash
cd examples/text_to_image
pip install -r requirements.txt
```
-
-
-```bash
-cd examples/text_to_image
-pip install -r requirements_flax.txt
-```
-
-
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -79,11 +63,8 @@ Lastly, if you want to train a model on your own dataset, take a look at the [Cr
## Script parameters
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) and let us know if you have any questions or concerns.
The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L193) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
@@ -106,7 +87,7 @@ Some basic and important parameters include:
### Min-SNR weighting
-The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch and is unavailable in the Flax training script.
+The [Min-SNR](https://huggingface.co/papers/2303.09556) weighting strategy can help with training by rebalancing the loss to achieve faster convergence. The training script supports predicting `epsilon` (noise) or `v_prediction`, but Min-SNR is compatible with both prediction types. This weighting strategy is only supported by PyTorch.
Add the `--snr_gamma` parameter and set it to the recommended value of 5.0:
@@ -155,16 +136,10 @@ Lastly, the [training loop](https://github.com/huggingface/diffusers/blob/8959c5
Once you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀
-
-
-
Let's train on the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset to generate your own Naruto characters. Set the environment variables `MODEL_NAME` and `dataset_name` to the model and the dataset (either from the Hub or a local path). If you're training on more than one GPU, add the `--multi_gpu` parameter to the `accelerate launch` command.
-
-
-To train on a local dataset, set the `TRAIN_DIR` and `OUTPUT_DIR` environment variables to the path of the dataset and where to save the model to.
-
-
+> [!TIP]
+> To train on a local dataset, set the `TRAIN_DIR` and `OUTPUT_DIR` environment variables to the path of the dataset and where to save the model to.
```bash
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
@@ -187,43 +162,8 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--push_to_hub
```
-
-
-
-Training with Flax can be faster on TPUs and GPUs thanks to [@duongna211](https://github.com/duongna21). Flax is more efficient on a TPU, but GPU performance is also great.
-
-Set the environment variables `MODEL_NAME` and `dataset_name` to the model and the dataset (either from the Hub or a local path).
-
-
-
-To train on a local dataset, set the `TRAIN_DIR` and `OUTPUT_DIR` environment variables to the path of the dataset and where to save the model to.
-
-
-
-```bash
-export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
-export dataset_name="lambdalabs/naruto-blip-captions"
-
-python train_text_to_image_flax.py \
- --pretrained_model_name_or_path=$MODEL_NAME \
- --dataset_name=$dataset_name \
- --resolution=512 --center_crop --random_flip \
- --train_batch_size=1 \
- --max_train_steps=15000 \
- --learning_rate=1e-05 \
- --max_grad_norm=1 \
- --output_dir="sd-naruto-model" \
- --push_to_hub
-```
-
-
-
-
Once training is complete, you can use your newly trained model for inference:
-
-
-
```py
from diffusers import StableDiffusionPipeline
import torch
@@ -234,42 +174,9 @@ image = pipeline(prompt="yoda").images[0]
image.save("yoda-naruto.png")
```
-
-
-
-```py
-import jax
-import numpy as np
-from flax.jax_utils import replicate
-from flax.training.common_utils import shard
-from diffusers import FlaxStableDiffusionPipeline
-
-pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("path/to/saved_model", dtype=jax.numpy.bfloat16)
-
-prompt = "yoda naruto"
-prng_seed = jax.random.PRNGKey(0)
-num_inference_steps = 50
-
-num_samples = jax.device_count()
-prompt = num_samples * [prompt]
-prompt_ids = pipeline.prepare_inputs(prompt)
-
-# shard inputs and rng
-params = replicate(params)
-prng_seed = jax.random.split(prng_seed, jax.device_count())
-prompt_ids = shard(prompt_ids)
-
-images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
-image.save("yoda-naruto.png")
-```
-
-
-
-
## Next steps
Congratulations on training your own text-to-image model! To learn more about how to use your new model, the following guides may be helpful:
-- Learn how to [load LoRA weights](../using-diffusers/loading_adapters#LoRA) for inference if you trained your model with LoRA.
+- Learn how to [load LoRA weights](../tutorials/using_peft_for_inference) for inference if you trained your model with LoRA.
- Learn more about how certain parameters like guidance scale or techniques such as prompt weighting can help you control inference in the [Text-to-image](../using-diffusers/conditional_image_generation) task guide.
diff --git a/docs/source/en/training/text_inversion.md b/docs/source/en/training/text_inversion.md
index b7083ae589..4912b6730a 100644
--- a/docs/source/en/training/text_inversion.md
+++ b/docs/source/en/training/text_inversion.md
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
[Textual Inversion](https://hf.co/papers/2208.01618) is a training technique for personalizing image generation models with just a few example images of what you want it to learn. This technique works by learning and updating the text embeddings (the new embeddings are tied to a special word you must use in the prompt) to match the example images you provide.
-If you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing` and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers). JAX/Flax training is also supported for efficient training on TPUs and GPUs, but it doesn't support gradient checkpointing or xFormers. With the same configuration and setup as PyTorch, the Flax training script should be at least ~70% faster!
+If you're training on a GPU with limited vRAM, you should try enabling the `gradient_checkpointing` and `mixed_precision` parameters in the training command. You can also reduce your memory footprint by using memory-efficient attention with [xFormers](../optimization/xformers).
This guide will explore the [textual_inversion.py](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
@@ -28,30 +28,12 @@ pip install .
Navigate to the example folder with the training script and install the required dependencies for the script you're using:
-
-
-
```bash
cd examples/textual_inversion
pip install -r requirements.txt
```
-
-
-
-
-```bash
-cd examples/textual_inversion
-pip install -r requirements_flax.txt
-```
-
-
-
-
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -75,11 +57,8 @@ write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py) and let us know if you have any questions or concerns.
## Script parameters
@@ -175,11 +154,8 @@ Set the environment variable `MODEL_NAME` to a model id on the Hub or a path to
- `token_identifier.txt`: the special placeholder token
- `type_of_concept.txt`: the type of concept you're training on (either "object" or "style")
-
-
-A full training run takes ~1 hour on a single V100 GPU.
-
-
+> [!WARNING]
+> A full training run takes ~1 hour on a single V100 GPU.
One more thing before you launch the script. If you're interested in following along with the training process, you can periodically save generated images as training progresses. Add the following parameters to the training command:
@@ -189,9 +165,6 @@ One more thing before you launch the script. If you're interested in following a
--validation_steps=100
```
-
-
-
```bash
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
export DATA_DIR="./cat"
@@ -214,36 +187,8 @@ accelerate launch textual_inversion.py \
--push_to_hub
```
-
-
-
-```bash
-export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export DATA_DIR="./cat"
-
-python textual_inversion_flax.py \
- --pretrained_model_name_or_path=$MODEL_NAME \
- --train_data_dir=$DATA_DIR \
- --learnable_property="object" \
- --placeholder_token="" \
- --initializer_token="toy" \
- --resolution=512 \
- --train_batch_size=1 \
- --max_train_steps=3000 \
- --learning_rate=5.0e-04 \
- --scale_lr \
- --output_dir="textual_inversion_cat" \
- --push_to_hub
-```
-
-
-
-
After training is complete, you can use your newly trained model for inference like:
-
-
-
```py
from diffusers import StableDiffusionPipeline
import torch
@@ -254,45 +199,8 @@ image = pipeline("A train", num_inference_steps=50).images[0]
image.save("cat-train.png")
```
-
-
-
-Flax doesn't support the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] method, but the textual_inversion_flax.py script [saves](https://github.com/huggingface/diffusers/blob/c0f058265161178f2a88849e92b37ffdc81f1dcc/examples/textual_inversion/textual_inversion_flax.py#L636C2-L636C2) the learned embeddings as a part of the model after training. This means you can use the model for inference like any other Flax model:
-
-```py
-import jax
-import numpy as np
-from flax.jax_utils import replicate
-from flax.training.common_utils import shard
-from diffusers import FlaxStableDiffusionPipeline
-
-model_path = "path-to-your-trained-model"
-pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)
-
-prompt = "A train"
-prng_seed = jax.random.PRNGKey(0)
-num_inference_steps = 50
-
-num_samples = jax.device_count()
-prompt = num_samples * [prompt]
-prompt_ids = pipeline.prepare_inputs(prompt)
-
-# shard inputs and rng
-params = replicate(params)
-prng_seed = jax.random.split(prng_seed, jax.device_count())
-prompt_ids = shard(prompt_ids)
-
-images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
-image.save("cat-train.png")
-```
-
-
-
-
## Next steps
Congratulations on training your own Textual Inversion model! 🎉 To learn more about how to use your new model, the following guides may be helpful:
-- Learn how to [load Textual Inversion embeddings](../using-diffusers/loading_adapters) and also use them as negative embeddings.
-- Learn how to use [Textual Inversion](textual_inversion_inference) for inference with Stable Diffusion 1/2 and Stable Diffusion XL.
\ No newline at end of file
+- Learn how to [load Textual Inversion embeddings](../using-diffusers/textual_inversion_inference) and also use them as negative embeddings.
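+
+As a quick preview, here's a minimal sketch of loading the learned embeddings into a fresh pipeline. It assumes the embeddings were saved to the `textual_inversion_cat` output directory from the training command above and uses the `<cat-toy>` placeholder token; refer to the linked guide for the full details, including negative embeddings.
+
+```py
+from diffusers import StableDiffusionPipeline
+import torch
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+
+# load the learned embeddings saved by the training script (assumed output directory)
+pipeline.load_textual_inversion("textual_inversion_cat")
+
+image = pipeline("A <cat-toy> backpack", num_inference_steps=50).images[0]
+image.save("cat-backpack.png")
+```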
\ No newline at end of file
diff --git a/docs/source/en/training/unconditional_training.md b/docs/source/en/training/unconditional_training.md
index d2facc7852..ab3bdd6416 100644
--- a/docs/source/en/training/unconditional_training.md
+++ b/docs/source/en/training/unconditional_training.md
@@ -31,11 +31,8 @@ cd examples/unconditional_image_generation
pip install -r requirements.txt
```
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -61,11 +58,8 @@ Lastly, if you want to train a model on your own dataset, take a look at the [Cr
## Script parameters
-
-
-The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training script that are important for understanding how to modify it, but they don't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) and let us know if you have any questions or concerns.
The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/096f84b05f9514fae9f185cbec0a4d38fbad9919/examples/unconditional_image_generation/train_unconditional.py#L55) function. It provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
@@ -163,11 +157,8 @@ Finally, the [training loop](https://github.com/huggingface/diffusers/blob/096f8
Once you've made all your changes or you're okay with the default configuration, you're ready to launch the training script! 🚀
-
-
-A full training run takes 2 hours on 4xV100 GPUs.
-
-
+> [!WARNING]
+> A full training run takes 2 hours on 4xV100 GPUs.
diff --git a/docs/source/en/training/wuerstchen.md b/docs/source/en/training/wuerstchen.md
index 38a1387dd3..1c362879a6 100644
--- a/docs/source/en/training/wuerstchen.md
+++ b/docs/source/en/training/wuerstchen.md
@@ -33,11 +33,8 @@ cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```
-
-
-🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
-
-
+> [!TIP]
+> 🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
Initialize an 🤗 Accelerate environment:
@@ -61,11 +58,8 @@ write_basic_config()
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
-
-
-The following sections highlight parts of the training scripts that are important for understanding how to modify it, but it doesn't cover every aspect of the [script](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) in detail. If you're interested in learning more, feel free to read through the scripts and let us know if you have any questions or concerns.
-
-
+> [!TIP]
+> The following sections highlight parts of the training scripts that are important for understanding how to modify them, but they don't cover every aspect of the [script](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) in detail. If you're interested in learning more, feel free to read through the scripts and let us know if you have any questions or concerns.
## Script parameters
@@ -133,11 +127,8 @@ Once you’ve made all your changes or you’re okay with the default configurat
Set the `DATASET_NAME` environment variable to the dataset name from the Hub. This guide uses the [Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) dataset, but you can create and train on your own datasets as well (see the [Create a dataset for training](create_dataset) guide).
-
-
-To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompt` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
-
-
+> [!TIP]
+> To monitor training progress with Weights & Biases, add the `--report_to=wandb` parameter to the training command. You’ll also need to add the `--validation_prompt` to the training command to keep track of results. This can be really useful for debugging the model and viewing intermediate results.
```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"
diff --git a/docs/source/en/tutorials/autopipeline.md b/docs/source/en/tutorials/autopipeline.md
index 44bf00398f..f0aa298b23 100644
--- a/docs/source/en/tutorials/autopipeline.md
+++ b/docs/source/en/tutorials/autopipeline.md
@@ -12,112 +12,56 @@ specific language governing permissions and limitations under the License.
# AutoPipeline
-Diffusers provides many pipelines for basic tasks like generating images, videos, audio, and inpainting. On top of these, there are specialized pipelines for adapters and features like upscaling, super-resolution, and more. Different pipeline classes can even use the same checkpoint because they share the same pretrained model! With so many different pipelines, it can be overwhelming to know which pipeline class to use.
+[AutoPipeline](../api/models/auto_model) is a *task-and-model* pipeline that automatically selects the correct pipeline subclass based on the task. It handles the complexity of loading different pipeline subclasses without needing to know the specific pipeline subclass name.
-The [AutoPipeline](../api/pipelines/auto_pipeline) class is designed to simplify the variety of pipelines in Diffusers. It is a generic *task-first* pipeline that lets you focus on a task ([`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]) without needing to know the specific pipeline class. The [AutoPipeline](../api/pipelines/auto_pipeline) automatically detects the correct pipeline class to use.
+This is unlike [`DiffusionPipeline`], a *model-only* pipeline that automatically selects the pipeline subclass based on the model.
-For example, let's use the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint.
-
-Under the hood, [AutoPipeline](../api/pipelines/auto_pipeline):
-
-1. Detects a `"stable-diffusion"` class from the [model_index.json](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0/blob/main/model_index.json) file.
-2. Depending on the task you're interested in, it loads the [`StableDiffusionPipeline`], [`StableDiffusionImg2ImgPipeline`], or [`StableDiffusionInpaintPipeline`]. Any parameter (`strength`, `num_inference_steps`, etc.) you would pass to these specific pipelines can also be passed to the [AutoPipeline](../api/pipelines/auto_pipeline).
-
-
-
+[`AutoPipelineForImage2Image`] returns a specific pipeline subclass (for example, [`StableDiffusionXLImg2ImgPipeline`]), which can only be used for image-to-image tasks.
```py
-from diffusers import AutoPipelineForText2Image
import torch
-
-pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
- "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-prompt = "cinematic photo of Godzilla eating sushi with a cat in a izakaya, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(37)
-image = pipe_txt2img(prompt, generator=generator).images[0]
-image
-```
-
-
-
-
-
-
-
-
-```py
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import load_image
-import torch
-
-pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
- "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png")
-
-prompt = "cinematic photo of Godzilla eating burgers with a cat in a fast food restaurant, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(53)
-image = pipe_img2img(prompt, image=init_image, generator=generator).images[0]
-image
-```
-
-Notice how the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint is used for both text-to-image and image-to-image tasks? To save memory and avoid loading the checkpoint twice, use the [`~DiffusionPipeline.from_pipe`] method.
-
-```py
-pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_txt2img).to("cuda")
-image = pipeline(prompt, image=init_image, generator=generator).images[0]
-image
-```
-
-You can learn more about the [`~DiffusionPipeline.from_pipe`] method in the [Reuse a pipeline](../using-diffusers/loading#reuse-a-pipeline) guide.
-
-
-
-
-
-
-## Unsupported checkpoints
-
-The [AutoPipeline](../api/pipelines/auto_pipeline) supports [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [Stable Diffusion XL](../api/pipelines/stable_diffusion/stable_diffusion_xl), [ControlNet](../api/pipelines/controlnet), [Kandinsky 2.1](../api/pipelines/kandinsky.md), [Kandinsky 2.2](../api/pipelines/kandinsky_v22), and [DeepFloyd IF](../api/pipelines/deepfloyd_if) checkpoints.
-
-If you try to load an unsupported checkpoint, you'll get an error.
-
-```py
-from diffusers import AutoPipelineForImage2Image
-import torch
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "openai/shap-e-img2img", torch_dtype=torch.float16, use_safetensors=True
+ "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
+)
+print(pipeline)
+"StableDiffusionXLImg2ImgPipeline {
+ "_class_name": "StableDiffusionXLImg2ImgPipeline",
+ ...
+"
+```
+
+Loading the same model with [`DiffusionPipeline`] returns the [`StableDiffusionXLPipeline`] subclass. It can be used for text-to-image, image-to-image, or inpainting tasks depending on the inputs.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
+)
+print(pipeline)
+"StableDiffusionXLPipeline {
+ "_class_name": "StableDiffusionXLPipeline",
+ ...
+"
+```
+
+Check the [mappings](https://github.com/huggingface/diffusers/blob/130fd8df54f24ffb006d84787b598d8adc899f23/src/diffusers/pipelines/auto_pipeline.py#L114) to see whether a model is supported or not.
+
+Trying to load an unsupported model returns an error.
+
+```py
+import torch
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "openai/shap-e-img2img", torch_dtype=torch.float16,
)
"ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
```
+
+There are three [AutoPipeline](../api/models/auto_model) classes: [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]. Each class has a predefined mapping that links a pipeline to its task-specific subclass.
+
+When [`~AutoPipelineForText2Image.from_pretrained`] is called, it extracts the class name from the `model_index.json` file and selects the appropriate pipeline subclass for the task based on the mapping.
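+
+For example, a minimal sketch of that flow with [`AutoPipelineForText2Image`], reusing the same checkpoint as above (the prompt and filename are illustrative):
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+
+# the "StableDiffusionXLPipeline" class name in model_index.json is mapped to the
+# text-to-image subclass for this model family
+pipeline = AutoPipelineForText2Image.from_pretrained(
+    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
+)
+print(pipeline.__class__.__name__)  # StableDiffusionXLPipeline
+
+image = pipeline("A cozy coffee shop interior with a cat on the windowsill").images[0]
+image.save("autopipeline_t2i.png")
+```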
\ No newline at end of file
diff --git a/docs/source/en/tutorials/basic_training.md b/docs/source/en/tutorials/basic_training.md
index 9a35b3438f..3aa2ae429b 100644
--- a/docs/source/en/tutorials/basic_training.md
+++ b/docs/source/en/tutorials/basic_training.md
@@ -18,11 +18,8 @@ Unconditional image generation is a popular application of diffusion models that
This tutorial will teach you how to train a [`UNet2DModel`] from scratch on a subset of the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset to generate your own 🦋 butterflies 🦋.
-
-
-💡 This training tutorial is based on the [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook. For additional details and context about diffusion models like how they work, check out the notebook!
-
-
+> [!TIP]
+> 💡 This training tutorial is based on the [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook. For additional details and context about diffusion models like how they work, check out the notebook!
Before you begin, make sure you have 🤗 Datasets installed to load and preprocess image datasets, and 🤗 Accelerate, to simplify training on any number of GPUs. The following command will also install [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize training metrics (you can also use [Weights & Biases](https://docs.wandb.ai/) to track your training).
@@ -94,11 +91,8 @@ You can easily load the [Smithsonian Butterflies](https://huggingface.co/dataset
>>> dataset = load_dataset(config.dataset_name, split="train")
```
-
-
-💡 You can find additional datasets from the [HugGan Community Event](https://huggingface.co/huggan) or you can use your own dataset by creating a local [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). Set `config.dataset_name` to the repository id of the dataset if it is from the HugGan Community Event, or `imagefolder` if you're using your own images.
-
-
+> [!TIP]
+> 💡 You can find additional datasets from the [HugGan Community Event](https://huggingface.co/huggan) or you can use your own dataset by creating a local [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). Set `config.dataset_name` to the repository id of the dataset if it is from the HugGan Community Event, or `imagefolder` if you're using your own images.
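+
+As a quick sketch of the `imagefolder` option mentioned above (the local path is hypothetical):
+
+```py
+>>> from datasets import load_dataset
+
+>>> # "imagefolder" is the 🤗 Datasets ImageFolder loader; point data_dir at your own images
+>>> config.dataset_name = "imagefolder"
+>>> dataset = load_dataset(config.dataset_name, data_dir="path/to/your/images", split="train")
+```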
🤗 Datasets uses the [`~datasets.Image`] feature to automatically decode the image data and load it as a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html) which we can visualize:
@@ -274,11 +268,8 @@ Then, you'll need a way to evaluate the model. For evaluation, you can use the [
Now you can wrap all these components together in a training loop with 🤗 Accelerate for easy TensorBoard logging, gradient accumulation, and mixed precision training. To upload the model to the Hub, write a function to get your repository name and information and then push it to the Hub.
-
-
-💡 The training loop below may look intimidating and long, but it'll be worth it later when you launch your training in just one line of code! If you can't wait and want to start generating images, feel free to copy and run the code below. You can always come back and examine the training loop more closely later, like when you're waiting for your model to finish training. 🤗
-
-
+> [!TIP]
+> 💡 The training loop below may look intimidating and long, but it'll be worth it later when you launch your training in just one line of code! If you can't wait and want to start generating images, feel free to copy and run the code below. You can always come back and examine the training loop more closely later, like when you're waiting for your model to finish training. 🤗
```py
>>> from accelerate import Accelerator
diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index 5cd47f8674..7bdd2a1ee9 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -94,7 +94,7 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
pipeline.unet.load_lora_adapter(
"jbilcke-hf/sdxl-cinematic-1",
weight_name="pytorch_lora_weights.safetensors",
- adapter_name="cinematic"
+ adapter_name="cinematic",
prefix="unet"
)
# use cnmt in the prompt to trigger the LoRA
@@ -688,4 +688,4 @@ Browse the [LoRA Studio](https://lorastudio.co/models) for different LoRAs to us
You can find additional LoRAs in the [FLUX LoRA the Explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer) and [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer) Spaces.
-Check out the [Fast LoRA inference for Flux with Diffusers and PEFT](https://huggingface.co/blog/lora-fast) blog post to learn how to optimize LoRA inference with methods like FlashAttention-3 and fp8 quantization.
\ No newline at end of file
+Check out the [Fast LoRA inference for Flux with Diffusers and PEFT](https://huggingface.co/blog/lora-fast) blog post to learn how to optimize LoRA inference with methods like FlashAttention-3 and fp8 quantization.
diff --git a/docs/source/en/using-diffusers/batched_inference.md b/docs/source/en/using-diffusers/batched_inference.md
index b5e55c27ca..cdb16ac121 100644
--- a/docs/source/en/using-diffusers/batched_inference.md
+++ b/docs/source/en/using-diffusers/batched_inference.md
@@ -16,24 +16,24 @@ Batch inference processes multiple prompts at a time to increase throughput. It
The downside is increased latency because you must wait for the entire batch to complete, and more GPU memory is required for large batches.
-
-
-
-For text-to-image, pass a list of prompts to the pipeline.
+For text-to-image, pass a list of prompts to the pipeline. For image-to-image, pass a list of images and a list of prompts. The example below demonstrates batched text-to-image inference.
```py
import torch
+import matplotlib.pyplot as plt
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16
-).to("cuda")
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
prompts = [
- "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
- "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
- "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+ "Cinematic shot of a cozy coffee shop interior, warm pastel light streaming through a window where a cat rests. Shallow depth of field, glowing cups in soft focus, dreamy lofi-inspired mood, nostalgic tones, framed like a quiet film scene.",
+ "Polaroid-style photograph of a cozy coffee shop interior, bathed in warm pastel light. A cat sits on the windowsill near steaming mugs. Soft, slightly faded tones and dreamy blur evoke nostalgia, a lofi mood, and the intimate, imperfect charm of instant film.",
+ "Soft watercolor illustration of a cozy coffee shop interior, pastel washes of color filling the space. A cat rests peacefully on the windowsill as warm light glows through. Gentle brushstrokes create a dreamy, lofi-inspired atmosphere with whimsical textures and nostalgic calm.",
+ "Isometric pixel-art illustration of a cozy coffee shop interior in detailed 8-bit style. Warm pastel light fills the space as a cat rests on the windowsill. Blocky furniture and tiny mugs add charm, low-res retro graphics enhance the nostalgic, lofi-inspired game aesthetic."
]
images = pipeline(
@@ -52,6 +52,10 @@ plt.tight_layout()
plt.show()
```
+
+
+
+
To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
```py
@@ -61,11 +65,18 @@ from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16
-).to("cuda")
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+
+prompt="""
+Isometric pixel-art illustration of a cozy coffee shop interior in detailed 8-bit style. Warm pastel light fills the
+space as a cat rests on the windowsill. Blocky furniture and tiny mugs add charm, low-res retro graphics enhance the
+nostalgic, lofi-inspired game aesthetic.
+"""
images = pipeline(
- prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+ prompt=prompt,
num_images_per_prompt=4
).images
@@ -81,6 +92,10 @@ plt.tight_layout()
plt.show()
```
+
+
+
+
Combine both approaches to generate different variations of different prompts.
```py
@@ -89,7 +104,7 @@ images = pipeline(
num_images_per_prompt=2,
).images
-fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+fig, axes = plt.subplots(2, 4, figsize=(12, 12))
axes = axes.flatten()
for i, image in enumerate(images):
@@ -101,126 +116,18 @@ plt.tight_layout()
plt.show()
```
-
-
-
-For image-to-image, pass a list of input images and prompts to the pipeline.
-
-```py
-import torch
-from diffusers.utils import load_image
-from diffusers import DiffusionPipeline
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16
-).to("cuda")
-
-input_images = [
- load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"),
- load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
- load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
-]
-
-prompts = [
- "cinematic photo of a beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
- "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
- "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
-]
-
-images = pipeline(
- prompt=prompts,
- image=input_images,
- guidance_scale=8.0,
- strength=0.5
-).images
-
-fig, axes = plt.subplots(2, 2, figsize=(12, 12))
-axes = axes.flatten()
-
-for i, image in enumerate(images):
- axes[i].imshow(image)
- axes[i].set_title(f"Image {i+1}")
- axes[i].axis('off')
-
-plt.tight_layout()
-plt.show()
-```
-
-To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
-
-```py
-import torch
-import matplotlib.pyplot as plt
-from diffusers.utils import load_image
-from diffusers import DiffusionPipeline
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16
-).to("cuda")
-
-input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
-
-images = pipeline(
- prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
- image=input_image,
- num_images_per_prompt=4
-).images
-
-fig, axes = plt.subplots(2, 2, figsize=(12, 12))
-axes = axes.flatten()
-
-for i, image in enumerate(images):
- axes[i].imshow(image)
- axes[i].set_title(f"Image {i+1}")
- axes[i].axis('off')
-
-plt.tight_layout()
-plt.show()
-```
-
-Combine both approaches to generate different variations of different prompts.
-
-```py
-input_images = [
- load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
- load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
-]
-
-prompts = [
- "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
- "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
-]
-
-images = pipeline(
- prompt=prompts,
- image=input_images,
- num_images_per_prompt=2,
-).images
-
-fig, axes = plt.subplots(2, 2, figsize=(12, 12))
-axes = axes.flatten()
-
-for i, image in enumerate(images):
- axes[i].imshow(image)
- axes[i].set_title(f"Image {i+1}")
- axes[i].axis('off')
-
-plt.tight_layout()
-plt.show()
-```
-
-
-
+
+
+
## Deterministic generation
Enable reproducible batch generation by passing a list of [Generator’s](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed to reuse it.
-Use a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch.
+> [!TIP]
+> Refer to the [Reproducibility](./reusing_seeds) docs to learn more about deterministic algorithms and the `Generator` object.
-Don't multiply the `Generator` by the batch size because that only creates one `Generator` object that is used sequentially for each image in the batch.
+Use a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. Don't multiply the `Generator` by the batch size because that only creates one `Generator` object that is used sequentially for each image in the batch.
```py
generator = [torch.Generator(device="cuda").manual_seed(0)] * 3
@@ -234,14 +141,16 @@ from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16
-).to("cuda")
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
prompts = [
- "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
- "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
- "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+ "Cinematic shot of a cozy coffee shop interior, warm pastel light streaming through a window where a cat rests. Shallow depth of field, glowing cups in soft focus, dreamy lofi-inspired mood, nostalgic tones, framed like a quiet film scene.",
+ "Polaroid-style photograph of a cozy coffee shop interior, bathed in warm pastel light. A cat sits on the windowsill near steaming mugs. Soft, slightly faded tones and dreamy blur evoke nostalgia, a lofi mood, and the intimate, imperfect charm of instant film.",
+ "Soft watercolor illustration of a cozy coffee shop interior, pastel washes of color filling the space. A cat rests peacefully on the windowsill as warm light glows through. Gentle brushstrokes create a dreamy, lofi-inspired atmosphere with whimsical textures and nostalgic calm.",
+ "Isometric pixel-art illustration of a cozy coffee shop interior in detailed 8-bit style. Warm pastel light fills the space as a cat rests on the windowsill. Blocky furniture and tiny mugs add charm, low-res retro graphics enhance the nostalgic, lofi-inspired game aesthetic."
]
images = pipeline(
@@ -261,4 +170,4 @@ plt.tight_layout()
plt.show()
```
-You can use this to iteratively select an image associated with a seed and then improve on it by crafting a more detailed prompt.
\ No newline at end of file
+You can use this to select an image associated with a seed and iteratively improve on it by crafting a more detailed prompt.
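+
+For example, a minimal sketch of that workflow, reusing the pipeline from above and assuming the image generated with seed `1` was the one you liked (the refined prompt is illustrative):
+
+```py
+# reuse the exact seed that produced the image you want to refine
+generator = torch.Generator(device="cuda").manual_seed(1)
+
+refined_prompt = (
+    "Polaroid-style photograph of a cozy coffee shop interior, bathed in warm pastel light. "
+    "A sleepy orange tabby curls up on the windowsill beside two steaming ceramic mugs, "
+    "faded tones, soft film grain, nostalgic lofi mood."
+)
+
+image = pipeline(refined_prompt, generator=generator).images[0]
+image.save("refined_seed_1.png")
+```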
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md
index e0fa885784..60b839805f 100644
--- a/docs/source/en/using-diffusers/callback.md
+++ b/docs/source/en/using-diffusers/callback.md
@@ -12,52 +12,37 @@ specific language governing permissions and limitations under the License.
# Pipeline callbacks
-The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
+A callback is a function that modifies [`DiffusionPipeline`] behavior. It is executed at the end of each denoising step, and its changes are propagated to the subsequent steps in the denoising process. Callbacks are useful for adjusting pipeline attributes or tensor variables to support new features without rewriting the underlying pipeline code.
-> [!TIP]
-> 🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
+Diffusers provides several official callbacks, listed in the pipeline [overview](../api/pipelines/overview#callbacks).
-This guide will demonstrate how callbacks work by a few features you can implement with them.
+To enable a callback, specify the number of denoising steps after which it takes effect with one of the following arguments.
-## Official callbacks
+- `cutoff_step_ratio` specifies when a callback is activated as a percentage of the total denoising steps.
+- `cutoff_step_index` specifies the exact step at which a callback is activated.
-We provide a list of callbacks you can plug into an existing pipeline and modify the denoising loop. This is the current list of official callbacks:
+The example below uses `cutoff_step_ratio=0.4`, which means the callback is activated once denoising reaches 40% of the total inference steps. [`~callbacks.SDXLCFGCutoffCallback`] disables classifier-free guidance (CFG) after a certain number of steps, which can help save compute without significantly affecting performance.
-- `SDCFGCutoffCallback`: Disables the CFG after a certain number of steps for all SD 1.5 pipelines, including text-to-image, image-to-image, inpaint, and controlnet.
-- `SDXLCFGCutoffCallback`: Disables the CFG after a certain number of steps for all SDXL pipelines, including text-to-image, image-to-image, inpaint, and controlnet.
-- `IPAdapterScaleCutoffCallback`: Disables the IP Adapter after a certain number of steps for all pipelines supporting IP-Adapter.
+Define a callback with either of the `cutoff` arguments and pass it to the `callback_on_step_end` parameter in the pipeline.
-> [!TIP]
-> If you want to add a new official callback, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) or [submit a PR](https://huggingface.co/docs/diffusers/main/en/conceptual/contribution#how-to-open-a-pr).
-
-To set up a callback, you need to specify the number of denoising steps after which the callback comes into effect. You can do so by using either one of these two arguments
-
-- `cutoff_step_ratio`: Float number with the ratio of the steps.
-- `cutoff_step_index`: Integer number with the exact number of the step.
-
-```python
+```py
import torch
-
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from diffusers.callbacks import SDXLCFGCutoffCallback
-
callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
-# can also be used with cutoff_step_index
+# if using cutoff_step_index
# callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
- variant="fp16",
-).to("cuda")
+ device_map="cuda"
+)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)
prompt = "a sports car at the road, best quality, high quality, high detail, 8k resolution"
-
generator = torch.Generator(device="cpu").manual_seed(2628670641)
-
-out = pipeline(
+output = pipeline(
prompt=prompt,
negative_prompt="",
guidance_scale=6.5,
@@ -65,83 +50,16 @@ out = pipeline(
generator=generator,
callback_on_step_end=callback,
)
-
-out.images[0].save("official_callback.png")
```
-[image comparison: without SDXLCFGCutoffCallback vs. with SDXLCFGCutoffCallback]
+If you want to add a new official callback, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) or [submit a PR](https://huggingface.co/docs/diffusers/main/en/conceptual/contribution#how-to-open-a-pr). Otherwise, you can also create your own callback as shown below.
-## Dynamic classifier-free guidance
+## Early stopping
-Dynamic classifier-free guidance (CFG) is a feature that allows you to disable CFG after a certain number of inference steps which can help you save compute with minimal cost to performance. The callback function for this should have the following arguments:
-
-- `pipeline` (or the pipeline instance) provides access to important properties such as `num_timesteps` and `guidance_scale`. You can modify these properties by updating the underlying attributes. For this example, you'll disable CFG by setting `pipeline._guidance_scale=0.0`.
-- `step_index` and `timestep` tell you where you are in the denoising loop. Use `step_index` to turn off CFG after reaching 40% of `num_timesteps`.
-- `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument, which is passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so please check a pipeline's `_callback_tensor_inputs` attribute for the list of variables you can modify. Some common variables include `latents` and `prompt_embeds`. For this function, change the batch size of `prompt_embeds` after setting `guidance_scale=0.0` in order for it to work properly.
-
-Your callback function should look something like this:
-
-```python
-def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
- # adjust the batch_size of prompt_embeds according to guidance_scale
- if step_index == int(pipeline.num_timesteps * 0.4):
- prompt_embeds = callback_kwargs["prompt_embeds"]
- prompt_embeds = prompt_embeds.chunk(2)[-1]
-
- # update guidance_scale and prompt_embeds
- pipeline._guidance_scale = 0.0
- callback_kwargs["prompt_embeds"] = prompt_embeds
- return callback_kwargs
-```
-
-Now, you can pass the callback function to the `callback_on_step_end` parameter and the `prompt_embeds` to `callback_on_step_end_tensor_inputs`.
+Early stopping is useful if you aren't happy with the intermediate results during generation. This callback sets a hardcoded stop point after which the pipeline terminates by setting the `_interrupt` attribute to `True`.
```py
-import torch
-from diffusers import StableDiffusionPipeline
-
-pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
-pipeline = pipeline.to("cuda")
-
-prompt = "a photo of an astronaut riding a horse on mars"
-
-generator = torch.Generator(device="cuda").manual_seed(1)
-out = pipeline(
- prompt,
- generator=generator,
- callback_on_step_end=callback_dynamic_cfg,
- callback_on_step_end_tensor_inputs=['prompt_embeds']
-)
-
-out.images[0].save("out_custom_cfg.png")
-```
-
-## Interrupt the diffusion process
-
-> [!TIP]
-> The interruption callback is supported for text-to-image, image-to-image, and inpainting for the [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview) and [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl).
-
-Stopping the diffusion process early is useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
-
-This callback function should take the following arguments: `pipeline`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
-
-In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.
-
-```python
-from diffusers import StableDiffusionPipeline
-
-pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
-pipeline.enable_model_cpu_offload()
-num_inference_steps = 50
+from diffusers import StableDiffusionXLPipeline
def interrupt_callback(pipeline, i, t, callback_kwargs):
stop_idx = 10
@@ -150,6 +68,11 @@ def interrupt_callback(pipeline, i, t, callback_kwargs):
return callback_kwargs
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5"
+)
+num_inference_steps = 50
+
pipeline(
"A photo of a cat",
num_inference_steps=num_inference_steps,
@@ -157,92 +80,11 @@ pipeline(
)
```
-## IP Adapter Cutoff
+## Display intermediate images
-IP Adapter is an image prompt adapter that can be used for diffusion models without any changes to the underlying model. We can use the IP Adapter Cutoff Callback to disable the IP Adapter after a certain number of steps. To set up the callback, you need to specify the number of denoising steps after which the callback comes into effect. You can do so by using either one of these two arguments:
+Visualizing the intermediate images is useful for progress monitoring and assessing the quality of the generated content. This callback decodes the latent tensors at each step and converts them to images.
-- `cutoff_step_ratio`: Float number with the ratio of the steps.
-- `cutoff_step_index`: Integer number with the exact number of the step.
-
-We need to download the diffusion model and load the ip_adapter for it as follows:
-
-```py
-from diffusers import AutoPipelineForText2Image
-from diffusers.utils import load_image
-import torch
-
-pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
-pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
-pipeline.set_ip_adapter_scale(0.6)
-```
-The setup for the callback should look something like this:
-
-```py
-
-from diffusers import AutoPipelineForText2Image
-from diffusers.callbacks import IPAdapterScaleCutoffCallback
-from diffusers.utils import load_image
-import torch
-
-
-pipeline = AutoPipelineForText2Image.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16
-).to("cuda")
-
-
-pipeline.load_ip_adapter(
- "h94/IP-Adapter",
- subfolder="sdxl_models",
- weight_name="ip-adapter_sdxl.bin"
-)
-
-pipeline.set_ip_adapter_scale(0.6)
-
-
-callback = IPAdapterScaleCutoffCallback(
- cutoff_step_ratio=None,
- cutoff_step_index=5
-)
-
-image = load_image(
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png"
-)
-
-generator = torch.Generator(device="cuda").manual_seed(2628670641)
-
-images = pipeline(
- prompt="a tiger sitting in a chair drinking orange juice",
- ip_adapter_image=image,
- negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
- generator=generator,
- num_inference_steps=50,
- callback_on_step_end=callback,
-).images
-
-images[0].save("custom_callback_img.png")
-```
-[image comparison: without IPAdapterScaleCutoffCallback vs. with IPAdapterScaleCutoffCallback]
-## Display image after each generation step
-
-> [!TIP]
-> This tip was contributed by [asomoza](https://github.com/asomoza).
-
-Display an image after each generation step by accessing and converting the latents after each step into an image. The latent space is compressed to 128x128, so the images are also 128x128 which is useful for a quick preview.
-
-1. Use the function below to convert the SDXL latents (4 channels) to RGB tensors (3 channels) as explained in the [Explaining the SDXL latent space](https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space) blog post.
+[Convert](https://huggingface.co/blog/TimothyAlexisVass/explaining-the-sdxl-latent-space) the Stable Diffusion XL latents (4 channels) to RGB tensors (3 channels).
```py
def latents_to_rgb(latents):
@@ -260,7 +102,7 @@ def latents_to_rgb(latents):
return Image.fromarray(image_array)
```
-2. Create a function to decode and save the latents into an image.
+Extract the latents and convert the first image in the batch to RGB. Save the image as a PNG file with the step number.
```py
def decode_tensors(pipe, step, timestep, callback_kwargs):
@@ -272,19 +114,18 @@ def decode_tensors(pipe, step, timestep, callback_kwargs):
return callback_kwargs
```
-3. Pass the `decode_tensors` function to the `callback_on_step_end` parameter to decode the tensors after each step. You also need to specify what you want to modify in the `callback_on_step_end_tensor_inputs` parameter, which in this case are the latents.
+Use the `callback_on_step_end_tensor_inputs` parameter to specify what input type to modify, which in this case are the latents.
```py
-from diffusers import AutoPipelineForText2Image
import torch
from PIL import Image
+from diffusers import AutoPipelineForText2Image
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True
-).to("cuda")
+ device_map="cuda"
+)
image = pipeline(
prompt="A croissant shaped like a cute bear.",
@@ -293,27 +134,3 @@ image = pipeline(
callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```
-[image grid: intermediate previews at steps 0, 19, 29, 39, and 49]
diff --git a/docs/source/en/using-diffusers/conditional_image_generation.md b/docs/source/en/using-diffusers/conditional_image_generation.md
index 7efc0c653e..eb75b6b8a8 100644
--- a/docs/source/en/using-diffusers/conditional_image_generation.md
+++ b/docs/source/en/using-diffusers/conditional_image_generation.md
@@ -18,11 +18,8 @@ When you think of diffusion models, text-to-image is usually one of the first th
From a very high level, a diffusion model takes a prompt and some random initial noise, and iteratively removes the noise to construct an image. The *denoising* process is guided by the prompt, and once the denoising process ends after a predetermined number of time steps, the image representation is decoded into an image.
-
-
-Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog post to learn more about how a latent diffusion model works.
-
-
+> [!TIP]
+> Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog post to learn more about how a latent diffusion model works.
You can generate images from a prompt in 🤗 Diffusers in two steps:
@@ -176,11 +173,8 @@ image
-
-
-Other models may have different default image sizes depending on the image sizes in the training dataset. For example, SDXL's default image size is 1024x1024 and using lower `height` and `width` values may result in lower quality images. Make sure you check the model's API reference first!
-
-
+> [!WARNING]
+> Other models may have different default image sizes depending on the image sizes in the training dataset. For example, SDXL's default image size is 1024x1024 and using lower `height` and `width` values may result in lower quality images. Make sure you check the model's API reference first!
### Guidance scale
@@ -272,11 +266,8 @@ There are several ways to exert more control over how an image is generated outs
Prompt weighting is a technique for increasing or decreasing the importance of concepts in a prompt to emphasize or minimize certain features in an image. We recommend using the [Compel](https://github.com/damian0815/compel) library to help you generate the weighted prompt embeddings.
-
-
-Learn how to create the prompt embeddings in the [Prompt weighting](weighted_prompts) guide. This example focuses on how to use the prompt embeddings in the pipeline.
-
-
+> [!TIP]
+> Learn how to create the prompt embeddings in the [Prompt weighting](weighted_prompts) guide. This example focuses on how to use the prompt embeddings in the pipeline.
Once you've created the embeddings, you can pass them to the `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter in the pipeline.
diff --git a/docs/source/en/using-diffusers/controlling_generation.md b/docs/source/en/using-diffusers/controlling_generation.md
index 8fd57a7cb8..b7b0ea4919 100644
--- a/docs/source/en/using-diffusers/controlling_generation.md
+++ b/docs/source/en/using-diffusers/controlling_generation.md
@@ -70,38 +70,6 @@ For convenience, we provide a table to denote which methods are inference-only a
[InstructPix2Pix](../api/pipelines/pix2pix) is fine-tuned from Stable Diffusion to support editing input images. It takes as inputs an image and a prompt describing an edit, and it outputs the edited image.
InstructPix2Pix has been explicitly trained to work well with [InstructGPT](https://openai.com/blog/instruction-following/)-like prompts.
-## Pix2Pix Zero
-
-[Paper](https://huggingface.co/papers/2302.03027)
-
-[Pix2Pix Zero](../api/pipelines/pix2pix_zero) allows modifying an image so that one concept or subject is translated to another one while preserving general image semantics.
-
-The denoising process is guided from one conceptual embedding towards another conceptual embedding. The intermediate latents are optimized during the denoising process to push the attention maps towards reference attention maps. The reference attention maps are from the denoising process of the input image and are used to encourage semantic preservation.
-
-Pix2Pix Zero can be used both to edit synthetic images as well as real images.
-
-- To edit synthetic images, one first generates an image given a caption.
- Next, we generate image captions for the concept that shall be edited and for the new target concept. We can use a model like [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) for this purpose. Then, "mean" prompt embeddings for both the source and target concepts are created via the text encoder. Finally, the pix2pix-zero algorithm is used to edit the synthetic image.
-- To edit a real image, one first generates an image caption using a model like [BLIP](https://huggingface.co/docs/transformers/model_doc/blip). Then one applies DDIM inversion on the prompt and image to generate "inverse" latents. Similar to before, "mean" prompt embeddings for both source and target concepts are created and finally the pix2pix-zero algorithm in combination with the "inverse" latents is used to edit the image.
-
-
-
-Pix2Pix Zero is the first model that allows "zero-shot" image editing. This means that the model
-can edit an image in less than a minute on a consumer GPU as shown [here](../api/pipelines/pix2pix_zero#usage-example).
-
-
-
-As mentioned above, Pix2Pix Zero includes optimizing the latents (and not any of the UNet, VAE, or the text encoder) to steer the generation toward a specific concept. This means that the overall
-pipeline might require more memory than a standard [StableDiffusionPipeline](../api/pipelines/stable_diffusion/text2img).
-
-
-
-An important distinction between methods like InstructPix2Pix and Pix2Pix Zero is that the former
-involves fine-tuning the pre-trained weights while the latter does not. This means that you can
-apply Pix2Pix Zero to any of the available Stable Diffusion models.
-
-
-
## Attend and Excite
[Paper](https://huggingface.co/papers/2301.13826)
@@ -184,14 +152,6 @@ multi-concept training by design. Like DreamBooth and Textual Inversion, Custom
teach a pre-trained text-to-image diffusion model about new concepts to generate outputs involving the
concept(s) of interest.
-## Model Editing
-
-[Paper](https://huggingface.co/papers/2303.08084)
-
-The [text-to-image model editing pipeline](../api/pipelines/model_editing) helps you mitigate some of the incorrect implicit assumptions a pre-trained text-to-image
-diffusion model might make about the subjects present in the input prompt. For example, if you prompt Stable Diffusion to generate images for "A pack of roses", the roses in the generated images
-are more likely to be red. This pipeline helps you change that assumption.
-
## DiffEdit
[Paper](https://huggingface.co/papers/2210.11427)
diff --git a/docs/source/en/using-diffusers/custom_pipeline_overview.md b/docs/source/en/using-diffusers/custom_pipeline_overview.md
index bfe48d28be..b087e57056 100644
--- a/docs/source/en/using-diffusers/custom_pipeline_overview.md
+++ b/docs/source/en/using-diffusers/custom_pipeline_overview.md
@@ -10,376 +10,163 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Load community pipelines and components
-
[[open-in-colab]]
-## Community pipelines
+# Community pipelines and components
-> [!TIP] Take a look at GitHub Issue [#841](https://github.com/huggingface/diffusers/issues/841) for more context about why we're adding community pipelines to help everyone easily share their work without being slowed down.
-
-Community pipelines are any [`DiffusionPipeline`] class that are different from the original paper implementation (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://huggingface.co/papers/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline.
-
-There are many cool community pipelines like [Marigold Depth Estimation](https://github.com/huggingface/diffusers/tree/main/examples/community#marigold-depth-estimation) or [InstantID](https://github.com/huggingface/diffusers/tree/main/examples/community#instantid-pipeline), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community).
-
-There are two types of community pipelines, those stored on the Hugging Face Hub and those stored on Diffusers GitHub repository. Hub pipelines are completely customizable (scheduler, models, pipeline code, etc.) while Diffusers GitHub pipelines are only limited to custom pipeline code.
-
-| | GitHub community pipeline | HF Hub community pipeline |
-|----------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
-| usage | same | same |
-| review process | open a Pull Request on GitHub and undergo a review process from the Diffusers team before merging; may be slower | upload directly to a Hub repository without any review; this is the fastest workflow |
-| visibility | included in the official Diffusers repository and documentation | included on your HF Hub profile and relies on your own usage/promotion to gain visibility |
-
-
-
-
-To load a Hugging Face Hub community pipeline, pass the repository id of the community pipeline to the `custom_pipeline` argument and the model repository where you'd like to load the pipeline weights and components from. For example, the example below loads a dummy pipeline from [hf-internal-testing/diffusers-dummy-pipeline](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py) and the pipeline weights and components from [google/ddpm-cifar10-32](https://huggingface.co/google/ddpm-cifar10-32):
-
-> [!WARNING]
-> By loading a community pipeline from the Hugging Face Hub, you are trusting that the code you are loading is safe. Make sure to inspect the code online before loading and running it automatically!
-
-```py
-from diffusers import DiffusionPipeline
-
-pipeline = DiffusionPipeline.from_pretrained(
- "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline", use_safetensors=True
-)
-```
-
-
-
-
-To load a GitHub community pipeline, pass the repository id of the community pipeline to the `custom_pipeline` argument and the model repository where you you'd like to load the pipeline weights and components from. You can also load model components directly. The example below loads the community [CLIP Guided Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#clip-guided-stable-diffusion) pipeline and the CLIP model components.
-
-```py
-from diffusers import DiffusionPipeline
-from transformers import CLIPImageProcessor, CLIPModel
-
-clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
-
-feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id)
-clip_model = CLIPModel.from_pretrained(clip_model_id)
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- custom_pipeline="clip_guided_stable_diffusion",
- clip_model=clip_model,
- feature_extractor=feature_extractor,
- use_safetensors=True,
-)
-```
-
-
-
-
-### Load from a local file
-
-Community pipelines can also be loaded from a local file if you pass a file path instead. The path to the passed directory must contain a pipeline.py file that contains the pipeline class.
-
-```py
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- custom_pipeline="./path/to/pipeline_directory/",
- clip_model=clip_model,
- feature_extractor=feature_extractor,
- use_safetensors=True,
-)
-```
-
-### Load from a specific version
-
-By default, community pipelines are loaded from the latest stable version of Diffusers. To load a community pipeline from another version, use the `custom_revision` parameter.
-
-
-
-
-For example, to load from the main branch:
-
-```py
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- custom_pipeline="clip_guided_stable_diffusion",
- custom_revision="main",
- clip_model=clip_model,
- feature_extractor=feature_extractor,
- use_safetensors=True,
-)
-```
-
-
-
-
-For example, to load from a previous version of Diffusers like v0.25.0:
-
-```py
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- custom_pipeline="clip_guided_stable_diffusion",
- custom_revision="v0.25.0",
- clip_model=clip_model,
- feature_extractor=feature_extractor,
- use_safetensors=True,
-)
-```
-
-
-
-
-### Load with from_pipe
-
-Community pipelines can also be loaded with the [`~DiffusionPipeline.from_pipe`] method which allows you to load and reuse multiple pipelines without any additional memory overhead (learn more in the [Reuse a pipeline](./loading#reuse-a-pipeline) guide). The memory requirement is determined by the largest single pipeline loaded.
-
-For example, let's load a community pipeline that supports [long prompts with weighting](https://github.com/huggingface/diffusers/tree/main/examples/community#long-prompt-weighting-stable-diffusion) from a Stable Diffusion pipeline.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-
-pipe_sd = DiffusionPipeline.from_pretrained("emilianJR/CyberRealistic_V3", torch_dtype=torch.float16)
-pipe_sd.to("cuda")
-# load long prompt weighting pipeline
-pipe_lpw = DiffusionPipeline.from_pipe(
- pipe_sd,
- custom_pipeline="lpw_stable_diffusion",
-).to("cuda")
-
-prompt = "cat, hiding in the leaves, ((rain)), zazie rainyday, beautiful eyes, macro shot, colorful details, natural lighting, amazing composition, subsurface scattering, amazing textures, filmic, soft light, ultra-detailed eyes, intricate details, detailed texture, light source contrast, dramatic shadows, cinematic light, depth of field, film grain, noise, dark background, hyperrealistic dslr film still, dim volumetric cinematic lighting"
-neg_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
-generator = torch.Generator(device="cpu").manual_seed(20)
-out_lpw = pipe_lpw(
- prompt,
- negative_prompt=neg_prompt,
- width=512,
- height=512,
- max_embeddings_multiples=3,
- num_inference_steps=50,
- generator=generator,
- ).images[0]
-out_lpw
-```
-
-
-
-
- Stable Diffusion with long prompt weighting
-
-
-
- Stable Diffusion
-
-
-
-## Example community pipelines
-
-Community pipelines are a really fun and creative way to extend the capabilities of the original pipeline with new and unique features. You can find all community pipelines in the [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) folder with inference and training examples for how to use them.
-
-This section showcases a couple of the community pipelines and hopefully it'll inspire you to create your own (feel free to open a PR for your community pipeline and ping us for a review)!
+Community pipelines are [`DiffusionPipeline`] classes that differ from the original paper implementation. They provide additional functionality or extend the original pipeline.
> [!TIP]
-> The [`~DiffusionPipeline.from_pipe`] method is particularly useful for loading community pipelines because many of them don't have pretrained weights and add a feature on top of an existing pipeline like Stable Diffusion or Stable Diffusion XL. You can learn more about the [`~DiffusionPipeline.from_pipe`] method in the [Load with from_pipe](custom_pipeline_overview#load-with-from_pipe) section.
+> Check out the community pipelines in [diffusers/examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) with inference and training examples for how to use them.
-
-
+Community pipelines are stored either on the Hub or in the Diffusers GitHub repository. Hub pipelines are completely customizable (scheduler, models, pipeline code, etc.), while GitHub pipelines are limited to the custom pipeline code. The table below compares the two community pipeline types.
-[Marigold](https://marigoldmonodepth.github.io/) is a depth estimation diffusion pipeline that uses the rich existing and inherent visual knowledge in diffusion models. It takes an input image and denoises and decodes it into a depth map. Marigold performs well even on images it hasn't seen before.
+| | GitHub | Hub |
+|---|---|---|
+| Usage | Same. | Same. |
+| Review process | Open a Pull Request on GitHub and undergo a review process from the Diffusers team before merging. This option is slower. | Upload directly to a Hub repository without a review. This is the fastest option. |
+| Visibility | Included in the official Diffusers repository and docs. | Included on your Hub profile and relies on your own usage and promotion to gain visibility. |
+
+## custom_pipeline
+
+Load either community pipeline type by passing the `custom_pipeline` argument to [`~DiffusionPipeline.from_pretrained`].
```py
import torch
-from PIL import Image
from diffusers import DiffusionPipeline
-from diffusers.utils import load_image
pipeline = DiffusionPipeline.from_pretrained(
- "prs-eth/marigold-lcm-v1-0",
- custom_pipeline="marigold_depth_estimation",
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ custom_pipeline="pipeline_stable_diffusion_3_instruct_pix2pix",
torch_dtype=torch.float16,
- variant="fp16",
+ device_map="cuda"
)
-
-pipeline.to("cuda")
-image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/community-marigold.png")
-output = pipeline(
- image,
- denoising_steps=4,
- ensemble_size=5,
- processing_res=768,
- match_input_res=True,
- batch_size=0,
- seed=33,
- color_map="Spectral",
- show_progress_bar=True,
-)
-depth_colored: Image.Image = output.depth_colored
-depth_colored.save("./depth_colored.png")
```
-
-
-
- original image
-
-
-
- colorized depth image
-
-
-
-
-
-
-[HD-Painter](https://hf.co/papers/2312.14091) is a high-resolution inpainting pipeline. It introduces a *Prompt-Aware Introverted Attention (PAIntA)* layer to better align a prompt with the area to be inpainted, and *Reweighting Attention Score Guidance (RASG)* to keep the latents more prompt-aligned and within their trained domain to generate realistc images.
+Add the `custom_revision` argument to [`~DiffusionPipeline.from_pretrained`] to load a community pipeline from a specific version (for example, `v0.30.0` or `main`). By default, community pipelines are loaded from the latest stable version of Diffusers.
```py
import torch
-from diffusers import DiffusionPipeline, DDIMScheduler
-from diffusers.utils import load_image
+from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5-inpainting",
- custom_pipeline="hd_painter"
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ custom_pipeline="pipeline_stable_diffusion_3_instruct_pix2pix",
+ custom_revision="main",
+ torch_dtype=torch.float16,
+ device_map="cuda"
)
-pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hd-painter.jpg")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hd-painter-mask.png")
-prompt = "football"
-image = pipeline(prompt, init_image, mask_image, use_rasg=True, use_painta=True, generator=torch.manual_seed(0)).images[0]
-image
```
-
-
-
- original image
-
-
-
- generated image
-
-
+> [!WARNING]
+> While the Hugging Face Hub [scans](https://huggingface.co/docs/hub/security-malware) files, you should still inspect the Hub pipeline code and make sure it is safe.
-
-
+There are a few ways to load a community pipeline.
+
+- Pass a path to `custom_pipeline` to load a local community pipeline. The directory must contain a `pipeline.py` file containing the pipeline class.
+
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ custom_pipeline="path/to/pipeline_directory",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+ )
+ ```
+
+- The `custom_pipeline` argument is also supported by [`~DiffusionPipeline.from_pipe`], which is useful for [reusing pipelines](./loading#reuse-a-pipeline) without using additional memory. It limits the memory usage to only the largest pipeline loaded.
+
+ ```py
+ import torch
+ from diffusers import DiffusionPipeline
+
+ pipeline_sd = DiffusionPipeline.from_pretrained("emilianJR/CyberRealistic_V3", torch_dtype=torch.float16, device_map="cuda")
+ pipeline_lpw = DiffusionPipeline.from_pipe(
+ pipeline_sd, custom_pipeline="lpw_stable_diffusion", device_map="cuda"
+ )
+ ```
+
+ The [`~DiffusionPipeline.from_pipe`] method is especially useful for loading community pipelines because many of them don't have pretrained weights. Community pipelines generally add a feature on top of an existing pipeline.
## Community components
-Community components allow users to build pipelines that may have customized components that are not a part of Diffusers. If your pipeline has custom components that Diffusers doesn't already support, you need to provide their implementations as Python modules. These customized components could be a VAE, UNet, and scheduler. In most cases, the text encoder is imported from the Transformers library. The pipeline code itself can also be customized.
+Community components let users build pipelines with custom transformers, UNets, VAEs, and schedulers not supported by Diffusers. These components require Python module implementations.
-This section shows how users should use community components to build a community pipeline.
+This section shows how to build a community pipeline with community components, using [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) as an example.
-You'll use the [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) pipeline checkpoint as an example.
-
-1. Import and load the text encoder from Transformers:
-
-```python
-from transformers import T5Tokenizer, T5EncoderModel
-
-pipe_id = "showlab/show-1-base"
-tokenizer = T5Tokenizer.from_pretrained(pipe_id, subfolder="tokenizer")
-text_encoder = T5EncoderModel.from_pretrained(pipe_id, subfolder="text_encoder")
-```
-
-2. Load a scheduler:
+1. Load the required components: the tokenizer, text encoder, scheduler, and image processor. The tokenizer and text encoder are generally imported from [Transformers](https://huggingface.co/docs/transformers/index).
```python
+from transformers import T5Tokenizer, T5EncoderModel, CLIPImageProcessor
from diffusers import DPMSolverMultistepScheduler
+pipeline_id = "showlab/show-1-base"
+tokenizer = T5Tokenizer.from_pretrained(pipeline_id, subfolder="tokenizer")
+text_encoder = T5EncoderModel.from_pretrained(pipeline_id, subfolder="text_encoder")
-scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="scheduler")
+scheduler = DPMSolverMultistepScheduler.from_pretrained(pipeline_id, subfolder="scheduler")
-```
-
-3. Load an image processor:
-
-```python
-from transformers import CLIPImageProcessor
-
-feature_extractor = CLIPImageProcessor.from_pretrained(pipe_id, subfolder="feature_extractor")
+feature_extractor = CLIPImageProcessor.from_pretrained(pipeline_id, subfolder="feature_extractor")
```
-
+> [!WARNING]
+> In steps 2 and 3, the custom [UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py) and [pipeline](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py) implementations must match the format shown in their files for this example to work.
-In steps 4 and 5, the custom [UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py) and [pipeline](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) implementation must match the format shown in their files for this example to work.
-
-
-
-4. Now you'll load a [custom UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py), which in this example, has already been implemented in [showone_unet_3d_condition.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) for your convenience. You'll notice the [`UNet3DConditionModel`] class name is changed to `ShowOneUNet3DConditionModel` because [`UNet3DConditionModel`] already exists in Diffusers. Any components needed for the `ShowOneUNet3DConditionModel` class should be placed in showone_unet_3d_condition.py.
-
- Once this is done, you can initialize the UNet:
-
- ```python
- from showone_unet_3d_condition import ShowOneUNet3DConditionModel
-
- unet = ShowOneUNet3DConditionModel.from_pretrained(pipe_id, subfolder="unet")
- ```
-
-5. Finally, you'll load the custom pipeline code. For this example, it has already been created for you in [pipeline_t2v_base_pixel.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py). This script contains a custom `TextToVideoIFPipeline` class for generating videos from text. Just like the custom UNet, any code needed for the custom pipeline to work should go in pipeline_t2v_base_pixel.py.
-
-Once everything is in place, you can initialize the `TextToVideoIFPipeline` with the `ShowOneUNet3DConditionModel`:
+2. Load a [custom UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py), which is already implemented in [showone_unet_3d_condition.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py). The class is renamed from [`UNet3DConditionModel`] to `ShowOneUNet3DConditionModel` because [`UNet3DConditionModel`] already exists in Diffusers. Any components required for the `ShowOneUNet3DConditionModel` class should be placed in `showone_unet_3d_condition.py`.
+
+```python
+from showone_unet_3d_condition import ShowOneUNet3DConditionModel
+
+unet = ShowOneUNet3DConditionModel.from_pretrained(pipeline_id, subfolder="unet")
+```
+
+3. Load the custom pipeline code (already implemented in [pipeline_t2v_base_pixel.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py)). This script contains a custom `TextToVideoIFPipeline` class for generating videos from text. Like the custom UNet, any code required for `TextToVideoIFPipeline` should be placed in `pipeline_t2v_base_pixel.py`.
+
+Initialize `TextToVideoIFPipeline` with `ShowOneUNet3DConditionModel`.
```python
-from pipeline_t2v_base_pixel import TextToVideoIFPipeline
import torch
+from pipeline_t2v_base_pixel import TextToVideoIFPipeline
pipeline = TextToVideoIFPipeline(
unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
 feature_extractor=feature_extractor
 )
-pipeline = pipeline.to(device="cuda")
-pipeline.torch_dtype = torch.float16
+pipeline = pipeline.to(device="cuda", dtype=torch.float16)
```
-Push the pipeline to the Hub to share with the community!
+4. Push the pipeline to the Hub to share with the community.
```python
pipeline.push_to_hub("custom-t2v-pipeline")
```
-After the pipeline is successfully pushed, you need to make a few changes:
+After the pipeline is successfully pushed, make the following changes.
-1. Change the `_class_name` attribute in [model_index.json](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2) to `"pipeline_t2v_base_pixel"` and `"TextToVideoIFPipeline"`.
-2. Upload `showone_unet_3d_condition.py` to the [unet](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) subfolder.
-3. Upload `pipeline_t2v_base_pixel.py` to the pipeline [repository](https://huggingface.co/sayakpaul/show-1-base-with-code/tree/main).
+- Change the `_class_name` attribute in [model_index.json](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2) to `"pipeline_t2v_base_pixel"` and `"TextToVideoIFPipeline"`.
+- Upload `showone_unet_3d_condition.py` to the [unet](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) subfolder.
+- Upload `pipeline_t2v_base_pixel.py` to the pipeline [repository](https://huggingface.co/sayakpaul/show-1-base-with-code/tree/main).
To run inference, add the `trust_remote_code` argument while initializing the pipeline to handle all the "magic" behind the scenes.
-> [!WARNING]
-> As an additional precaution with `trust_remote_code=True`, we strongly encourage you to pass a commit hash to the `revision` parameter in [`~DiffusionPipeline.from_pretrained`] to make sure the code hasn't been updated with some malicious new lines of code (unless you fully trust the model owners).
-
```python
-from diffusers import DiffusionPipeline
import torch
+from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"/", trust_remote_code=True, torch_dtype=torch.float16
-).to("cuda")
-
-prompt = "hello"
-
-# Text embeds
-prompt_embeds, negative_embeds = pipeline.encode_prompt(prompt)
-
-# Keyframes generation (8x64x40, 2fps)
-video_frames = pipeline(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_embeds,
- num_frames=8,
- height=40,
- width=64,
- num_inference_steps=2,
- guidance_scale=9.0,
- output_type="pt"
-).frames
-```
-
-As an additional reference, take a look at the repository structure of [stabilityai/japanese-stable-diffusion-xl](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/) which also uses the `trust_remote_code` feature.
-
-```python
-from diffusers import DiffusionPipeline
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/japanese-stable-diffusion-xl", trust_remote_code=True
)
-pipeline.to("cuda")
```
+
+> [!WARNING]
+> As an additional precaution with `trust_remote_code=True`, we strongly encourage passing a commit hash to the `revision` argument in [`~DiffusionPipeline.from_pretrained`] to make sure the code hasn't been updated with new malicious code (unless you fully trust the model owners).
+
+## Resources
+
+- Take a look at issue [#841](https://github.com/huggingface/diffusers/issues/841) for more context about why community pipelines were added: to help everyone easily share their work without being slowed down.
+- Check out the [stabilityai/japanese-stable-diffusion-xl](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/) repository for an additional example of a community pipeline that also uses the `trust_remote_code` feature.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/diffedit.md b/docs/source/en/using-diffusers/diffedit.md
index bb1c234dd6..adea210263 100644
--- a/docs/source/en/using-diffusers/diffedit.md
+++ b/docs/source/en/using-diffusers/diffedit.md
@@ -156,11 +156,8 @@ print(source_prompts)
print(target_prompts)
```
-
-
-Check out the [generation strategy](https://huggingface.co/docs/transformers/main/en/generation_strategies) guide if you're interested in learning more about strategies for generating different quality text.
-
-
+> [!TIP]
+> Check out the [generation strategy](https://huggingface.co/docs/transformers/main/en/generation_strategies) guide if you're interested in learning more about strategies for generating different quality text.
Load the text encoder model used by the [`StableDiffusionDiffEditPipeline`] to encode the text. You'll use the text encoder to compute the text embeddings:
diff --git a/docs/source/en/using-diffusers/image_quality.md b/docs/source/en/using-diffusers/image_quality.md
index 517d985190..29ce483d5e 100644
--- a/docs/source/en/using-diffusers/image_quality.md
+++ b/docs/source/en/using-diffusers/image_quality.md
@@ -10,13 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Controlling image quality
-
-The components of a diffusion model, like the UNet and scheduler, can be optimized to improve the quality of generated images leading to better details. These techniques are especially useful if you don't have the resources to simply use a larger model for inference. You can enable these techniques during inference without any additional training.
-
-This guide will show you how to turn these techniques on in your pipeline and how to configure them to improve the quality of your generated images.
-
-## Details
+# FreeU
[FreeU](https://hf.co/papers/2309.11497) improves image details by rebalancing the UNet's backbone and skip connection weights. The skip connections can cause the model to overlook some of the backbone semantics which may lead to unnatural image details in the generated image. This technique does not require any additional training and can be applied on the fly during inference for tasks like image-to-image and text-to-video.
@@ -139,7 +133,7 @@ export_to_video(video_frames, "teddy_bear.mp4", fps=10)
-Call the [`pipelines.StableDiffusionMixin.disable_freeu`] method to disable FreeU.
+Call the [`~pipelines.StableDiffusionMixin.disable_freeu`] method to disable FreeU.
```py
pipeline.disable_freeu()
diff --git a/docs/source/en/using-diffusers/img2img.md b/docs/source/en/using-diffusers/img2img.md
index 3f42c9396d..ef00bf7f9b 100644
--- a/docs/source/en/using-diffusers/img2img.md
+++ b/docs/source/en/using-diffusers/img2img.md
@@ -33,11 +33,8 @@ pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
```
-
-
-You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/fp16#scaled-dot-product-attention).
-
-
+> [!TIP]
+> You'll notice throughout the guide that we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/fp16#scaled-dot-product-attention).
2. Load an image to pass to the pipeline:
@@ -386,11 +383,8 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipeline(prompt, image=init_image, output_type="latent").images[0]
```
-
-
-It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
-
-
+> [!TIP]
+> It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):
@@ -449,11 +443,8 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image_1 = pipeline(prompt, image=init_image, output_type="latent").images[0]
```
-
-
-It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
-
-
+> [!TIP]
+> It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in *latent* space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE.
Chain it to an upscaler pipeline to increase the image resolution:
diff --git a/docs/source/en/using-diffusers/inference_with_lcm.md b/docs/source/en/using-diffusers/inference_with_lcm.md
index d0a47449ad..cde4168d38 100644
--- a/docs/source/en/using-diffusers/inference_with_lcm.md
+++ b/docs/source/en/using-diffusers/inference_with_lcm.md
@@ -257,7 +257,7 @@ LCMs are compatible with adapters like LoRA, ControlNet, T2I-Adapter, and Animat
### LoRA
-[LoRA](../using-diffusers/loading_adapters#lora) adapters can be rapidly finetuned to learn a new style from just a few images and plugged into a pretrained model to generate images in that style.
+[LoRA](../tutorials/using_peft_for_inference) adapters can be rapidly finetuned to learn a new style from just a few images and plugged into a pretrained model to generate images in that style.
diff --git a/docs/source/en/using-diffusers/inference_with_tcd_lora.md b/docs/source/en/using-diffusers/inference_with_tcd_lora.md
index 88dd4733b5..2aaf9c8aa8 100644
--- a/docs/source/en/using-diffusers/inference_with_tcd_lora.md
+++ b/docs/source/en/using-diffusers/inference_with_tcd_lora.md
@@ -18,7 +18,7 @@ Trajectory Consistency Distillation (TCD) enables a model to generate higher qua
The major advantages of TCD are:
-- Better than Teacher: TCD demonstrates superior generative quality at both small and large inference steps and exceeds the performance of [DPM-Solver++(2S)](../../api/schedulers/multistep_dpm_solver) with Stable Diffusion XL (SDXL). There is no additional discriminator or LPIPS supervision included during TCD training.
+- Better than Teacher: TCD demonstrates superior generative quality at both small and large inference steps and exceeds the performance of [DPM-Solver++(2S)](../api/schedulers/multistep_dpm_solver) with Stable Diffusion XL (SDXL). There is no additional discriminator or LPIPS supervision included during TCD training.
- Flexible Inference Steps: The inference steps for TCD sampling can be freely adjusted without adversely affecting the image quality.
@@ -166,7 +166,7 @@ image = pipe(
TCD-LoRA also supports other LoRAs trained on different styles. For example, let's load the [TheLastBen/Papercut_SDXL](https://huggingface.co/TheLastBen/Papercut_SDXL) LoRA and fuse it with the TCD-LoRA with the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method.
> [!TIP]
-> Check out the [Merge LoRAs](merge_loras) guide to learn more about efficient merging methods.
+> Check out the [Merge LoRAs](../tutorials/using_peft_for_inference#merge) guide to learn more about efficient merging methods.
```python
import torch
@@ -335,9 +335,8 @@ grid_image = make_image_grid([canny_image, image], rows=1, cols=2)
```

-
-The inference parameters in this example might not work for all examples, so we recommend you to try different values for `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale` and `cross_attention_kwargs` parameters and choose the best one.
-
+> [!TIP]
+> The inference parameters in this example might not work for all examples, so we recommend trying different values for the `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale`, and `cross_attention_kwargs` parameters and choosing the best combination.
diff --git a/docs/source/en/using-diffusers/inpaint.md b/docs/source/en/using-diffusers/inpaint.md
index 695ec04088..28da3a68a5 100644
--- a/docs/source/en/using-diffusers/inpaint.md
+++ b/docs/source/en/using-diffusers/inpaint.md
@@ -33,11 +33,8 @@ pipeline.enable_model_cpu_offload()
pipeline.enable_xformers_memory_efficient_attention()
```
-
-
-You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, it's not necessary to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/fp16#scaled-dot-product-attention).
-
-
+> [!TIP]
+> You'll notice throughout the guide that we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] to save memory and increase inference speed. If you're using PyTorch 2.0, it's not necessary to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/fp16#scaled-dot-product-attention).
2. Load the base and mask images:
@@ -639,11 +636,8 @@ pipeline.enable_xformers_memory_efficient_attention()
image = pipeline(prompt=prompt, image=image_inpainting, mask_image=mask_image, output_type="latent").images[0]
```
-
-
-It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. For example, in the [Text-to-image-to-inpaint](#text-to-image-to-inpaint) section, Kandinsky 2.2 uses a different VAE class than the Stable Diffusion model so it won't work. But if you use Stable Diffusion v1.5 for both pipelines, then you can keep everything in latent space because they both use [`AutoencoderKL`].
-
-
+> [!TIP]
+> It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. For example, in the [Text-to-image-to-inpaint](#text-to-image-to-inpaint) section, Kandinsky 2.2 uses a different VAE class than the Stable Diffusion model so it won't work. But if you use Stable Diffusion v1.5 for both pipelines, then you can keep everything in latent space because they both use [`AutoencoderKL`].
Finally, you can pass this image to an image-to-image pipeline to put the finishing touches on it. It is more efficient to use the [`~AutoPipelineForImage2Image.from_pipe`] method to reuse the existing pipeline components, and avoid unnecessarily loading all the pipeline components into memory again.
diff --git a/docs/source/en/using-diffusers/kandinsky.md b/docs/source/en/using-diffusers/kandinsky.md
index a482380524..2671c108b3 100644
--- a/docs/source/en/using-diffusers/kandinsky.md
+++ b/docs/source/en/using-diffusers/kandinsky.md
@@ -31,15 +31,12 @@ Before you begin, make sure you have the following libraries installed:
#!pip install -q diffusers transformers accelerate
```
-
-
-Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding.
-
-
-
-Kandinsky 3 has a more concise architecture and it doesn't require a prior model. This means it's usage is identical to other diffusion models like [Stable Diffusion XL](sdxl).
-
-
+> [!WARNING]
+> Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding.
+>
+> Kandinsky 3 has a more concise architecture and it doesn't require a prior model. This means its usage is identical to other diffusion models like [Stable Diffusion XL](sdxl).
## Text-to-image
@@ -321,20 +318,17 @@ make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], r
## Inpainting
-
-
-⚠️ The Kandinsky models use ⬜️ **white pixels** to represent the masked area now instead of black pixels. If you are using [`KandinskyInpaintPipeline`] in production, you need to change the mask to use white pixels:
-
-```py
-# For PIL input
-import PIL.ImageOps
-mask = PIL.ImageOps.invert(mask)
-
-# For PyTorch and NumPy input
-mask = 1 - mask
-```
-
-
+> [!WARNING]
+> ⚠️ The Kandinsky models now use ⬜️ **white pixels** to represent the masked area instead of black pixels. If you are using [`KandinskyInpaintPipeline`] in production, you need to change the mask to use white pixels:
+>
+> ```py
+> # For PIL input
+> import PIL.ImageOps
+> mask = PIL.ImageOps.invert(mask)
+>
+> # For PyTorch and NumPy input
+> mask = 1 - mask
+> ```
For inpainting, you'll need the original image, a mask of the area to replace in the original image, and a text prompt of what to inpaint. Load the prior pipeline:
@@ -565,11 +559,8 @@ image
## ControlNet
-
-
-⚠️ ControlNet is only supported for Kandinsky 2.2!
-
-
+> [!WARNING]
+> ⚠️ ControlNet is only supported for Kandinsky 2.2!
ControlNet enables conditioning large pretrained diffusion models with additional inputs such as a depth map or edge detection. For example, you can condition Kandinsky 2.2 with a depth map so the model understands and preserves the structure of the depth image.
diff --git a/docs/source/en/using-diffusers/loading.md b/docs/source/en/using-diffusers/loading.md
index 591a138296..3fb608b1c2 100644
--- a/docs/source/en/using-diffusers/loading.md
+++ b/docs/source/en/using-diffusers/loading.md
@@ -10,574 +10,243 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Load pipelines
-
[[open-in-colab]]
-Diffusion systems consist of multiple components like parameterized models and schedulers that interact in complex ways. That is why we designed the [`DiffusionPipeline`] to wrap the complexity of the entire diffusion system into an easy-to-use API. At the same time, the [`DiffusionPipeline`] is entirely customizable so you can modify each component to build a diffusion system for your use case.
+# DiffusionPipeline
-This guide will show you how to load:
+Diffusion models consist of multiple components like UNets or diffusion transformers (DiTs), text encoders, variational autoencoders (VAEs), and schedulers. The [`DiffusionPipeline`] wraps all of these components into a single easy-to-use API without giving up the flexibility to modify its components.
-- pipelines from the Hub and locally
-- different components into a pipeline
-- multiple pipelines without increasing memory usage
-- checkpoint variants such as different floating point types or non-exponential mean averaged (EMA) weights
+This guide will show you how to load a [`DiffusionPipeline`].
-## Load a pipeline
+## Loading a pipeline
+
+[`DiffusionPipeline`] is a base pipeline class that automatically selects and returns an instance of a model's pipeline subclass, like [`QwenImagePipeline`], by scanning the `model_index.json` file for the class name.
+
+Pass a model id to [`~DiffusionPipeline.from_pretrained`] to load a pipeline.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+```
+
+Every model has a specific pipeline subclass that inherits from [`DiffusionPipeline`]. A subclass usually has a narrow focus and is task-specific. See the table below for an example.
+
+| pipeline subclass | task |
+|---|---|
+| [`QwenImagePipeline`] | text-to-image |
+| [`QwenImageImg2ImgPipeline`] | image-to-image |
+| [`QwenImageInpaintPipeline`] | inpaint |
+
+You could use the subclass directly by passing a model id to [`~QwenImagePipeline.from_pretrained`].
+
+```py
+import torch
+from diffusers import QwenImagePipeline
+
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+```
> [!TIP]
-> Skip to the [DiffusionPipeline explained](#diffusionpipeline-explained) section if you're interested in an explanation about how the [`DiffusionPipeline`] class works.
+> Refer to the [Single file format](./other-formats#single-file-format) docs to learn how to load single file models.
-There are two ways to load a pipeline for a task:
+### Local pipelines
-1. Load the generic [`DiffusionPipeline`] class and allow it to automatically detect the correct pipeline class from the checkpoint.
-2. Load a specific pipeline class for a specific task.
-
-
-
-
-The [`DiffusionPipeline`] class is a simple and generic way to load the latest trending diffusion model from the [Hub](https://huggingface.co/models?library=diffusers&sort=trending). It uses the [`~DiffusionPipeline.from_pretrained`] method to automatically detect the correct pipeline class for a task from the checkpoint, downloads and caches all the required configuration and weight files, and returns a pipeline ready for inference.
-
-```python
-from diffusers import DiffusionPipeline
-
-pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-```
-
-This same checkpoint can also be used for an image-to-image task. The [`DiffusionPipeline`] class can handle any task as long as you provide the appropriate inputs. For example, for an image-to-image task, you need to pass an initial image to the pipeline.
+Pipelines can also be loaded from local files. Use [`~huggingface_hub.snapshot_download`] to download a model repository.
```py
-from diffusers import DiffusionPipeline
+from huggingface_hub import snapshot_download
-pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png")
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=init_image).images[0]
+snapshot_download(repo_id="Qwen/Qwen-Image")
```
-
-
-
-Checkpoints can be loaded by their specific pipeline class if you already know it. For example, to load a Stable Diffusion model, use the [`StableDiffusionPipeline`] class.
-
-```python
-from diffusers import StableDiffusionPipeline
-
-pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-```
-
-This same checkpoint may also be used for another task like image-to-image. To differentiate what task you want to use the checkpoint for, you have to use the corresponding task-specific pipeline class. For example, to use the same checkpoint for image-to-image, use the [`StableDiffusionImg2ImgPipeline`] class.
+The model is downloaded to your [cache](../installation#cache). Pass the folder path to [`~QwenImagePipeline.from_pretrained`] to load it.
```py
-from diffusers import StableDiffusionImg2ImgPipeline
-
-pipeline = StableDiffusionImg2ImgPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True)
-```
-
-
-
-
-Use the Space below to gauge a pipeline's memory requirements before you download and load it to see if it runs on your hardware.
-
-
-
-
-
-
-
-
-### Specifying Component-Specific Data Types
-
-You can customize the data types for individual sub-models by passing a dictionary to the `torch_dtype` parameter. This allows you to load different components of a pipeline in different floating point precisions. For instance, if you want to load the transformer with `torch.bfloat16` and all other components with `torch.float16`, you can pass a dictionary mapping:
-
-```python
-from diffusers import HunyuanVideoPipeline
import torch
+from diffusers import QwenImagePipeline
-pipe = HunyuanVideoPipeline.from_pretrained(
- "hunyuanvideo-community/HunyuanVideo",
- torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
+pipeline = QwenImagePipeline.from_pretrained(
+ "path/to/your/cache", torch_dtype=torch.bfloat16, device_map="cuda"
)
-print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)
```
-If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
+The [`~QwenImagePipeline.from_pretrained`] method won't download files from the Hub when it detects a local path. But this also means it won't download and cache any updates that have been made to the model either.
-### Local pipeline
+## Pipeline data types
-To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.
+Use the `torch_dtype` argument in [`~DiffusionPipeline.from_pretrained`] to load a model with a specific data type. This allows you to load different models in different precisions. For example, loading a large transformer model in half-precision reduces the memory required.
-```bash
-git-lfs install
-git clone https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5
+Pass the data type for each model as a dictionary to `torch_dtype`. Use the `default` key to set the default data type. If a model isn't in the dictionary and `default` isn't provided, it is loaded in full precision (`torch.float32`).
+
+```py
+import torch
+from diffusers import QwenImagePipeline
+
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image",
+ torch_dtype={"transformer": torch.bfloat16, "default": torch.float16},
+)
+print(pipeline.transformer.dtype, pipeline.vae.dtype)
```
-This creates a local folder, ./stable-diffusion-v1-5, on your disk and you should pass its path to [`~DiffusionPipeline.from_pretrained`].
+You don't need to use a dictionary if you're loading all the models in the same data type.
-```python
+```py
+import torch
+from diffusers import QwenImagePipeline
+
+pipeline = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
+)
+print(pipeline.transformer.dtype, pipeline.vae.dtype)
+```
+
+## Device placement
+
+The `device_map` argument determines individual model or pipeline placement on an accelerator like a GPU. It is especially helpful when there are multiple GPUs.
+
+A pipeline supports two options for `device_map`, `"cuda"` and `"balanced"`. Refer to the table below to compare the placement strategies.
+
+| option | description |
+|---|---|
+| `"cuda"` | places pipeline on a supported accelerator device like CUDA |
+| `"balanced"` | evenly distributes pipeline on all GPUs |
+
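+For example, the short sketch below (an illustrative setup assuming more than one visible GPU) distributes the pipeline's models evenly across all available GPUs with the `"balanced"` strategy.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+# evenly distribute the pipeline's models across every available GPU
+pipeline = DiffusionPipeline.from_pretrained(
+    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16, device_map="balanced"
+)
+```
+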
+Use the `max_memory` argument in [`~DiffusionPipeline.from_pretrained`] to allocate a maximum amount of memory to use on each device. By default, Diffusers uses the maximum amount available.
+
+```py
+import torch
from diffusers import DiffusionPipeline
-stable_diffusion = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", use_safetensors=True)
+max_memory = {0: "16GB", 1: "16GB"}
+pipeline = DiffusionPipeline.from_pretrained(
+    "Qwen/Qwen-Image",
+    torch_dtype=torch.bfloat16,
+    device_map="balanced",
+    max_memory=max_memory,
+)
```
-The [`~DiffusionPipeline.from_pretrained`] method won't download files from the Hub when it detects a local path, but this also means it won't download and cache the latest changes to a checkpoint.
-
-## Customize a pipeline
-
-You can customize a pipeline by loading different components into it. This is important because you can:
-
-- change to a scheduler with faster generation speed or higher generation quality depending on your needs (call the `scheduler.compatibles` method on your pipeline to see compatible schedulers)
-- change a default pipeline component to a newer and better performing one
-
-For example, let's customize the default [stabilityai/stable-diffusion-xl-base-1.0](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0) checkpoint with:
-
-- The [`HeunDiscreteScheduler`] to generate higher quality images at the expense of slower generation speed. You must pass the `subfolder="scheduler"` parameter in [`~HeunDiscreteScheduler.from_pretrained`] to load the scheduler configuration into the correct [subfolder](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main/scheduler) of the pipeline repository.
-- A more stable VAE that runs in fp16.
+The `hf_device_map` attribute allows you to access and view the `device_map`.
```py
-from diffusers import StableDiffusionXLPipeline, HeunDiscreteScheduler, AutoencoderKL
+print(pipeline.hf_device_map)
+# {'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
+```
+
+Reset a pipeline's `device_map` with the [`~DiffusionPipeline.reset_device_map`] method. This is necessary if you want to use methods such as `.to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`].
+
+```py
+pipeline.reset_device_map()
+```
+
+## Parallel loading
+
+Large models are often [sharded](../training/distributed_inference#model-sharding) into smaller files so that they are easier to load. Diffusers supports loading shards in parallel to speed up the loading process.
+
+Set `HF_ENABLE_PARALLEL_LOADING` to `"YES"` to enable parallel loading of shards.
+
+The `device_map` argument should be set to `"cuda"` to pre-allocate a large chunk of memory based on the model size. This substantially reduces model load time because warming up the memory allocator now avoids many smaller calls to the allocator later.
+
+```py
+import os
import torch
+from diffusers import DiffusionPipeline
-scheduler = HeunDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
-vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
+os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "Wan-AI/Wan2.2-I2V-A14B-Diffusers", torch_dtype=torch.bfloat16, device_map="cuda"
+)
```
-Now pass the new scheduler and VAE to the [`StableDiffusionXLPipeline`].
+## Replacing models in a pipeline
+
+[`DiffusionPipeline`] is flexible and accommodates loading different models or schedulers. You can experiment with different schedulers to optimize for generation speed or quality, and you can replace models with more performant ones.
+
+The example below uses a more stable VAE version.
```py
-pipeline = StableDiffusionXLPipeline.from_pretrained(
+import torch
+from diffusers import DiffusionPipeline, AutoModel
+
+vae = AutoModel.from_pretrained(
+ "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+)
+
+pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
- scheduler=scheduler,
vae=vae,
torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True
-).to("cuda")
+ device_map="cuda"
+)
```
-## Reuse a pipeline
+## Reusing models in multiple pipelines
-When you load multiple pipelines that share the same model components, it makes sense to reuse the shared components instead of reloading everything into memory again, especially if your hardware is memory-constrained. For example:
+When working with multiple pipelines that use the same model, the [`~DiffusionPipeline.from_pipe`] method enables reusing a model instead of reloading it each time. This allows you to use multiple pipelines without increasing memory usage.
-1. You generated an image with the [`StableDiffusionPipeline`] but you want to improve its quality with the [`StableDiffusionSAGPipeline`]. Both of these pipelines share the same pretrained model, so it'd be a waste of memory to load the same model twice.
-2. You want to add a model component, like a [`MotionAdapter`](../api/pipelines/animatediff#animatediffpipeline), to [`AnimateDiffPipeline`] which was instantiated from an existing [`StableDiffusionPipeline`]. Again, both pipelines share the same pretrained model, so it'd be a waste of memory to load an entirely new pipeline again.
+Memory usage is determined by the pipeline with the highest memory requirement regardless of the number of pipelines.
-With the [`DiffusionPipeline.from_pipe`] API, you can switch between multiple pipelines to take advantage of their different features without increasing memory-usage. It is similar to turning on and off a feature in your pipeline.
-
-> [!TIP]
-> To switch between tasks (rather than features), use the [`~DiffusionPipeline.from_pipe`] method with the [AutoPipeline](../api/pipelines/auto_pipeline) class, which automatically identifies the pipeline class based on the task (learn more in the [AutoPipeline](../tutorials/autopipeline) tutorial).
-
-Let's start with a [`StableDiffusionPipeline`] and then reuse the loaded model components to create a [`StableDiffusionSAGPipeline`] to increase generation quality. You'll use the [`StableDiffusionPipeline`] with an [IP-Adapter](./ip_adapter) to generate a bear eating pizza.
-
-```python
-from diffusers import DiffusionPipeline, StableDiffusionSAGPipeline
-import torch
-import gc
-from diffusers.utils import load_image
-from accelerate.utils import compute_module_sizes
-
-image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
-
-pipe_sd = DiffusionPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", torch_dtype=torch.float16)
-pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
-pipe_sd.set_ip_adapter_scale(0.6)
-pipe_sd.to("cuda")
-
-generator = torch.Generator(device="cpu").manual_seed(33)
-out_sd = pipe_sd(
- prompt="bear eats pizza",
- negative_prompt="wrong white balance, dark, sketches,worst quality,low quality",
- ip_adapter_image=image,
- num_inference_steps=50,
- generator=generator,
-).images[0]
-out_sd
-```
-
-
-
-
-
-For reference, you can check how much memory this process consumed.
-
-```python
-def bytes_to_giga_bytes(bytes):
- return bytes / 1024 / 1024 / 1024
-print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
-"Max memory allocated: 4.406213283538818 GB"
-```
-
-Now, reuse the same pipeline components from [`StableDiffusionPipeline`] in [`StableDiffusionSAGPipeline`] with the [`~DiffusionPipeline.from_pipe`] method.
+The example below loads a pipeline and then loads a second pipeline with [`~DiffusionPipeline.from_pipe`] to use [perturbed-attention guidance (PAG)](../api/pipelines/pag) to improve generation quality.
> [!WARNING]
-> Some pipeline methods may not function properly on new pipelines created with [`~DiffusionPipeline.from_pipe`]. For instance, the [`~DiffusionPipeline.enable_model_cpu_offload`] method installs hooks on the model components based on a unique offloading sequence for each pipeline. If the models are executed in a different order in the new pipeline, the CPU offloading may not work correctly.
->
-> To ensure everything works as expected, we recommend re-applying a pipeline method on a new pipeline created with [`~DiffusionPipeline.from_pipe`].
+> Use [`AutoPipelineForText2Image`] because [`DiffusionPipeline`] doesn't support PAG. Refer to the [AutoPipeline](../tutorials/autopipeline) docs to learn more.
-```python
-pipe_sag = StableDiffusionSAGPipeline.from_pipe(
- pipe_sd
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+
+pipeline_sdxl = AutoPipelineForText2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, device_map="cuda"
)
-
-generator = torch.Generator(device="cpu").manual_seed(33)
-out_sag = pipe_sag(
- prompt="bear eats pizza",
- negative_prompt="wrong white balance, dark, sketches,worst quality,low quality",
- ip_adapter_image=image,
- num_inference_steps=50,
- generator=generator,
- guidance_scale=1.0,
- sag_scale=0.75
-).images[0]
-out_sag
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+image = pipeline_sdxl(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+# Max memory reserved: 10.47 GB
```
-
-
-
-
-If you check the memory usage, you'll see it remains the same as before because [`StableDiffusionPipeline`] and [`StableDiffusionSAGPipeline`] are sharing the same pipeline components. This allows you to use them interchangeably without any additional memory overhead.
+Set `enable_pag=True` in the second pipeline to enable PAG. The second pipeline uses the same amount of memory because it shares model weights with the first one.
```py
-print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
-"Max memory allocated: 4.406213283538818 GB"
+pipeline = AutoPipelineForText2Image.from_pipe(
+ pipeline_sdxl, enable_pag=True
+)
+prompt = """
+cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+"""
+image = pipeline(prompt).images[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+# Max memory reserved: 10.47 GB
```
-Let's animate the image with the [`AnimateDiffPipeline`] and also add a [`MotionAdapter`] module to the pipeline. For the [`AnimateDiffPipeline`], you need to unload the IP-Adapter first and reload it *after* you've created your new pipeline (this only applies to the [`AnimateDiffPipeline`]).
+> [!WARNING]
+> Pipelines created by [`~DiffusionPipeline.from_pipe`] share the same models and *state*. Modifying the state of a model in one pipeline affects all the other pipelines that share the same model.
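+
+For example, in the minimal sketch below (building on the pipelines above; the LoRA repository is only an illustrative choice), loading a LoRA through the PAG pipeline also affects the original pipeline because they share the same underlying models.
+
+```py
+# the LoRA is injected into the shared UNet, so both `pipeline` and
+# `pipeline_sdxl` now generate with the LoRA applied
+pipeline.load_lora_weights(
+    "ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors"
+)
+```
+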
-```py
-from diffusers import AnimateDiffPipeline, MotionAdapter, DDIMScheduler
-from diffusers.utils import export_to_gif
-
-pipe_sag.unload_ip_adapter()
-adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
-
-pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
-pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")
-# load IP-Adapter and LoRA weights again
-pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
-pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
-pipe_animate.to("cuda")
-
-generator = torch.Generator(device="cpu").manual_seed(33)
-pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
-out = pipe_animate(
- prompt="bear eats pizza",
- num_frames=16,
- num_inference_steps=50,
- ip_adapter_image=image,
- generator=generator,
-).frames[0]
-export_to_gif(out, "out_animate.gif")
-```
-
-
-
-
-
-The [`AnimateDiffPipeline`] is more memory-intensive and consumes 15GB of memory (see the [Memory-usage of from_pipe](#memory-usage-of-from_pipe) section to learn what this means for your memory-usage).
-
-```py
-print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
-"Max memory allocated: 15.178664207458496 GB"
-```
-
-### Modify from_pipe components
-
-Pipelines loaded with [`~DiffusionPipeline.from_pipe`] can be customized with different model components or methods. However, whenever you modify the *state* of the model components, it affects all the other pipelines that share the same components. For example, if you call [`~diffusers.loaders.IPAdapterMixin.unload_ip_adapter`] on the [`StableDiffusionSAGPipeline`], you won't be able to use IP-Adapter with the [`StableDiffusionPipeline`] because it's been removed from their shared components.
-
-```py
-pipe.sag_unload_ip_adapter()
-
-generator = torch.Generator(device="cpu").manual_seed(33)
-out_sd = pipe_sd(
- prompt="bear eats pizza",
- negative_prompt="wrong white balance, dark, sketches,worst quality,low quality",
- ip_adapter_image=image,
- num_inference_steps=50,
- generator=generator,
-).images[0]
-"AttributeError: 'NoneType' object has no attribute 'image_projection_layers'"
-```
-
-### Memory usage of from_pipe
-
-The memory requirement of loading multiple pipelines with [`~DiffusionPipeline.from_pipe`] is determined by the pipeline with the highest memory-usage regardless of the number of pipelines you create.
-
-| Pipeline | Memory usage (GB) |
-|---|---|
-| StableDiffusionPipeline | 4.400 |
-| StableDiffusionSAGPipeline | 4.400 |
-| AnimateDiffPipeline | 15.178 |
-
-The [`AnimateDiffPipeline`] has the highest memory requirement, so the *total memory-usage* is based only on the [`AnimateDiffPipeline`]. Your memory-usage will not increase if you create additional pipelines as long as their memory requirements doesn't exceed that of the [`AnimateDiffPipeline`]. Each pipeline can be used interchangeably without any additional memory overhead.
+Some methods may not work correctly on pipelines created with [`~DiffusionPipeline.from_pipe`]. For example, [`~DiffusionPipeline.enable_model_cpu_offload`] relies on a unique model execution order, which may differ in the new pipeline. To ensure proper functionality, reapply these methods on the new pipeline.
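+
+As a minimal sketch (assuming the pipelines aren't already placed on a device with `device_map`), offloading is reapplied on the new pipeline created with [`~DiffusionPipeline.from_pipe`].
+
+```py
+import torch
+from diffusers import AutoPipelineForText2Image
+
+pipeline_sdxl = AutoPipelineForText2Image.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+)
+pipeline_sdxl.enable_model_cpu_offload()
+
+# create a second pipeline that shares the same models, then reapply
+# offloading so the hooks match the new pipeline's execution order
+pipeline_pag = AutoPipelineForText2Image.from_pipe(pipeline_sdxl, enable_pag=True)
+pipeline_pag.enable_model_cpu_offload()
+```
+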
## Safety checker
-Diffusers implements a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) for Stable Diffusion models which can generate harmful content. The safety checker screens the generated output against known hardcoded not-safe-for-work (NSFW) content. If for whatever reason you'd like to disable the safety checker, pass `safety_checker=None` to the [`~DiffusionPipeline.from_pretrained`] method.
+Diffusers provides a [safety checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) for older Stable Diffusion models to prevent generating harmful content. It screens the generated output against a set of hardcoded harmful concepts.
-```python
+If you want to disable the safety checker, pass `safety_checker=None` in [`~DiffusionPipeline.from_pretrained`] as shown below.
+
+```py
from diffusers import DiffusionPipeline
-pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, use_safetensors=True)
+pipeline = DiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None
+)
"""
You have disabled the safety checker for by passing `safety_checker=None`. Ensure that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling it only for use cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
"""
-```
-
-## Checkpoint variants
-
-A checkpoint variant is usually a checkpoint whose weights are:
-
-- Stored in a different floating point type, such as [torch.float16](https://pytorch.org/docs/stable/tensors.html#data-types), because it only requires half the bandwidth and storage to download. You can't use this variant if you're continuing training or using a CPU.
-- Non-exponential mean averaged (EMA) weights which shouldn't be used for inference. You should use this variant to continue finetuning a model.
-
-> [!TIP]
-> When the checkpoints have identical model structures, but they were trained on different datasets and with a different training setup, they should be stored in separate repositories. For example, [stabilityai/stable-diffusion-2](https://hf.co/stabilityai/stable-diffusion-2) and [stabilityai/stable-diffusion-2-1](https://hf.co/stabilityai/stable-diffusion-2-1) are stored in separate repositories.
-
-Otherwise, a variant is **identical** to the original checkpoint. They have exactly the same serialization format (like [safetensors](./using_safetensors)), model structure, and their weights have identical tensor shapes.
-
-| **checkpoint type** | **weight name** | **argument for loading weights** |
-|---------------------|---------------------------------------------|----------------------------------|
-| original | diffusion_pytorch_model.safetensors | |
-| floating point | diffusion_pytorch_model.fp16.safetensors | `variant`, `torch_dtype` |
-| non-EMA | diffusion_pytorch_model.non_ema.safetensors | `variant` |
-
-There are two important arguments for loading variants:
-
-- `torch_dtype` specifies the floating point precision of the loaded checkpoint. For example, if you want to save bandwidth by loading a fp16 variant, you should set `variant="fp16"` and `torch_dtype=torch.float16` to *convert the weights* to fp16. Otherwise, the fp16 weights are converted to the default fp32 precision.
-
- If you only set `torch_dtype=torch.float16`, the default fp32 weights are downloaded first and then converted to fp16.
-
-- `variant` specifies which files should be loaded from the repository. For example, if you want to load a non-EMA variant of a UNet from [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet), set `variant="non_ema"` to download the `non_ema` file.
-
-
-
-
-```py
-from diffusers import DiffusionPipeline
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16, use_safetensors=True
-)
-```
-
-
-
-
-```py
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="non_ema", use_safetensors=True
-)
-```
-
-
-
-
-Use the `variant` parameter in the [`DiffusionPipeline.save_pretrained`] method to save a checkpoint as a different floating point type or as a non-EMA variant. You should try save a variant to the same folder as the original checkpoint, so you have the option of loading both from the same folder.
-
-
-
-
-```python
-from diffusers import DiffusionPipeline
-
-pipeline.save_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", variant="fp16")
-```
-
-
-
-
-```py
-pipeline.save_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", variant="non_ema")
-```
-
-
-
-
-If you don't save the variant to an existing folder, you must specify the `variant` argument otherwise it'll throw an `Exception` because it can't find the original checkpoint.
-
-```python
-# 👎 this won't work
-pipeline = DiffusionPipeline.from_pretrained(
- "./stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-)
-# 👍 this works
-pipeline = DiffusionPipeline.from_pretrained(
- "./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16, use_safetensors=True
-)
-```
-
-## DiffusionPipeline explained
-
-As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things:
-
-- Download the latest version of the folder structure required for inference and cache it. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] reuses the cache and won't redownload the files.
-- Load the cached weights into the correct pipeline [class](../api/pipelines/overview#diffusers-summary) - retrieved from the `model_index.json` file - and return an instance of it.
-
-The pipelines' underlying folder structure corresponds directly with their class instances. For example, the [`StableDiffusionPipeline`] corresponds to the folder structure in [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5).
-
-```python
-from diffusers import DiffusionPipeline
-
-repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-pipeline = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
-print(pipeline)
-```
-
-You'll see pipeline is an instance of [`StableDiffusionPipeline`], which consists of seven components:
-
-- `"feature_extractor"`: a [`~transformers.CLIPImageProcessor`] from 🤗 Transformers.
-- `"safety_checker"`: a [component](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32) for screening against harmful content.
-- `"scheduler"`: an instance of [`PNDMScheduler`].
-- `"text_encoder"`: a [`~transformers.CLIPTextModel`] from 🤗 Transformers.
-- `"tokenizer"`: a [`~transformers.CLIPTokenizer`] from 🤗 Transformers.
-- `"unet"`: an instance of [`UNet2DConditionModel`].
-- `"vae"`: an instance of [`AutoencoderKL`].
-
-```json
-StableDiffusionPipeline {
- "feature_extractor": [
- "transformers",
- "CLIPImageProcessor"
- ],
- "safety_checker": [
- "stable_diffusion",
- "StableDiffusionSafetyChecker"
- ],
- "scheduler": [
- "diffusers",
- "PNDMScheduler"
- ],
- "text_encoder": [
- "transformers",
- "CLIPTextModel"
- ],
- "tokenizer": [
- "transformers",
- "CLIPTokenizer"
- ],
- "unet": [
- "diffusers",
- "UNet2DConditionModel"
- ],
- "vae": [
- "diffusers",
- "AutoencoderKL"
- ]
-}
-```
-
-Compare the components of the pipeline instance to the [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main) folder structure, and you'll see there is a separate folder for each of the components in the repository:
-
-```
-.
-├── feature_extractor
-│ └── preprocessor_config.json
-├── model_index.json
-├── safety_checker
-│ ├── config.json
-| ├── model.fp16.safetensors
-│ ├── model.safetensors
-│ ├── pytorch_model.bin
-| └── pytorch_model.fp16.bin
-├── scheduler
-│ └── scheduler_config.json
-├── text_encoder
-│ ├── config.json
-| ├── model.fp16.safetensors
-│ ├── model.safetensors
-│ |── pytorch_model.bin
-| └── pytorch_model.fp16.bin
-├── tokenizer
-│ ├── merges.txt
-│ ├── special_tokens_map.json
-│ ├── tokenizer_config.json
-│ └── vocab.json
-├── unet
-│ ├── config.json
-│ ├── diffusion_pytorch_model.bin
-| |── diffusion_pytorch_model.fp16.bin
-│ |── diffusion_pytorch_model.f16.safetensors
-│ |── diffusion_pytorch_model.non_ema.bin
-│ |── diffusion_pytorch_model.non_ema.safetensors
-│ └── diffusion_pytorch_model.safetensors
-|── vae
-. ├── config.json
-. ├── diffusion_pytorch_model.bin
- ├── diffusion_pytorch_model.fp16.bin
- ├── diffusion_pytorch_model.fp16.safetensors
- └── diffusion_pytorch_model.safetensors
-```
-
-You can access each of the components of the pipeline as an attribute to view its configuration:
-
-```py
-pipeline.tokenizer
-CLIPTokenizer(
- name_or_path="/root/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/tokenizer",
- vocab_size=49408,
- model_max_length=77,
- is_fast=False,
- padding_side="right",
- truncation_side="right",
- special_tokens={
- "bos_token": AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
- "eos_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
- "unk_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
- "pad_token": "<|endoftext|>",
- },
- clean_up_tokenization_spaces=True
-)
-```
-
-Every pipeline expects a [`model_index.json`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/model_index.json) file that tells the [`DiffusionPipeline`]:
-
-- which pipeline class to load from `_class_name`
-- which version of 🧨 Diffusers was used to create the model in `_diffusers_version`
-- what components from which library are stored in the subfolders (`name` corresponds to the component and subfolder name, `library` corresponds to the name of the library to load the class from, and `class` corresponds to the class name)
-
-```json
-{
- "_class_name": "StableDiffusionPipeline",
- "_diffusers_version": "0.6.0",
- "feature_extractor": [
- "transformers",
- "CLIPImageProcessor"
- ],
- "safety_checker": [
- "stable_diffusion",
- "StableDiffusionSafetyChecker"
- ],
- "scheduler": [
- "diffusers",
- "PNDMScheduler"
- ],
- "text_encoder": [
- "transformers",
- "CLIPTextModel"
- ],
- "tokenizer": [
- "transformers",
- "CLIPTokenizer"
- ],
- "unet": [
- "diffusers",
- "UNet2DConditionModel"
- ],
- "vae": [
- "diffusers",
- "AutoencoderKL"
- ]
-}
-```
+```
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md
index 11afbf29d3..b6e333ed77 100644
--- a/docs/source/en/using-diffusers/other-formats.md
+++ b/docs/source/en/using-diffusers/other-formats.md
@@ -10,77 +10,183 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Model files and layouts
-
[[open-in-colab]]
-Diffusion models are saved in various file types and organized in different layouts. Diffusers stores model weights as safetensors files in *Diffusers-multifolder* layout and it also supports loading files (like safetensors and ckpt files) from a *single-file* layout which is commonly used in the diffusion ecosystem.
+# Model formats
-Each layout has its own benefits and use cases, and this guide will show you how to load the different files and layouts, and how to convert them.
+Diffusion models are typically stored in the Diffusers format or single-file format. Model files can be stored in various file types such as safetensors, dduf, or ckpt.
-## Files
+> [!TIP]
+> Format refers to whether the weights are stored in a directory structure or a single file, and file refers to the underlying file type (such as safetensors or ckpt).
-PyTorch model weights are typically saved with Python's [pickle](https://docs.python.org/3/library/pickle.html) utility as ckpt or bin files. However, pickle is not secure and pickled files may contain malicious code that can be executed. This vulnerability is a serious concern given the popularity of model sharing. To address this security issue, the [Safetensors](https://hf.co/docs/safetensors) library was developed as a secure alternative to pickle, which saves models as safetensors files.
+This guide will show you how to load pipelines and models from these formats and files.
+
+## Diffusers format
+
+The Diffusers format stores each model (UNet, transformer, text encoder) in a separate subfolder. There are several benefits to storing models separately.
+
+- Faster overall pipeline initialization because you can load the individual model you need or load them all in parallel.
+- Reduced memory usage because you don't need to load all the pipeline components if you only need one model. You can also [reuse](./loading#reusing-models-in-multiple-pipelines) a model that is shared between multiple pipelines.
+- Lower storage requirements because common models shared between multiple pipelines are only downloaded once.
+- Flexibility to use new or improved models in a pipeline.
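+
+For example, you can load a single component, such as the UNet, directly from its subfolder without initializing the whole pipeline (a minimal sketch; the repository and dtype are only examples).
+
+```py
+import torch
+from diffusers import UNet2DConditionModel
+
+# only the UNet weights in the "unet" subfolder are downloaded and loaded
+unet = UNet2DConditionModel.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+)
+```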
+
+## Single file format
+
+A single-file format stores *all* the model (UNet, transformer, text encoder) weights in a single file. Benefits of single-file formats include the following.
+
+- Greater compatibility with [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
+- Easier to download and share a single file.
+
+Use [`~loaders.FromSingleFileMixin.from_single_file`] to load a single file.
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_single_file(
+ "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+```
+
+The [`~loaders.FromSingleFileMixin.from_single_file`] method also supports passing new models or schedulers.
+
+```py
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel
+
+transformer = FluxTransformer2DModel.from_single_file(
+ "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=torch.bfloat16
+)
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ transformer=transformer,
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
+```
+
+### Configuration options
+
+Diffusers format models have a `config.json` file in their repositories with important attributes such as the number of layers and attention heads. The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically determines the appropriate config to use from `config.json`. This may fail in a few rare instances, in which case you should use the `config` argument.
+
+You should also use the `config` argument if the models in a pipeline are different from the original implementation or if the checkpoint doesn't have the necessary metadata to determine the correct config.
+
+```py
+from diffusers import StableDiffusionXLPipeline
+
+ckpt_path = "https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors"
+
+pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, config="segmind/SSD-1B")
+```
+
+When you use `original_config` with `local_files_only=True`, Diffusers attempts to infer the pipeline components from the type signatures of the pipeline class instead of downloading the config files from a Hub repository. This prevents backward breaking changes in code that can't connect to the internet, but it isn't as reliable as providing a path to a local model with the `config` argument and may lead to errors. To avoid errors, run the pipeline once with `local_files_only=False` to download the config files to the local cache.
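+
+The `original_config` argument accepts a local path or URL to the original YAML config file (a minimal sketch; the checkpoint and config URLs are only examples).
+
+```py
+from diffusers import StableDiffusionXLPipeline
+
+ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
+original_config = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+
+pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, original_config=original_config)
+```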
+
+Override default configs by passing the arguments directly to [`~loaders.FromSingleFileMixin.from_single_file`]. The examples below demonstrate how to override the configs in a pipeline or model.
+
+```py
+from diffusers import StableDiffusionXLInstructPix2PixPipeline
+
+ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
+pipeline = StableDiffusionXLInstructPix2PixPipeline.from_single_file(
+ ckpt_path, config="diffusers/sdxl-instructpix2pix-768", is_cosxl_edit=True
+)
+```
+
+```py
+from diffusers import UNet2DConditionModel
+
+ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
+model = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)
+```
+
+### Local files
+
+The [`~loaders.FromSingleFileMixin.from_single_file`] method attempts to configure a pipeline or model by inferring the model type from the keys in the checkpoint file. For example, any single file checkpoint based on the Stable Diffusion XL base model is configured from [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
+
+If you're working with local files, download the config files with the [`~huggingface_hub.snapshot_download`] method and the model checkpoint with [`~huggingface_hub.hf_hub_download`]. These files are downloaded to your [cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache), but you can download them to a specific directory with the `local_dir` argument.
+
+```py
+from huggingface_hub import hf_hub_download, snapshot_download
+from diffusers import StableDiffusionXLPipeline
+
+my_local_checkpoint_path = hf_hub_download(
+ repo_id="segmind/SSD-1B",
+ filename="SSD-1B.safetensors"
+)
+
+my_local_config_path = snapshot_download(
+ repo_id="segmind/SSD-1B",
+ allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
+)
+
+pipeline = StableDiffusionXLPipeline.from_single_file(
+ my_local_checkpoint_path, config=my_local_config_path, local_files_only=True
+)
+```
+
+### Symlink
+
+If you're working with a file system that does not support symlinking, download the checkpoint file to a local directory first with the `local_dir` parameter. Using the `local_dir` parameter automatically disables symlinks.
+
+```py
+from huggingface_hub import hf_hub_download, snapshot_download
+from diffusers import StableDiffusionXLPipeline
+
+my_local_checkpoint_path = hf_hub_download(
+ repo_id="segmind/SSD-1B",
+    filename="SSD-1B.safetensors",
+ local_dir="my_local_checkpoints",
+)
+print("My local checkpoint: ", my_local_checkpoint_path)
+
+my_local_config_path = snapshot_download(
+ repo_id="segmind/SSD-1B",
+ allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
+)
+print("My local config: ", my_local_config_path)
+```
+
+Pass these paths to [`~loaders.FromSingleFileMixin.from_single_file`].
+
+```py
+pipeline = StableDiffusionXLPipeline.from_single_file(
+ my_local_checkpoint_path, config=my_local_config_path, local_files_only=True
+)
+```
+
+## File types
+
+Models can be stored in several file types. Safetensors is the most common file type, but you may encounter other file types on the Hub or in the wider diffusion community.
### safetensors
-> [!TIP]
-> Learn more about the design decisions and why safetensor files are preferred for saving and loading model weights in the [Safetensors audited as really safe and becoming the default](https://blog.eleuther.ai/safetensors-security-audit/) blog post.
+[Safetensors](https://hf.co/docs/safetensors) is a safe and fast file type for securely storing and loading tensors. It restricts the header size to limit certain types of attacks, supports lazy loading (useful for distributed setups), and generally loads faster than pickled files.
-[Safetensors](https://hf.co/docs/safetensors) is a safe and fast file format for securely storing and loading tensors. Safetensors restricts the header size to limit certain types of attacks, supports lazy loading (useful for distributed setups), and has generally faster loading speeds.
+Diffusers loads safetensors files by default whenever they are available (the Safetensors library is a required dependency of Diffusers).
-Make sure you have the [Safetensors](https://hf.co/docs/safetensors) library installed.
-
-```py
-!pip install safetensors
-```
-
-Safetensors stores weights in a safetensors file. Diffusers loads safetensors files by default if they're available and the Safetensors library is installed. There are two ways safetensors files can be organized:
-
-1. Diffusers-multifolder layout: there may be several separate safetensors files, one for each pipeline component (text encoder, UNet, VAE), organized in subfolders (check out the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main) repository as an example)
-2. single-file layout: all the model weights may be saved in a single file (check out the [WarriorMama777/OrangeMixs](https://hf.co/WarriorMama777/OrangeMixs/tree/main/Models/AbyssOrangeMix) repository as an example)
-
-
-
-
-Use the [`~DiffusionPipeline.from_pretrained`] method to load a model with safetensors files stored in multiple folders.
+Use [`~DiffusionPipeline.from_pretrained`] or [`~loaders.FromSingleFileMixin.from_single_file`] to load safetensors files.
```py
+import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- use_safetensors=True
+ "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16,
+ device_map="cuda"
+)
+
+pipeline = DiffusionPipeline.from_single_file(
+ "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
+ torch_dtype=torch.float16,
)
```
-
-
+If you're using a checkpoint trained with a Diffusers training script, metadata, such as the LoRA configuration, is automatically saved. When the file is loaded, the metadata is parsed to correctly configure the LoRA and avoid missing or incorrect LoRA configs. Inspect the metadata of a safetensors file by clicking on the logo next to the file on the Hub.
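+
+You can also read the metadata programmatically with the safetensors library without loading any weights (a minimal sketch; the file path is only illustrative).
+
+```py
+from safetensors import safe_open
+
+with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
+    print(f.metadata())    # parsed header metadata, such as a stored LoRA config
+    print(list(f.keys()))  # tensor names are listed without loading the tensors
+```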
-Use the [`~loaders.FromSingleFileMixin.from_single_file`] method to load a model with all the weights stored in a single safetensors file.
-
-```py
-from diffusers import StableDiffusionPipeline
-
-pipeline = StableDiffusionPipeline.from_single_file(
- "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
-)
-```
-
-
-
-
-#### LoRAs
-
-[LoRAs](../tutorials/using_peft_for_inference) are lightweight checkpoints fine-tuned to generate images or video in a specific style. If you are using a checkpoint trained with a Diffusers training script, the LoRA configuration is automatically saved as metadata in a safetensors file. When the safetensors file is loaded, the metadata is parsed to correctly configure the LoRA and avoids missing or incorrect LoRA configurations.
-
-The easiest way to inspect the metadata, if available, is by clicking on the Safetensors logo next to the weights.
-
-
-
-
-
-For LoRAs that aren't trained with Diffusers, you can still save metadata with the `transformer_lora_adapter_metadata` and `text_encoder_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`] as long as it is a safetensors file.
+For LoRAs that aren't trained with Diffusers, save the metadata with either the `transformer_lora_adapter_metadata` or `unet_lora_adapter_metadata` argument, depending on your model. For the text encoder, use the `text_encoder_lora_adapter_metadata` and `text_encoder_2_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`]. This is only supported for safetensors files.
```py
import torch
@@ -91,422 +197,88 @@ pipeline = FluxPipeline.from_pretrained(
).to("cuda")
pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
pipeline.save_lora_weights(
- transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
- text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
+ text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8},
+ text_encoder_2_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
)
```
### ckpt
-> [!WARNING]
-> Pickled files may be unsafe because they can be exploited to execute malicious code. It is recommended to use safetensors files instead where possible, or convert the weights to safetensors files.
+Older model weights are commonly saved with Python's [pickle](https://docs.python.org/3/library/pickle.html) utility in a ckpt file.
-PyTorch's [torch.save](https://pytorch.org/docs/stable/generated/torch.save.html) function uses Python's [pickle](https://docs.python.org/3/library/pickle.html) utility to serialize and save models. These files are saved as a ckpt file and they contain the entire model's weights.
+Pickled files may be unsafe because they can be exploited to execute malicious code. It is recommended to use safetensors files instead, or to convert existing weights to safetensors.
-Use the [`~loaders.FromSingleFileMixin.from_single_file`] method to directly load a ckpt file.
+Use [`~loaders.FromSingleFileMixin.from_single_file`] to load a ckpt file.
```py
-from diffusers import StableDiffusionPipeline
+from diffusers import DiffusionPipeline
-pipeline = StableDiffusionPipeline.from_single_file(
+pipeline = DiffusionPipeline.from_single_file(
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt"
)
```
-## Storage layout
+### dduf
-There are two ways model files are organized, either in a Diffusers-multifolder layout or in a single-file layout. The Diffusers-multifolder layout is the default, and each component file (text encoder, UNet, VAE) is stored in a separate subfolder. Diffusers also supports loading models from a single-file layout where all the components are bundled together.
+> [!TIP]
+> DDUF is an experimental file type and the API may change. Refer to the DDUF [docs](https://huggingface.co/docs/hub/dduf) to learn more.
-### Diffusers-multifolder
+DDUF is a file type designed to unify different diffusion model distribution methods and weight-saving formats. It is a standardized and flexible method to package all components of a diffusion model into a single file, providing a balance between the Diffusers and single-file formats.
-The Diffusers-multifolder layout is the default storage layout for Diffusers. Each component's (text encoder, UNet, VAE) weights are stored in a separate subfolder. The weights can be stored as safetensors or ckpt files.
-
-
-
-
- multifolder layout
-
-
-
- UNet subfolder
-
-
-
-To load from Diffusers-multifolder layout, use the [`~DiffusionPipeline.from_pretrained`] method.
+Use the `dduf_file` argument in [`~DiffusionPipeline.from_pretrained`] to load a DDUF file. You can also load quantized DDUF files as long as they are stored in the Diffusers format.
```py
+import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True,
-).to("cuda")
+ "DDUF/FLUX.1-dev-DDUF",
+ dduf_file="FLUX.1-dev.dduf",
+ torch_dtype=torch.bfloat16,
+ device_map="cuda"
+)
```
-Benefits of using the Diffusers-multifolder layout include:
-
-1. Faster to load each component file individually or in parallel.
-2. Reduced memory usage because you only load the components you need. For example, models like [SDXL Turbo](https://hf.co/stabilityai/sdxl-turbo), [SDXL Lightning](https://hf.co/ByteDance/SDXL-Lightning), and [Hyper-SD](https://hf.co/ByteDance/Hyper-SD) have the same components except for the UNet. You can reuse their shared components with the [`~DiffusionPipeline.from_pipe`] method without consuming any additional memory (take a look at the [Reuse a pipeline](./loading#reuse-a-pipeline) guide) and only load the UNet. This way, you don't need to download redundant components and unnecessarily use more memory.
-
- ```py
- import torch
- from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
-
- # download one model
- sdxl_pipeline = StableDiffusionXLPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True,
- ).to("cuda")
-
- # switch UNet for another model
- unet = UNet2DConditionModel.from_pretrained(
- "stabilityai/sdxl-turbo",
- subfolder="unet",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True
- )
- # reuse all the same components in new model except for the UNet
- turbo_pipeline = StableDiffusionXLPipeline.from_pipe(
- sdxl_pipeline, unet=unet,
- ).to("cuda")
- turbo_pipeline.scheduler = EulerDiscreteScheduler.from_config(
- turbo_pipeline.scheduler.config,
- timestep+spacing="trailing"
- )
- image = turbo_pipeline(
- "an astronaut riding a unicorn on mars",
- num_inference_steps=1,
- guidance_scale=0.0,
- ).images[0]
- image
- ```
-
-3. Reduced storage requirements because if a component, such as the SDXL [VAE](https://hf.co/madebyollin/sdxl-vae-fp16-fix), is shared across multiple models, you only need to download and store a single copy of it instead of downloading and storing it multiple times. For 10 SDXL models, this can save ~3.5GB of storage. The storage savings is even greater for newer models like PixArt Sigma, where the [text encoder](https://hf.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS/tree/main/text_encoder) alone is ~19GB!
-4. Flexibility to replace a component in the model with a newer or better version.
-
- ```py
- from diffusers import DiffusionPipeline, AutoencoderKL
-
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- vae=vae,
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True,
- ).to("cuda")
- ```
-
-5. More visibility and information about a model's components, which are stored in a [config.json](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json) file in each component subfolder.
-
-### Single-file
-
-The single-file layout stores all the model weights in a single file. All the model components (text encoder, UNet, VAE) weights are kept together instead of separately in subfolders. This can be a safetensors or ckpt file.
-
-
-
-
-
-To load from a single-file layout, use the [`~loaders.FromSingleFileMixin.from_single_file`] method.
+To save a pipeline as a DDUF file, use the [`~huggingface_hub.export_folder_as_dduf`] utility.
```py
import torch
-from diffusers import StableDiffusionXLPipeline
-
-pipeline = StableDiffusionXLPipeline.from_single_file(
- "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
- torch_dtype=torch.float16,
- variant="fp16",
- use_safetensors=True,
-).to("cuda")
-```
-
-Benefits of using a single-file layout include:
-
-1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
-2. Easier to manage (download and share) a single file.
-
-### DDUF
-
-> [!WARNING]
-> DDUF is an experimental file format and APIs related to it can change in the future.
-
-DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format.
-
-Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).
-
-Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].
-
-```py
from diffusers import DiffusionPipeline
-import torch
-
-pipe = DiffusionPipeline.from_pretrained(
- "DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
-).to("cuda")
-image = pipe(
- "photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
-).images[0]
-image.save("cat.png")
-```
-
-To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.
-
-```py
from huggingface_hub import export_folder_as_dduf
-from diffusers import DiffusionPipeline
-import torch
-pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
save_folder = "flux-dev"
-pipe.save_pretrained("flux-dev")
+pipeline.save_pretrained("flux-dev")
export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)
+```
-> [!TIP]
-> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.
+## Converting formats and files
-## Convert layout and files
+Diffusers provides scripts and methods to convert formats and file types to enable broader support across the diffusion ecosystem.
-Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
+Take a look at the [diffusers/scripts](https://github.com/huggingface/diffusers/tree/main/scripts) folder to find a conversion script. Scripts with `to_diffusers` appended at the end convert a model to the Diffusers format. Each script has its own set of arguments for configuring the conversion, so make sure you check what arguments are available.
-Take a look at the [diffusers/scripts](https://github.com/huggingface/diffusers/tree/main/scripts) collection to find a script that fits your conversion needs.
-
-> [!TIP]
-> Scripts that have "`to_diffusers`" appended at the end mean they convert a model to the Diffusers-multifolder layout. Each script has their own specific set of arguments for configuring the conversion, so make sure you check what arguments are available!
-
-For example, to convert a Stable Diffusion XL model stored in Diffusers-multifolder layout to a single-file layout, run the [convert_diffusers_to_original_sdxl.py](https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_sdxl.py) script. Provide the path to the model to convert, and the path to save the converted model to. You can optionally specify whether you want to save the model as a safetensors file and whether to save the model in half-precision.
+The example below converts a model stored in Diffusers format to a single-file format. Provide the path to the model to convert and where to save the converted model. You can optionally specify what file type and data type to save the model as.
```bash
python convert_diffusers_to_original_sdxl.py --model_path path/to/model/to/convert --checkpoint_path path/to/save/model/to --use_safetensors
```
-You can also save a model to Diffusers-multifolder layout with the [`~DiffusionPipeline.save_pretrained`] method. This creates a directory for you if it doesn't already exist, and it also saves the files as a safetensors file by default.
+The [`~DiffusionPipeline.save_pretrained`] method also saves a model in the Diffusers format and takes care of creating a subfolder for each component. The files are saved as safetensors files by default.
```py
-from diffusers import StableDiffusionXLPipeline
+from diffusers import DiffusionPipeline
-pipeline = StableDiffusionXLPipeline.from_single_file(
+pipeline = DiffusionPipeline.from_single_file(
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors",
)
-pipeline.save_pretrained()
+# `save_pretrained` requires a directory to save the model to (the name here is illustrative)
+pipeline.save_pretrained("sdxl-base-1.0")
```
-Lastly, there are also Spaces, such as [SD To Diffusers](https://hf.co/spaces/diffusers/sd-to-diffusers) and [SD-XL To Diffusers](https://hf.co/spaces/diffusers/sdxl-to-diffusers), that provide a more user-friendly interface for converting models to Diffusers-multifolder layout. This is the easiest and most convenient option for converting layouts, and it'll open a PR on your model repository with the converted files. However, this option is not as reliable as running a script, and the Space may fail for more complicated models.
+Finally, you can use a Space like [SD To Diffusers](https://hf.co/spaces/diffusers/sd-to-diffusers) or [SD-XL To Diffusers](https://hf.co/spaces/diffusers/sdxl-to-diffusers) to convert models to the Diffusers format. It'll open a PR on your model repository with the converted files. This is the easiest way to convert a model, but it may fail for more complicated models. Using a conversion script is more reliable.
-## Single-file layout usage
+## Resources
-Now that you're familiar with the differences between the Diffusers-multifolder and single-file layout, this section shows you how to load models and pipeline components, customize configuration options for loading, and load local files with the [`~loaders.FromSingleFileMixin.from_single_file`] method.
+- Learn more about the design decisions and why safetensors files are preferred for saving and loading model weights in the [Safetensors audited as really safe and becoming the default](https://blog.eleuther.ai/safetensors-security-audit/) blog post.
-### Load a pipeline or model
-
-Pass the file path of the pipeline or model to the [`~loaders.FromSingleFileMixin.from_single_file`] method to load it.
-
-
-
-
-```py
-from diffusers import StableDiffusionXLPipeline
-
-ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
-pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path)
-```
-
-
-
-
-```py
-from diffusers import StableCascadeUNet
-
-ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
-model = StableCascadeUNet.from_single_file(ckpt_path)
-```
-
-
-
-
-Customize components in the pipeline by passing them directly to the [`~loaders.FromSingleFileMixin.from_single_file`] method. For example, you can use a different scheduler in a pipeline.
-
-```py
-from diffusers import StableDiffusionXLPipeline, DDIMScheduler
-
-ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
-scheduler = DDIMScheduler()
-pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, scheduler=scheduler)
-```
-
-Or you could use a ControlNet model in the pipeline.
-
-```py
-from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
-
-ckpt_path = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
-controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
-pipeline = StableDiffusionControlNetPipeline.from_single_file(ckpt_path, controlnet=controlnet)
-```
-
-### Customize configuration options
-
-Models have a configuration file that define their attributes like the number of inputs in a UNet. Pipelines configuration options are available in the pipeline's class. For example, if you look at the [`StableDiffusionXLInstructPix2PixPipeline`] class, there is an option to scale the image latents with the `is_cosxl_edit` parameter.
-
-These configuration files can be found in the models Hub repository or another location from which the configuration file originated (for example, a GitHub repository or locally on your device).
-
-
-
-
-> [!TIP]
-> The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically maps the checkpoint to the appropriate model repository, but there are cases where it is useful to use the `config` parameter. For example, if the model components in the checkpoint are different from the original checkpoint or if a checkpoint doesn't have the necessary metadata to correctly determine the configuration to use for the pipeline.
-
-The [`~loaders.FromSingleFileMixin.from_single_file`] method automatically determines the configuration to use from the configuration file in the model repository. You could also explicitly specify the configuration to use by providing the repository id to the `config` parameter.
-
-```py
-from diffusers import StableDiffusionXLPipeline
-
-ckpt_path = "https://huggingface.co/segmind/SSD-1B/blob/main/SSD-1B.safetensors"
-repo_id = "segmind/SSD-1B"
-
-pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, config=repo_id)
-```
-
-The model loads the configuration file for the [UNet](https://huggingface.co/segmind/SSD-1B/blob/main/unet/config.json), [VAE](https://huggingface.co/segmind/SSD-1B/blob/main/vae/config.json), and [text encoder](https://huggingface.co/segmind/SSD-1B/blob/main/text_encoder/config.json) from their respective subfolders in the repository.
-
-
-
-
-The [`~loaders.FromSingleFileMixin.from_single_file`] method can also load the original configuration file of a pipeline that is stored elsewhere. Pass a local path or URL of the original configuration file to the `original_config` parameter.
-
-```py
-from diffusers import StableDiffusionXLPipeline
-
-ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
-original_config = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
-
-pipeline = StableDiffusionXLPipeline.from_single_file(ckpt_path, original_config=original_config)
-```
-
-> [!TIP]
-> Diffusers attempts to infer the pipeline components based on the type signatures of the pipeline class when you use `original_config` with `local_files_only=True`, instead of fetching the configuration files from the model repository on the Hub. This prevents backward breaking changes in code that can't connect to the internet to fetch the necessary configuration files.
->
-> This is not as reliable as providing a path to a local model repository with the `config` parameter, and might lead to errors during pipeline configuration. To avoid errors, run the pipeline with `local_files_only=False` once to download the appropriate pipeline configuration files to the local cache.
-
-
-
-
-While the configuration files specify the pipeline or models default parameters, you can override them by providing the parameters directly to the [`~loaders.FromSingleFileMixin.from_single_file`] method. Any parameter supported by the model or pipeline class can be configured in this way.
-
-
-
-
-For example, to scale the image latents in [`StableDiffusionXLInstructPix2PixPipeline`] pass the `is_cosxl_edit` parameter.
-
-```python
-from diffusers import StableDiffusionXLInstructPix2PixPipeline
-
-ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
-pipeline = StableDiffusionXLInstructPix2PixPipeline.from_single_file(ckpt_path, config="diffusers/sdxl-instructpix2pix-768", is_cosxl_edit=True)
-```
-
-
-
-
-For example, to upcast the attention dimensions in a [`UNet2DConditionModel`] pass the `upcast_attention` parameter.
-
-```python
-from diffusers import UNet2DConditionModel
-
-ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors"
-model = UNet2DConditionModel.from_single_file(ckpt_path, upcast_attention=True)
-```
-
-
-
-
-### Local files
-
-In Diffusers>=v0.28.0, the [`~loaders.FromSingleFileMixin.from_single_file`] method attempts to configure a pipeline or model by inferring the model type from the keys in the checkpoint file. The inferred model type is used to determine the appropriate model repository on the Hugging Face Hub to configure the model or pipeline.
-
-For example, any single file checkpoint based on the Stable Diffusion XL base model will use the [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model repository to configure the pipeline.
-
-But if you're working in an environment with restricted internet access, you should download the configuration files with the [`~huggingface_hub.snapshot_download`] function, and the model checkpoint with the [`~huggingface_hub.hf_hub_download`] function. By default, these files are downloaded to the Hugging Face Hub [cache directory](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache), but you can specify a preferred directory to download the files to with the `local_dir` parameter.
-
-Pass the configuration and checkpoint paths to the [`~loaders.FromSingleFileMixin.from_single_file`] method to load locally.
-
-
-
-
-```python
-from huggingface_hub import hf_hub_download, snapshot_download
-
-my_local_checkpoint_path = hf_hub_download(
- repo_id="segmind/SSD-1B",
- filename="SSD-1B.safetensors"
-)
-
-my_local_config_path = snapshot_download(
- repo_id="segmind/SSD-1B",
- allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
-)
-
-pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
-```
-
-
-
-
-```python
-from huggingface_hub import hf_hub_download, snapshot_download
-
-my_local_checkpoint_path = hf_hub_download(
- repo_id="segmind/SSD-1B",
- filename="SSD-1B.safetensors"
- local_dir="my_local_checkpoints"
-)
-
-my_local_config_path = snapshot_download(
- repo_id="segmind/SSD-1B",
- allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
- local_dir="my_local_config"
-)
-
-pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
-```
-
-
-
-
-#### Local files without symlink
-
-> [!TIP]
-> In huggingface_hub>=v0.23.0, the `local_dir_use_symlinks` argument isn't necessary for the [`~huggingface_hub.hf_hub_download`] and [`~huggingface_hub.snapshot_download`] functions.
-
-The [`~loaders.FromSingleFileMixin.from_single_file`] method relies on the [huggingface_hub](https://hf.co/docs/huggingface_hub/index) caching mechanism to fetch and store checkpoints and configuration files for models and pipelines. If you're working with a file system that does not support symlinking, you should download the checkpoint file to a local directory first, and disable symlinking with the `local_dir_use_symlink=False` parameter in the [`~huggingface_hub.hf_hub_download`] function and [`~huggingface_hub.snapshot_download`] functions.
-
-```python
-from huggingface_hub import hf_hub_download, snapshot_download
-
-my_local_checkpoint_path = hf_hub_download(
- repo_id="segmind/SSD-1B",
- filename="SSD-1B.safetensors"
- local_dir="my_local_checkpoints",
- local_dir_use_symlinks=False
-)
-print("My local checkpoint: ", my_local_checkpoint_path)
-
-my_local_config_path = snapshot_download(
- repo_id="segmind/SSD-1B",
- allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
- local_dir_use_symlinks=False,
-)
-print("My local config: ", my_local_config_path)
-```
-
-Then you can pass the local paths to the `pretrained_model_link_or_path` and `config` parameters.
-
-```python
-pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
-```
diff --git a/docs/source/en/using-diffusers/pag.md b/docs/source/en/using-diffusers/pag.md
index 46d716bcf8..c11a5dc379 100644
--- a/docs/source/en/using-diffusers/pag.md
+++ b/docs/source/en/using-diffusers/pag.md
@@ -219,11 +219,8 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
pipeline.enable_model_cpu_offload()
```
-
-
-If you already have a controlnet pipeline and want to enable PAG, you can use the `from_pipe` API: `AutoPipelineForText2Image.from_pipe(pipeline_controlnet, enable_pag=True)`
-
-
+> [!TIP]
+> If you already have a controlnet pipeline and want to enable PAG, you can use the `from_pipe` API: `AutoPipelineForText2Image.from_pipe(pipeline_controlnet, enable_pag=True)`
You can use the pipeline in the same way you normally use ControlNet pipelines, with the added option to specify a `pag_scale` parameter. Note that PAG works well for unconditional generation. In this example, we will generate an image without a prompt.
diff --git a/docs/source/en/using-diffusers/push_to_hub.md b/docs/source/en/using-diffusers/push_to_hub.md
index c77ce27656..4319f620a9 100644
--- a/docs/source/en/using-diffusers/push_to_hub.md
+++ b/docs/source/en/using-diffusers/push_to_hub.md
@@ -10,19 +10,22 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Push files to the Hub
-
[[open-in-colab]]
-🤗 Diffusers provides a [`~diffusers.utils.PushToHubMixin`] for uploading your model, scheduler, or pipeline to the Hub. It is an easy way to store your files on the Hub, and also allows you to share your work with others. Under the hood, the [`~diffusers.utils.PushToHubMixin`]:
+# Sharing pipelines and models
+
+Share your pipelines, models, and schedulers on the Hub with the [`~diffusers.utils.PushToHubMixin`] class. This class:
1. creates a repository on the Hub
2. saves your model, scheduler, or pipeline files so they can be reloaded later
3. uploads folder containing these files to the Hub
-This guide will show you how to use the [`~diffusers.utils.PushToHubMixin`] to upload your files to the Hub.
+This guide will show you how to upload your files to the Hub with the [`~diffusers.utils.PushToHubMixin`] class.
-You'll need to log in to your Hub account with your access [token](https://huggingface.co/settings/tokens) first:
+Log in to your Hugging Face account with your access [token](https://huggingface.co/settings/tokens).
+
+
+
```py
from huggingface_hub import notebook_login
@@ -30,9 +33,19 @@ from huggingface_hub import notebook_login
notebook_login()
```
+
+
+
+```bash
+hf auth login
+```
+
+
+
+
## Models
-To push a model to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the model to be stored on the Hub:
+To push a model to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the model.
```py
from diffusers import ControlNetModel
@@ -48,15 +61,9 @@ controlnet = ControlNetModel(
controlnet.push_to_hub("my-controlnet-model")
```
-For models, you can also specify the [*variant*](loading#checkpoint-variants) of the weights to push to the Hub. For example, to push `fp16` weights:
+The [`~diffusers.utils.PushToHubMixin.push_to_hub`] method saves the model's `config.json` file and the weights are automatically saved as safetensors files.
-```py
-controlnet.push_to_hub("my-controlnet-model", variant="fp16")
-```
-
-The [`~diffusers.utils.PushToHubMixin.push_to_hub`] function saves the model's `config.json` file and the weights are automatically saved in the `safetensors` format.
-
-Now you can reload the model from your repository on the Hub:
+Load the model again with [`~ModelMixin.from_pretrained`].
```py
model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model")
@@ -64,7 +71,7 @@ model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model")
## Scheduler
-To push a scheduler to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the scheduler to be stored on the Hub:
+To push a scheduler to the Hub, call [`~diffusers.utils.PushToHubMixin.push_to_hub`] and specify the repository id of the scheduler.
```py
from diffusers import DDIMScheduler
@@ -81,7 +88,7 @@ scheduler.push_to_hub("my-controlnet-scheduler")
The [`~diffusers.utils.PushToHubMixin.push_to_hub`] function saves the scheduler's `scheduler_config.json` file to the specified repository.
-Now you can reload the scheduler from your repository on the Hub:
+Load the scheduler again with [`~SchedulerMixin.from_pretrained`].
```py
scheduler = DDIMScheduler.from_pretrained("your-namepsace/my-controlnet-scheduler")
@@ -89,7 +96,7 @@ scheduler = DDIMScheduler.from_pretrained("your-namepsace/my-controlnet-schedule
## Pipeline
-You can also push an entire pipeline with all it's components to the Hub. For example, initialize the components of a [`StableDiffusionPipeline`] with the parameters you want:
+To push a pipeline to the Hub, initialize the pipeline components with your desired parameters.
```py
from diffusers import (
@@ -143,7 +150,7 @@ text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
```
-Pass all of the components to the [`StableDiffusionPipeline`] and call [`~diffusers.utils.PushToHubMixin.push_to_hub`] to push the pipeline to the Hub:
+Pass all components to the pipeline and call [`~diffusers.utils.PushToHubMixin.push_to_hub`].
```py
components = {
@@ -160,7 +167,7 @@ pipeline = StableDiffusionPipeline(**components)
pipeline.push_to_hub("my-pipeline")
```
-The [`~diffusers.utils.PushToHubMixin.push_to_hub`] function saves each component to a subfolder in the repository. Now you can reload the pipeline from your repository on the Hub:
+The [`~diffusers.utils.PushToHubMixin.push_to_hub`] method saves each component to a subfolder in the repository. Load the pipeline again with [`~DiffusionPipeline.from_pretrained`].
```py
pipeline = StableDiffusionPipeline.from_pretrained("your-namespace/my-pipeline")
@@ -168,10 +175,10 @@ pipeline = StableDiffusionPipeline.from_pretrained("your-namespace/my-pipeline")
## Privacy
-Set `private=True` in the [`~diffusers.utils.PushToHubMixin.push_to_hub`] function to keep your model, scheduler, or pipeline files private:
+Set `private=True` in [`~diffusers.utils.PushToHubMixin.push_to_hub`] to keep your model, scheduler, or pipeline files private.
```py
controlnet.push_to_hub("my-controlnet-model-private", private=True)
```
-Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for`. You must be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) to load a model from a private repository.
\ No newline at end of file
+Private repositories are only visible to you. Other users won't be able to clone the repository and it won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for`. You must be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) to load a model from a private repository.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/reusing_seeds.md b/docs/source/en/using-diffusers/reusing_seeds.md
index ac9350f24c..b4aed0aa63 100644
--- a/docs/source/en/using-diffusers/reusing_seeds.md
+++ b/docs/source/en/using-diffusers/reusing_seeds.md
@@ -10,129 +10,86 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Reproducible pipelines
+# Reproducibility
-Diffusion models are inherently random which is what allows it to generate different outputs every time it is run. But there are certain times when you want to generate the same output every time, like when you're testing, replicating results, and even [improving image quality](#deterministic-batch-generation). While you can't expect to get identical results across platforms, you can expect reproducible results across releases and platforms within a certain tolerance range (though even this may vary).
+Diffusion is a random process that generates a different output every time. For certain situations like testing and replicating results, you want to generate the same result each time, across releases and platforms within a certain tolerance range.
-This guide will show you how to control randomness for deterministic generation on a CPU and GPU.
+This guide will show you how to control sources of randomness and enable deterministic algorithms.
+
+## Generator
+
+Pipelines rely on [torch.randn](https://pytorch.org/docs/stable/generated/torch.randn.html), which uses a different random seed each time, to create the initial noise tensors. To generate the same output on a CPU or GPU, use a [Generator](https://docs.pytorch.org/docs/stable/generated/torch.Generator.html) to manage how random values are generated.
> [!TIP]
-> We strongly recommend reading PyTorch's [statement about reproducibility](https://pytorch.org/docs/stable/notes/randomness.html):
->
-> "Completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms. Furthermore, results may not be reproducible between CPU and GPU executions, even when using identical seeds."
+> If reproducibility is important to your use case, we recommend always using a CPU `Generator`. The performance loss is often negligible and you'll generate more similar values than with a GPU `Generator`.
-## Control randomness
+
+
-During inference, pipelines rely heavily on random sampling operations which include creating the
-Gaussian noise tensors to denoise and adding noise to the scheduling step.
+The GPU uses a different random number generator than the CPU. Diffusers solves this with the [`~utils.torch_utils.randn_tensor`] function, which creates the random tensor on the CPU and then moves it to the GPU. This function is used everywhere inside the pipeline, so you don't need to explicitly call it.
-Take a look at the tensor values in the [`DDIMPipeline`] after two inference steps.
+Use [manual_seed](https://docs.pytorch.org/docs/stable/generated/torch.manual_seed.html) as shown below to set a seed.
-```python
-from diffusers import DDIMPipeline
-import numpy as np
-
-ddim = DDIMPipeline.from_pretrained( "google/ddpm-cifar10-32", use_safetensors=True)
-image = ddim(num_inference_steps=2, output_type="np").images
-print(np.abs(image).sum())
-```
-
-Running the code above prints one value, but if you run it again you get a different value.
-
-Each time the pipeline is run, [torch.randn](https://pytorch.org/docs/stable/generated/torch.randn.html) uses a different random seed to create the Gaussian noise tensors. This leads to a different result each time it is run and enables the diffusion pipeline to generate a different random image each time.
-
-But if you need to reliably generate the same image, that depends on whether you're running the pipeline on a CPU or GPU.
-
-> [!TIP]
-> It might seem unintuitive to pass `Generator` objects to a pipeline instead of the integer value representing the seed. However, this is the recommended design when working with probabilistic models in PyTorch because a `Generator` is a *random state* that can be passed to multiple pipelines in a sequence. As soon as the `Generator` is consumed, the *state* is changed in place which means even if you passed the same `Generator` to a different pipeline, it won't produce the same result because the state is already changed.
-
-
-
-
-To generate reproducible results on a CPU, you'll need to use a PyTorch [Generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) and set a seed. Now when you run the code, it always prints a value of `1491.1711` because the `Generator` object with the seed is passed to all the random functions in the pipeline. You should get a similar, if not the same, result on whatever hardware and PyTorch version you're using.
-
-```python
+```py
import torch
import numpy as np
from diffusers import DDIMPipeline
-ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
+ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", device_map="cuda")
+generator = torch.manual_seed(0)
+image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
+print(np.abs(image).sum())
+```
+
+
+
+
+Set `device="cpu"` in the `Generator` and use [manual_seed](https://docs.pytorch.org/docs/stable/generated/torch.manual_seed.html) to set a seed for generating random numbers.
+
+```py
+import torch
+import numpy as np
+from diffusers import DDIMPipeline
+
+ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
generator = torch.Generator(device="cpu").manual_seed(0)
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
print(np.abs(image).sum())
```
-
-
-
-Writing a reproducible pipeline on a GPU is a bit trickier, and full reproducibility across different hardware is not guaranteed because matrix multiplication - which diffusion pipelines require a lot of - is less deterministic on a GPU than a CPU. For example, if you run the same code example from the CPU example, you'll get a different result even though the seed is identical. This is because the GPU uses a different random number generator than the CPU.
-
-```python
-import torch
-import numpy as np
-from diffusers import DDIMPipeline
-
-ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
-ddim.to("cuda")
-generator = torch.Generator(device="cuda").manual_seed(0)
-image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
-print(np.abs(image).sum())
-```
-
-To avoid this issue, Diffusers has a [`~utils.torch_utils.randn_tensor`] function for creating random noise on the CPU, and then moving the tensor to a GPU if necessary. The [`~utils.torch_utils.randn_tensor`] function is used everywhere inside the pipeline. Now you can call [torch.manual_seed](https://pytorch.org/docs/stable/generated/torch.manual_seed.html) which automatically creates a CPU `Generator` that can be passed to the pipeline even if it is being run on a GPU.
-
-```python
-import torch
-import numpy as np
-from diffusers import DDIMPipeline
-
-ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
-ddim.to("cuda")
-generator = torch.manual_seed(0)
-image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
-print(np.abs(image).sum())
-```
-
-> [!TIP]
-> If reproducibility is important to your use case, we recommend always passing a CPU `Generator`. The performance loss is often negligible and you'll generate more similar values than if the pipeline had been run on a GPU.
-
-Finally, more complex pipelines such as [`UnCLIPPipeline`], are often extremely
-susceptible to precision error propagation. You'll need to use
-exactly the same hardware and PyTorch version for full reproducibility.
-
+The `Generator` object should be passed to the pipeline instead of an integer seed. A `Generator` maintains a *random state* that is consumed and modified when used. Once consumed, the same `Generator` object produces different results in subsequent calls, even across different pipelines, because its *state* has changed.
+
+```diff
+generator = torch.manual_seed(0)
+
+for _ in range(5):
+- image = pipeline(prompt, generator=generator)
++ image = pipeline(prompt, generator=torch.manual_seed(0))
+```
+
## Deterministic algorithms
-You can also configure PyTorch to use deterministic algorithms to create a reproducible pipeline. The downside is that deterministic algorithms may be slower than non-deterministic ones and you may observe a decrease in performance.
+PyTorch supports [deterministic algorithms](https://docs.pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms), where available, for certain operations so they produce the same results. Deterministic algorithms may be slower, which can reduce overall performance.
-Non-deterministic behavior occurs when operations are launched in more than one CUDA stream. To avoid this, set the environment variable [CUBLAS_WORKSPACE_CONFIG](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime.
-
-PyTorch typically benchmarks multiple algorithms to select the fastest one, but if you want reproducibility, you should disable this feature because the benchmark may select different algorithms each time. Set Diffusers [enable_full_determinism](https://github.com/huggingface/diffusers/blob/142f353e1c638ff1d20bd798402b68f72c1ebbdd/src/diffusers/utils/testing_utils.py#L861) to enable deterministic algorithms.
-
-```py
-enable_full_determinism()
-```
-
-Now when you run the same pipeline twice, you'll get identical results.
+Use Diffusers' [enable_full_determinism](https://github.com/huggingface/diffusers/blob/142f353e1c638ff1d20bd798402b68f72c1ebbdd/src/diffusers/utils/testing_utils.py#L861) function to enable deterministic algorithms.
```py
import torch
-from diffusers import DDIMScheduler, StableDiffusionPipeline
+from diffusers.utils.testing_utils import enable_full_determinism
-pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", use_safetensors=True).to("cuda")
-pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
-g = torch.Generator(device="cuda")
-
-prompt = "A bear is playing a guitar on Times Square"
-
-g.manual_seed(0)
-result1 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
-
-g.manual_seed(0)
-result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="latent").images
-
-print("L_inf dist =", abs(result1 - result2).max())
-"L_inf dist = tensor(0., device='cuda:0')"
+enable_full_determinism()
```
+
+Under the hood, `enable_full_determinism` does the following (a rough manual equivalent is sketched after the list):
+
+- Setting the environment variable [CUBLAS_WORKSPACE_CONFIG](https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility) to `:16:8` to only use one buffer size during runtime. Non-deterministic behavior occurs when operations are used in more than one CUDA stream.
+- Disabling benchmarking to find the fastest convolution operation by setting `torch.backends.cudnn.benchmark=False`. Non-deterministic behavior occurs because the benchmark may select different algorithms each time depending on hardware or benchmarking noise.
+- Disabling TensorFloat32 (TF32) operations in favor of more precise and consistent full-precision operations.
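+
+A minimal sketch of roughly equivalent manual settings is shown below. This is an approximation under the assumptions above, not the exact implementation of `enable_full_determinism`.
+
+```py
+import os
+import torch
+
+# Assumption: these settings approximate what enable_full_determinism configures
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"  # single cuBLAS workspace size
+torch.use_deterministic_algorithms(True)         # error out on non-deterministic ops
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False           # don't benchmark to pick convolution algorithms
+torch.backends.cuda.matmul.allow_tf32 = False    # disable TF32 matmuls
+torch.backends.cudnn.allow_tf32 = False          # disable TF32 convolutions
+```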
+
+
+## Resources
+
+We strongly recommend reading PyTorch's developer notes about [Reproducibility](https://docs.pytorch.org/docs/stable/notes/randomness.html). You can try to limit randomness, but it is not *guaranteed* even with an identical seed.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/scheduler_features.md b/docs/source/en/using-diffusers/scheduler_features.md
deleted file mode 100644
index f7977d53d5..0000000000
--- a/docs/source/en/using-diffusers/scheduler_features.md
+++ /dev/null
@@ -1,235 +0,0 @@
-
-
-# Scheduler features
-
-The scheduler is an important component of any diffusion model because it controls the entire denoising (or sampling) process. There are many types of schedulers, some are optimized for speed and some for quality. With Diffusers, you can modify the scheduler configuration to use custom noise schedules, sigmas, and rescale the noise schedule. Changing these parameters can have profound effects on inference quality and speed.
-
-This guide will demonstrate how to use these features to improve inference quality.
-
-> [!TIP]
-> Diffusers currently only supports the `timesteps` and `sigmas` parameters for a select list of schedulers and pipelines. Feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
-
-## Timestep schedules
-
-The timestep or noise schedule determines the amount of noise at each sampling step. The scheduler uses this to generate an image with the corresponding amount of noise at each step. The timestep schedule is generated from the scheduler's default configuration, but you can customize the scheduler to use new and optimized sampling schedules that aren't in Diffusers yet.
-
-For example, [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) is a method for optimizing a sampling schedule to generate a high-quality image in as little as 10 steps. The optimal [10-step schedule](https://github.com/huggingface/diffusers/blob/a7bf77fc284810483f1e60afe34d1d27ad91ce2e/src/diffusers/schedulers/scheduling_utils.py#L51) for Stable Diffusion XL is:
-
-```py
-from diffusers.schedulers import AysSchedules
-
-sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
-print(sampling_schedule)
-"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
-```
-
-You can use the AYS sampling schedule in a pipeline by passing it to the `timesteps` parameter.
-
-```py
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "SG161222/RealVisXL_V4.0",
- torch_dtype=torch.float16,
- variant="fp16",
-).to("cuda")
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++")
-
-prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
-generator = torch.Generator(device="cpu").manual_seed(2487854446)
-image = pipeline(
- prompt=prompt,
- negative_prompt="",
- generator=generator,
- timesteps=sampling_schedule,
-).images[0]
-```
-
-
-
-
- AYS timestep schedule 10 steps
-
-
-
- Linearly-spaced timestep schedule 10 steps
-
-
-
- Linearly-spaced timestep schedule 25 steps
-
-
-
-## Timestep spacing
-
-The way sample steps are selected in the schedule can affect the quality of the generated image, especially with respect to [rescaling the noise schedule](#rescale-noise-schedule), which can enable a model to generate much brighter or darker images. Diffusers provides three timestep spacing methods:
-
-- `leading` creates evenly spaced steps
-- `linspace` includes the first and last steps and evenly selects the remaining intermediate steps
-- `trailing` only includes the last step and evenly selects the remaining intermediate steps starting from the end
-
-It is recommended to use the `trailing` spacing method because it generates higher quality images with more details when there are fewer sample steps. But the difference in quality is not as obvious for more standard sample step values.
-
-```py
-import torch
-from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
-
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "SG161222/RealVisXL_V4.0",
- torch_dtype=torch.float16,
- variant="fp16",
-).to("cuda")
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing")
-
-prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
-generator = torch.Generator(device="cpu").manual_seed(2487854446)
-image = pipeline(
- prompt=prompt,
- negative_prompt="",
- generator=generator,
- num_inference_steps=5,
-).images[0]
-image
-```
-
-
-
-
- trailing spacing after 5 steps
-
-
-
- leading spacing after 5 steps
-
-
-
-## Sigmas
-
-The `sigmas` parameter is the amount of noise added at each timestep according to the timestep schedule. Like the `timesteps` parameter, you can customize the `sigmas` parameter to control how much noise is added at each step. When you use a custom `sigmas` value, the `timesteps` are calculated from the custom `sigmas` value and the default scheduler configuration is ignored.
-
-For example, you can manually pass the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) for something like the 10-step AYS schedule from before to the pipeline.
-
-```py
-import torch
-
-from diffusers import DiffusionPipeline, EulerDiscreteScheduler
-
-model_id = "stabilityai/stable-diffusion-xl-base-1.0"
-pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0",
- torch_dtype=torch.float16,
- variant="fp16",
-).to("cuda")
-pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
-
-sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
-prompt = "anthropomorphic capybara wearing a suit and working with a computer"
-generator = torch.Generator(device='cuda').manual_seed(123)
-image = pipeline(
- prompt=prompt,
- num_inference_steps=10,
- sigmas=sigmas,
- generator=generator
-).images[0]
-```
-
-When you take a look at the scheduler's `timesteps` parameter, you'll see that it is the same as the AYS timestep schedule because the `timestep` schedule is calculated from the `sigmas`.
-
-```py
-print(f" timesteps: {pipe.scheduler.timesteps}")
-"timesteps: tensor([999., 845., 730., 587., 443., 310., 193., 116., 53., 13.], device='cuda:0')"
-```
-
-### Karras sigmas
-
-> [!TIP]
-> Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas.
->
-> Karras sigmas should not be used for models that weren't trained with them. For example, the base Stable Diffusion XL model shouldn't use Karras sigmas but the [DreamShaperXL](https://hf.co/Lykon/dreamshaper-xl-1-0) model can since they are trained with Karras sigmas.
-
-Karras scheduler's use the timestep schedule and sigmas from the [Elucidating the Design Space of Diffusion-Based Generative Models](https://hf.co/papers/2206.00364) paper. This scheduler variant applies a smaller amount of noise per step as it approaches the end of the sampling process compared to other schedulers, and can increase the level of details in the generated image.
-
-Enable Karras sigmas by setting `use_karras_sigmas=True` in the scheduler.
-
-```py
-import torch
-from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
-
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "SG161222/RealVisXL_V4.0",
- torch_dtype=torch.float16,
- variant="fp16",
-).to("cuda")
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, algorithm_type="sde-dpmsolver++", use_karras_sigmas=True)
-
-prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
-generator = torch.Generator(device="cpu").manual_seed(2487854446)
-image = pipeline(
- prompt=prompt,
- negative_prompt="",
- generator=generator,
-).images[0]
-```
-
-
-
-
- Karras sigmas enabled
-
-
-
- Karras sigmas disabled
-
-
-
-## Rescale noise schedule
-
-In the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://hf.co/papers/2305.08891) paper, the authors discovered that common noise schedules allowed some signal to leak into the last timestep. This signal leakage at inference can cause models to only generate images with medium brightness. By enforcing a zero signal-to-noise ratio (SNR) for the timstep schedule and sampling from the last timestep, the model can be improved to generate very bright or dark images.
-
-> [!TIP]
-> For inference, you need a model that has been trained with *v_prediction*. To train your own model with *v_prediction*, add the following flag to the [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) scripts.
->
-> ```bash
-> --prediction_type="v_prediction"
-> ```
-
-For example, load the [ptx0/pseudo-journey-v2](https://hf.co/ptx0/pseudo-journey-v2) checkpoint which was trained with `v_prediction` and the [`DDIMScheduler`]. Configure the following parameters in the [`DDIMScheduler`]:
-
-* `rescale_betas_zero_snr=True` to rescale the noise schedule to zero SNR
-* `timestep_spacing="trailing"` to start sampling from the last timestep
-
-Set `guidance_rescale` in the pipeline to prevent over-exposure. A lower value increases brightness but some of the details may appear washed out.
-
-```py
-from diffusers import DiffusionPipeline, DDIMScheduler
-
-pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", use_safetensors=True)
-
-pipeline.scheduler = DDIMScheduler.from_config(
- pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
-)
-pipeline.to("cuda")
-prompt = "cinematic photo of a snowy mountain at night with the northern lights aurora borealis overhead, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(23)
-image = pipeline(prompt, guidance_rescale=0.7, generator=generator).images[0]
-image
-```
-
-
-
-
- default Stable Diffusion v2-1 image
-
-
-
- image with zero SNR and trailing timestep spacing enabled
-
-
diff --git a/docs/source/en/using-diffusers/schedulers.md b/docs/source/en/using-diffusers/schedulers.md
index aabb9dd31c..0e236e4e3e 100644
--- a/docs/source/en/using-diffusers/schedulers.md
+++ b/docs/source/en/using-diffusers/schedulers.md
@@ -10,247 +10,273 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Load schedulers and models
-
[[open-in-colab]]
-Diffusion pipelines are a collection of interchangeable schedulers and models that can be mixed and matched to tailor a pipeline to a specific use case. The scheduler encapsulates the entire denoising process such as the number of denoising steps and the algorithm for finding the denoised sample. A scheduler is not parameterized or trained so they don't take very much memory. The model is usually only concerned with the forward pass of going from a noisy input to a less noisy sample.
+# Schedulers
-This guide will show you how to load schedulers and models to customize a pipeline. You'll use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint throughout this guide, so let's load it first.
+A scheduler is an algorithm that provides instructions to the denoising process such as how much noise to remove at a certain step. It takes the model prediction from step *t* and applies an update for how to compute the next sample at step *t-1*. Different schedulers produce different results; some are faster while others are more accurate.
+
+Diffusers supports many schedulers and allows you to modify their timestep schedules, timestep spacing, and more, to generate high-quality images in fewer steps.
+
+This guide will show you how to load and customize schedulers.
+
+## Loading schedulers
+
+Schedulers aren't trained and don't have any learnable parameters; they are defined by a configuration file. Access the `.scheduler` attribute of a pipeline to view the configuration.
```py
import torch
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-```
-
-You can see what scheduler this pipeline uses with the `pipeline.scheduler` attribute.
-
-```py
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, device_map="cuda"
+)
pipeline.scheduler
-PNDMScheduler {
- "_class_name": "PNDMScheduler",
- "_diffusers_version": "0.21.4",
- "beta_end": 0.012,
- "beta_schedule": "scaled_linear",
- "beta_start": 0.00085,
- "clip_sample": false,
- "num_train_timesteps": 1000,
- "set_alpha_to_one": false,
- "skip_prk_steps": true,
- "steps_offset": 1,
- "timestep_spacing": "leading",
- "trained_betas": null
-}
```
-## Load a scheduler
-
-Schedulers are defined by a configuration file that can be used by a variety of schedulers. Load a scheduler with the [`SchedulerMixin.from_pretrained`] method, and specify the `subfolder` parameter to load the configuration file into the correct subfolder of the pipeline repository.
-
-For example, to load the [`DDIMScheduler`]:
-
-```py
-from diffusers import DDIMScheduler, DiffusionPipeline
-
-ddim = DDIMScheduler.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler")
-```
-
-Then you can pass the newly loaded scheduler to the pipeline.
-
-```python
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=ddim, torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-```
-
-## Compare schedulers
-
-Schedulers have their own unique strengths and weaknesses, making it difficult to quantitatively compare which scheduler works best for a pipeline. You typically have to make a trade-off between denoising speed and denoising quality. We recommend trying out different schedulers to find one that works best for your use case. Call the `pipeline.scheduler.compatibles` attribute to see what schedulers are compatible with a pipeline.
-
-Let's compare the [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], and the [`DPMSolverMultistepScheduler`] on the following prompt and seed.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
-generator = torch.Generator(device="cuda").manual_seed(8)
-```
-
-To change the pipelines scheduler, use the [`~ConfigMixin.from_config`] method to load a different scheduler's `pipeline.scheduler.config` into the pipeline.
-
-
-
-
-[`LMSDiscreteScheduler`] typically generates higher quality images than the default scheduler.
-
-```py
-from diffusers import LMSDiscreteScheduler
-
-pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-
-
-[`EulerDiscreteScheduler`] can generate higher quality images in just 30 steps.
-
-```py
-from diffusers import EulerDiscreteScheduler
-
-pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-
-
-[`EulerAncestralDiscreteScheduler`] can generate higher quality images in just 30 steps.
-
-```py
-from diffusers import EulerAncestralDiscreteScheduler
-
-pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-
-
-[`DPMSolverMultistepScheduler`] provides a balance between speed and quality and can generate higher quality images in just 20 steps.
+Load a different scheduler with [`~SchedulerMixin.from_pretrained`] and specify the `subfolder` argument to load the configuration file from the correct subfolder of the pipeline repository. Then pass the new scheduler to the pipeline.
```py
from diffusers import DPMSolverMultistepScheduler
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-image = pipeline(prompt, generator=generator).images[0]
+dpm = DPMSolverMultistepScheduler.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler"
+)
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ scheduler=dpm,
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+pipeline.scheduler
+```
+
+## Timestep schedules
+
+The timestep or noise schedule decides how noise is distributed over the denoising process. The schedule can be linear or more concentrated toward the beginning or end. It is a precomputed sequence of noise levels generated from the scheduler's default configuration, but it can be customized to use other schedules.
+
+> [!TIP]
+> The `timesteps` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
+
+The example below uses the [Align Your Steps (AYS)](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/) schedule which can generate a high-quality image in 10 steps, significantly speeding up generation and reducing computation time.
+
+Import the schedule and pass it to the `timesteps` argument in the pipeline.
+
+```py
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+from diffusers.schedulers import AysSchedules
+
+sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]
+print(sampling_schedule)
+"[999, 845, 730, 587, 443, 310, 193, 116, 53, 13]"
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "SG161222/RealVisXL_V4.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
+)
+
+prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
+image = pipeline(
+ prompt=prompt,
+ negative_prompt="",
+ timesteps=sampling_schedule,
+).images[0]
+```
+
+
+
+
+ AYS timestep schedule 10 steps
+
+
+
+ Linearly-spaced timestep schedule 10 steps
+
+
+
+ Linearly-spaced timestep schedule 25 steps
+
+
+
+### Rescaling schedules
+
+Denoising should begin with pure noise, where the signal-to-noise ratio (SNR) is zero. However, some models don't actually start from pure noise, which makes it difficult to generate images at brightness extremes.
+
+> [!TIP]
+> Train your own model with `v_prediction` by adding the `--prediction_type="v_prediction"` flag to your training script. You can also [search](https://huggingface.co/search/full-text?q=v_prediction&type=model) for existing models trained with `v_prediction`.
+
+To fix this, the model must be trained with `v_prediction`. For a model trained with `v_prediction`, enable the following arguments in the scheduler.
+
+- Set `rescale_betas_zero_snr=True` to rescale the noise schedule to the very last timestep with exactly zero SNR
+- Set `timestep_spacing="trailing"` to force sampling from the last timestep with pure noise
+
+```py
+from diffusers import DiffusionPipeline, DDIMScheduler
+
+pipeline = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", device_map="cuda")
+
+pipeline.scheduler = DDIMScheduler.from_config(
+ pipeline.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
+)
+```
+
+Set `guidance_rescale` in the pipeline to avoid overexposed images. A lower value increases brightness, but some details may appear washed out.
+
+```py
+prompt = """
+cinematic photo of a snowy mountain at night with the northern lights aurora borealis
+overhead, 35mm photograph, film, professional, 4k, highly detailed
+"""
+image = pipeline(prompt, guidance_rescale=0.7).images[0]
+```
+
+
+
+
+ default Stable Diffusion v2-1 image
+
+
+
+ image with zero SNR and trailing timestep spacing enabled
+
+
+
+## Timestep spacing
+
+Timestep spacing refers to the specific steps *t* to sample from the schedule. Diffusers provides three spacing strategies as shown below.
+
+| spacing strategy | spacing calculation | example timesteps |
+|---|---|---|
+| `leading` | evenly spaced steps | `[900, 800, 700, ..., 100, 0]` |
+| `linspace` | include first and last steps and evenly divide remaining intermediate steps | `[1000, 888.89, 777.78, ..., 111.11, 0]` |
+| `trailing` | include last step and evenly divide remaining intermediate steps beginning from the end | `[999, 899, 799, 699, 599, 499, 399, 299, 199, 99]` |
+
+Pass the spacing strategy to the `timestep_spacing` argument in the scheduler.
+
+> [!TIP]
+> The `trailing` strategy typically produces higher quality images with more detail when using fewer steps, but the difference in quality is not as obvious at more standard step counts.
+
+```py
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "SG161222/RealVisXL_V4.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
+)
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, timestep_spacing="trailing"
+)
+
+prompt = "A cinematic shot of a cute little black cat sitting on a pumpkin at night"
+image = pipeline(
+ prompt=prompt,
+ negative_prompt="",
+ num_inference_steps=5,
+).images[0]
image
```
-
-
-
-
- LMSDiscreteScheduler
+
+ trailing spacing after 5 steps
-
- EulerDiscreteScheduler
-
-
-
-
-
- EulerAncestralDiscreteScheduler
-
-
-
- DPMSolverMultistepScheduler
+
+ leading spacing after 5 steps
-Most images look very similar and are comparable in quality. Again, it often comes down to your specific use case so a good approach is to run multiple different schedulers and compare the results.
+## Sigmas
-### Flax schedulers
+Sigmas measure how noisy a sample is at a certain step, as defined by the schedule. When using custom `sigmas`, the `timesteps` are calculated from these values instead of the default scheduler configuration.
-To compare Flax schedulers, you need to additionally load the scheduler state into the model parameters. For example, let's change the default scheduler in [`FlaxStableDiffusionPipeline`] to use the super fast [`FlaxDPMSolverMultistepScheduler`].
+> [!TIP]
+> The `sigmas` argument is only supported for a select list of schedulers and pipelines. Feel free to open a feature request if you want to extend these parameters to a scheduler and pipeline that does not currently support it!
-> [!WARNING]
-> The [`FlaxLMSDiscreteScheduler`] and [`FlaxDDPMScheduler`] are not compatible with the [`FlaxStableDiffusionPipeline`] yet.
+Pass the custom sigmas to the `sigmas` argument in the pipeline. The example below uses the [sigmas](https://github.com/huggingface/diffusers/blob/6529ee67ec02fcf58d2fd9242164ea002b351d75/src/diffusers/schedulers/scheduling_utils.py#L55) from the 10-step AYS schedule.
```py
-import jax
-import numpy as np
-from flax.jax_utils import replicate
-from flax.training.common_utils import shard
-from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
-scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- subfolder="scheduler"
+pipeline = DiffusionPipeline.from_pretrained(
+ "SG161222/RealVisXL_V4.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
)
-pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- scheduler=scheduler,
- variant="bf16",
- dtype=jax.numpy.bfloat16,
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config, algorithm_type="sde-dpmsolver++"
)
-params["scheduler"] = scheduler_state
+
+sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.0]
+prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
+image = pipeline(
+ prompt=prompt,
+ negative_prompt="",
+ sigmas=sigmas,
+).images[0]
```
-Then you can take advantage of Flax's compatibility with TPUs to generate a number of images in parallel. You'll need to make a copy of the model parameters for each available device and then split the inputs across them to generate your desired number of images.
+### Karras sigmas
+
+[Karras sigmas](https://huggingface.co/papers/2206.00364) resample the noise schedule for more efficient sampling by clustering sigmas more densely in the middle of the sequence, where structure reconstruction is critical, while using fewer sigmas at the beginning and end where noise changes have less impact. This can increase the level of detail in a generated image.
+
+Set `use_karras_sigmas=True` in the scheduler to enable it.
```py
-# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
-prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
-num_samples = jax.device_count()
-prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
-prng_seed = jax.random.PRNGKey(0)
-num_inference_steps = 25
-
-# shard inputs and rng
-params = replicate(params)
-prng_seed = jax.random.split(prng_seed, jax.device_count())
-prompt_ids = shard(prompt_ids)
-
-images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
-```
-
-## Models
-
-Models are loaded from the [`ModelMixin.from_pretrained`] method, which downloads and caches the latest version of the model weights and configurations. If the latest files are available in the local cache, [`~ModelMixin.from_pretrained`] reuses files in the cache instead of re-downloading them.
-
-Models can be loaded from a subfolder with the `subfolder` argument. For example, the model weights for [stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5) are stored in the [unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet) subfolder.
-
-```python
-from diffusers import UNet2DConditionModel
-
-unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True)
-```
-
-They can also be directly loaded from a [repository](https://huggingface.co/google/ddpm-cifar10-32/tree/main).
-
-```python
-from diffusers import UNet2DModel
-
-unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
-```
-
-To load and save model variants, specify the `variant` argument in [`ModelMixin.from_pretrained`] and [`ModelMixin.save_pretrained`].
-
-```python
-from diffusers import UNet2DConditionModel
-
-unet = UNet2DConditionModel.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
+pipeline = DiffusionPipeline.from_pretrained(
+ "SG161222/RealVisXL_V4.0",
+ torch_dtype=torch.float16,
+ device_map="cuda"
)
-unet.save_pretrained("./local-unet", variant="non_ema")
-```
-
-Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
-
-```py
-from diffusers import AutoModel
-
-unet = AutoModel.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+ pipeline.scheduler.config,
+ algorithm_type="sde-dpmsolver++",
+ use_karras_sigmas=True,
)
+
+prompt = "A cinematic shot of a cute little rabbit wearing a jacket and doing a thumbs up"
+image = pipeline(
+ prompt=prompt,
+ negative_prompt="",
+).images[0]
```
-You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights unlike the `torch_dtype` argument that respects the `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
+
+
+
+ Karras sigmas enabled
+
+
+
+ Karras sigmas disabled
+
+
+
+Refer to the scheduler API [overview](../api/schedulers/overview) for a list of schedulers that support Karras sigmas. They should only be used for models trained with Karras sigmas.
+
+## Choosing a scheduler
+
+It's important to try different schedulers to find the best one for your use case. Here are a few recommendations to help you get started (the first option is sketched in code after the list).
+
+- DPM++ 2M SDE Karras is generally a good all-purpose option.
+- [`TCDScheduler`] works well for distilled models.
+- [`FlowMatchEulerDiscreteScheduler`] and [`FlowMatchHeunDiscreteScheduler`] for FlowMatch models.
+- [`EulerDiscreteScheduler`] or [`EulerAncestralDiscreteScheduler`] for generating anime style images.
+- DPM++ 2M paired with [`LCMScheduler`] on SDXL for generating realistic images.
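+
+For example, DPM++ 2M SDE Karras corresponds to [`DPMSolverMultistepScheduler`] with the SDE algorithm type and Karras sigmas enabled. The sketch below assumes the same [SG161222/RealVisXL_V4.0](https://hf.co/SG161222/RealVisXL_V4.0) checkpoint used earlier in this guide; adapt it to your model.
+
+```py
+import torch
+from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "SG161222/RealVisXL_V4.0",
+    torch_dtype=torch.float16,
+    device_map="cuda"
+)
+# DPM++ 2M SDE Karras: multistep DPM-Solver++, SDE variant, Karras sigmas
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+    pipeline.scheduler.config,
+    algorithm_type="sde-dpmsolver++",
+    use_karras_sigmas=True,
+)
+```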
+
+## Resources
+
+- Read the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) paper for more details about rescaling the noise schedule to enforce zero SNR.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/sdxl.md b/docs/source/en/using-diffusers/sdxl.md
index 106005c338..275394a03c 100644
--- a/docs/source/en/using-diffusers/sdxl.md
+++ b/docs/source/en/using-diffusers/sdxl.md
@@ -29,15 +29,12 @@ Before you begin, make sure you have the following libraries installed:
#!pip install -q diffusers transformers accelerate invisible-watermark>=0.2.0
```
-
-
-We recommend installing the [invisible-watermark](https://pypi.org/project/invisible-watermark/) library to help identify images that are generated. If the invisible-watermark library is installed, it is used by default. To disable the watermarker:
-
-```py
-pipeline = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
-```
-
-
+> [!WARNING]
+> We recommend installing the [invisible-watermark](https://pypi.org/project/invisible-watermark/) library to help identify images that are generated. If the invisible-watermark library is installed, it is used by default. To disable the watermarker:
+>
+> ```py
+> pipeline = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False)
+> ```
## Load model checkpoints
@@ -174,11 +171,8 @@ refiner = DiffusionPipeline.from_pretrained(
To use this approach, you need to define the number of timesteps for each model to run through their respective stages. For the base model, this is controlled by the [`denoising_end`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__.denoising_end) parameter and for the refiner model, it is controlled by the [`denoising_start`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline.__call__.denoising_start) parameter.
-
-
-The `denoising_end` and `denoising_start` parameters should be a float between 0 and 1. These parameters are represented as a proportion of discrete timesteps as defined by the scheduler. If you're also using the `strength` parameter, it'll be ignored because the number of denoising steps is determined by the discrete timesteps the model is trained on and the declared fractional cutoff.
-
-
+> [!TIP]
+> The `denoising_end` and `denoising_start` parameters should be a float between 0 and 1. These parameters are represented as a proportion of discrete timesteps as defined by the scheduler. If you're also using the `strength` parameter, it'll be ignored because the number of denoising steps is determined by the discrete timesteps the model is trained on and the declared fractional cutoff.
Let's set `denoising_end=0.8` so the base model performs the first 80% of denoising the **high-noise** timesteps and set `denoising_start=0.8` so the refiner model performs the last 20% of denoising the **low-noise** timesteps. The base model output should be in **latent** space instead of a PIL image.
@@ -285,11 +279,8 @@ refiner = DiffusionPipeline.from_pretrained(
).to("cuda")
```
-
-
-You can use SDXL refiner with a different base model. For example, you can use the [Hunyuan-DiT](../../api/pipelines/hunyuandit) or [PixArt-Sigma](../../api/pipelines/pixart_sigma) pipelines to generate images with better prompt adherence. Once you have generated an image, you can pass it to the SDXL refiner model to enhance final generation quality.
-
-
+> [!WARNING]
+> You can use SDXL refiner with a different base model. For example, you can use the [Hunyuan-DiT](../api/pipelines/hunyuandit) or [PixArt-Sigma](../api/pipelines/pixart_sigma) pipelines to generate images with better prompt adherence. Once you have generated an image, you can pass it to the SDXL refiner model to enhance final generation quality.
Generate an image from the base model, and set the model output to **latent** space:
@@ -322,11 +313,8 @@ For inpainting, load the base and the refiner model in the [`StableDiffusionXLIn
SDXL training involves several additional conditioning techniques, which are referred to as *micro-conditioning*. These include original image size, target image size, and cropping parameters. The micro-conditionings can be used at inference time to create high-quality, centered images.
-
-
-You can use both micro-conditioning and negative micro-conditioning parameters thanks to classifier-free guidance. They are available in the [`StableDiffusionXLPipeline`], [`StableDiffusionXLImg2ImgPipeline`], [`StableDiffusionXLInpaintPipeline`], and [`StableDiffusionXLControlNetPipeline`].
-
-
+> [!TIP]
+> You can use both micro-conditioning and negative micro-conditioning parameters thanks to classifier-free guidance. They are available in the [`StableDiffusionXLPipeline`], [`StableDiffusionXLImg2ImgPipeline`], [`StableDiffusionXLInpaintPipeline`], and [`StableDiffusionXLControlNetPipeline`].
### Size conditioning
diff --git a/docs/source/en/using-diffusers/shap-e.md b/docs/source/en/using-diffusers/shap-e.md
index 51f0f53b02..8cd62b3ffd 100644
--- a/docs/source/en/using-diffusers/shap-e.md
+++ b/docs/source/en/using-diffusers/shap-e.md
@@ -151,11 +151,8 @@ images = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, fra
Use the [`~utils.export_to_ply`] function to save the mesh output as a `ply` file:
-
-
-You can optionally save the mesh output as an `obj` file with the [`~utils.export_to_obj`] function. The ability to save the mesh output in a variety of formats makes it more flexible for downstream usage!
-
-
+> [!TIP]
+> You can optionally save the mesh output as an `obj` file with the [`~utils.export_to_obj`] function. The ability to save the mesh output in a variety of formats makes it more flexible for downstream usage!
```py
from diffusers.utils import export_to_ply
diff --git a/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md b/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
deleted file mode 100644
index ac9ffe0dfc..0000000000
--- a/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
+++ /dev/null
@@ -1,225 +0,0 @@
-
-
-# JAX/Flax
-
-[[open-in-colab]]
-
-🤗 Diffusers supports Flax for super fast inference on Google TPUs, such as those available in Colab, Kaggle or Google Cloud Platform. This guide shows you how to run inference with Stable Diffusion using JAX/Flax.
-
-Before you begin, make sure you have the necessary libraries installed:
-
-```py
-# uncomment to install the necessary libraries in Colab
-#!pip install -q jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
-#!pip install -q diffusers
-```
-
-You should also make sure you're using a TPU backend. While JAX does not run exclusively on TPUs, you'll get the best performance on a TPU because each server has 8 TPU accelerators working in parallel.
-
-If you are running this guide in Colab, select *Runtime* in the menu above, select the option *Change runtime type*, and then select *TPU* under the *Hardware accelerator* setting. Import JAX and quickly check whether you're using a TPU:
-
-```python
-import jax
-import jax.tools.colab_tpu
-jax.tools.colab_tpu.setup_tpu()
-
-num_devices = jax.device_count()
-device_type = jax.devices()[0].device_kind
-
-print(f"Found {num_devices} JAX devices of type {device_type}.")
-assert (
- "TPU" in device_type,
- "Available device is not a TPU, please select TPU from Runtime > Change runtime type > Hardware accelerator"
-)
-# Found 8 JAX devices of type Cloud TPU.
-```
-
-Great, now you can import the rest of the dependencies you'll need:
-
-```python
-import jax.numpy as jnp
-from jax import pmap
-from flax.jax_utils import replicate
-from flax.training.common_utils import shard
-
-from diffusers import FlaxStableDiffusionPipeline
-```
-
-## Load a model
-
-Flax is a functional framework, so models are stateless and parameters are stored outside of them. Loading a pretrained Flax pipeline returns *both* the pipeline and the model weights (or parameters). In this guide, you'll use `bfloat16`, a more efficient half-float type that is supported by TPUs (you can also use `float32` for full precision if you want).
-
-```python
-dtype = jnp.bfloat16
-pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- variant="bf16",
- dtype=dtype,
-)
-```
-
-## Inference
-
-TPUs usually have 8 devices working in parallel, so let's use the same prompt for each device. This means you can perform inference on 8 devices at once, with each device generating one image. As a result, you'll get 8 images in the same amount of time it takes for one chip to generate a single image!
-
-
-
-Learn more details in the [How does parallelization work?](#how-does-parallelization-work) section.
-
-
-
-After replicating the prompt, get the tokenized text ids by calling the `prepare_inputs` function on the pipeline. The length of the tokenized text is set to 77 tokens as required by the configuration of the underlying CLIP text model.
-
-```python
-prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
-prompt = [prompt] * jax.device_count()
-prompt_ids = pipeline.prepare_inputs(prompt)
-prompt_ids.shape
-# (8, 77)
-```
-
-Model parameters and inputs have to be replicated across the 8 parallel devices. The parameters dictionary is replicated with [`flax.jax_utils.replicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.jax_utils.html#flax.jax_utils.replicate) which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using `shard`.
-
-```python
-# parameters
-p_params = replicate(params)
-
-# arrays
-prompt_ids = shard(prompt_ids)
-prompt_ids.shape
-# (8, 1, 77)
-```
-
-This shape means each one of the 8 devices receives as an input a `jnp` array with shape `(1, 77)`, where `1` is the batch size per device. On TPUs with sufficient memory, you could have a batch size larger than `1` if you want to generate multiple images (per chip) at once.
-
-Next, create a random number generator to pass to the generation function. This is standard procedure in Flax, which is very serious and opinionated about random numbers. All functions that deal with random numbers are expected to receive a generator to ensure reproducibility, even when you're training across multiple distributed devices.
-
-The helper function below uses a seed to initialize a random number generator. As long as you use the same seed, you'll get the exact same results. Feel free to use different seeds when exploring results later in the guide.
-
-```python
-def create_key(seed=0):
- return jax.random.PRNGKey(seed)
-```
-
-The helper function, or `rng`, is split 8 times so each device receives a different generator and generates a different image.
-
-```python
-rng = create_key(0)
-rng = jax.random.split(rng, jax.device_count())
-```
-
-To take advantage of JAX's optimized speed on a TPU, pass `jit=True` to the pipeline to compile the JAX code into an efficient representation and to ensure the model runs in parallel across the 8 devices.
-
-
-
-You need to ensure all your inputs have the same shape in subsequent calls, otherwise JAX will need to recompile the code which is slower.
-
-
-
-The first inference run takes more time because it needs to compile the code, but subsequent calls (even with different inputs) are much faster. For example, it took more than a minute to compile on a TPU v2-8, but then it takes about **7s** on a future inference run!
-
-```py
-%%time
-images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
-
-# CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
-# Wall time: 1min 29s
-```
-
-The returned array has shape `(8, 1, 512, 512, 3)` which should be reshaped to remove the second dimension and get 8 images of `512 × 512 × 3`. Then you can use the [`~utils.numpy_to_pil`] function to convert the arrays into images.
-
-```python
-from diffusers.utils import make_image_grid
-
-images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
-images = pipeline.numpy_to_pil(images)
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-
-## Using different prompts
-
-You don't necessarily have to use the same prompt on all devices. For example, to generate 8 different prompts:
-
-```python
-prompts = [
- "Labrador in the style of Hokusai",
- "Painting of a squirrel skating in New York",
- "HAL-9000 in the style of Van Gogh",
- "Times Square under water, with fish and a dolphin swimming around",
- "Ancient Roman fresco showing a man working on his laptop",
- "Close-up photograph of young black woman against urban background, high quality, bokeh",
- "Armchair in the shape of an avocado",
- "Clown astronaut in space, with Earth in the background",
-]
-
-prompt_ids = pipeline.prepare_inputs(prompts)
-prompt_ids = shard(prompt_ids)
-
-images = pipeline(prompt_ids, p_params, rng, jit=True).images
-images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
-images = pipeline.numpy_to_pil(images)
-
-make_image_grid(images, 2, 4)
-```
-
-
-
-## How does parallelization work?
-
-The Flax pipeline in 🤗 Diffusers automatically compiles the model and runs it in parallel on all available devices. Let's take a closer look at how that process works.
-
-JAX parallelization can be done in multiple ways. The easiest one revolves around using the [`jax.pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function to achieve single-program multiple-data (SPMD) parallelization. It means running several copies of the same code, each on different data inputs. More sophisticated approaches are possible, and you can go over to the JAX [documentation](https://jax.readthedocs.io/en/latest/index.html) to explore this topic in more detail if you are interested!
-
-`jax.pmap` does two things:
-
-1. Compiles (or "`jit`s") the code which is similar to `jax.jit()`. This does not happen when you call `pmap`, and only the first time the `pmap`ped function is called.
-2. Ensures the compiled code runs in parallel on all available devices.
-
-To demonstrate, call `pmap` on the pipeline's `_generate` method (this is a private method that generates images and may be renamed or removed in future releases of 🤗 Diffusers):
-
-```python
-p_generate = pmap(pipeline._generate)
-```
-
-After calling `pmap`, the prepared function `p_generate` will:
-
-1. Make a copy of the underlying function, `pipeline._generate`, on each device.
-2. Send each device a different portion of the input arguments (this is why it's necessary to call the *shard* function). In this case, `prompt_ids` has shape `(8, 1, 77, 768)` so the array is split into 8 and each copy of `_generate` receives an input with shape `(1, 77, 768)`.
-
-The most important thing to pay attention to here is the batch size (1 in this example), and the input dimensions that make sense for your code. You don't have to change anything else to make the code work in parallel.
-
-The first time you call the pipeline takes more time, but the calls afterward are much faster. The `block_until_ready` function is used to correctly measure inference time because JAX uses asynchronous dispatch and returns control to the Python loop as soon as it can. You don't need to use that in your code; blocking occurs automatically when you want to use the result of a computation that has not yet been materialized.
-
-```py
-%%time
-images = p_generate(prompt_ids, p_params, rng)
-images = images.block_until_ready()
-
-# CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
-# Wall time: 1min 15s
-```
-
-Check your image dimensions to see if they're correct:
-
-```python
-images.shape
-# (8, 1, 512, 512, 3)
-```
-
-## Resources
-
-To learn more about how JAX works with Stable Diffusion, you may be interested in reading:
-
-* [Accelerating Stable Diffusion XL Inference with JAX on Cloud TPU v5e](https://hf.co/blog/sdxl_jax)
diff --git a/docs/source/en/using-diffusers/text-img2vid.md b/docs/source/en/using-diffusers/text-img2vid.md
index 67d1fd118e..9b69a2fded 100644
--- a/docs/source/en/using-diffusers/text-img2vid.md
+++ b/docs/source/en/using-diffusers/text-img2vid.md
@@ -98,7 +98,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16
},
- components_to_quantize=["transformer"]
+ components_to_quantize="transformer"
)
pipeline = HunyuanVideoPipeline.from_pretrained(
@@ -287,7 +287,7 @@ export_to_video(output, "output.mp4", fps=16)
## Reduce memory usage
-Recent video models like [`HunyuanVideoPipeline`] and [`WanPipeline`], which have 10B+ parameters, require a lot of memory and it often exceeds the memory availabe on consumer hardware. Diffusers offers several techniques for reducing the memory requirements of these large models.
+Recent video models like [`HunyuanVideoPipeline`] and [`WanPipeline`], which have 10B+ parameters, require a lot of memory and it often exceeds the memory available on consumer hardware. Diffusers offers several techniques for reducing the memory requirements of these large models.
> [!TIP]
> Refer to the [Reduce memory usage](../optimization/memory) guide for more details about other memory saving techniques.
diff --git a/docs/source/en/using-diffusers/unconditional_image_generation.md b/docs/source/en/using-diffusers/unconditional_image_generation.md
index 0208d715d4..0add5bab67 100644
--- a/docs/source/en/using-diffusers/unconditional_image_generation.md
+++ b/docs/source/en/using-diffusers/unconditional_image_generation.md
@@ -26,11 +26,8 @@ image = generator().images[0]
image
```
-
-
-Want to generate images of something else? Take a look at the training [guide](../training/unconditional_training) to learn how to train a model to generate your own images.
-
-
+> [!TIP]
+> Want to generate images of something else? Take a look at the training [guide](../training/unconditional_training) to learn how to train a model to generate your own images.
The output image is a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object that can be saved:
diff --git a/docs/source/en/using-diffusers/weighted_prompts.md b/docs/source/en/using-diffusers/weighted_prompts.md
index 2ebf92d0eb..f89ebfe4a2 100644
--- a/docs/source/en/using-diffusers/weighted_prompts.md
+++ b/docs/source/en/using-diffusers/weighted_prompts.md
@@ -10,426 +10,96 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Prompt techniques
-
[[open-in-colab]]
-Prompts are important because they describe what you want a diffusion model to generate. The best prompts are detailed, specific, and well-structured to help the model realize your vision. But crafting a great prompt takes time and effort and sometimes it may not be enough because language and words can be imprecise. This is where you need to boost your prompt with other techniques, such as prompt enhancing and prompt weighting, to get the results you want.
+# Prompting
-This guide will show you how you can use these prompt techniques to generate high-quality images with lower effort and adjust the weight of certain keywords in a prompt.
+Prompts describe what a model should generate. Good prompts are detailed, specific, and structured, and they produce better images and videos.
-## Prompt engineering
+This guide shows you how to write effective prompts and introduces techniques that make them stronger.
-> [!TIP]
-> This is not an exhaustive guide on prompt engineering, but it will help you understand the necessary parts of a good prompt. We encourage you to continue experimenting with different prompts and combine them in new ways to see what works best. As you write more prompts, you'll develop an intuition for what works and what doesn't!
+## Writing good prompts
-New diffusion models do a pretty good job of generating high-quality images from a basic prompt, but it is still important to create a well-written prompt to get the best results. Here are a few tips for writing a good prompt:
+Every effective prompt needs three core elements.
-1. What is the image *medium*? Is it a photo, a painting, a 3D illustration, or something else?
-2. What is the image *subject*? Is it a person, animal, object, or scene?
-3. What *details* would you like to see in the image? This is where you can get really creative and have a lot of fun experimenting with different words to bring your image to life. For example, what is the lighting like? What is the vibe and aesthetic? What kind of art or illustration style are you looking for? The more specific and precise words you use, the better the model will understand what you want to generate.
+1. Subject - what you want to generate. Start your prompt here.
+2. Style - the medium or aesthetic. How should it look?
+3. Context - details about actions, setting, and mood.
+
+Use these elements as a structured narrative, not a keyword list. Modern models respond better to natural language than to keyword matching. Start simple, then add details.
+
+Context is especially important for creating better prompts. Try adding lighting, artistic details, and mood.
-
-
- "A photo of a banana-shaped couch in a living room"
+
+
+ A cute cat lounges on a leaf in a pool during a peaceful summer afternoon, in lofi art style, illustration.
-
-
- "A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the windows"
+
+
+ A cute cat lounges on a floating leaf in a sparkling pool during a peaceful summer afternoon. Clear reflections ripple across the water, with sunlight casting soft, smooth highlights. The illustration is detailed and polished, with elegant lines and harmonious colors, evoking a relaxing, serene, and whimsical lofi mood, anime-inspired and visually comforting.
-## Prompt enhancing with GPT2
-
-Prompt enhancing is a technique for quickly improving prompt quality without spending too much effort constructing one. It uses a model like GPT2 pretrained on Stable Diffusion text prompts to automatically enrich a prompt with additional important keywords to generate high-quality images.
-
-The technique works by curating a list of specific keywords and forcing the model to generate those words to enhance the original prompt. This way, your prompt can be "a cat" and GPT2 can enhance the prompt to "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain quality sharp focus beautiful detailed intricate stunning amazing epic".
+Be specific and add context. Use photography terms like lens type, focal length, camera angles, and depth of field.
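+
+A minimal sketch of how such a prompt is used (it reuses the [Lykon/dreamshaper-xl-1-0](https://huggingface.co/Lykon/dreamshaper-xl-1-0) checkpoint that appears later in this guide; any text-to-image checkpoint works):
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.bfloat16, device_map="cuda"
+)
+
+# subject, style, and context written as a structured narrative
+prompt = (
+    "A cute cat lounges on a floating leaf in a sparkling pool during a peaceful summer afternoon. "
+    "Sunlight casts soft highlights across the rippling water. Detailed, polished illustration with "
+    "elegant lines and harmonious colors, in a relaxing, anime-inspired lofi style."
+)
+image = pipeline(prompt).images[0]
+image.save("lofi_cat.png")
+```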
> [!TIP]
-> You should also use a [*offset noise*](https://www.crosslabs.org//blog/diffusion-with-offset-noise) LoRA to improve the contrast in bright and dark images and create better lighting overall. This [LoRA](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors) is available from [stabilityai/stable-diffusion-xl-base-1.0](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0).
-
-Start by defining certain styles and a list of words (you can check out a more comprehensive list of [words](https://hf.co/LykosAI/GPT-Prompt-Expansion-Fooocus-v2/blob/main/positive.txt) and [styles](https://github.com/lllyasviel/Fooocus/tree/main/sdxl_styles) used by Fooocus) to enhance a prompt with.
-
-```py
-import torch
-from transformers import GenerationConfig, GPT2LMHeadModel, GPT2Tokenizer, LogitsProcessor, LogitsProcessorList
-from diffusers import StableDiffusionXLPipeline
-
-styles = {
- "cinematic": "cinematic film still of {prompt}, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
- "anime": "anime artwork of {prompt}, anime style, key visual, vibrant, studio anime, highly detailed",
- "photographic": "cinematic photo of {prompt}, 35mm photograph, film, professional, 4k, highly detailed",
- "comic": "comic of {prompt}, graphic illustration, comic art, graphic novel art, vibrant, highly detailed",
- "lineart": "line art drawing {prompt}, professional, sleek, modern, minimalist, graphic, line art, vector graphics",
- "pixelart": " pixel-art {prompt}, low-res, blocky, pixel art style, 8-bit graphics",
-}
-
-words = [
- "aesthetic", "astonishing", "beautiful", "breathtaking", "composition", "contrasted", "epic", "moody", "enhanced",
- "exceptional", "fascinating", "flawless", "glamorous", "glorious", "illumination", "impressive", "improved",
- "inspirational", "magnificent", "majestic", "hyperrealistic", "smooth", "sharp", "focus", "stunning", "detailed",
- "intricate", "dramatic", "high", "quality", "perfect", "light", "ultra", "highly", "radiant", "satisfying",
- "soothing", "sophisticated", "stylish", "sublime", "terrific", "touching", "timeless", "wonderful", "unbelievable",
- "elegant", "awesome", "amazing", "dynamic", "trendy",
-]
-```
-
-You may have noticed in the `words` list, there are certain words that can be paired together to create something more meaningful. For example, the words "high" and "quality" can be combined to create "high quality". Let's pair these words together and remove the words that can't be paired.
-
-```py
-word_pairs = ["highly detailed", "high quality", "enhanced quality", "perfect composition", "dynamic light"]
-
-def find_and_order_pairs(s, pairs):
- words = s.split()
- found_pairs = []
- for pair in pairs:
- pair_words = pair.split()
- if pair_words[0] in words and pair_words[1] in words:
- found_pairs.append(pair)
- words.remove(pair_words[0])
- words.remove(pair_words[1])
-
- for word in words[:]:
- for pair in pairs:
- if word in pair.split():
- words.remove(word)
- break
- ordered_pairs = ", ".join(found_pairs)
- remaining_s = ", ".join(words)
- return ordered_pairs, remaining_s
-```
-
-Next, implement a custom [`~transformers.LogitsProcessor`] class that assigns tokens in the `words` list a value of 0 and assigns tokens not in the `words` list a negative value so they aren't picked during generation. This way, generation is biased towards words in the `words` list. After a word from the list is used, it is also assigned a negative value so it isn't picked again.
-
-```py
-class CustomLogitsProcessor(LogitsProcessor):
- def __init__(self, bias):
- super().__init__()
- self.bias = bias
-
- def __call__(self, input_ids, scores):
- if len(input_ids.shape) == 2:
- last_token_id = input_ids[0, -1]
- self.bias[last_token_id] = -1e10
- return scores + self.bias
-
-word_ids = [tokenizer.encode(word, add_prefix_space=True)[0] for word in words]
-bias = torch.full((tokenizer.vocab_size,), -float("Inf")).to("cuda")
-bias[word_ids] = 0
-processor = CustomLogitsProcessor(bias)
-processor_list = LogitsProcessorList([processor])
-```
-
-Combine the prompt and the `cinematic` style prompt defined in the `styles` dictionary earlier.
-
-```py
-prompt = "a cat basking in the sun on a roof in Turkey"
-style = "cinematic"
-
-prompt = styles[style].format(prompt=prompt)
-prompt
-"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
-```
-
-Load a GPT2 tokenizer and model from the [Gustavosta/MagicPrompt-Stable-Diffusion](https://huggingface.co/Gustavosta/MagicPrompt-Stable-Diffusion) checkpoint (this specific checkpoint is trained to generate prompts) to enhance the prompt.
-
-```py
-tokenizer = GPT2Tokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
-model = GPT2LMHeadModel.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion", torch_dtype=torch.float16).to(
- "cuda"
-)
-model.eval()
-
-inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-token_count = inputs["input_ids"].shape[1]
-max_new_tokens = 50 - token_count
-
-generation_config = GenerationConfig(
- penalty_alpha=0.7,
- top_k=50,
- eos_token_id=model.config.eos_token_id,
- pad_token_id=model.config.eos_token_id,
- pad_token=model.config.pad_token_id,
- do_sample=True,
-)
-
-with torch.no_grad():
- generated_ids = model.generate(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- max_new_tokens=max_new_tokens,
- generation_config=generation_config,
- logits_processor=proccesor_list,
- )
-```
-
-Then you can combine the input prompt and the generated prompt. Feel free to take a look at what the generated prompt (`generated_part`) is, the word pairs that were found (`pairs`), and the remaining words (`words`). This is all packed together in the `enhanced_prompt`.
-
-```py
-output_tokens = [tokenizer.decode(generated_id, skip_special_tokens=True) for generated_id in generated_ids]
-input_part, generated_part = output_tokens[0][: len(prompt)], output_tokens[0][len(prompt) :]
-pairs, words = find_and_order_pairs(generated_part, word_pairs)
-formatted_generated_part = pairs + ", " + words
-enhanced_prompt = input_part + ", " + formatted_generated_part
-enhanced_prompt
-["cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain quality sharp focus beautiful detailed intricate stunning amazing epic"]
-```
-
-Finally, load a pipeline and the offset noise LoRA with a *low weight* to generate an image with the enhanced prompt.
-
-```py
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-
-pipeline.load_lora_weights(
- "stabilityai/stable-diffusion-xl-base-1.0",
- weight_name="sd_xl_offset_example-lora_1.0.safetensors",
- adapter_name="offset",
-)
-pipeline.set_adapters(["offset"], adapter_weights=[0.2])
-
-image = pipeline(
- enhanced_prompt,
- width=1152,
- height=896,
- guidance_scale=7.5,
- num_inference_steps=25,
-).images[0]
-image
-```
-
-
-
-
- "a cat basking in the sun on a roof in Turkey"
-
-
-
- "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
-
-
+> Try a [prompt enhancer](https://huggingface.co/models?sort=downloads&search=prompt+enhancer) to help improve your prompt structure.
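+
+One possible approach is a small language model trained on image prompts. The sketch below assumes the [Gustavosta/MagicPrompt-Stable-Diffusion](https://huggingface.co/Gustavosta/MagicPrompt-Stable-Diffusion) checkpoint; the model choice and generation settings are illustrative.
+
+```py
+from transformers import pipeline
+
+# a GPT2 model trained on Stable Diffusion prompts (illustrative choice)
+enhancer = pipeline("text-generation", model="Gustavosta/MagicPrompt-Stable-Diffusion")
+
+short_prompt = "a cute cat lounging on a leaf in a pool"
+enhanced_prompt = enhancer(short_prompt, max_new_tokens=40, do_sample=True)[0]["generated_text"]
+print(enhanced_prompt)
+```
+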
## Prompt weighting
-Prompt weighting provides a way to emphasize or de-emphasize certain parts of a prompt, allowing for more control over the generated image. A prompt can include several concepts, which gets turned into contextualized text embeddings. The embeddings are used by the model to condition its cross-attention layers to generate an image (read the Stable Diffusion [blog post](https://huggingface.co/blog/stable_diffusion) to learn more about how it works).
+Prompt weighting makes some words in a prompt stronger and others weaker. It scales the text embeddings for those words so you control how much influence each concept has.
-Prompt weighting works by increasing or decreasing the scale of the text embedding vector that corresponds to its concept in the prompt because you may not necessarily want the model to focus on all concepts equally. The easiest way to prepare the prompt embeddings is to use [Stable Diffusion Long Prompt Weighted Embedding](https://github.com/xhinker/sd_embed) (sd_embed). Once you have the prompt-weighted embeddings, you can pass them to any pipeline that has a [prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) (and optionally [negative_prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds)) parameter, such as [`StableDiffusionPipeline`], [`StableDiffusionControlNetPipeline`], and [`StableDiffusionXLPipeline`].
+Diffusers handles this through the `prompt_embeds` and `pooled_prompt_embeds` arguments, which take the scaled text embedding vectors. Use the [sd_embed](https://github.com/xhinker/sd_embed) library to generate these embeddings; it also supports longer prompts.
-
-
-If your favorite pipeline doesn't have a `prompt_embeds` parameter, please open an [issue](https://github.com/huggingface/diffusers/issues/new/choose) so we can add it!
-
-
-
-This guide will show you how to weight your prompts with sd_embed.
-
-Before you begin, make sure you have the latest version of sd_embed installed:
-
-```bash
-pip install git+https://github.com/xhinker/sd_embed.git@main
-```
-
-For this example, let's use [`StableDiffusionXLPipeline`].
+> [!NOTE]
+> The sd_embed library only supports Stable Diffusion, Stable Diffusion XL, Stable Diffusion 3, Stable Cascade, and Flux. Prompt weighting doesn't necessarily help for newer models like Flux, which already have very good prompt adherence.
```py
-from diffusers import StableDiffusionXLPipeline, UniPCMultistepScheduler
-import torch
-
-pipe = StableDiffusionXLPipeline.from_pretrained("Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16)
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-pipe.to("cuda")
+!uv pip install git+https://github.com/xhinker/sd_embed.git@main
```
-To upweight or downweight a concept, surround the text with parentheses. More parentheses applies a heavier weight on the text. You can also append a numerical multiplier to the text to indicate how much you want to increase or decrease its weights by.
+Format weighted text with numerical multipliers or parentheses. More parentheses mean stronger weighting.
| format | multiplier |
|---|---|
-| `(hippo)` | increase by 1.1x |
-| `((hippo))` | increase by 1.21x |
-| `(hippo:1.5)` | increase by 1.5x |
-| `(hippo:0.5)` | decrease by 4x |
+| `(cat)` | increase by 1.1x |
+| `((cat))` | increase by 1.21x |
+| `(cat:1.5)` | increase by 1.5x |
+| `(cat:0.5)` | decrease by 4x |
-Create a prompt and use a combination of parentheses and numerical multipliers to upweight various text.
+Create a weighted prompt and pass it to [get_weighted_text_embeddings_sdxl](https://github.com/xhinker/sd_embed/blob/4a47f71150a22942fa606fb741a1c971d95ba56f/src/sd_embed/embedding_funcs.py#L405) to generate embeddings.
+
+> [!TIP]
+> You can also generate weighted negative prompt embeddings and pass them to `negative_prompt_embeds` and `negative_pooled_prompt_embeds`, as shown in the sketch after the example below.
```py
+import torch
+from diffusers import DiffusionPipeline
from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl
-prompt = """A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
-This imaginative creature features the distinctive, bulky body of a hippo,
-but with a texture and appearance resembling a golden-brown, crispy waffle.
-The creature might have elements like waffle squares across its skin and a syrup-like sheen.
-It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
-possibly including oversized utensils or plates in the background.
-The image should evoke a sense of playful absurdity and culinary fantasy.
-"""
-
-neg_prompt = """\
-skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
-(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
-extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
-(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
-bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
-(normal quality:2),lowres,((monochrome)),((grayscale))
-"""
-```
-
-Use the `get_weighted_text_embeddings_sdxl` function to generate the prompt embeddings and the negative prompt embeddings. It'll also generated the pooled and negative pooled prompt embeddings since you're using the SDXL model.
-
-> [!TIP]
-> You can safely ignore the error message below about the token index length exceeding the models maximum sequence length. All your tokens will be used in the embedding process.
->
-> ```
-> Token indices sequence length is longer than the specified maximum sequence length for this model
-> ```
-
-```py
-(
- prompt_embeds,
- prompt_neg_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds
-) = get_weighted_text_embeddings_sdxl(
- pipe,
- prompt=prompt,
- neg_prompt=neg_prompt
+pipeline = DiffusionPipeline.from_pretrained(
+ "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.bfloat16, device_map="cuda"
)
-image = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=prompt_neg_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=30,
- height=1024,
- width=1024 + 512,
- guidance_scale=4.0,
- generator=torch.Generator("cuda").manual_seed(2)
-).images[0]
-image
+prompt = """
+A (cute cat:1.4) lounges on a (floating leaf:1.2) in a (sparkling pool:1.1) during a peaceful summer afternoon.
+Gentle ripples reflect pastel skies, while (sunlight:1.1) casts soft highlights. The illustration is smooth and polished
+with elegant, sketchy lines and subtle gradients, evoking a ((whimsical, nostalgic, dreamy lofi atmosphere:2.0)),
+(anime-inspired:1.6), calming, comforting, and visually serene.
+"""
+
+prompt_embeds, _, pooled_prompt_embeds, *_ = get_weighted_text_embeddings_sdxl(pipeline, prompt=prompt)
+```
+
+Pass the embeddings to `prompt_embeds` and `pooled_prompt_embeds` to generate your image.
+
+```py
+image = pipeline(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds).images[0]
```
-
+
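+
+To include a weighted negative prompt, pass `neg_prompt` as well and forward the returned negative embeddings. A minimal sketch reusing the pipeline and prompt from above (the negative prompt text is only an example):
+
+```py
+neg_prompt = "(worst quality:2), (low quality:2), blurry, ((monochrome)), ((grayscale))"
+
+(
+    prompt_embeds,
+    negative_prompt_embeds,
+    pooled_prompt_embeds,
+    negative_pooled_prompt_embeds,
+) = get_weighted_text_embeddings_sdxl(pipeline, prompt=prompt, neg_prompt=neg_prompt)
+
+image = pipeline(
+    prompt_embeds=prompt_embeds,
+    negative_prompt_embeds=negative_prompt_embeds,
+    pooled_prompt_embeds=pooled_prompt_embeds,
+    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+).images[0]
+```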
-> [!TIP]
-> Refer to the [sd_embed](https://github.com/xhinker/sd_embed) repository for additional details about long prompt weighting for FLUX.1, Stable Cascade, and Stable Diffusion 1.5.
-
-### Textual inversion
-
-[Textual inversion](../training/text_inversion) is a technique for learning a specific concept from some images which you can use to generate new images conditioned on that concept.
-
-Create a pipeline and use the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] function to load the textual inversion embeddings (feel free to browse the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer) for 100+ trained concepts):
-
-```py
-import torch
-from diffusers import StableDiffusionPipeline
-
-pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
-).to("cuda")
-pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
-```
-
-Add the `` text to the prompt to trigger the textual inversion.
-
-```py
-from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
-
-prompt = """ A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
-This imaginative creature features the distinctive, bulky body of a hippo,
-but with a texture and appearance resembling a golden-brown, crispy waffle.
-The creature might have elements like waffle squares across its skin and a syrup-like sheen.
-It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
-possibly including oversized utensils or plates in the background.
-The image should evoke a sense of playful absurdity and culinary fantasy.
-"""
-
-neg_prompt = """\
-skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
-(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
-extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
-(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
-bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
-(normal quality:2),lowres,((monochrome)),((grayscale))
-"""
-```
-
-Use the `get_weighted_text_embeddings_sd15` function to generate the prompt embeddings and the negative prompt embeddings.
-
-```py
-(
- prompt_embeds,
- prompt_neg_embeds,
-) = get_weighted_text_embeddings_sd15(
- pipe,
- prompt=prompt,
- neg_prompt=neg_prompt
-)
-
-image = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=prompt_neg_embeds,
- height=768,
- width=896,
- guidance_scale=4.0,
- generator=torch.Generator("cuda").manual_seed(2)
-).images[0]
-image
-```
-
-
-
-
-
-### DreamBooth
-
-[DreamBooth](../training/dreambooth) is a technique for generating contextualized images of a subject given just a few images of the subject to train on. It is similar to textual inversion, but DreamBooth trains the full model whereas textual inversion only fine-tunes the text embeddings. This means you should use [`~DiffusionPipeline.from_pretrained`] to load the DreamBooth model (feel free to browse the [Stable Diffusion Dreambooth Concepts Library](https://huggingface.co/sd-dreambooth-library) for 100+ trained models):
-
-```py
-import torch
-from diffusers import DiffusionPipeline, UniPCMultistepScheduler
-
-pipe = DiffusionPipeline.from_pretrained("sd-dreambooth-library/dndcoverart-v1", torch_dtype=torch.float16).to("cuda")
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-```
-
-Depending on the model you use, you'll need to incorporate the model's unique identifier into your prompt. For example, the `dndcoverart-v1` model uses the identifier `dndcoverart`:
-
-```py
-from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
-
-prompt = """dndcoverart of A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
-This imaginative creature features the distinctive, bulky body of a hippo,
-but with a texture and appearance resembling a golden-brown, crispy waffle.
-The creature might have elements like waffle squares across its skin and a syrup-like sheen.
-It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
-possibly including oversized utensils or plates in the background.
-The image should evoke a sense of playful absurdity and culinary fantasy.
-"""
-
-neg_prompt = """\
-skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
-(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
-extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
-(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
-bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
-(normal quality:2),lowres,((monochrome)),((grayscale))
-"""
-
-(
- prompt_embeds
- , prompt_neg_embeds
-) = get_weighted_text_embeddings_sd15(
- pipe
- , prompt = prompt
- , neg_prompt = neg_prompt
-)
-```
-
-
-
-
+Prompt weighting works with [Textual inversion](./textual_inversion_inference) and [DreamBooth](./dreambooth) adapters too.
\ No newline at end of file
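+
+For example, a minimal sketch that combines a textual inversion embedding with a weighted Stable Diffusion 1.5 prompt (the embedding's trigger token is assumed to be `<midjourney-style>`):
+
+```py
+import torch
+from diffusers import StableDiffusionPipeline
+from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+).to("cuda")
+# load a textual inversion embedding to teach the model a new concept
+pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
+
+# include the trigger token (assumed) and weight the rest of the prompt as usual
+prompt = "<midjourney-style> a (cute cat:1.4) lounging on a floating leaf in a sparkling pool"
+prompt_embeds, _ = get_weighted_text_embeddings_sd15(pipeline, prompt=prompt)
+
+image = pipeline(prompt_embeds=prompt_embeds).images[0]
+```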
diff --git a/docs/source/en/using-diffusers/write_own_pipeline.md b/docs/source/en/using-diffusers/write_own_pipeline.md
index 15a7e8dc7c..e34727b5da 100644
--- a/docs/source/en/using-diffusers/write_own_pipeline.md
+++ b/docs/source/en/using-diffusers/write_own_pipeline.md
@@ -110,11 +110,8 @@ Stable Diffusion is a text-to-image *latent diffusion* model. It is called a lat
As you can see, this is already more complex than the DDPM pipeline which only contains a UNet model. The Stable Diffusion model has three separate pretrained models.
-
-
-💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models work.
-
-
+> [!TIP]
+> 💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models work.
Now that you know what you need for the Stable Diffusion pipeline, load all these components with the [`~ModelMixin.from_pretrained`] method. You can find them in the pretrained [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint, and each component is stored in a separate subfolder:
@@ -155,11 +152,8 @@ To speed up inference, move the models to a GPU since, unlike the scheduler, the
The next step is to tokenize the text to generate embeddings. The text is used to condition the UNet model and steer the diffusion process towards something that resembles the input prompt.
-
-
-💡 The `guidance_scale` parameter determines how much weight should be given to the prompt when generating an image.
-
-
+> [!TIP]
+> 💡 The `guidance_scale` parameter determines how much weight should be given to the prompt when generating an image.
Feel free to choose any prompt you like if you want to generate something else!
@@ -202,15 +196,12 @@ Let's concatenate the conditional and unconditional embeddings into a batch to a
Next, generate some initial random noise as a starting point for the diffusion process. This is the latent representation of the image, and it'll be gradually denoised. At this point, the `latent` image is smaller than the final image size but that's okay though because the model will transform it into the final 512x512 image dimensions later.
-
-
-💡 The height and width are divided by 8 because the `vae` model has 3 down-sampling layers. You can check by running the following:
-
-```py
-2 ** (len(vae.config.block_out_channels) - 1) == 8
-```
-
-
+> [!TIP]
+> 💡 The height and width are divided by 8 because the `vae` model has 3 down-sampling layers. You can check by running the following:
+>
+> ```py
+> 2 ** (len(vae.config.block_out_channels) - 1) == 8
+> ```
```py
>>> latents = torch.randn(
@@ -289,5 +280,5 @@ This is really what 🧨 Diffusers is designed for: to make it intuitive and eas
For your next steps, feel free to:
-* Learn how to [build and contribute a pipeline](../using-diffusers/contribute_pipeline) to 🧨 Diffusers. We can't wait and see what you'll come up with!
+* Learn how to [build and contribute a pipeline](../conceptual/contribution) to 🧨 Diffusers. We can't wait and see what you'll come up with!
* Explore [existing pipelines](../api/pipelines/overview) in the library, and see if you can deconstruct and build a pipeline from scratch using the models and schedulers separately.
diff --git a/docs/source/ja/installation.md b/docs/source/ja/installation.md
index 97d60528c4..fd6f4eda0f 100644
--- a/docs/source/ja/installation.md
+++ b/docs/source/ja/installation.md
@@ -108,11 +108,8 @@ pip install -e ".[flax]"
Python は通常のライブラリパスに加えて、クローンしたフォルダの中を探すようになります。
例えば、Python パッケージが通常 `~/anaconda3/envs/main/lib/python3.10/site-packages/` にインストールされている場合、Python はクローンした `~/diffusers/` フォルダも同様に参照します。
-
-
-ライブラリを使い続けたい場合は、`diffusers`フォルダを残しておく必要があります。
-
-
+> [!WARNING]
+> ライブラリを使い続けたい場合は、`diffusers`フォルダを残しておく必要があります。
これで、以下のコマンドで簡単にクローンを最新版の🤗 Diffusersにアップデートできます:
diff --git a/docs/source/ja/quicktour.md b/docs/source/ja/quicktour.md
index 03b340b352..ce88aaf7b5 100644
--- a/docs/source/ja/quicktour.md
+++ b/docs/source/ja/quicktour.md
@@ -24,11 +24,8 @@ specific language governing permissions and limitations under the License.
この案内では、[`DiffusionPipeline`]を生成に使用する方法を紹介し、モデルとスケジューラを組み合わせて[`DiffusionPipeline`]の内部で起こっていることを再現する方法を説明します。
-
-
-この案内は🧨 Diffusers [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)を簡略化したもので、すぐに使い始めることができます。Diffusers 🧨のゴール、設計哲学、コアAPIの詳細についてもっと知りたい方は、ノートブックをご覧ください!
-
-
+> [!TIP]
+> この案内は🧨 Diffusers [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)を簡略化したもので、すぐに使い始めることができます。Diffusers 🧨のゴール、設計哲学、コアAPIの詳細についてもっと知りたい方は、ノートブックをご覧ください!
始める前に必要なライブラリーがすべてインストールされていることを確認してください:
@@ -56,11 +53,8 @@ specific language governing permissions and limitations under the License.
この[`DiffusionPipeline`]はHugging Face Hubに保存されている任意の[チェックポイント](https://huggingface.co/models?library=diffusers&sort=downloads)を使用することができます。
この案内では、[`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)チェックポイントでテキストから画像へ生成します。
-
-
-[Stable Diffusion]モデルについては、モデルを実行する前にまず[ライセンス](https://huggingface.co/spaces/CompVis/stable-diffusion-license)を注意深くお読みください。🧨 Diffusers は、攻撃的または有害なコンテンツを防ぐために [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) を実装していますが、モデルの改良された画像生成機能により、潜在的に有害なコンテンツが生成される可能性があります。
-
-
+> [!WARNING]
+> [Stable Diffusion]モデルについては、モデルを実行する前にまず[ライセンス](https://huggingface.co/spaces/CompVis/stable-diffusion-license)を注意深くお読みください。🧨 Diffusers は、攻撃的または有害なコンテンツを防ぐために [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) を実装していますが、モデルの改良された画像生成機能により、潜在的に有害なコンテンツが生成される可能性があります。
モデルを[`~DiffusionPipeline.from_pretrained`]メソッドでロードします:
@@ -204,11 +198,8 @@ torch.Size([1, 3, 256, 256])
スケジューラは、モデルの出力(この場合は `noisy_residual` )が与えられたときに、ノイズの多いサンプルからノイズの少ないサンプルへの移行を管理します。
-
-
-🧨 Diffusersは拡散システムを構築するためのツールボックスです。[`DiffusionPipeline`]は事前に構築された拡散システムを使い始めるのに便利な方法ですが、独自のモデルとスケジューラコンポーネントを個別に選択してカスタム拡散システムを構築することもできます。
-
-
+> [!TIP]
+> 🧨 Diffusersは拡散システムを構築するためのツールボックスです。[`DiffusionPipeline`]は事前に構築された拡散システムを使い始めるのに便利な方法ですが、独自のモデルとスケジューラコンポーネントを個別に選択してカスタム拡散システムを構築することもできます。
この案内では、[`DDPMScheduler`]を[`~diffusers.ConfigMixin.from_config`]メソッドでインスタンス化します:
@@ -232,11 +223,8 @@ DDPMScheduler {
}
```
-
-
-💡 スケジューラがどのようにコンフィギュレーションからインスタンス化されるかに注目してください。モデルとは異なり、スケジューラは学習可能な重みを持たず、パラメーターを持ちません!
-
-
+> [!TIP]
+> 💡 スケジューラがどのようにコンフィギュレーションからインスタンス化されるかに注目してください。モデルとは異なり、スケジューラは学習可能な重みを持たず、パラメーターを持ちません!
最も重要なパラメータは以下の通りです:
diff --git a/docs/source/ja/stable_diffusion.md b/docs/source/ja/stable_diffusion.md
index 85f2b38a7d..79abfa005d 100644
--- a/docs/source/ja/stable_diffusion.md
+++ b/docs/source/ja/stable_diffusion.md
@@ -37,11 +37,8 @@ prompt = "portrait photo of a old warrior chief"
## Speed
-
-
-💡 GPUを利用できない場合は、[Colab](https://colab.research.google.com/)のようなGPUプロバイダーから無料で利用できます!
-
-
+> [!TIP]
+> 💡 GPUを利用できない場合は、[Colab](https://colab.research.google.com/)のようなGPUプロバイダーから無料で利用できます!
画像生成を高速化する最も簡単な方法の1つは、PyTorchモジュールと同じようにGPU上にパイプラインを配置することです:
@@ -88,11 +85,8 @@ image
今回、画像生成にかかった時間はわずか11秒で、以前より3倍近く速くなりました!
-
-
-💡 パイプラインは常に `float16` で実行することを強くお勧めします。
-
-
+> [!TIP]
+> 💡 パイプラインは常に `float16` で実行することを強くお勧めします。
生成ステップ数を減らすという方法もあります。より効率的なスケジューラを選択することで、出力品質を犠牲にすることなくステップ数を減らすことができます。`compatibles`メソッドを呼び出すことで、[`DiffusionPipeline`]の現在のモデルと互換性のあるスケジューラを見つけることができます:
diff --git a/docs/source/ja/tutorials/autopipeline.md b/docs/source/ja/tutorials/autopipeline.md
index a9a780186a..7dc678da90 100644
--- a/docs/source/ja/tutorials/autopipeline.md
+++ b/docs/source/ja/tutorials/autopipeline.md
@@ -16,11 +16,8 @@ Diffusersは様々なタスクをこなすことができ、テキストから
`AutoPipeline` クラスは、🤗 Diffusers の様々なパイプラインをよりシンプルするために設計されています。この汎用的でタスク重視のパイプラインによってタスクそのものに集中することができます。`AutoPipeline` は、使用するべき正しいパイプラインクラスを自動的に検出するため、特定のパイプラインクラス名を知らなくても、タスクのチェックポイントを簡単にロードできます。
-
-
-どのタスクがサポートされているかは、[AutoPipeline](../api/pipelines/auto_pipeline) のリファレンスをご覧ください。現在、text-to-image、image-to-image、inpaintingをサポートしています。
-
-
+> [!TIP]
+> どのタスクがサポートされているかは、[AutoPipeline](../api/pipelines/auto_pipeline) のリファレンスをご覧ください。現在、text-to-image、image-to-image、inpaintingをサポートしています。
このチュートリアルでは、`AutoPipeline` を使用して、事前に学習された重みが与えられたときに、特定のタスクを読み込むためのパイプラインクラスを自動的に推測する方法を示します。
diff --git a/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
index 34a00d63fe..ba85b4a855 100644
--- a/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
+++ b/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
@@ -207,11 +207,8 @@ image = refiner(
동일한 40 단계에서 base 모델을 실행한다면, 이미지의 디테일(예: 사자의 눈과 코)이 떨어졌을 것입니다:
-
-
-앙상블 방식은 사용 가능한 모든 스케줄러에서 잘 작동합니다!
-
-
+> [!TIP]
+> 앙상블 방식은 사용 가능한 모든 스케줄러에서 잘 작동합니다!
#### 2.) 노이즈가 완전히 제거된 기본 이미지에서 이미지 출력을 정제하기
@@ -248,11 +245,8 @@ image = refiner(prompt=prompt, image=image[None, :]).images[0]
|---|---|
|  |  |
-
-
-refiner는 또한 인페인팅 설정에 잘 사용될 수 있습니다. 아래에 보여지듯이 [`StableDiffusionXLInpaintPipeline`] 클래스를 사용해서 만들어보세요.
-
-
+> [!TIP]
+> refiner는 또한 인페인팅 설정에 잘 사용될 수 있습니다. 아래에 보여지듯이 [`StableDiffusionXLInpaintPipeline`] 클래스를 사용해서 만들어보세요.
Denoiser 앙상블 설정에서 인페인팅에 refiner를 사용하려면 다음을 수행하면 됩니다:
diff --git a/docs/source/ko/conceptual/ethical_guidelines.md b/docs/source/ko/conceptual/ethical_guidelines.md
index b8c55048bf..63fc4a7741 100644
--- a/docs/source/ko/conceptual/ethical_guidelines.md
+++ b/docs/source/ko/conceptual/ethical_guidelines.md
@@ -14,51 +14,47 @@ specific language governing permissions and limitations under the License.
## 서문 [[preamble]]
-[Diffusers](https://huggingface.co/docs/diffusers/index)는 사전 훈련된 diffusion 모델을 제공하며 추론 및 훈련을 위한 모듈식 툴박스로 사용됩니다.
+[Diffusers](https://huggingface.co/docs/diffusers/index)는 사전 훈련된 diffusion 모델을 제공하며, 추론과 훈련을 위한 모듈형 툴박스로 활용됩니다.
-이 기술의 실제 적용과 사회에 미칠 수 있는 부정적인 영향을 고려하여 Diffusers 라이브러리의 개발, 사용자 기여 및 사용에 윤리 지침을 제공하는 것이 중요하다고 생각합니다.
-
-이이 기술을 사용함에 따른 위험은 여전히 검토 중이지만, 몇 가지 예를 들면: 예술가들에 대한 저작권 문제; 딥 페이크의 악용; 부적절한 맥락에서의 성적 콘텐츠 생성; 동의 없는 사칭; 소수자 집단의 억압을 영속화하는 유해한 사회적 편견 등이 있습니다.
-
-우리는 위험을 지속적으로 추적하고 커뮤니티의 응답과 소중한 피드백에 따라 다음 지침을 조정할 것입니다.
+이 기술의 실제 적용 사례와 사회에 미칠 수 있는 잠재적 부정적 영향을 고려할 때, Diffusers 라이브러리의 개발, 사용자 기여, 사용에 윤리 지침을 제공하는 것이 중요하다고 생각합니다.
+이 기술 사용과 관련된 위험은 여전히 검토 중이지만, 예를 들면: 예술가의 저작권 문제, 딥페이크 악용, 부적절한 맥락에서의 성적 콘텐츠 생성, 비동의 사칭, 소수자 집단 억압을 영속화하는 유해한 사회적 편견 등이 있습니다.
+우리는 이러한 위험을 지속적으로 추적하고, 커뮤니티의 반응과 소중한 피드백에 따라 아래 지침을 조정할 것입니다.
## 범위 [[scope]]
-Diffusers 커뮤니티는 프로젝트의 개발에 다음과 같은 윤리 지침을 적용하며, 특히 윤리적 문제와 관련된 민감한 주제에 대한 커뮤니티의 기여를 조정하는 데 도움을 줄 것입니다.
-
+Diffusers 커뮤니티는 프로젝트 개발에 다음 윤리 지침을 적용하며, 특히 윤리적 문제와 관련된 민감한 주제에 대해 커뮤니티의 기여를 조정하는 데 도움을 줄 것입니다.
## 윤리 지침 [[ethical-guidelines]]
-다음 윤리 지침은 일반적으로 적용되지만, 민감한 윤리적 문제와 관련하여 기술적 선택을 할 때 이를 우선적으로 적용할 것입니다. 나아가, 해당 기술의 최신 동향과 관련된 새로운 위험이 발생함에 따라 이러한 윤리 원칙을 조정할 것을 약속드립니다.
+다음 윤리 지침은 일반적으로 적용되지만, 윤리적으로 민감한 문제와 관련된 기술적 선택을 할 때 우선적으로 적용됩니다. 또한, 해당 기술의 최신 동향과 관련된 새로운 위험이 발생함에 따라 이러한 윤리 원칙을 지속적으로 조정할 것을 약속합니다.
-- **투명성**: 우리는 PR을 관리하고, 사용자에게 우리의 선택을 설명하며, 기술적 의사결정을 내릴 때 투명성을 유지할 것을 약속합니다.
+- **투명성**: 우리는 PR 관리, 사용자에게 선택의 이유 설명, 기술적 의사결정 과정에서 투명성을 유지할 것을 약속합니다.
-- **일관성**: 우리는 프로젝트 관리에서 사용자들에게 동일한 수준의 관심을 보장하고 기술적으로 안정되고 일관된 상태를 유지할 것을 약속합니다.
+- **일관성**: 프로젝트 관리에서 모든 사용자에게 동일한 수준의 관심을 보장하고, 기술적으로 안정적이고 일관된 상태를 유지할 것을 약속합니다.
-- **간결성**: Diffusers 라이브러리를 사용하고 활용하기 쉽게 만들기 위해, 프로젝트의 목표를 간결하고 일관성 있게 유지할 것을 약속합니다.
+- **간결성**: Diffusers 라이브러리를 쉽게 사용하고 활용할 수 있도록, 프로젝트의 목표를 간결하고 일관성 있게 유지할 것을 약속합니다.
-- **접근성**: Diffusers 프로젝트는 기술적 전문 지식 없어도 프로젝트 운영에 참여할 수 있는 기여자의 진입장벽을 낮춥니다. 이를 통해 연구 결과물이 커뮤니티에 더 잘 접근할 수 있게 됩니다.
+- **접근성**: Diffusers 프로젝트는 기술적 전문지식이 없어도 기여할 수 있도록 진입장벽을 낮춥니다. 이를 통해 연구 결과물이 커뮤니티에 더 잘 접근될 수 있습니다.
-- **재현성**: 우리는 Diffusers 라이브러리를 통해 제공되는 업스트림(upstream) 코드, 모델 및 데이터셋의 재현성에 대해 투명하게 공개할 것을 목표로 합니다.
-
-- **책임**: 우리는 커뮤니티와 팀워크를 통해, 이 기술의 잠재적인 위험과 위험을 예측하고 완화하는 데 대한 공동 책임을 가지고 있습니다.
+- **재현성**: 우리는 Diffusers 라이브러리를 통해 제공되는 업스트림 코드, 모델, 데이터셋의 재현성에 대해 투명하게 공개하는 것을 목표로 합니다.
+- **책임**: 커뮤니티와 팀워크를 통해, 이 기술의 잠재적 위험을 예측하고 완화하는 데 공동 책임을 집니다.
## 구현 사례: 안전 기능과 메커니즘 [[examples-of-implementations-safety-features-and-mechanisms]]
-팀은 diffusion 기술과 관련된 잠재적인 윤리 및 사회적 위험에 대처하기 위한 기술적 및 비기술적 도구를 제공하고자 하고 있습니다. 또한, 커뮤니티의 참여는 이러한 기능의 구현하고 우리와 함께 인식을 높이는 데 매우 중요합니다.
+팀은 diffusion 기술과 관련된 잠재적 윤리 및 사회적 위험에 대응하기 위해 기술적·비기술적 도구를 제공하고자 노력하고 있습니다. 또한, 커뮤니티의 참여는 이러한 기능 구현과 인식 제고에 매우 중요합니다.
-- [**커뮤니티 탭**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): 이를 통해 커뮤니티는 프로젝트에 대해 토론하고 더 나은 협력을 할 수 있습니다.
+- [**커뮤니티 탭**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): 커뮤니티가 프로젝트에 대해 토론하고 더 나은 협업을 할 수 있도록 지원합니다.
-- **편향 탐색 및 평가**: Hugging Face 팀은 Stable Diffusion 모델의 편향성을 대화형으로 보여주는 [space](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)을 제공합니다. 이런 의미에서, 우리는 편향 탐색 및 평가를 지원하고 장려합니다.
+- **편향 탐색 및 평가**: Hugging Face 팀은 Stable Diffusion 모델의 편향성을 대화형으로 보여주는 [space](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)를 제공합니다. 우리는 이러한 편향 탐색과 평가를 지원하고 장려합니다.
- **배포에서의 안전 유도**
- - [**안전한 Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe): 이는 필터되지 않은 웹 크롤링 데이터셋으로 훈련된 Stable Diffusion과 같은 모델이 부적절한 변질에 취약한 문제를 완화합니다. 관련 논문: [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105).
+ - [**안전한 Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe): 필터링되지 않은 웹 크롤링 데이터셋으로 훈련된 Stable Diffusion과 같은 모델이 부적절하게 변질되는 문제를 완화합니다. 관련 논문: [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105).
- - [**안전 검사기**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py): 이미지가 생성된 후에 이미자가 임베딩 공간에서 일련의 하드코딩된 유해 개념의 클래스일 확률을 확인하고 비교합니다. 유해 개념은 역공학을 방지하기 위해 의도적으로 숨겨져 있습니다.
+ - [**안전 검사기**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py): 생성된 이미지가 임베딩 공간에서 하드코딩된 유해 개념 클래스와 일치할 확률을 확인하고 비교합니다. 유해 개념은 역공학을 방지하기 위해 의도적으로 숨겨져 있습니다.
-- **Hub에서의 단계적인 배포**: 특히 민감한 상황에서는 일부 리포지토리에 대한 접근을 제한해야 합니다. 이 단계적인 배포는 중간 단계로, 리포지토리 작성자가 사용에 대한 더 많은 통제력을 갖게 합니다.
+- **Hub에서의 단계적 배포**: 특히 민감한 상황에서는 일부 리포지토리에 대한 접근을 제한할 수 있습니다. 단계적 배포는 리포지토리 작성자가 사용에 대해 더 많은 통제권을 갖도록 하는 중간 단계입니다.
-- **라이선싱**: [OpenRAILs](https://huggingface.co/blog/open_rail)와 같은 새로운 유형의 라이선싱을 통해 자유로운 접근을 보장하면서도 더 책임 있는 사용을 위한 일련의 제한을 둘 수 있습니다.
+- **라이선싱**: [OpenRAILs](https://huggingface.co/blog/open_rail)와 같은 새로운 유형의 라이선스를 통해 자유로운 접근을 보장하면서도 보다 책임 있는 사용을 위한 일련의 제한을 둘 수 있습니다.
diff --git a/docs/source/ko/conceptual/evaluation.md b/docs/source/ko/conceptual/evaluation.md
index 2d296420bc..731b511485 100644
--- a/docs/source/ko/conceptual/evaluation.md
+++ b/docs/source/ko/conceptual/evaluation.md
@@ -95,11 +95,8 @@ images = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generato
다양한 모델을 사용하여 모든 프롬프트에서 생성된 여러 이미지들이 생성되면 (평가 과정에서) 이러한 결과물들은 사람 평가자들에게 점수를 매기기 위해 제시됩니다. DrawBench와 PartiPrompts 벤치마크에 대한 자세한 내용은 각각의 논문을 참조하십시오.
-
-
-모델이 훈련 중일 때 추론 샘플을 살펴보는 것은 훈련 진행 상황을 측정하는 데 유용합니다. [훈련 스크립트](https://github.com/huggingface/diffusers/tree/main/examples/)에서는 TensorBoard와 Weights & Biases에 대한 추가 지원과 함께 이 유틸리티를 지원합니다.
-
-
+> [!TIP]
+> 모델이 훈련 중일 때 추론 샘플을 살펴보는 것은 훈련 진행 상황을 측정하는 데 유용합니다. [훈련 스크립트](https://github.com/huggingface/diffusers/tree/main/examples/)에서는 TensorBoard와 Weights & Biases에 대한 추가 지원과 함께 이 유틸리티를 지원합니다.
## 정량적 평가[[quantitative-evaluation]]
@@ -193,11 +190,8 @@ print(f"CLIP Score with v-1-5: {sd_clip_score_1_5}")
[v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 체크포인트가 이전 버전보다 더 나은 성능을 보이는 것 같습니다. 그러나 CLIP 점수를 계산하기 위해 사용한 프롬프트의 수가 상당히 적습니다. 보다 실용적인 평가를 위해서는 이 수를 훨씬 높게 설정하고, 프롬프트를 다양하게 사용해야 합니다.
-
-
-이 점수에는 몇 가지 제한 사항이 있습니다. 훈련 데이터셋의 캡션은 웹에서 크롤링되어 이미지와 관련된 `alt` 및 유사한 태그에서 추출되었습니다. 이들은 인간이 이미지를 설명하는 데 사용할 수 있는 것과 일치하지 않을 수 있습니다. 따라서 여기서는 몇 가지 프롬프트를 "엔지니어링"해야 했습니다.
-
-
+> [!WARNING]
+> 이 점수에는 몇 가지 제한 사항이 있습니다. 훈련 데이터셋의 캡션은 웹에서 크롤링되어 이미지와 관련된 `alt` 및 유사한 태그에서 추출되었습니다. 이들은 인간이 이미지를 설명하는 데 사용할 수 있는 것과 일치하지 않을 수 있습니다. 따라서 여기서는 몇 가지 프롬프트를 "엔지니어링"해야 했습니다.
### 이미지 조건화된 텍스트-이미지 생성[[image-conditioned-text-to-image-generation]]
@@ -405,11 +399,8 @@ CLIP 점수와 마찬가지로, CLIP 방향 유사성이 높을수록 좋습니
[`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline)와 같은 유사한 파이프라인에도 이러한 메트릭을 사용할 수 있습니다.
-
-
-CLIP 점수와 CLIP 방향 유사성 모두 CLIP 모델에 의존하기 때문에 평가가 편향될 수 있습니다
-
-
+> [!TIP]
+> CLIP 점수와 CLIP 방향 유사성 모두 CLIP 모델에 의존하기 때문에 평가가 편향될 수 있습니다
***IS, FID (나중에 설명할 예정), 또는 KID와 같은 메트릭을 확장하는 것은 어려울 수 있습니다***. 평가 중인 모델이 대규모 이미지 캡셔닝 데이터셋 (예: [LAION-5B 데이터셋](https://laion.ai/blog/laion-5b/))에서 사전 훈련되었을 때 이는 문제가 될 수 있습니다. 왜냐하면 이러한 메트릭의 기반에는 중간 이미지 특징을 추출하기 위해 ImageNet-1k 데이터셋에서 사전 훈련된 InceptionNet이 사용되기 때문입니다. Stable Diffusion의 사전 훈련 데이터셋은 InceptionNet의 사전 훈련 데이터셋과 겹치는 부분이 제한적일 수 있으므로 따라서 여기에는 좋은 후보가 아닙니다.
@@ -532,19 +523,16 @@ FID는 낮을수록 좋습니다. 여러 가지 요소가 FID에 영향을 줄
마지막 두 가지 요소에 대해서는, 다른 시드와 추론 단계에서 평가를 실행하고 평균 결과를 보고하는 것은 좋은 실천 방법입니다
-
-
-FID 결과는 많은 요소에 의존하기 때문에 취약할 수 있습니다:
-
-* 계산 중 사용되는 특정 Inception 모델.
-* 계산의 구현 정확도.
-* 이미지 형식 (PNG 또는 JPG에서 시작하는 경우가 다릅니다).
-
-이러한 사항을 염두에 두면, FID는 유사한 실행을 비교할 때 가장 유용하지만, 저자가 FID 측정 코드를 주의 깊게 공개하지 않는 한 논문 결과를 재현하기는 어렵습니다.
-
-이러한 사항은 KID 및 IS와 같은 다른 관련 메트릭에도 적용됩니다.
-
-
+> [!WARNING]
+> FID 결과는 많은 요소에 의존하기 때문에 취약할 수 있습니다:
+>
+> * 계산 중 사용되는 특정 Inception 모델.
+> * 계산의 구현 정확도.
+> * 이미지 형식 (PNG 또는 JPG에서 시작하는 경우가 다릅니다).
+>
+> 이러한 사항을 염두에 두면, FID는 유사한 실행을 비교할 때 가장 유용하지만, 저자가 FID 측정 코드를 주의 깊게 공개하지 않는 한 논문 결과를 재현하기는 어렵습니다.
+>
+> 이러한 사항은 KID 및 IS와 같은 다른 관련 메트릭에도 적용됩니다.
마지막 단계로, `fake_images`를 시각적으로 검사해 봅시다.
diff --git a/docs/source/ko/installation.md b/docs/source/ko/installation.md
index c03b464290..198ca4b7c7 100644
--- a/docs/source/ko/installation.md
+++ b/docs/source/ko/installation.md
@@ -107,11 +107,8 @@ pip install -e ".[flax]"
Python은 이제 일반 라이브러리 경로에 더하여 복제한 폴더 내부를 살펴봅니다.
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.10/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
-
-
-라이브러리를 계속 사용하려면 `diffusers` 폴더를 유지해야 합니다.
-
-
+> [!WARNING]
+> 라이브러리를 계속 사용하려면 `diffusers` 폴더를 유지해야 합니다.
이제 다음 명령어를 사용하여 최신 버전의 🤗 Diffusers로 쉽게 업데이트할 수 있습니다:
diff --git a/docs/source/ko/optimization/coreml.md b/docs/source/ko/optimization/coreml.md
index 60f19fd2c3..73ca851177 100644
--- a/docs/source/ko/optimization/coreml.md
+++ b/docs/source/ko/optimization/coreml.md
@@ -16,11 +16,8 @@ specific language governing permissions and limitations under the License.
Core ML 모델은 Apple 기기에서 사용할 수 있는 모든 컴퓨팅 엔진들, 즉 CPU, GPU, Apple Neural Engine(또는 Apple Silicon Mac 및 최신 iPhone/iPad에서 사용할 수 있는 텐서 최적화 가속기인 ANE)을 활용할 수 있습니다. 모델과 실행 중인 기기에 따라 Core ML은 컴퓨팅 엔진도 혼합하여 사용할 수 있으므로, 예를 들어 모델의 일부가 CPU에서 실행되는 반면 다른 부분은 GPU에서 실행될 수 있습니다.
-
-
-PyTorch에 내장된 `mps` 가속기를 사용하여 Apple Silicon Macs에서 `diffusers` Python 코드베이스를 실행할 수도 있습니다. 이 방법은 [mps 가이드]에 자세히 설명되어 있지만 네이티브 앱과 호환되지 않습니다.
-
-
+> [!TIP]
+> PyTorch에 내장된 `mps` 가속기를 사용하여 Apple Silicon Macs에서 `diffusers` Python 코드베이스를 실행할 수도 있습니다. 이 방법은 [mps 가이드]에 자세히 설명되어 있지만 네이티브 앱과 호환되지 않습니다.
## Stable Diffusion Core ML 체크포인트
diff --git a/docs/source/ko/optimization/fp16.md b/docs/source/ko/optimization/fp16.md
index db0370875e..56f1330c40 100644
--- a/docs/source/ko/optimization/fp16.md
+++ b/docs/source/ko/optimization/fp16.md
@@ -74,18 +74,16 @@ prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
-
- 어떤 파이프라인에서도 [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) 를 사용하는 것은 검은색 이미지를 생성할 수 있고, 순수한 float16 정밀도를 사용하는 것보다 항상 느리기 때문에 사용하지 않는 것이 좋습니다.
-
+> [!WARNING]
+> 어떤 파이프라인에서도 [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) 를 사용하는 것은 검은색 이미지를 생성할 수 있고, 순수한 float16 정밀도를 사용하는 것보다 항상 느리기 때문에 사용하지 않는 것이 좋습니다.
## 추가 메모리 절약을 위한 슬라이스 어텐션
추가 메모리 절약을 위해, 한 번에 모두 계산하는 대신 단계적으로 계산을 수행하는 슬라이스 버전의 어텐션(attention)을 사용할 수 있습니다.
-
- Attention slicing은 모델이 하나 이상의 어텐션 헤드를 사용하는 한, 배치 크기가 1인 경우에도 유용합니다.
- 하나 이상의 어텐션 헤드가 있는 경우 *QK^T* 어텐션 매트릭스는 상당한 양의 메모리를 절약할 수 있는 각 헤드에 대해 순차적으로 계산될 수 있습니다.
-
+> [!TIP]
+> Attention slicing은 모델이 하나 이상의 어텐션 헤드를 사용하는 한, 배치 크기가 1인 경우에도 유용합니다.
+> 하나 이상의 어텐션 헤드가 있는 경우 *QK^T* 어텐션 매트릭스는 상당한 양의 메모리를 절약할 수 있는 각 헤드에 대해 순차적으로 계산될 수 있습니다.
각 헤드에 대해 순차적으로 어텐션 계산을 수행하려면, 다음과 같이 추론 전에 파이프라인에서 [`~StableDiffusionPipeline.enable_attention_slicing`]를 호출하면 됩니다:
@@ -161,9 +159,8 @@ image = pipe(prompt).images[0]
참고로 이 방법은 전체 모델이 아닌 서브모듈 수준에서 작동합니다. 이는 메모리 소비를 최소화하는 가장 좋은 방법이지만 프로세스의 반복적 특성으로 인해 추론 속도가 훨씬 느립니다. 파이프라인의 UNet 구성 요소는 여러 번 실행됩니다('num_inference_steps' 만큼). 매번 UNet의 서로 다른 서브모듈이 순차적으로 온로드된 다음 필요에 따라 오프로드되므로 메모리 이동 횟수가 많습니다.
-
-또 다른 최적화 방법인 모델 오프로딩을 사용하는 것을 고려하십시오. 이는 훨씬 빠르지만 메모리 절약이 크지는 않습니다.
-
+> [!TIP]
+> 또 다른 최적화 방법인 모델 오프로딩을 사용하는 것을 고려하십시오. 이는 훨씬 빠르지만 메모리 절약이 크지는 않습니다.
또한 ttention slicing과 연결해서 최소 메모리(< 2GB)로도 동작할 수 있습니다.
@@ -231,9 +228,8 @@ pipe.enable_attention_slicing(1)
image = pipe(prompt).images[0]
```
-
-이 기능을 사용하려면 'accelerate' 버전 0.17.0 이상이 필요합니다.
-
+> [!TIP]
+> 이 기능을 사용하려면 'accelerate' 버전 0.17.0 이상이 필요합니다.
## Channels Last 메모리 형식 사용하기
diff --git a/docs/source/ko/optimization/mps.md b/docs/source/ko/optimization/mps.md
index 4daeaf5dba..004374c4af 100644
--- a/docs/source/ko/optimization/mps.md
+++ b/docs/source/ko/optimization/mps.md
@@ -27,11 +27,8 @@ Diffusers는 Stable Diffusion 추론을 위해 PyTorch `mps`를 사용해 Apple
아래 코도는 익숙한 `to()` 인터페이스를 사용하여 `mps` 백엔드로 Stable Diffusion 파이프라인을 M1 또는 M2 장치로 이동하는 방법을 보여줍니다.
-
-
-**PyTorch 1.13을 사용 중일 때 ** 추가 일회성 전달을 사용하여 파이프라인을 "프라이밍"하는 것을 추천합니다. 이것은 발견한 이상한 문제에 대한 임시 해결 방법입니다. 첫 번째 추론 전달은 후속 전달와 약간 다른 결과를 생성합니다. 이 전달은 한 번만 수행하면 되며 추론 단계를 한 번만 사용하고 결과를 폐기해도 됩니다.
-
-
+> [!WARNING]
+> **PyTorch 1.13을 사용 중일 때 ** 추가 일회성 전달을 사용하여 파이프라인을 "프라이밍"하는 것을 추천합니다. 이것은 발견한 이상한 문제에 대한 임시 해결 방법입니다. 첫 번째 추론 전달은 후속 전달와 약간 다른 결과를 생성합니다. 이 전달은 한 번만 수행하면 되며 추론 단계를 한 번만 사용하고 결과를 폐기해도 됩니다.
이전 팁에서 설명한 것들을 포함한 여러 문제를 해결하므로 PyTorch 2 이상을 사용하는 것이 좋습니다.
diff --git a/docs/source/ko/optimization/xformers.md b/docs/source/ko/optimization/xformers.md
index 3e4d107c0a..96fab34acf 100644
--- a/docs/source/ko/optimization/xformers.md
+++ b/docs/source/ko/optimization/xformers.md
@@ -21,16 +21,10 @@ specific language governing permissions and limitations under the License.
pip install xformers
```
-
-
-xFormers PIP 패키지에는 최신 버전의 PyTorch(xFormers 0.0.16에 1.13.1)가 필요합니다. 이전 버전의 PyTorch를 사용해야 하는 경우 [프로젝트 지침](https://github.com/facebookresearch/xformers#installing-xformers)의 소스를 사용해 xFormers를 설치하는 것이 좋습니다.
-
-
+> [!TIP]
+> xFormers PIP 패키지에는 최신 버전의 PyTorch(xFormers 0.0.16에 1.13.1)가 필요합니다. 이전 버전의 PyTorch를 사용해야 하는 경우 [프로젝트 지침](https://github.com/facebookresearch/xformers#installing-xformers)의 소스를 사용해 xFormers를 설치하는 것이 좋습니다.
xFormers를 설치하면, [여기](fp16#memory-efficient-attention)서 설명한 것처럼 'enable_xformers_memory_efficient_attention()'을 사용하여 추론 속도를 높이고 메모리 소비를 줄일 수 있습니다.
-
-
-[이 이슈](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)에 따르면 xFormers `v0.0.16`에서 GPU를 사용한 학습(파인 튜닝 또는 Dreambooth)을 할 수 없습니다. 해당 문제가 발견되면. 해당 코멘트를 참고해 development 버전을 설치하세요.
-
-
+> [!WARNING]
+> [이 이슈](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)에 따르면 xFormers `v0.0.16`에서 GPU를 사용한 학습(파인 튜닝 또는 Dreambooth)을 할 수 없습니다. 해당 문제가 발견되면. 해당 코멘트를 참고해 development 버전을 설치하세요.
diff --git a/docs/source/ko/quicktour.md b/docs/source/ko/quicktour.md
index 58ebb8960f..0a3cd0f7c4 100644
--- a/docs/source/ko/quicktour.md
+++ b/docs/source/ko/quicktour.md
@@ -23,11 +23,8 @@ Diffusion 모델은 이미지나 오디오와 같은 관심 샘플들을 생성
훑어보기에서는 추론을 위해 [`DiffusionPipeline`]을 사용하는 방법을 보여준 다음, 모델과 스케줄러를 결합하여 [`DiffusionPipeline`] 내부에서 일어나는 일을 복제하는 방법을 안내합니다.
-
-
-훑어보기는 간결한 버전의 🧨 Diffusers 소개로서 [노트북](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) 빠르게 시작할 수 있도록 도와드립니다. 디퓨저의 목표, 디자인 철학, 핵심 API에 대한 추가 세부 정보를 자세히 알아보려면 노트북을 확인하세요!
-
-
+> [!TIP]
+> 훑어보기는 간결한 버전의 🧨 Diffusers 소개로서 [노트북](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) 빠르게 시작할 수 있도록 도와드립니다. 디퓨저의 목표, 디자인 철학, 핵심 API에 대한 추가 세부 정보를 자세히 알아보려면 노트북을 확인하세요!
시작하기 전에 필요한 라이브러리가 모두 설치되어 있는지 확인하세요:
@@ -55,11 +52,8 @@ Diffusion 모델은 이미지나 오디오와 같은 관심 샘플들을 생성
허깅페이스 허브에 저장된 모든 [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads)에 대해 [`DiffusionPipeline`]을 사용할 수 있습니다.
이 훑어보기에서는 text-to-image 생성을 위한 [`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 체크포인트를 로드합니다.
-
-
-[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) 모델의 경우, 모델을 실행하기 전에 [라이선스](https://huggingface.co/spaces/CompVis/stable-diffusion-license)를 먼저 주의 깊게 읽어주세요. 🧨 Diffusers는 불쾌하거나 유해한 콘텐츠를 방지하기 위해 [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py)를 구현하고 있지만, 모델의 향상된 이미지 생성 기능으로 인해 여전히 잠재적으로 유해한 콘텐츠가 생성될 수 있습니다.
-
-
+> [!WARNING]
+> [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) 모델의 경우, 모델을 실행하기 전에 [라이선스](https://huggingface.co/spaces/CompVis/stable-diffusion-license)를 먼저 주의 깊게 읽어주세요. 🧨 Diffusers는 불쾌하거나 유해한 콘텐츠를 방지하기 위해 [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py)를 구현하고 있지만, 모델의 향상된 이미지 생성 기능으로 인해 여전히 잠재적으로 유해한 콘텐츠가 생성될 수 있습니다.
[`~DiffusionPipeline.from_pretrained`] 방법으로 모델 로드하기:
@@ -203,11 +197,8 @@ torch.Size([1, 3, 256, 256])
스케줄러는 모델 출력이 주어졌을 때 노이즈가 많은 샘플에서 노이즈가 적은 샘플로 전환하는 것을 관리합니다 - 이 경우 'noisy_residual'.
-
-
-🧨 Diffusers는 Diffusion 시스템을 구축하기 위한 툴박스입니다. [`DiffusionPipeline`]을 사용하면 미리 만들어진 Diffusion 시스템을 편리하게 시작할 수 있지만, 모델과 스케줄러 구성 요소를 개별적으로 선택하여 사용자 지정 Diffusion 시스템을 구축할 수도 있습니다.
-
-
+> [!TIP]
+> 🧨 Diffusers는 Diffusion 시스템을 구축하기 위한 툴박스입니다. [`DiffusionPipeline`]을 사용하면 미리 만들어진 Diffusion 시스템을 편리하게 시작할 수 있지만, 모델과 스케줄러 구성 요소를 개별적으로 선택하여 사용자 지정 Diffusion 시스템을 구축할 수도 있습니다.
훑어보기의 경우, [`~diffusers.ConfigMixin.from_config`] 메서드를 사용하여 [`DDPMScheduler`]를 인스턴스화합니다:
@@ -231,11 +222,8 @@ DDPMScheduler {
}
```
-
-
-💡 스케줄러가 구성에서 어떻게 인스턴스화되는지 주목하세요. 모델과 달리 스케줄러에는 학습 가능한 가중치가 없으며 매개변수도 없습니다!
-
-
+> [!TIP]
+> 💡 스케줄러가 구성에서 어떻게 인스턴스화되는지 주목하세요. 모델과 달리 스케줄러에는 학습 가능한 가중치가 없으며 매개변수도 없습니다!
가장 중요한 매개변수는 다음과 같습니다:
diff --git a/docs/source/ko/stable_diffusion.md b/docs/source/ko/stable_diffusion.md
index 794bdf9c66..0f61e16d2a 100644
--- a/docs/source/ko/stable_diffusion.md
+++ b/docs/source/ko/stable_diffusion.md
@@ -37,11 +37,8 @@ prompt = "portrait photo of a old warrior chief"
## 속도
-
-
-💡 GPU에 액세스할 수 없는 경우 다음과 같은 GPU 제공업체에서 무료로 사용할 수 있습니다!. [Colab](https://colab.research.google.com/)
-
-
+> [!TIP]
+> 💡 GPU에 액세스할 수 없는 경우 다음과 같은 GPU 제공업체에서 무료로 사용할 수 있습니다!. [Colab](https://colab.research.google.com/)
추론 속도를 높이는 가장 간단한 방법 중 하나는 Pytorch 모듈을 사용할 때와 같은 방식으로 GPU에 파이프라인을 배치하는 것입니다:
@@ -89,11 +86,8 @@ image
이번에는 이미지를 생성하는 데 약 11초밖에 걸리지 않아 이전보다 3배 가까이 빨라졌습니다!
-
-
-💡 파이프라인은 항상 `float16`에서 실행할 것을 강력히 권장하며, 지금까지 출력 품질이 저하되는 경우는 거의 없었습니다.
-
-
+> [!TIP]
+> 💡 파이프라인은 항상 `float16`에서 실행할 것을 강력히 권장하며, 지금까지 출력 품질이 저하되는 경우는 거의 없었습니다.
또 다른 옵션은 추론 단계의 수를 줄이는 것입니다. 보다 효율적인 스케줄러를 선택하면 출력 품질 저하 없이 단계 수를 줄이는 데 도움이 될 수 있습니다. 현재 모델과 호환되는 스케줄러는 `compatibles` 메서드를 호출하여 [`DiffusionPipeline`]에서 찾을 수 있습니다:
diff --git a/docs/source/ko/training/controlnet.md b/docs/source/ko/training/controlnet.md
index 434ca959bd..e868b57c55 100644
--- a/docs/source/ko/training/controlnet.md
+++ b/docs/source/ko/training/controlnet.md
@@ -20,11 +20,8 @@ specific language governing permissions and limitations under the License.
아래의 스크립트를 실행하기 전에, 라이브러리의 학습 의존성을 설치해야 합니다.
-
-
-가장 최신 버전의 예시 스크립트를 성공적으로 실행하기 위해서는, 소스에서 설치하고 최신 버전의 설치를 유지하는 것을 강력하게 추천합니다. 우리는 예시 스크립트들을 자주 업데이트하고 예시에 맞춘 특정한 요구사항을 설치합니다.
-
-
+> [!WARNING]
+> 가장 최신 버전의 예시 스크립트를 성공적으로 실행하기 위해서는, 소스에서 설치하고 최신 버전의 설치를 유지하는 것을 강력하게 추천합니다. 우리는 예시 스크립트들을 자주 업데이트하고 예시에 맞춘 특정한 요구사항을 설치합니다.
위 사항을 만족시키기 위해서, 새로운 가상환경에서 다음 일련의 스텝을 실행하세요:
diff --git a/docs/source/ko/training/create_dataset.md b/docs/source/ko/training/create_dataset.md
index a869cd09f0..c459a9d6a1 100644
--- a/docs/source/ko/training/create_dataset.md
+++ b/docs/source/ko/training/create_dataset.md
@@ -11,11 +11,8 @@
- 이미지 폴더를 `--train_data_dir` 인수에 제공합니다.
- 데이터셋을 Hub에 업로드하고 데이터셋 리포지토리 id를 `--dataset_name` 인수에 전달합니다.
-
-
-💡 학습에 사용할 이미지 데이터셋을 만드는 방법에 대한 자세한 내용은 [이미지 데이터셋 만들기](https://huggingface.co/docs/datasets/image_dataset) 가이드를 참고하세요.
-
-
+> [!TIP]
+> 💡 학습에 사용할 이미지 데이터셋을 만드는 방법에 대한 자세한 내용은 [이미지 데이터셋 만들기](https://huggingface.co/docs/datasets/image_dataset) 가이드를 참고하세요.
## 폴더 형태로 데이터셋 구축하기
@@ -40,11 +37,8 @@ accelerate launch train_unconditional.py \
## Hub에 데이터 올리기
-
-
-💡 데이터셋을 만들고 Hub에 업로드하는 것에 대한 자세한 내용은 [🤗 Datasets을 사용한 이미지 검색](https://huggingface.co/blog/image-search-datasets) 게시물을 참고하세요.
-
-
+> [!TIP]
+> 💡 데이터셋을 만들고 Hub에 업로드하는 것에 대한 자세한 내용은 [🤗 Datasets을 사용한 이미지 검색](https://huggingface.co/blog/image-search-datasets) 게시물을 참고하세요.
PIL 인코딩된 이미지가 포함된 `이미지` 열을 생성하는 [이미지 폴더](https://huggingface.co/docs/datasets/image_load#imagefolder) 기능을 사용하여 데이터셋 생성을 시작합니다.
diff --git a/docs/source/ko/training/distributed_inference.md b/docs/source/ko/training/distributed_inference.md
index c4d6400d97..e63764f5eb 100644
--- a/docs/source/ko/training/distributed_inference.md
+++ b/docs/source/ko/training/distributed_inference.md
@@ -32,9 +32,8 @@ Use the `--num_processes` argument to specify the number of GPUs to use, and cal
accelerate launch run_distributed.py --num_processes=2
```
-자세한 내용은 [🤗 Accelerate를 사용한 분산 추론](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 가이드를 참조하세요.
-
-
+> [!TIP]
+> 자세한 내용은 [🤗 Accelerate를 사용한 분산 추론](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 가이드를 참조하세요.
## Pytoerch 분산
diff --git a/docs/source/ko/training/dreambooth.md b/docs/source/ko/training/dreambooth.md
index 8e62f8edab..3e5a17d5f6 100644
--- a/docs/source/ko/training/dreambooth.md
+++ b/docs/source/ko/training/dreambooth.md
@@ -51,11 +51,8 @@ write_basic_config()
## 파인튜닝
-
-
-DreamBooth 파인튜닝은 하이퍼파라미터에 매우 민감하고 과적합되기 쉽습니다. 적절한 하이퍼파라미터를 선택하는 데 도움이 되도록 다양한 권장 설정이 포함된 [심층 분석](https://huggingface.co/blog/dreambooth)을 살펴보는 것이 좋습니다.
-
-
+> [!WARNING]
+> DreamBooth 파인튜닝은 하이퍼파라미터에 매우 민감하고 과적합되기 쉽습니다. 적절한 하이퍼파라미터를 선택하는 데 도움이 되도록 다양한 권장 설정이 포함된 [심층 분석](https://huggingface.co/blog/dreambooth)을 살펴보는 것이 좋습니다.
@@ -176,11 +173,8 @@ python train_dreambooth_flax.py \
해당 스크립트를 사용하면 `unet`과 함께 `text_encoder`를 파인튜닝할 수 있습니다. 실험에서(자세한 내용은 [🧨 Diffusers를 사용해 DreamBooth로 Stable Diffusion 학습하기](https://huggingface.co/blog/dreambooth) 게시물을 확인하세요), 특히 얼굴 이미지를 생성할 때 훨씬 더 나은 결과를 얻을 수 있습니다.
-
-
-텍스트 인코더를 학습시키려면 추가 메모리가 필요해 16GB GPU로는 동작하지 않습니다. 이 옵션을 사용하려면 최소 24GB VRAM이 필요합니다.
-
-
+> [!WARNING]
+> 텍스트 인코더를 학습시키려면 추가 메모리가 필요해 16GB GPU로는 동작하지 않습니다. 이 옵션을 사용하려면 최소 24GB VRAM이 필요합니다.
`--train_text_encoder` 인수를 학습 스크립트에 전달하여 `text_encoder` 및 `unet`을 파인튜닝할 수 있습니다:
diff --git a/docs/source/ko/training/lora.md b/docs/source/ko/training/lora.md
index 5bcef27143..515e3fd65e 100644
--- a/docs/source/ko/training/lora.md
+++ b/docs/source/ko/training/lora.md
@@ -14,11 +14,8 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-
-
-현재 LoRA는 [`UNet2DConditionalModel`]의 어텐션 레이어에서만 지원됩니다.
-
-
+> [!WARNING]
+> 현재 LoRA는 [`UNet2DConditionModel`]의 어텐션 레이어에서만 지원됩니다.
[LoRA(Low-Rank Adaptation of Large Language Models)](https://huggingface.co/papers/2106.09685)는 메모리를 적게 사용하면서 대규모 모델의 학습을 가속화하는 학습 방법입니다. 이는 rank-decomposition weight 행렬 쌍(**업데이트 행렬**이라고 함)을 추가하고 새로 추가된 가중치**만** 학습합니다. 여기에는 몇 가지 장점이 있습니다.
@@ -28,11 +25,8 @@ specific language governing permissions and limitations under the License.
- 메모리 효율성이 향상되어 Tesla T4, RTX 3080 또는 RTX 2080 Ti와 같은 소비자용 GPU에서 파인튜닝을 실행할 수 있습니다! T4와 같은 GPU는 무료이며 Kaggle 또는 Google Colab 노트북에서 쉽게 액세스할 수 있습니다.
-
-
-💡 LoRA는 어텐션 레이어에만 한정되지는 않습니다. 저자는 언어 모델의 어텐션 레이어를 수정하는 것이 매우 효율적으로 죻은 성능을 얻기에 충분하다는 것을 발견했습니다. 이것이 LoRA 가중치를 모델의 어텐션 레이어에 추가하는 것이 일반적인 이유입니다. LoRA 작동 방식에 대한 자세한 내용은 [Using LoRA for effective Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) 블로그를 확인하세요!
-
-
+> [!TIP]
+> 💡 LoRA는 어텐션 레이어에만 한정되지는 않습니다. 저자는 언어 모델의 어텐션 레이어를 수정하는 것이 매우 효율적으로 좋은 성능을 얻기에 충분하다는 것을 발견했습니다. 이것이 LoRA 가중치를 모델의 어텐션 레이어에 추가하는 것이 일반적인 이유입니다. LoRA 작동 방식에 대한 자세한 내용은 [Using LoRA for effective Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) 블로그를 확인하세요!
[cloneofsimo](https://github.com/cloneofsimo)는 인기 있는 [lora](https://github.com/cloneofsimo/lora) GitHub 리포지토리에서 Stable Diffusion을 위한 LoRA 학습을 최초로 시도했습니다. 🧨 Diffusers는 [text-to-image 생성](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) 및 [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora)을 지원합니다. 이 가이드는 두 가지를 모두 수행하는 방법을 보여줍니다.
@@ -104,11 +98,8 @@ accelerate launch train_dreambooth_lora.py \
*기본 모델의 가중치 위에* 파인튜닝된 DreamBooth 모델에서 LoRA 가중치를 불러온 다음, 더 빠른 추론을 위해 파이프라인을 GPU로 이동합니다. LoRA 가중치를 프리징된 사전 훈련된 모델 가중치와 병합할 때, 선택적으로 'scale' 매개변수로 어느 정도의 가중치를 병합할 지 조절할 수 있습니다:
-
-
-💡 `0`의 `scale` 값은 LoRA 가중치를 사용하지 않아 원래 모델의 가중치만 사용한 것과 같고, `1`의 `scale` 값은 파인튜닝된 LoRA 가중치만 사용함을 의미합니다. 0과 1 사이의 값들은 두 결과들 사이로 보간됩니다.
-
-
+> [!TIP]
+> 💡 `0`의 `scale` 값은 LoRA 가중치를 사용하지 않아 원래 모델의 가중치만 사용한 것과 같고, `1`의 `scale` 값은 파인튜닝된 LoRA 가중치만 사용함을 의미합니다. 0과 1 사이의 값들은 두 결과들 사이로 보간됩니다.
```py
>>> pipe.unet.load_attn_procs(model_path)
diff --git a/docs/source/ko/training/text2image.md b/docs/source/ko/training/text2image.md
index 4283f73ed9..b26603bf1b 100644
--- a/docs/source/ko/training/text2image.md
+++ b/docs/source/ko/training/text2image.md
@@ -13,11 +13,8 @@ specific language governing permissions and limitations under the License.
# Text-to-image
-
-
-text-to-image 파인튜닝 스크립트는 experimental 상태입니다. 과적합하기 쉽고 치명적인 망각과 같은 문제에 부딪히기 쉽습니다. 자체 데이터셋에서 최상의 결과를 얻으려면 다양한 하이퍼파라미터를 탐색하는 것이 좋습니다.
-
-
+> [!WARNING]
+> text-to-image 파인튜닝 스크립트는 experimental 상태입니다. 과적합하기 쉽고 치명적인 망각과 같은 문제에 부딪히기 쉽습니다. 자체 데이터셋에서 최상의 결과를 얻으려면 다양한 하이퍼파라미터를 탐색하는 것이 좋습니다.
Stable Diffusion과 같은 text-to-image 모델은 텍스트 프롬프트에서 이미지를 생성합니다. 이 가이드는 PyTorch 및 Flax를 사용하여 자체 데이터셋에서 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) 모델로 파인튜닝하는 방법을 보여줍니다. 이 가이드에 사용된 text-to-image 파인튜닝을 위한 모든 학습 스크립트에 관심이 있는 경우 이 [리포지토리](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image)에서 자세히 찾을 수 있습니다.
diff --git a/docs/source/ko/training/text_inversion.md b/docs/source/ko/training/text_inversion.md
index b27bed7d14..d8b44930e3 100644
--- a/docs/source/ko/training/text_inversion.md
+++ b/docs/source/ko/training/text_inversion.md
@@ -23,11 +23,8 @@ specific language governing permissions and limitations under the License.
이 가이드에서는 textual-inversion으로 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 모델을 학습하는 방법을 설명합니다. 이 가이드에서 사용된 모든 textual-inversion 학습 스크립트는 [여기](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion)에서 확인할 수 있습니다. 내부적으로 어떻게 작동하는지 자세히 살펴보고 싶으시다면 해당 링크를 참조해주시기 바랍니다.
-
-
-[Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library)에는 커뮤니티에서 제작한 학습된 textual-inversion 모델들이 있습니다. 시간이 지남에 따라 더 많은 콘셉트들이 추가되어 유용한 리소스로 성장할 것입니다!
-
-
+> [!TIP]
+> [Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library)에는 커뮤니티에서 제작한 학습된 textual-inversion 모델들이 있습니다. 시간이 지남에 따라 더 많은 콘셉트들이 추가되어 유용한 리소스로 성장할 것입니다!
시작하기 전에 학습을 위한 의존성 라이브러리들을 설치해야 합니다:
@@ -100,11 +97,8 @@ snapshot_download(
- `token_identifier.txt`
- `type_of_concept.txt`.
-
-
-💡V100 GPU 1개를 기준으로 전체 학습에는 최대 1시간이 걸립니다. 학습이 완료되기를 기다리는 동안 궁금한 점이 있으면 아래 섹션에서 [textual-inversion이 어떻게 작동하는지](https://huggingface.co/docs/diffusers/training/text_inversion#how-it-works) 자유롭게 확인하세요 !
-
-
+> [!TIP]
+> 💡V100 GPU 1개를 기준으로 전체 학습에는 최대 1시간이 걸립니다. 학습이 완료되기를 기다리는 동안 궁금한 점이 있으면 아래 섹션에서 [textual-inversion이 어떻게 작동하는지](https://huggingface.co/docs/diffusers/training/text_inversion#how-it-works) 자유롭게 확인하세요 !
@@ -128,15 +122,12 @@ accelerate launch textual_inversion.py \
--push_to_hub
```
-
-
-💡학습 성능을 올리기 위해, 플레이스홀더 토큰(``)을 (단일한 임베딩 벡터가 아닌) 복수의 임베딩 벡터로 표현하는 것 역시 고려할 있습니다. 이러한 트릭이 모델이 보다 복잡한 이미지의 스타일(앞서 말한 콘셉트)을 더 잘 캡처하는 데 도움이 될 수 있습니다. 복수의 임베딩 벡터 학습을 활성화하려면 다음 옵션을 전달하십시오.
-
-```bash
---num_vectors=5
-```
-
-
+> [!TIP]
+> 💡학습 성능을 올리기 위해, 플레이스홀더 토큰(``)을 (단일한 임베딩 벡터가 아닌) 복수의 임베딩 벡터로 표현하는 것 역시 고려할 수 있습니다. 이러한 트릭이 모델이 보다 복잡한 이미지의 스타일(앞서 말한 콘셉트)을 더 잘 캡처하는 데 도움이 될 수 있습니다. 복수의 임베딩 벡터 학습을 활성화하려면 다음 옵션을 전달하십시오.
+>
+> ```bash
+> --num_vectors=5
+> ```
@@ -193,11 +184,8 @@ textual-inversion 스크립트는 기본적으로 textual-inversion을 통해
-
-
-💡 커뮤니티는 [sd-concepts-library](https://huggingface.co/sd-concepts-library) 라는 대규모의 textual-inversion 임베딩 벡터 라이브러리를 만들었습니다. textual-inversion 임베딩을 밑바닥부터 학습하는 대신, 해당 라이브러리에 본인이 찾는 textual-inversion 임베딩이 이미 추가되어 있지 않은지를 확인하는 것도 좋은 방법이 될 것 같습니다.
-
-
+> [!TIP]
+> 💡 커뮤니티는 [sd-concepts-library](https://huggingface.co/sd-concepts-library) 라는 대규모의 textual-inversion 임베딩 벡터 라이브러리를 만들었습니다. textual-inversion 임베딩을 밑바닥부터 학습하는 대신, 해당 라이브러리에 본인이 찾는 textual-inversion 임베딩이 이미 추가되어 있지 않은지를 확인하는 것도 좋은 방법이 될 것 같습니다.
textual-inversion 임베딩 벡터를 불러오기 위해서는, 먼저 해당 임베딩 벡터를 학습할 때 사용한 모델을 불러와야 합니다. 여기서는 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 모델이 사용되었다고 가정하고 불러오겠습니다.
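
For illustration only, a hedged sketch of loading a community embedding with `load_textual_inversion`; the `sd-concepts-library/cat-toy` repository and the `<cat-toy>` token are example values, not part of this guide's walkthrough:

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Download the learned embedding and register its placeholder token with the tokenizer
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
image = pipe("a <cat-toy> sitting on a park bench").images[0]
```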
diff --git a/docs/source/ko/training/unconditional_training.md b/docs/source/ko/training/unconditional_training.md
index c8c463da6b..04a9a6c7ea 100644
--- a/docs/source/ko/training/unconditional_training.md
+++ b/docs/source/ko/training/unconditional_training.md
@@ -78,11 +78,8 @@ write_basic_config()
학습 스크립트는 `diffusion_pytorch_model.bin` 파일을 생성하고, 그것을 당신의 리포지토리에 저장합니다.
-
-
-💡 전체 학습은 V100 GPU 4개를 사용할 경우, 2시간이 소요됩니다.
-
-
+> [!TIP]
+> 💡 전체 학습은 V100 GPU 4개를 사용할 경우, 2시간이 소요됩니다.
예를 들어, [Oxford Flowers](https://huggingface.co/datasets/huggan/flowers-102-categories) 데이터셋을 사용해 파인튜닝할 경우:
diff --git a/docs/source/ko/tutorials/basic_training.md b/docs/source/ko/tutorials/basic_training.md
index 2c4c89edd1..05ce1037b5 100644
--- a/docs/source/ko/tutorials/basic_training.md
+++ b/docs/source/ko/tutorials/basic_training.md
@@ -19,11 +19,8 @@ Unconditional 이미지 생성은 학습에 사용된 데이터셋과 유사한
이 튜토리얼은 나만의 🦋 나비 🦋를 생성하기 위해 [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) 데이터셋의 하위 집합에서 [`UNet2DModel`] 모델을 학습하는 방법을 가르쳐줄 것입니다.
-
-
-💡 이 학습 튜토리얼은 [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) 노트북 기반으로 합니다. Diffusion 모델의 작동 방식 및 자세한 내용은 노트북을 확인하세요!
-
-
+> [!TIP]
+> 💡 이 학습 튜토리얼은 [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) 노트북 기반으로 합니다. Diffusion 모델의 작동 방식 및 자세한 내용은 노트북을 확인하세요!
시작 전에, 🤗 Datasets을 불러오고 전처리하기 위해 데이터셋이 설치되어 있는지 다수 GPU에서 학습을 간소화하기 위해 🤗 Accelerate 가 설치되어 있는지 확인하세요. 그 후 학습 메트릭을 시각화하기 위해 [TensorBoard](https://www.tensorflow.org/tensorboard)를 또한 설치하세요. (또한 학습 추적을 위해 [Weights & Biases](https://docs.wandb.ai/)를 사용할 수 있습니다.)
diff --git a/docs/source/ko/using-diffusers/controlling_generation.md b/docs/source/ko/using-diffusers/controlling_generation.md
index 1b9a8b5df5..db22fe042d 100644
--- a/docs/source/ko/using-diffusers/controlling_generation.md
+++ b/docs/source/ko/using-diffusers/controlling_generation.md
@@ -85,12 +85,9 @@ Pix2Pix Zero는 합성 이미지와 실제 이미지를 편집하는 데 모두
다음으로 편집할 컨셉과 새로운 타겟 컨셉에 대한 이미지 캡션을 생성합니다. 이를 위해 [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)와 같은 모델을 사용할 수 있습니다. 그런 다음 텍스트 인코더를 통해 소스 개념과 대상 개념 모두에 대한 "평균" 프롬프트 임베딩을 생성합니다. 마지막으로, 합성 이미지를 편집하기 위해 pix2pix-zero 알고리즘을 사용합니다.
- 실제 이미지를 편집하려면 먼저 [BLIP](https://huggingface.co/docs/transformers/model_doc/blip)과 같은 모델을 사용하여 이미지 캡션을 생성합니다. 그런 다음 프롬프트와 이미지에 ddim 반전을 적용하여 "역(inverse)" latents을 생성합니다. 이전과 마찬가지로 소스 및 대상 개념 모두에 대한 "평균(mean)" 프롬프트 임베딩이 생성되고 마지막으로 "역(inverse)" latents와 결합된 pix2pix-zero 알고리즘이 이미지를 편집하는 데 사용됩니다.
-
-
-Pix2Pix Zero는 '제로 샷(zero-shot)' 이미지 편집이 가능한 최초의 모델입니다.
-즉, 이 모델은 다음과 같이 일반 소비자용 GPU에서 1분 이내에 이미지를 편집할 수 있습니다(../api/pipelines/stable_diffusion/pix2pix_zero#usage-example).
-
-
+> [!TIP]
+> Pix2Pix Zero는 '제로 샷(zero-shot)' 이미지 편집이 가능한 최초의 모델입니다.
+> 즉, 이 모델은 [여기](../api/pipelines/stable_diffusion/pix2pix_zero#usage-example)에서 볼 수 있듯이 일반 소비자용 GPU에서 1분 이내에 이미지를 편집할 수 있습니다.
위에서 언급했듯이 Pix2Pix Zero에는 특정 개념으로 세대를 유도하기 위해 (UNet, VAE 또는 텍스트 인코더가 아닌) latents을 최적화하는 기능이 포함되어 있습니다.즉, 전체 파이프라인에 표준 [StableDiffusionPipeline](../api/pipelines/stable_diffusion/text2img)보다 더 많은 메모리가 필요할 수 있습니다.
@@ -140,13 +137,10 @@ SAG는 고빈도 세부 정보를 기반으로 하지 않은 예측에서 완전
사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion_2#depthtoimage)를 참조하세요.
-
-
-InstructPix2Pix와 Pix2Pix Zero와 같은 방법의 중요한 차이점은 전자의 경우
-는 사전 학습된 가중치를 미세 조정하는 반면, 후자는 그렇지 않다는 것입니다. 즉, 다음을 수행할 수 있습니다.
-사용 가능한 모든 안정적 확산 모델에 Pix2Pix Zero를 적용할 수 있습니다.
-
-
+> [!TIP]
+> InstructPix2Pix와 Pix2Pix Zero 같은 방법들의 중요한 차이점은, 전자는 사전 학습된 가중치를 미세 조정하는 반면 후자는 그렇지 않다는 것입니다. 즉, 사용 가능한 모든 Stable Diffusion 모델에 Pix2Pix Zero를 적용할 수 있습니다.
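
To make the contrast concrete, a hedged sketch of the fine-tuned-checkpoint side (InstructPix2Pix); the checkpoint id is the public `timbrooks/instruct-pix2pix` release, and the input image path is a placeholder:

```py
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

init_image = load_image("path/to/your/image.png")  # placeholder input image
# image_guidance_scale controls how strongly the edit stays anchored to the input image
image = pipe("make it snowy", image=init_image, image_guidance_scale=1.5).images[0]
```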
## MultiDiffusion Panorama
diff --git a/docs/source/ko/using-diffusers/custom_pipeline_overview.md b/docs/source/ko/using-diffusers/custom_pipeline_overview.md
index b143bf8ab0..caeeca8cef 100644
--- a/docs/source/ko/using-diffusers/custom_pipeline_overview.md
+++ b/docs/source/ko/using-diffusers/custom_pipeline_overview.md
@@ -20,11 +20,8 @@ specific language governing permissions and limitations under the License.
허브에서 커뮤니티 파이프라인을 로드하려면, 커뮤니티 파이프라인의 리포지토리 ID와 (파이프라인 가중치 및 구성 요소를 로드하려는) 모델의 리포지토리 ID를 인자로 전달해야 합니다. 예를 들어, 아래 예시에서는 `hf-internal-testing/diffusers-dummy-pipeline`에서 더미 파이프라인을 불러오고, `google/ddpm-cifar10-32`에서 파이프라인의 가중치와 컴포넌트들을 로드합니다.
-
-
-🔒 허깅 페이스 허브에서 커뮤니티 파이프라인을 불러오는 것은 곧 해당 코드가 안전하다고 신뢰하는 것입니다. 코드를 자동으로 불러오고 실행하기 앞서 반드시 온라인으로 해당 코드의 신뢰성을 검사하세요!
-
-
+> [!WARNING]
+> 🔒 허깅 페이스 허브에서 커뮤니티 파이프라인을 불러오는 것은 곧 해당 코드가 안전하다고 신뢰하는 것입니다. 코드를 자동으로 불러오고 실행하기 앞서 반드시 온라인으로 해당 코드의 신뢰성을 검사하세요!
```py
from diffusers import DiffusionPipeline
diff --git a/docs/source/ko/using-diffusers/diffedit.md b/docs/source/ko/using-diffusers/diffedit.md
index 74b9e97831..edf23f0214 100644
--- a/docs/source/ko/using-diffusers/diffedit.md
+++ b/docs/source/ko/using-diffusers/diffedit.md
@@ -156,11 +156,8 @@ print(source_prompts)
print(target_prompts)
```
-
-
-다양한 품질의 텍스트를 생성하는 전략에 대해 자세히 알아보려면 [생성 전략](https://huggingface.co/docs/transformers/main/en/generation_strategies) 가이드를 참조하세요.
-
-
+> [!TIP]
+> 다양한 품질의 텍스트를 생성하는 전략에 대해 자세히 알아보려면 [생성 전략](https://huggingface.co/docs/transformers/main/en/generation_strategies) 가이드를 참조하세요.
텍스트 인코딩을 위해 [`StableDiffusionDiffEditPipeline`]에서 사용하는 텍스트 인코더 모델을 불러옵니다. 텍스트 인코더를 사용하여 텍스트 임베딩을 계산합니다:
diff --git a/docs/source/ko/using-diffusers/img2img.md b/docs/source/ko/using-diffusers/img2img.md
index 8da840f748..3901fb755f 100644
--- a/docs/source/ko/using-diffusers/img2img.md
+++ b/docs/source/ko/using-diffusers/img2img.md
@@ -53,11 +53,8 @@ init_image
-
-
-💡 `strength`는 입력 이미지에 추가되는 노이즈의 양을 제어하는 0.0에서 1.0 사이의 값입니다. 1.0에 가까운 값은 다양한 변형을 허용하지만 입력 이미지와 의미적으로 일치하지 않는 이미지를 생성합니다.
-
-
+> [!TIP]
+> 💡 `strength`는 입력 이미지에 추가되는 노이즈의 양을 제어하는 0.0에서 1.0 사이의 값입니다. 1.0에 가까운 값은 다양한 변형을 허용하지만 입력 이미지와 의미적으로 일치하지 않는 이미지를 생성합니다.
프롬프트를 정의하고(지브리 스타일(Ghibli-style)에 맞게 조정된 이 체크포인트의 경우 프롬프트 앞에 `ghibli style` 토큰을 붙여야 합니다) 파이프라인을 실행합니다:
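
A hedged sketch of how `strength` enters the call, assuming the img2img pipeline built earlier in this guide is named `pipe` and `init_image` is the image loaded above:

```py
prompt = "ghibli style, a fantasy landscape with castles"

# strength=0.75 adds a moderate amount of noise: plenty of variation while keeping the layout of init_image
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]
```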
diff --git a/docs/source/ko/using-diffusers/inpaint.md b/docs/source/ko/using-diffusers/inpaint.md
index adf1251176..cefb892186 100644
--- a/docs/source/ko/using-diffusers/inpaint.md
+++ b/docs/source/ko/using-diffusers/inpaint.md
@@ -59,11 +59,8 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
:-------------------------:|:-------------------------:|:-------------------------:|-------------------------:|
| | ***Face of a yellow cat, high resolution, sitting on a park bench*** | |
-
-
-이전의 실험적인 인페인팅 구현에서는 품질이 낮은 다른 프로세스를 사용했습니다. 이전 버전과의 호환성을 보장하기 위해 새 모델이 포함되지 않은 사전학습된 파이프라인을 불러오면 이전 인페인팅 방법이 계속 적용됩니다.
-
-
+> [!WARNING]
+> 이전의 실험적인 인페인팅 구현에서는 품질이 낮은 다른 프로세스를 사용했습니다. 이전 버전과의 호환성을 보장하기 위해 새 모델이 포함되지 않은 사전학습된 파이프라인을 불러오면 이전 인페인팅 방법이 계속 적용됩니다.
아래 Space에서 이미지 인페인팅을 직접 해보세요!
diff --git a/docs/source/ko/using-diffusers/kandinsky.md b/docs/source/ko/using-diffusers/kandinsky.md
index cc554c67f9..8eff8f5629 100644
--- a/docs/source/ko/using-diffusers/kandinsky.md
+++ b/docs/source/ko/using-diffusers/kandinsky.md
@@ -31,15 +31,12 @@ Kandinsky 모델은 일련의 다국어 text-to-image 생성 모델입니다. Ka
#!pip install -q diffusers transformers accelerate
```
-
-
-Kandinsky 2.1과 2.2의 사용법은 매우 유사합니다! 유일한 차이점은 Kandinsky 2.2는 latents를 디코딩할 때 `프롬프트`를 입력으로 받지 않는다는 것입니다. 대신, Kandinsky 2.2는 디코딩 중에는 `image_embeds`만 받아들입니다.
-
-
-
-Kandinsky 3는 더 간결한 아키텍처를 가지고 있으며 prior 모델이 필요하지 않습니다. 즉, [Stable Diffusion XL](sdxl)과 같은 다른 diffusion 모델과 사용법이 동일합니다.
-
-
+> [!WARNING]
+> Kandinsky 2.1과 2.2의 사용법은 매우 유사합니다! 유일한 차이점은 Kandinsky 2.2는 latents를 디코딩할 때 `prompt`를 입력으로 받지 않는다는 것입니다. 대신, Kandinsky 2.2는 디코딩 중에는 `image_embeds`만 받아들입니다.
+>
+> Kandinsky 3는 더 간결한 아키텍처를 가지고 있으며 prior 모델이 필요하지 않습니다. 즉, [Stable Diffusion XL](sdxl)과 같은 다른 diffusion 모델과 사용법이 동일합니다.
## Text-to-image
@@ -321,20 +318,17 @@ make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], r
## Inpainting
-
-
-⚠️ Kandinsky 모델은 이제 검은색 픽셀 대신 ⬜️ **흰색 픽셀**을 사용하여 마스크 영역을 표현합니다. 프로덕션에서 [`KandinskyInpaintPipeline`]을 사용하는 경우 흰색 픽셀을 사용하도록 마스크를 변경해야 합니다:
-
-```py
-# PIL 입력에 대해
-import PIL.ImageOps
-mask = PIL.ImageOps.invert(mask)
-
-# PyTorch와 NumPy 입력에 대해
-mask = 1 - mask
-```
-
-
+> [!WARNING]
+> ⚠️ Kandinsky 모델은 이제 검은색 픽셀 대신 ⬜️ **흰색 픽셀**을 사용하여 마스크 영역을 표현합니다. 프로덕션에서 [`KandinskyInpaintPipeline`]을 사용하는 경우 흰색 픽셀을 사용하도록 마스크를 변경해야 합니다:
+>
+> ```py
+> # PIL 입력에 대해
+> import PIL.ImageOps
+> mask = PIL.ImageOps.invert(mask)
+>
+> # PyTorch와 NumPy 입력에 대해
+> mask = 1 - mask
+> ```
인페인팅에서는 원본 이미지, 원본 이미지에서 대체할 영역의 마스크, 인페인팅할 내용에 대한 텍스트 프롬프트가 필요합니다. Prior 파이프라인을 불러옵니다:
@@ -565,11 +559,8 @@ image
## ControlNet
-
-
-⚠️ ControlNet은 Kandinsky 2.2에서만 지원됩니다!
-
-
+> [!WARNING]
+> ⚠️ ControlNet은 Kandinsky 2.2에서만 지원됩니다!
ControlNet을 사용하면 depth map이나 edge detection와 같은 추가 입력을 통해 사전학습된 large diffusion 모델을 conditioning할 수 있습니다. 예를 들어, 모델이 depth map의 구조를 이해하고 보존할 수 있도록 깊이 맵으로 Kandinsky 2.2를 conditioning할 수 있습니다.
diff --git a/docs/source/ko/using-diffusers/loading.md b/docs/source/ko/using-diffusers/loading.md
index 3d6b7634b4..2160acacc2 100644
--- a/docs/source/ko/using-diffusers/loading.md
+++ b/docs/source/ko/using-diffusers/loading.md
@@ -30,11 +30,8 @@ diffusion 모델의 훈련과 추론에 필요한 모든 것은 [`DiffusionPipel
## Diffusion 파이프라인
-
-
-💡 [`DiffusionPipeline`] 클래스가 동작하는 방식에 보다 자세한 내용이 궁금하다면, [DiffusionPipeline explained](#diffusionpipeline에-대해-알아보기) 섹션을 확인해보세요.
-
-
+> [!TIP]
+> 💡 [`DiffusionPipeline`] 클래스가 동작하는 방식에 보다 자세한 내용이 궁금하다면, [DiffusionPipeline explained](#diffusionpipeline에-대해-알아보기) 섹션을 확인해보세요.
[`DiffusionPipeline`] 클래스는 diffusion 모델을 [허브](https://huggingface.co/models?library=diffusers)로부터 불러오는 가장 심플하면서 보편적인 방식입니다. [`DiffusionPipeline.from_pretrained`] 메서드는 적합한 파이프라인 클래스를 자동으로 탐지하고, 필요한 구성요소(configuration)와 가중치(weight) 파일들을 다운로드하고 캐싱한 다음, 해당 파이프라인 인스턴스를 반환합니다.
@@ -175,11 +172,8 @@ Variant란 일반적으로 다음과 같은 체크포인트들을 의미합니
- `torch.float16`과 같이 정밀도는 더 낮지만, 용량 역시 더 작은 부동소수점 타입의 가중치를 사용하는 체크포인트. *(다만 이와 같은 variant의 경우, 추가적인 훈련과 CPU환경에서의 구동이 불가능합니다.)*
- Non-EMA 가중치를 사용하는 체크포인트. *(Non-EMA 가중치의 경우, 파인 튜닝 단계에서 사용하는 것이 권장되는데, 추론 단계에선 사용하지 않는 것이 권장됩니다.)*
-
-
-💡 모델 구조는 동일하지만 서로 다른 학습 환경에서 서로 다른 데이터셋으로 학습된 체크포인트들이 있을 경우, 해당 체크포인트들은 variant 단계가 아닌 리포지토리 단계에서 분리되어 관리되어야 합니다. (즉, 해당 체크포인트들은 서로 다른 리포지토리에서 따로 관리되어야 합니다. 예시: [`stable-diffusion-v1-4`], [`stable-diffusion-v1-5`]).
-
-
+> [!TIP]
+> 💡 모델 구조는 동일하지만 서로 다른 학습 환경에서 서로 다른 데이터셋으로 학습된 체크포인트들이 있을 경우, 해당 체크포인트들은 variant 단계가 아닌 리포지토리 단계에서 분리되어 관리되어야 합니다. (즉, 해당 체크포인트들은 서로 다른 리포지토리에서 따로 관리되어야 합니다. 예시: [`stable-diffusion-v1-4`], [`stable-diffusion-v1-5`]).
| **checkpoint type** | **weight name** | **argument for loading weights** |
| ------------------- | ----------------------------------- | -------------------------------- |
diff --git a/docs/source/ko/using-diffusers/loading_adapters.md b/docs/source/ko/using-diffusers/loading_adapters.md
index f0d085bc6a..e7ae116575 100644
--- a/docs/source/ko/using-diffusers/loading_adapters.md
+++ b/docs/source/ko/using-diffusers/loading_adapters.md
@@ -18,11 +18,8 @@ specific language governing permissions and limitations under the License.
이 가이드에서는 DreamBooth, textual inversion 및 LoRA 가중치를 불러오는 방법을 설명합니다.
-
-
-사용할 체크포인트와 임베딩은 [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer), [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer), [Diffusers Models Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery)에서 찾아보시기 바랍니다.
-
-
+> [!TIP]
+> 사용할 체크포인트와 임베딩은 [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer), [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer), [Diffusers Models Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery)에서 찾아보시기 바랍니다.
## DreamBooth
@@ -101,11 +98,8 @@ image
[Low-Rank Adaptation (LoRA)](https://huggingface.co/papers/2106.09685)은 속도가 빠르고 파일 크기가 (수백 MB로) 작기 때문에 널리 사용되는 학습 기법입니다. 이 가이드의 다른 방법과 마찬가지로, LoRA는 몇 장의 이미지만으로 새로운 스타일을 학습하도록 모델을 학습시킬 수 있습니다. 이는 diffusion 모델에 새로운 가중치를 삽입한 다음 전체 모델 대신 새로운 가중치만 학습시키는 방식으로 작동합니다. 따라서 LoRA를 더 빠르게 학습시키고 더 쉽게 저장할 수 있습니다.
-
-
-LoRA는 다른 학습 방법과 함께 사용할 수 있는 매우 일반적인 학습 기법입니다. 예를 들어, DreamBooth와 LoRA로 모델을 학습하는 것이 일반적입니다. 또한 새롭고 고유한 이미지를 생성하기 위해 여러 개의 LoRA를 불러오고 병합하는 것이 점점 더 일반화되고 있습니다. 병합은 이 불러오기 가이드의 범위를 벗어나므로 자세한 내용은 심층적인 [LoRA 병합](merge_loras) 가이드에서 확인할 수 있습니다.
-
-
+> [!TIP]
+> LoRA는 다른 학습 방법과 함께 사용할 수 있는 매우 일반적인 학습 기법입니다. 예를 들어, DreamBooth와 LoRA로 모델을 학습하는 것이 일반적입니다. 또한 새롭고 고유한 이미지를 생성하기 위해 여러 개의 LoRA를 불러오고 병합하는 것이 점점 더 일반화되고 있습니다. 병합은 이 불러오기 가이드의 범위를 벗어나므로 자세한 내용은 심층적인 [LoRA 병합](merge_loras) 가이드에서 확인할 수 있습니다.
LoRA는 다른 모델과 함께 사용해야 합니다:
@@ -184,11 +178,8 @@ pipe.set_adapters("my_adapter", scales)
이는 여러 어댑터에서도 작동합니다. 방법은 [이 가이드](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength)를 참조하세요.
-
-
-현재 [`~loaders.LoraLoaderMixin.set_adapters`]는 어텐션 가중치의 스케일링만 지원합니다. LoRA에 다른 부분(예: resnets or down-/upsamplers)이 있는 경우 1.0의 스케일을 유지합니다.
-
-
+> [!WARNING]
+> 현재 [`~loaders.LoraLoaderMixin.set_adapters`]는 어텐션 가중치의 스케일링만 지원합니다. LoRA에 다른 부분(예: resnets or down-/upsamplers)이 있는 경우 1.0의 스케일을 유지합니다.
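
As a small illustration of the multi-adapter note above, a sketch assuming two LoRAs were previously loaded with `adapter_name="toy"` and `adapter_name="pixel"` (both names are placeholders):

```py
# Blend two previously loaded LoRA adapters; only attention weights are scaled per adapter
pipe.set_adapters(["toy", "pixel"], adapter_weights=[0.8, 0.5])
image = pipe("a toy robot in pixel art style").images[0]
```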
### Kohya와 TheLastBen
@@ -222,14 +213,11 @@ image = pipeline(prompt).images[0]
image
```
-
-
-Kohya LoRA를 🤗 Diffusers와 함께 사용할 때 몇 가지 제한 사항이 있습니다:
-
-- [여기](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736)에 설명된 여러 가지 이유로 인해 이미지가 ComfyUI와 같은 UI에서 생성된 이미지와 다르게 보일 수 있습니다.
-- [LyCORIS 체크포인트](https://github.com/KohakuBlueleaf/LyCORIS)가 완전히 지원되지 않습니다. [`~loaders.LoraLoaderMixin.load_lora_weights`] 메서드는 LoRA 및 LoCon 모듈로 LyCORIS 체크포인트를 불러올 수 있지만, Hada 및 LoKR은 지원되지 않습니다.
-
-
+> [!WARNING]
+> Kohya LoRA를 🤗 Diffusers와 함께 사용할 때 몇 가지 제한 사항이 있습니다:
+>
+> - [여기](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736)에 설명된 여러 가지 이유로 인해 이미지가 ComfyUI와 같은 UI에서 생성된 이미지와 다르게 보일 수 있습니다.
+> - [LyCORIS 체크포인트](https://github.com/KohakuBlueleaf/LyCORIS)가 완전히 지원되지 않습니다. [`~loaders.LoraLoaderMixin.load_lora_weights`] 메서드는 LoRA 및 LoCon 모듈로 LyCORIS 체크포인트를 불러올 수 있지만, Hada 및 LoKR은 지원되지 않습니다.
@@ -326,9 +314,8 @@ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=
IP-Adapter FaceID 모델은 CLIP 이미지 임베딩 대신 `insightface`에서 생성한 이미지 임베딩을 사용하는 실험적인 IP Adapter입니다. 이러한 모델 중 일부는 LoRA를 사용하여 ID 일관성을 개선하기도 합니다.
이러한 모델을 사용하려면 `insightface`와 해당 요구 사항을 모두 설치해야 합니다.
-
-InsightFace 사전학습된 모델은 비상업적 연구 목적으로만 사용할 수 있으므로, IP-Adapter-FaceID 모델은 연구 목적으로만 릴리즈되었으며 상업적 용도로는 사용할 수 없습니다.
-
+> [!WARNING]
+> InsightFace 사전학습된 모델은 비상업적 연구 목적으로만 사용할 수 있으므로, IP-Adapter-FaceID 모델은 연구 목적으로만 릴리즈되었으며 상업적 용도로는 사용할 수 없습니다.
```py
pipeline = AutoPipelineForText2Image.from_pretrained(
diff --git a/docs/source/ko/using-diffusers/other-formats.md b/docs/source/ko/using-diffusers/other-formats.md
index 3034551f48..f5a71f56eb 100644
--- a/docs/source/ko/using-diffusers/other-formats.md
+++ b/docs/source/ko/using-diffusers/other-formats.md
@@ -14,11 +14,8 @@ specific language governing permissions and limitations under the License.
Stable Diffusion 모델들은 학습 및 저장된 프레임워크와 다운로드 위치에 따라 다양한 형식으로 제공됩니다. 이러한 형식을 🤗 Diffusers에서 사용할 수 있도록 변환하면 추론을 위한 [다양한 스케줄러 사용](schedulers), 사용자 지정 파이프라인 구축, 추론 속도 최적화를 위한 다양한 기법과 방법 등 라이브러리에서 지원하는 모든 기능을 사용할 수 있습니다.
-
-
-우리는 `.safetensors` 형식을 추천합니다. 왜냐하면 기존의 pickled 파일은 취약하고 머신에서 코드를 실행할 때 악용될 수 있는 것에 비해 훨씬 더 안전합니다. (safetensors 불러오기 가이드에서 자세히 알아보세요.)
-
-
+> [!TIP]
+> 우리는 `.safetensors` 형식을 추천합니다. 기존의 pickled 파일은 취약하고 머신에서 코드를 실행할 때 악용될 수 있는 반면, `.safetensors`는 훨씬 더 안전하기 때문입니다. (safetensors 불러오기 가이드에서 자세히 알아보세요.)
이 가이드에서는 다른 Stable Diffusion 형식을 🤗 Diffusers와 호환되도록 변환하는 방법을 설명합니다.
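
As one concrete instance of what this guide covers, a single-file `.safetensors` checkpoint can usually be loaded directly with `from_single_file`; the file path below is a placeholder:

```py
from diffusers import StableDiffusionPipeline

# Loads a checkpoint stored in the original single-file layout (e.g. exported by A1111-style tools)
pipeline = StableDiffusionPipeline.from_single_file("path/to/model.safetensors")
```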
diff --git a/docs/source/ko/using-diffusers/schedulers.md b/docs/source/ko/using-diffusers/schedulers.md
index 55424c9982..b12c08b8c8 100644
--- a/docs/source/ko/using-diffusers/schedulers.md
+++ b/docs/source/ko/using-diffusers/schedulers.md
@@ -318,12 +318,9 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
-
-
-다음 Flax 스케줄러는 *아직* Flax Stable Diffusion 파이프라인과 호환되지 않습니다.
-
-- `FlaxLMSDiscreteScheduler`
-- `FlaxDDPMScheduler`
-
-
+> [!WARNING]
+> 다음 Flax 스케줄러는 *아직* Flax Stable Diffusion 파이프라인과 호환되지 않습니다.
+>
+> - `FlaxLMSDiscreteScheduler`
+> - `FlaxDDPMScheduler`
diff --git a/docs/source/ko/using-diffusers/shap-e.md b/docs/source/ko/using-diffusers/shap-e.md
index abf5a182b3..4c9d7fb7d1 100644
--- a/docs/source/ko/using-diffusers/shap-e.md
+++ b/docs/source/ko/using-diffusers/shap-e.md
@@ -151,11 +151,8 @@ images = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, fra
메시 출력을 `ply` 파일로 저장하려면 [`~utils.export_to_ply`] 함수를 사용합니다:
-
-
-선택적으로 [`~utils.export_to_obj`] 함수를 사용하여 메시 출력을 `obj` 파일로 저장할 수 있습니다. 다양한 형식으로 메시 출력을 저장할 수 있어 다운스트림에서 더욱 유연하게 사용할 수 있습니다!
-
-
+> [!TIP]
+> 선택적으로 [`~utils.export_to_obj`] 함수를 사용하여 메시 출력을 `obj` 파일로 저장할 수 있습니다. 다양한 형식으로 메시 출력을 저장할 수 있어 다운스트림에서 더욱 유연하게 사용할 수 있습니다!
```py
from diffusers.utils import export_to_ply
diff --git a/docs/source/ko/using-diffusers/unconditional_image_generation.md b/docs/source/ko/using-diffusers/unconditional_image_generation.md
index c3eaac4b03..b8fe800578 100644
--- a/docs/source/ko/using-diffusers/unconditional_image_generation.md
+++ b/docs/source/ko/using-diffusers/unconditional_image_generation.md
@@ -20,11 +20,8 @@ Unconditional 이미지 생성은 비교적 간단한 작업입니다. 모델이
먼저 ['DiffusionPipeline']의 인스턴스를 생성하고 다운로드할 파이프라인의 [체크포인트](https://huggingface.co/models?library=diffusers&sort=downloads)를 지정합니다. 허브의 🧨 diffusion 체크포인트 중 하나를 사용할 수 있습니다(사용할 체크포인트는 나비 이미지를 생성합니다).
-
-
-💡 나만의 unconditional 이미지 생성 모델을 학습시키고 싶으신가요? 학습 가이드를 살펴보고 나만의 이미지를 생성하는 방법을 알아보세요.
-
-
+> [!TIP]
+> 💡 나만의 unconditional 이미지 생성 모델을 학습시키고 싶으신가요? 학습 가이드를 살펴보고 나만의 이미지를 생성하는 방법을 알아보세요.
이 가이드에서는 unconditional 이미지 생성에 ['DiffusionPipeline']과 [DDPM](https://huggingface.co/papers/2006.11239)을 사용합니다:
diff --git a/docs/source/ko/using-diffusers/write_own_pipeline.md b/docs/source/ko/using-diffusers/write_own_pipeline.md
index 45678763cc..ae6ce238ac 100644
--- a/docs/source/ko/using-diffusers/write_own_pipeline.md
+++ b/docs/source/ko/using-diffusers/write_own_pipeline.md
@@ -110,11 +110,8 @@ Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent di
보시다시피, 이것은 UNet 모델만 포함된 DDPM 파이프라인보다 더 복잡합니다. Stable Diffusion 모델에는 세 개의 개별 사전학습된 모델이 있습니다.
-
-
-💡 VAE, UNet 및 텍스트 인코더 모델의 작동방식에 대한 자세한 내용은 [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) 블로그를 참조하세요.
-
-
+> [!TIP]
+> 💡 VAE, UNet 및 텍스트 인코더 모델의 작동방식에 대한 자세한 내용은 [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) 블로그를 참조하세요.
이제 Stable Diffusion 파이프라인에 필요한 구성요소들이 무엇인지 알았으니, [`~ModelMixin.from_pretrained`] 메서드를 사용해 모든 구성요소를 불러옵니다. 사전학습된 체크포인트 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)에서 찾을 수 있으며, 각 구성요소들은 별도의 하위 폴더에 저장되어 있습니다:
@@ -151,11 +148,8 @@ Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent di
다음 단계는 임베딩을 생성하기 위해 텍스트를 토큰화하는 것입니다. 이 텍스트는 UNet 모델에서 condition으로 사용되고 입력 프롬프트와 유사한 방향으로 diffusion 프로세스를 조정하는 데 사용됩니다.
-
-
-💡 `guidance_scale` 매개변수는 이미지를 생성할 때 프롬프트에 얼마나 많은 가중치를 부여할지 결정합니다.
-
-
+> [!TIP]
+> 💡 `guidance_scale` 매개변수는 이미지를 생성할 때 프롬프트에 얼마나 많은 가중치를 부여할지 결정합니다.
다른 프롬프트를 생성하고 싶다면 원하는 프롬프트를 자유롭게 선택하세요!
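
For reference, `guidance_scale` is applied later in the denoising loop roughly as follows; the variable names are the conventional ones and are assumptions here, not lines from this guide:

```py
# Classifier-free guidance: push the conditioned prediction away from the unconditioned one
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```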
@@ -198,15 +192,12 @@ Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent di
그다음 diffusion 프로세스의 시작점으로 초기 랜덤 노이즈를 생성합니다. 이것이 이미지의 잠재적 표현이며 점차적으로 노이즈가 제거됩니다. 이 시점에서 `latent` 이미지는 최종 이미지 크기보다 작지만 나중에 모델이 이를 512x512 이미지 크기로 변환하므로 괜찮습니다.
-
-
-💡 `vae` 모델에는 3개의 다운 샘플링 레이어가 있기 때문에 높이와 너비가 8로 나뉩니다. 다음을 실행하여 확인할 수 있습니다:
-
-```py
-2 ** (len(vae.config.block_out_channels) - 1) == 8
-```
-
-
+> [!TIP]
+> 💡 `vae` 모델에는 3개의 다운 샘플링 레이어가 있기 때문에 높이와 너비가 8로 나뉩니다. 다음을 실행하여 확인할 수 있습니다:
+>
+> ```py
+> 2 ** (len(vae.config.block_out_channels) - 1) == 8
+> ```
```py
>>> latents = torch.randn(
diff --git a/docs/source/pt/installation.md b/docs/source/pt/installation.md
index 1e83e36ca1..acc767110c 100644
--- a/docs/source/pt/installation.md
+++ b/docs/source/pt/installation.md
@@ -104,11 +104,8 @@ Esses comandos irá linkar a pasta que você clonou o repositório e os caminhos
Python então irá procurar dentro da pasta que você clonou além dos caminhos normais das bibliotecas.
Por exemplo, se o pacote python for tipicamente instalado no `~/anaconda3/envs/main/lib/python3.10/site-packages/`, o Python também irá procurar na pasta `~/diffusers/` que você clonou.
-
-
-Você deve deixar a pasta `diffusers` se você quiser continuar usando a biblioteca.
-
-
+> [!WARNING]
+> Você deve manter a pasta `diffusers` se você quiser continuar usando a biblioteca.
Agora você pode facilmente atualizar seu clone para a última versão do 🤗 Diffusers com o seguinte comando:
diff --git a/docs/source/pt/quicktour.md b/docs/source/pt/quicktour.md
index 109f7e2712..5996b65a9c 100644
--- a/docs/source/pt/quicktour.md
+++ b/docs/source/pt/quicktour.md
@@ -24,11 +24,8 @@ Seja você um desenvolvedor ou um usuário, esse tour rápido irá introduzir vo
Esse tour rápido mostrará como usar o [`DiffusionPipeline`] para inferência, e então mostrará como combinar um modelo e um agendador para replicar o que está acontecendo dentro do [`DiffusionPipeline`].
-
-
-Esse tour rápido é uma versão simplificada da introdução 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) para ajudar você a começar rápido. Se você quer aprender mais sobre o objetivo do 🧨 Diffusers, filosofia de design, e detalhes adicionais sobre a API principal, veja o notebook!
-
-
+> [!TIP]
+> Esse tour rápido é uma versão simplificada da introdução 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) para ajudar você a começar rápido. Se você quer aprender mais sobre o objetivo do 🧨 Diffusers, filosofia de design, e detalhes adicionais sobre a API principal, veja o notebook!
Antes de começar, certifique-se de ter todas as bibliotecas necessárias instaladas:
@@ -56,11 +53,8 @@ Comece criando uma instância do [`DiffusionPipeline`] e especifique qual checkp
Você pode usar o [`DiffusionPipeline`] para qualquer [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) armazenado no Hugging Face Hub.
Nesse quicktour, você carregará o checkpoint [`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) para geração de texto para imagem.
-
-
-Para os modelos de [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion), por favor leia cuidadosamente a [licença](https://huggingface.co/spaces/CompVis/stable-diffusion-license) primeiro antes de rodar o modelo. 🧨 Diffusers implementa uma verificação de segurança: [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) para prevenir conteúdo ofensivo ou nocivo, mas as capacidades de geração de imagem aprimorada do modelo podem ainda produzir conteúdo potencialmente nocivo.
-
-
+> [!WARNING]
+> Para os modelos de [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion), por favor leia cuidadosamente a [licença](https://huggingface.co/spaces/CompVis/stable-diffusion-license) primeiro antes de rodar o modelo. 🧨 Diffusers implementa uma verificação de segurança: [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) para prevenir conteúdo ofensivo ou nocivo, mas as capacidades de geração de imagem aprimorada do modelo podem ainda produzir conteúdo potencialmente nocivo.
Para carregar o modelo com o método [`~DiffusionPipeline.from_pretrained`]:
@@ -204,11 +198,8 @@ Para geração de exemplos reais, você precisará de um agendador para guiar o
Agendadores gerenciam a retirada do ruído de uma amostra ruidosa para uma amostra menos ruidosa dado a saída do modelo - nesse caso, é o `noisy_residual`.
-
-
-🧨 Diffusers é uma caixa de ferramentas para construir sistemas de difusão. Enquanto o [`DiffusionPipeline`] é uma forma conveniente de começar com um sistema de difusão pré-construído, você também pode escolher seus próprios modelos e agendadores separadamente para construir um sistema de difusão personalizado.
-
-
+> [!TIP]
+> 🧨 Diffusers é uma caixa de ferramentas para construir sistemas de difusão. Enquanto o [`DiffusionPipeline`] é uma forma conveniente de começar com um sistema de difusão pré-construído, você também pode escolher seus próprios modelos e agendadores separadamente para construir um sistema de difusão personalizado.
Para o tour rápido, você irá instanciar o [`DDPMScheduler`] com o método [`~diffusers.ConfigMixin.from_config`]:
@@ -232,11 +223,8 @@ DDPMScheduler {
}
```
-
-
-💡 Perceba como o agendador é instanciado de uma configuração. Diferentemente de um modelo, um agendador não tem pesos treináveis e é livre de parâmetros!
-
-
+> [!TIP]
+> 💡 Perceba como o agendador é instanciado de uma configuração. Diferentemente de um modelo, um agendador não tem pesos treináveis e é livre de parâmetros!
Alguns dos parâmetros mais importantes são:
diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml
index 6416c468a8..337d010fc7 100644
--- a/docs/source/zh/_toctree.yml
+++ b/docs/source/zh/_toctree.yml
@@ -1,12 +1,150 @@
-- sections:
+- title: 开始Diffusers
+ sections:
- local: index
- title: 🧨 Diffusers
+ title: Diffusers
+ - local: installation
+ title: 安装
- local: quicktour
title: 快速入门
- local: stable_diffusion
title: 有效和高效的扩散
- - local: consisid
- title: 身份保持的文本到视频生成
- - local: installation
- title: 安装
- title: 开始
+
+- title: DiffusionPipeline
+ isExpanded: false
+ sections:
+ - local: using-diffusers/schedulers
+ title: Load schedulers and models
+
+- title: Inference
+ isExpanded: false
+ sections:
+ - local: training/distributed_inference
+ title: Distributed inference
+
+- title: Inference optimization
+ isExpanded: false
+ sections:
+ - local: optimization/fp16
+ title: Accelerate inference
+ - local: optimization/cache
+ title: Caching
+ - local: optimization/memory
+ title: Reduce memory usage
+ - local: optimization/speed-memory-optims
+ title: Compile and offloading quantized models
+ - title: Community optimizations
+ sections:
+ - local: optimization/pruna
+ title: Pruna
+ - local: optimization/xformers
+ title: xFormers
+ - local: optimization/tome
+ title: Token merging
+ - local: optimization/deepcache
+ title: DeepCache
+ - local: optimization/tgate
+ title: TGATE
+ - local: optimization/xdit
+ title: xDiT
+ - local: optimization/para_attn
+ title: ParaAttention
+
+- title: Hybrid Inference
+ isExpanded: false
+ sections:
+ - local: hybrid_inference/overview
+ title: Overview
+ - local: hybrid_inference/vae_encode
+ title: VAE Encode
+ - local: hybrid_inference/api_reference
+ title: API Reference
+
+- title: Modular Diffusers
+ isExpanded: false
+ sections:
+ - local: modular_diffusers/overview
+ title: Overview
+ - local: modular_diffusers/quickstart
+ title: Quickstart
+ - local: modular_diffusers/modular_diffusers_states
+ title: States
+ - local: modular_diffusers/pipeline_block
+ title: ModularPipelineBlocks
+ - local: modular_diffusers/sequential_pipeline_blocks
+ title: SequentialPipelineBlocks
+ - local: modular_diffusers/loop_sequential_pipeline_blocks
+ title: LoopSequentialPipelineBlocks
+ - local: modular_diffusers/auto_pipeline_blocks
+ title: AutoPipelineBlocks
+ - local: modular_diffusers/modular_pipeline
+ title: ModularPipeline
+ - local: modular_diffusers/components_manager
+ title: ComponentsManager
+ - local: modular_diffusers/guiders
+ title: Guiders
+
+- title: Training
+ isExpanded: false
+ sections:
+ - local: training/overview
+ title: Overview
+ - local: training/adapt_a_model
+ title: Adapt a model to a new task
+ - title: Models
+ sections:
+ - local: training/text2image
+ title: Text-to-image
+ - local: training/kandinsky
+ title: Kandinsky 2.2
+ - local: training/wuerstchen
+ title: Wuerstchen
+ - local: training/controlnet
+ title: ControlNet
+ - local: training/instructpix2pix
+ title: InstructPix2Pix
+ - title: Methods
+ sections:
+ - local: training/text_inversion
+ title: Textual Inversion
+ - local: training/dreambooth
+ title: DreamBooth
+ - local: training/lora
+ title: LoRA
+
+- title: Model accelerators and hardware
+ isExpanded: false
+ sections:
+ - local: optimization/onnx
+ title: ONNX
+ - local: optimization/open_vino
+ title: OpenVINO
+ - local: optimization/coreml
+ title: Core ML
+ - local: optimization/mps
+ title: Metal Performance Shaders (MPS)
+ - local: optimization/habana
+ title: Intel Gaudi
+ - local: optimization/neuron
+ title: AWS Neuron
+
+- title: Specific pipeline examples
+ isExpanded: false
+ sections:
+ - local: using-diffusers/consisid
+ title: ConsisID
+
+- title: Resources
+ isExpanded: false
+ sections:
+ - title: Task recipes
+ sections:
+ - local: community_projects
+ title: Projects built with Diffusers
+ - local: conceptual/philosophy
+ title: Philosophy
+ - local: conceptual/contribution
+ title: How to contribute?
+ - local: conceptual/ethical_guidelines
+ title: Diffusers' Ethical Guidelines
+ - local: conceptual/evaluation
+ title: Evaluating Diffusion Models
diff --git a/docs/source/zh/community_projects.md b/docs/source/zh/community_projects.md
new file mode 100644
index 0000000000..0440142452
--- /dev/null
+++ b/docs/source/zh/community_projects.md
@@ -0,0 +1,89 @@
+
+
+# 社区项目
+
+欢迎来到社区项目。这个空间致力于展示我们充满活力的社区使用`diffusers`库创建的令人难以置信的工作和创新应用。
+
+本节旨在:
+
+- 突出使用`diffusers`构建的多样化和鼓舞人心的项目
+- 促进我们社区内的知识共享
+- 提供如何利用`diffusers`的实际例子
+
+探索愉快,感谢您成为Diffusers社区的一部分!
+
+
-
-非常的令人印象深刻! Let's tweak the second image - 把 `Generator` 的种子设置为 `1` - 添加一些关于年龄的主题文本:
-
-```python
-prompts = [
- "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
-]
-
-generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
-images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
-make_image_grid(images, 2, 2)
-```
-
-
+
+非常令人印象深刻!让我们调整第二张图像 - 把 `Generator` 的种子设置为 `1` - 添加一些关于年龄的主题文本:
+
+```python
+prompts = [
+ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+]
+
+generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
+images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
+make_image_grid(images, 2, 2)
+```
+
+
+
+多数生成图像质量相近,实际选择需根据具体场景测试多种调度器进行比较。
+
+### Flax调度器
+
+对比Flax调度器时,需额外将调度器状态加载到模型参数中。例如将[`FlaxStableDiffusionPipeline`]的默认调度器切换为超高效的[`FlaxDPMSolverMultistepScheduler`]:
+
+> [!WARNING]
+> [`FlaxLMSDiscreteScheduler`]和[`FlaxDDPMScheduler`]目前暂不兼容[`FlaxStableDiffusionPipeline`]。
+
+```python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
+
+scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ subfolder="scheduler"
+)
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ scheduler=scheduler,
+ variant="bf16",
+ dtype=jax.numpy.bfloat16,
+)
+params["scheduler"] = scheduler_state
+```
+
+利用Flax对TPU的兼容性实现并行图像生成。需为每个设备复制模型参数,并分配输入数据:
+
+```python
+# 每个并行设备生成1张图像(TPUv2-8/TPUv3-8支持8设备并行)
+prompt = "一张宇航员在火星上骑马的高清照片,高分辨率,高画质。"
+num_samples = jax.device_count()
+prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
+
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 25
+
+# 分配输入和随机种子
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+```
+
+## 模型加载
+
+通过[`ModelMixin.from_pretrained`]方法加载模型,该方法会下载并缓存模型权重和配置的最新版本。若本地缓存已存在最新文件,则直接复用缓存而非重复下载。
+
+通过`subfolder`参数可从子目录加载模型。例如[stable-diffusion-v1-5/stable-diffusion-v1-5](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5)的模型权重存储在[unet](https://hf.co/stable-diffusion-v1-5/stable-diffusion-v1-5/tree/main/unet)子目录中:
+
+```python
+from diffusers import UNet2DConditionModel
+
+unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", use_safetensors=True)
+```
+
+也可直接从[仓库](https://huggingface.co/google/ddpm-cifar10-32/tree/main)加载:
+
+```python
+from diffusers import UNet2DModel
+
+unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32", use_safetensors=True)
+```
+
+加载和保存模型变体时,需在[`ModelMixin.from_pretrained`]和[`ModelMixin.save_pretrained`]中指定`variant`参数:
+
+```python
+from diffusers import UNet2DConditionModel
+
+unet = UNet2DConditionModel.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", variant="non_ema", use_safetensors=True
+)
+unet.save_pretrained("./local-unet", variant="non_ema")
+```
+
+使用[`~ModelMixin.from_pretrained`]的`torch_dtype`参数指定模型加载精度:
+
+```python
+import torch
+from diffusers import AutoModel
+
+unet = AutoModel.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+)
+```
+
+也可使用[torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html)方法即时转换精度,但会转换所有权重(不同于`torch_dtype`参数会保留`_keep_in_fp32_modules`中的层)。这对某些必须保持fp32精度的层尤为重要(参见[示例](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374))。
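
A minimal contrast with the `torch_dtype` path mentioned above, reusing the `unet` from the previous snippet (sketch only):

```python
import torch

# Unlike torch_dtype at load time, .to() casts every parameter,
# including layers the model would otherwise keep in fp32
unet = unet.to(torch.float16)
```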
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index a30624e35a..5aa33190d4 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
# /// script
# dependencies = [
@@ -24,6 +25,10 @@
# "Jinja2",
# "peft>=0.11.1",
# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
# ]
# ///
@@ -89,7 +94,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
@@ -1398,6 +1403,7 @@ def main(args):
torch_dtype = torch.float16
elif args.prior_generation_precision == "bf16":
torch_dtype = torch.bfloat16
+
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
@@ -1418,7 +1424,8 @@ def main(args):
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
- images = pipeline(example["prompt"]).images
+ with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+ images = pipeline(prompt=example["prompt"]).images
for i, image in enumerate(images):
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 17c5150eb1..924323753b 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
# /// script
# dependencies = [
@@ -87,7 +88,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 65e280801c..3aad6b7b49 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
# /// script
# dependencies = [
@@ -94,7 +95,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
index 1ebc58b494..b4440e807e 100644
--- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py
+++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index f6903fde0a..9a1e5fd45c 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -52,7 +52,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py
index 93b33a189e..ae12012a4c 100644
--- a/examples/cogview4-control/train_control_cogview4.py
+++ b/examples/cogview4-control/train_control_cogview4.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -59,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/community/README.md b/examples/community/README.md
index e4fbd79366..e314463077 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -88,6 +88,8 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
+
+
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py
index ec653bcdb4..a7c540ceb9 100644
--- a/examples/community/composable_stable_diffusion.py
+++ b/examples/community/composable_stable_diffusion.py
@@ -398,7 +398,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py
index a2561c9198..091d0fbf8d 100644
--- a/examples/community/imagic_stable_diffusion.py
+++ b/examples/community/imagic_stable_diffusion.py
@@ -147,7 +147,7 @@ class ImagicStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py
index 7b9bd043d0..499230b1e2 100644
--- a/examples/community/img2img_inpainting.py
+++ b/examples/community/img2img_inpainting.py
@@ -197,7 +197,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py
index 460bb464f3..5b96c14d63 100644
--- a/examples/community/interpolate_stable_diffusion.py
+++ b/examples/community/interpolate_stable_diffusion.py
@@ -173,7 +173,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline, StableDiffusionMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index ccb17a51e6..cb017c0bbe 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -888,7 +888,7 @@ class StableDiffusionLongPromptWeightingPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1131,7 +1131,7 @@ class StableDiffusionLongPromptWeightingPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/lpw_stable_diffusion_onnx.py b/examples/community/lpw_stable_diffusion_onnx.py
index ab1462b81b..92effc1933 100644
--- a/examples/community/lpw_stable_diffusion_onnx.py
+++ b/examples/community/lpw_stable_diffusion_onnx.py
@@ -721,7 +721,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
output_type (`str`, *optional*, defaults to `"pil"`):
@@ -918,7 +918,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
output_type (`str`, *optional*, defaults to `"pil"`):
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index ea67738ab7..272c5d5652 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -1519,7 +1519,7 @@ class SDXLLongPromptWeightingPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
prompt_embeds (`torch.Tensor`, *optional*):
diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py
index 8be773c138..3bdaef7981 100644
--- a/examples/community/marigold_depth_estimation.py
+++ b/examples/community/marigold_depth_estimation.py
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
class MarigoldDepthOutput(BaseOutput):
diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py
index 274851e2ac..3871552672 100644
--- a/examples/community/matryoshka.py
+++ b/examples/community/matryoshka.py
@@ -1475,11 +1475,8 @@ class MatryoshkaFusedAttnProcessor2_0:
fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
For cross-attention modules, key and value projection matrices are fused.
- <Tip warning={true}>
-
- This API is currently 🧪 experimental in nature and can change in future.
-
- </Tip>
+ > [!WARNING]
+ > This API is currently 🧪 experimental in nature and can change in future.
"""
def __init__(self):
@@ -2696,11 +2693,8 @@ class MatryoshkaUNet2DConditionModel(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
- <Tip warning={true}>
-
- This API is 🧪 experimental.
-
- </Tip>
+ > [!WARNING]
+ > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -2719,11 +2713,8 @@ class MatryoshkaUNet2DConditionModel(
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
- <Tip warning={true}>
-
- This API is 🧪 experimental.
-
- </Tip>
+ > [!WARNING]
+ > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
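Note: the fused-projection docstrings touched above only describe the idea briefly. As a rough, assumption-level illustration (not the `MatryoshkaFusedAttnProcessor2_0` code from this diff), "fusing QKV projections" means replacing the separate query/key/value linear layers with a single linear layer whose weight is the concatenation of the three, so one matmul produces all three tensors:

```python
# Hypothetical sketch of QKV fusion, for illustration only -- not the
# Matryoshka implementation changed in the hunks above.
import torch
import torch.nn as nn


def fuse_qkv(to_q: nn.Linear, to_k: nn.Linear, to_v: nn.Linear) -> nn.Linear:
    """Concatenate three projection layers into one so a single matmul yields Q, K and V."""
    out_features = to_q.out_features + to_k.out_features + to_v.out_features
    fused = nn.Linear(to_q.in_features, out_features, bias=to_q.bias is not None)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
        if to_q.bias is not None:
            fused.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))
    return fused


# Usage (self-attention, where the three output sizes match):
#   qkv = fuse_qkv(attn.to_q, attn.to_k, attn.to_v)
#   q, k, v = qkv(hidden_states).chunk(3, dim=-1)
```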
diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py
index 5e7453ed12..afef4e9e97 100644
--- a/examples/community/multilingual_stable_diffusion.py
+++ b/examples/community/multilingual_stable_diffusion.py
@@ -187,7 +187,7 @@ class MultilingualStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/pipeline_controlnet_xl_kolors.py b/examples/community/pipeline_controlnet_xl_kolors.py
index af5586990e..dc90aacdbc 100644
--- a/examples/community/pipeline_controlnet_xl_kolors.py
+++ b/examples/community/pipeline_controlnet_xl_kolors.py
@@ -888,7 +888,7 @@ class KolorsControlNetPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_controlnet_xl_kolors_img2img.py b/examples/community/pipeline_controlnet_xl_kolors_img2img.py
index c0831945ed..189d031214 100644
--- a/examples/community/pipeline_controlnet_xl_kolors_img2img.py
+++ b/examples/community/pipeline_controlnet_xl_kolors_img2img.py
@@ -1066,7 +1066,7 @@ class KolorsControlNetImg2ImgPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_controlnet_xl_kolors_inpaint.py b/examples/community/pipeline_controlnet_xl_kolors_inpaint.py
index db15d99ac3..4b6123cc1f 100644
--- a/examples/community/pipeline_controlnet_xl_kolors_inpaint.py
+++ b/examples/community/pipeline_controlnet_xl_kolors_inpaint.py
@@ -1298,7 +1298,7 @@ class KolorsControlNetInpaintPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py
index c9b57a6ece..119b39cefe 100644
--- a/examples/community/pipeline_demofusion_sdxl.py
+++ b/examples/community/pipeline_demofusion_sdxl.py
@@ -724,7 +724,7 @@ class DemoFusionSDXLPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_faithdiff_stable_diffusion_xl.py b/examples/community/pipeline_faithdiff_stable_diffusion_xl.py
index 43ef55d32c..a8fdc133d0 100644
--- a/examples/community/pipeline_faithdiff_stable_diffusion_xl.py
+++ b/examples/community/pipeline_faithdiff_stable_diffusion_xl.py
@@ -1705,6 +1705,12 @@ class FaithDiffStableDiffusionXLPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
self.unet.denoise_encoder.enable_tiling()
@@ -1713,6 +1719,12 @@ class FaithDiffStableDiffusionXLPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
self.unet.denoise_encoder.disable_tiling()
@@ -1906,7 +1918,7 @@ class FaithDiffStableDiffusionXLPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
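The deprecation added above points callers at the VAE's own methods (FaithDiff additionally tiles its denoise encoder, which the generic sketch below does not cover). A minimal migration sketch, with an illustrative checkpoint name:

```python
# Migration sketch only; the checkpoint below is just an example.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Deprecated on the pipeline (still works until removal, but now warns):
# pipe.enable_vae_tiling()

# Preferred: call the methods on the VAE component directly.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
```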
diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py
index 7d6358cb32..3677e73136 100644
--- a/examples/community/pipeline_flux_differential_img2img.py
+++ b/examples/community/pipeline_flux_differential_img2img.py
@@ -730,7 +730,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
- latents tensor will ge generated by `mask_image`.
+ latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -769,7 +769,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_flux_kontext_multiple_images.py b/examples/community/pipeline_flux_kontext_multiple_images.py
index ef0c643a40..9e6ae427db 100644
--- a/examples/community/pipeline_flux_kontext_multiple_images.py
+++ b/examples/community/pipeline_flux_kontext_multiple_images.py
@@ -35,6 +35,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -643,6 +644,12 @@ class FluxKontextPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
@@ -651,6 +658,12 @@ class FluxKontextPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def preprocess_image(self, image: PipelineImageInput, _auto_resize: bool, multiple_of: int) -> torch.Tensor:
@@ -885,7 +898,7 @@ class FluxKontextPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py
index 631d04b762..2cd6eb088c 100644
--- a/examples/community/pipeline_flux_rf_inversion.py
+++ b/examples/community/pipeline_flux_rf_inversion.py
@@ -30,6 +30,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -526,6 +527,12 @@ class RFInversionFluxPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -533,6 +540,12 @@ class RFInversionFluxPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -541,6 +554,12 @@ class RFInversionFluxPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -548,6 +567,12 @@ class RFInversionFluxPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents_inversion(
@@ -711,7 +736,7 @@ class RFInversionFluxPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
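For context, the `deprecate(name, version, message)` calls inserted in these hunks come from `diffusers.utils`. As an assumption-level stand-in (not the actual diffusers implementation, which also does version checks and deprecated-argument handling), such a helper essentially does:

```python
# Simplified stand-in for a deprecate() helper; illustrative only.
import warnings


def deprecate(name: str, removal_version: str, message: str) -> None:
    """Warn that `name` is deprecated and will be removed in `removal_version`."""
    warnings.warn(
        f"`{name}` is deprecated and will be removed in version {removal_version}. {message}",
        FutureWarning,
        stacklevel=3,  # point the warning at the caller of the deprecated method
    )
```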
diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py
index 93bcd3af75..74cd5c6981 100644
--- a/examples/community/pipeline_flux_semantic_guidance.py
+++ b/examples/community/pipeline_flux_semantic_guidance.py
@@ -35,6 +35,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -702,6 +703,12 @@ class FluxSemanticGuidancePipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
@@ -710,6 +717,12 @@ class FluxSemanticGuidancePipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
@@ -853,7 +866,7 @@ class FluxSemanticGuidancePipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_flux_with_cfg.py b/examples/community/pipeline_flux_with_cfg.py
index 1b8dc9ecb8..5bc13f7e5e 100644
--- a/examples/community/pipeline_flux_with_cfg.py
+++ b/examples/community/pipeline_flux_with_cfg.py
@@ -28,6 +28,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -503,6 +504,12 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -510,6 +517,12 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -518,6 +531,12 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -525,6 +544,12 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -639,7 +664,7 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_kolors_differential_img2img.py b/examples/community/pipeline_kolors_differential_img2img.py
index 9491447409..d299c83981 100644
--- a/examples/community/pipeline_kolors_differential_img2img.py
+++ b/examples/community/pipeline_kolors_differential_img2img.py
@@ -904,7 +904,7 @@ class KolorsDifferentialImg2ImgPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_kolors_inpainting.py b/examples/community/pipeline_kolors_inpainting.py
index cce9f10ded..3cab8ecac0 100644
--- a/examples/community/pipeline_kolors_inpainting.py
+++ b/examples/community/pipeline_kolors_inpainting.py
@@ -1246,7 +1246,7 @@ class KolorsInpaintPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 065edc0cfb..8d94dc9248 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -611,7 +611,7 @@ class Prompt2PromptPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py
index ea168036c1..10438af365 100644
--- a/examples/community/pipeline_sdxl_style_aligned.py
+++ b/examples/community/pipeline_sdxl_style_aligned.py
@@ -1480,7 +1480,7 @@ class StyleAlignedSDXLPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
index 693485d175..1803cf60cc 100644
--- a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
+++ b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
@@ -29,11 +29,7 @@ from diffusers.models.transformers import SD3Transformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
-from diffusers.utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
@@ -748,7 +744,7 @@ class StableDiffusion3DifferentialImg2ImgPipeline(DiffusionPipeline):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py b/examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py
index 6923db23a6..d9cee800e8 100644
--- a/examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py
+++ b/examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py
@@ -945,7 +945,7 @@ class StableDiffusion3InstructPix2PixPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stable_diffusion_boxdiff.py b/examples/community/pipeline_stable_diffusion_boxdiff.py
index ebca3017c3..07e29b9c05 100644
--- a/examples/community/pipeline_stable_diffusion_boxdiff.py
+++ b/examples/community/pipeline_stable_diffusion_boxdiff.py
@@ -504,6 +504,12 @@ class StableDiffusionBoxDiffPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -511,6 +517,12 @@ class StableDiffusionBoxDiffPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -519,6 +531,12 @@ class StableDiffusionBoxDiffPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -526,6 +544,12 @@ class StableDiffusionBoxDiffPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def _encode_prompt(
@@ -924,11 +948,8 @@ class StableDiffusionBoxDiffPipeline(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
- <Tip warning={true}>
-
- This API is 🧪 experimental.
-
- </Tip>
+ > [!WARNING]
+ > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
@@ -954,11 +975,8 @@ class StableDiffusionBoxDiffPipeline(
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
- <Tip warning={true}>
-
- This API is 🧪 experimental.
-
- </Tip>
+ > [!WARNING]
+ > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
diff --git a/examples/community/pipeline_stable_diffusion_pag.py b/examples/community/pipeline_stable_diffusion_pag.py
index 69a0059d98..6b62b610af 100644
--- a/examples/community/pipeline_stable_diffusion_pag.py
+++ b/examples/community/pipeline_stable_diffusion_pag.py
@@ -471,6 +471,12 @@ class StableDiffusionPAGPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -478,6 +484,12 @@ class StableDiffusionPAGPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -486,6 +498,12 @@ class StableDiffusionPAGPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -493,6 +511,12 @@ class StableDiffusionPAGPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def _encode_prompt(
@@ -916,9 +940,8 @@ class StableDiffusionPAGPipeline(
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
- <Tip warning={true}>
- This API is 🧪 experimental.
- </Tip>
+ > [!WARNING]
+ > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
@@ -942,9 +965,8 @@ class StableDiffusionPAGPipeline(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
- <Tip warning={true}>
- This API is 🧪 experimental.
- </Tip>
+ > [!WARNING]
+ > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index ab8064c6e3..a881814c2a 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -1786,7 +1786,7 @@ class StableDiffusionXL_AE_Pipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
index ccf1098c61..564a19e923 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
@@ -973,7 +973,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
index 38db19148d..c73433b20f 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -1329,7 +1329,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
index b9f00cb82d..89388e10cb 100644
--- a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
+++ b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
@@ -1053,7 +1053,7 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stable_diffusion_xl_ipex.py b/examples/community/pipeline_stable_diffusion_xl_ipex.py
index eda6089f59..aa2b24f396 100644
--- a/examples/community/pipeline_stable_diffusion_xl_ipex.py
+++ b/examples/community/pipeline_stable_diffusion_xl_ipex.py
@@ -832,7 +832,7 @@ class StableDiffusionXLPipelineIpex(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stg_cogvideox.py b/examples/community/pipeline_stg_cogvideox.py
index 1c98ae0f6d..bdb6aecc30 100644
--- a/examples/community/pipeline_stg_cogvideox.py
+++ b/examples/community/pipeline_stg_cogvideox.py
@@ -632,7 +632,7 @@ class CogVideoXSTGPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stg_hunyuan_video.py b/examples/community/pipeline_stg_hunyuan_video.py
index a2cb9aa1b7..028d54d047 100644
--- a/examples/community/pipeline_stg_hunyuan_video.py
+++ b/examples/community/pipeline_stg_hunyuan_video.py
@@ -26,7 +26,7 @@ from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
-from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
+from diffusers.utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
@@ -481,6 +481,12 @@ class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -488,6 +494,12 @@ class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -496,6 +508,12 @@ class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -503,6 +521,12 @@ class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
diff --git a/examples/community/pipeline_stg_ltx.py b/examples/community/pipeline_stg_ltx.py
index f7ccf99e96..70069a33f5 100644
--- a/examples/community/pipeline_stg_ltx.py
+++ b/examples/community/pipeline_stg_ltx.py
@@ -620,7 +620,7 @@ class LTXSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderM
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stg_ltx_image2video.py b/examples/community/pipeline_stg_ltx_image2video.py
index 3b3d233380..c32805e141 100644
--- a/examples/community/pipeline_stg_ltx_image2video.py
+++ b/examples/community/pipeline_stg_ltx_image2video.py
@@ -682,7 +682,7 @@ class LTXImageToVideoSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVide
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_stg_mochi.py b/examples/community/pipeline_stg_mochi.py
index b6ab1b192c..ad9317f6bc 100644
--- a/examples/community/pipeline_stg_mochi.py
+++ b/examples/community/pipeline_stg_mochi.py
@@ -26,11 +26,7 @@ from diffusers.models import AutoencoderKLMochi, MochiTransformer3DModel
from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
-from diffusers.utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from diffusers.utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
@@ -458,6 +454,12 @@ class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -465,6 +467,12 @@ class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -473,6 +481,12 @@ class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -480,6 +494,12 @@ class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -603,7 +623,7 @@ class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/pipeline_zero1to3.py b/examples/community/pipeline_zero1to3.py
index 0db543b169..9e29566978 100644
--- a/examples/community/pipeline_zero1to3.py
+++ b/examples/community/pipeline_zero1to3.py
@@ -657,7 +657,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py
index 133c232943..78a15a03b0 100644
--- a/examples/community/rerender_a_video.py
+++ b/examples/community/rerender_a_video.py
@@ -656,7 +656,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/run_onnx_controlnet.py b/examples/community/run_onnx_controlnet.py
index 2221fc09db..f0ab2a2b96 100644
--- a/examples/community/run_onnx_controlnet.py
+++ b/examples/community/run_onnx_controlnet.py
@@ -591,7 +591,7 @@ class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/run_tensorrt_controlnet.py b/examples/community/run_tensorrt_controlnet.py
index b9e71724c0..e4f1abc83b 100644
--- a/examples/community/run_tensorrt_controlnet.py
+++ b/examples/community/run_tensorrt_controlnet.py
@@ -695,7 +695,7 @@ class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py
index ab6cf2d9cd..4d5cea497f 100755
--- a/examples/community/sd_text2img_k_diffusion.py
+++ b/examples/community/sd_text2img_k_diffusion.py
@@ -326,7 +326,7 @@ class StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/seed_resize_stable_diffusion.py b/examples/community/seed_resize_stable_diffusion.py
index 3c823012c1..eafe7572aa 100644
--- a/examples/community/seed_resize_stable_diffusion.py
+++ b/examples/community/seed_resize_stable_diffusion.py
@@ -122,7 +122,7 @@ class SeedResizeStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/stable_diffusion_comparison.py b/examples/community/stable_diffusion_comparison.py
index 36e7dba2de..22f3b3e0c3 100644
--- a/examples/community/stable_diffusion_comparison.py
+++ b/examples/community/stable_diffusion_comparison.py
@@ -279,7 +279,7 @@ class StableDiffusionComparisonPipeline(DiffusionPipeline, StableDiffusionMixin)
latents (`torch.Tensor`, optional):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, optional, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py
index 877464454a..6d8038cfd4 100644
--- a/examples/community/stable_diffusion_controlnet_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_img2img.py
@@ -670,7 +670,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, StableDiffusio
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py
index 175c47d015..fe7b808b6b 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint.py
@@ -810,7 +810,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, StableDiffusio
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
index 51e7ac38dd..2b5dc77fe5 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
@@ -804,7 +804,7 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline, StableD
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py
index aa9ab1b242..e5dd249e04 100644
--- a/examples/community/stable_diffusion_controlnet_reference.py
+++ b/examples/community/stable_diffusion_controlnet_reference.py
@@ -179,7 +179,7 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py
index 18d5e8feaa..7d1cd4f5d0 100644
--- a/examples/community/stable_diffusion_ipex.py
+++ b/examples/community/stable_diffusion_ipex.py
@@ -615,7 +615,7 @@ class StableDiffusionIPEXPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py
index 69fa0722cf..6f7dce9823 100644
--- a/examples/community/stable_diffusion_reference.py
+++ b/examples/community/stable_diffusion_reference.py
@@ -885,7 +885,7 @@ class StableDiffusionReferencePipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/stable_diffusion_repaint.py b/examples/community/stable_diffusion_repaint.py
index 9f6172f3b8..94b9f8b01b 100644
--- a/examples/community/stable_diffusion_repaint.py
+++ b/examples/community/stable_diffusion_repaint.py
@@ -678,7 +678,7 @@ class StableDiffusionRepaintPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py
index 11926a5d9a..eb05557496 100644
--- a/examples/community/stable_diffusion_xl_reference.py
+++ b/examples/community/stable_diffusion_xl_reference.py
@@ -380,7 +380,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py
index 2908388029..f262cf2cac 100644
--- a/examples/community/text_inpainting.py
+++ b/examples/community/text_inpainting.py
@@ -180,7 +180,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/community/tiled_upscaling.py b/examples/community/tiled_upscaling.py
index 56eb3e89b5..7a5e77155c 100644
--- a/examples/community/tiled_upscaling.py
+++ b/examples/community/tiled_upscaling.py
@@ -231,7 +231,7 @@ class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
tile_size (`int`, *optional*):
The size of the tiles. Too big can result in an OOM-error.
tile_border (`int`, *optional*):
diff --git a/examples/community/wildcard_stable_diffusion.py b/examples/community/wildcard_stable_diffusion.py
index c750610ca3..d40221e5b1 100644
--- a/examples/community/wildcard_stable_diffusion.py
+++ b/examples/community/wildcard_stable_diffusion.py
@@ -209,7 +209,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/examples/conftest.py b/examples/conftest.py
index 9b8996430f..ff7543ba82 100644
--- a/examples/conftest.py
+++ b/examples/conftest.py
@@ -25,6 +25,11 @@ from os.path import abspath, dirname, join
git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
sys.path.insert(1, git_repo_path)
+# Add parent directory to path so we can import from tests
+repo_root = abspath(dirname(dirname(__file__)))
+if repo_root not in sys.path:
+ sys.path.insert(0, repo_root)
+
# silence FutureWarning warnings in tests since often we can't act on them until
# they become normal warnings - i.e. the tests still need to test the current functionality
@@ -32,13 +37,13 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_addoption(parser):
- from diffusers.utils.testing_utils import pytest_addoption_shared
+ from tests.testing_utils import pytest_addoption_shared
pytest_addoption_shared(parser)
def pytest_terminal_summary(terminalreporter):
- from diffusers.utils.testing_utils import pytest_terminal_summary_main
+ from tests.testing_utils import pytest_terminal_summary_main
make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
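For reference, a minimal sketch of the conftest pattern this hunk moves to: the repo root is put on `sys.path` so the shared pytest helpers can be imported from the top-level `tests` package (the helper names come from the patch; the exact repository layout is assumed).

```python
import sys
from os.path import abspath, dirname

# examples/conftest.py sits one level below the repo root, so two dirname()
# calls reach the directory that contains the top-level `tests/` package.
repo_root = abspath(dirname(dirname(__file__)))
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)


def pytest_addoption(parser):
    # Imported lazily so the sys.path tweak above has already taken effect.
    from tests.testing_utils import pytest_addoption_shared

    pytest_addoption_shared(parser)
```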
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index 5822967d05..fb3ad01183 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import functools
@@ -73,7 +74,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
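The same version bump recurs throughout the example scripts below. As a hedged illustration of what the guard does (the helper is the one these scripts already import from `diffusers.utils`):

```python
from diffusers.utils import check_min_version

# No-op when the installed diffusers is at least the pinned development version;
# otherwise it raises early with an informative error instead of failing mid-training.
check_min_version("0.36.0.dev0")
```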
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index e7f64ef14d..bb35649b51 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -66,7 +67,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index 4b79a59134..99ad07d240 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -79,7 +80,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index 057b86eaaa..9f38b8c9b6 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import functools
@@ -72,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index 09982f0546..3c51dd25c2 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -78,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index c9be7a7f92..7d85878e66 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import contextlib
@@ -60,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 2c08ffc49a..d1e1c8efd8 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -60,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index d281668e11..6d786f6320 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -65,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py
index 033c9d7f26..1d6fc57640 100644
--- a/examples/controlnet/train_controlnet_sd3.py
+++ b/examples/controlnet/train_controlnet_sd3.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import contextlib
@@ -23,6 +24,8 @@ import math
import os
import random
import shutil
+
+# Add repo root to path to import from tests
from pathlib import Path
import accelerate
@@ -53,15 +56,14 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.testing_utils import backend_empty_cache
-from diffusers.utils.torch_utils import is_compiled_module
+from diffusers.utils.torch_utils import backend_empty_cache, is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
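With the import moved as above, the SD3 ControlNet script no longer reaches into the test helpers for `backend_empty_cache`. A hedged sketch of how the two utilities are typically paired in these scripts (the device and throwaway pipeline setup are assumed):

```python
import torch

from diffusers.training_utils import free_memory
from diffusers.utils.torch_utils import backend_empty_cache  # new import location

device = "cuda" if torch.cuda.is_available() else "cpu"

# ... run a throwaway validation pipeline, then drop the reference (del pipeline) ...
backend_empty_cache(device)  # device-agnostic cache clear (CUDA, NPU, XPU, ...)
free_memory()                # garbage-collects and clears accelerator caches
```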
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 3d182f8f4c..d9e2a712c4 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import functools
@@ -61,7 +62,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index ce4fec0a12..c105a3786e 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import itertools
@@ -63,7 +64,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index c6c119ff97..006e583e9f 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -19,8 +19,9 @@ cd diffusers
pip install -e .
```
-Then cd in the example folder and run
+Install the requirements in the `examples/dreambooth` folder as shown below.
```bash
+cd examples/dreambooth
pip install -r requirements.txt
```
diff --git a/examples/dreambooth/README_qwen.md b/examples/dreambooth/README_qwen.md
index ed4a4f5ac5..68c546a25d 100644
--- a/examples/dreambooth/README_qwen.md
+++ b/examples/dreambooth/README_qwen.md
@@ -75,9 +75,9 @@ Now, we can launch training using:
```bash
export MODEL_NAME="Qwen/Qwen-Image"
export INSTANCE_DIR="dog"
-export OUTPUT_DIR="trained-sana-lora"
+export OUTPUT_DIR="trained-qwenimage-lora"
-accelerate launch train_dreambooth_lora_sana.py \
+accelerate launch train_dreambooth_lora_qwen_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 1807e9bd80..503e2ae1d4 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -63,7 +64,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index ccf4626cf8..6c09f0a84c 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index b3e7560251..c24d16c600 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
# /// script
# dependencies = [
@@ -79,7 +80,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
@@ -641,6 +642,7 @@ def parse_args(input_args=None):
],
help="The image interpolation method to use for resizing images.",
)
+ parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1181,6 +1183,13 @@ def main(args):
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ transformer.set_attention_backend("_native_npu")
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
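The same `--enable_npu_flash_attention` flag and guard are added to the other Flux scripts below. A minimal sketch of the pattern, factored into a hypothetical helper (the function name is illustrative; the underlying calls are the ones used in the patch):

```python
import logging

from diffusers.utils.import_utils import is_torch_npu_available

logger = logging.getLogger(__name__)


def maybe_enable_npu_flash_attention(transformer, enabled: bool) -> None:
    """Hypothetical helper mirroring the guard added in these training scripts."""
    if not enabled:
        return
    if not is_torch_npu_available():
        raise ValueError("NPU flash attention requires torch_npu extensions and an NPU device.")
    logger.info("npu flash attention enabled.")
    transformer.set_attention_backend("_native_npu")
```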
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index aaf61f9813..b105aa5536 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -74,7 +75,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 6ec532e630..3b6ab814f2 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
# /// script
# dependencies = [
@@ -24,6 +25,10 @@
# "Jinja2",
# "peft>=0.11.1",
# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
# ]
# ///
@@ -79,6 +84,7 @@ from diffusers.utils import (
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -86,7 +92,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
@@ -685,6 +691,7 @@ def parse_args(input_args=None):
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1128,6 +1135,7 @@ def main(args):
torch_dtype = torch.float16
elif args.prior_generation_precision == "bf16":
torch_dtype = torch.bfloat16
+
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
@@ -1148,7 +1156,8 @@ def main(args):
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
- images = pipeline(example["prompt"]).images
+ with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+ images = pipeline(prompt=example["prompt"]).images
for i, image in enumerate(images):
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
@@ -1156,8 +1165,7 @@ def main(args):
image.save(image_filename)
del pipeline
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
+ free_memory()
# Handle the repository creation
if accelerator.is_main_process:
@@ -1212,6 +1220,13 @@ def main(args):
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ transformer.set_attention_backend("_native_npu")
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
@@ -1718,6 +1733,10 @@ def main(args):
device=accelerator.device,
prompt=args.instance_prompt,
)
+ else:
+ prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
# Convert images to latent space
if args.cache_latents:
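For context, a condensed sketch of the prior-preservation loop after these changes (`pipeline`, `sample_dataloader`, `accelerator`, and `torch_dtype` are assumed to be set up as in the script):

```python
import torch

from diffusers.training_utils import free_memory


def generate_class_images(pipeline, sample_dataloader, accelerator, torch_dtype):
    """Illustrative only: autocast the pipeline call and use free_memory() afterwards."""
    for example in sample_dataloader:
        # Autocast keeps fp16/bf16 prior generation consistent on any accelerator device.
        with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
            images = pipeline(prompt=example["prompt"]).images
        # ... hash and save `images` as in the script ...
    del pipeline  # drop the local reference so the memory can actually be reclaimed
    free_memory()  # replaces the CUDA-only torch.cuda.empty_cache() call
```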
diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
index 38896728fa..fc6df87768 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -12,6 +12,25 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
+# ]
+# ///
import argparse
import copy
@@ -28,8 +47,9 @@ from pathlib import Path
import numpy as np
import torch
import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
@@ -72,7 +92,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
@@ -705,6 +725,7 @@ def parse_args(input_args=None):
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1220,6 +1241,9 @@ def main(args):
kwargs_handlers=[kwargs],
)
+ if accelerator.distributed_type == DistributedType.DEEPSPEED:
+ AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
+
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
@@ -1268,6 +1292,7 @@ def main(args):
subfolder="transformer",
revision=args.revision,
variant=args.variant,
+ torch_dtype=torch_dtype,
)
pipeline = FluxKontextPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1290,7 +1315,8 @@ def main(args):
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
- images = pipeline(example["prompt"]).images
+ with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+ images = pipeline(prompt=example["prompt"]).images
for i, image in enumerate(images):
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
@@ -1353,6 +1379,13 @@ def main(args):
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
+ if args.enable_npu_flash_attention:
+ if is_torch_npu_available():
+ logger.info("npu flash attention enabled.")
+ transformer.set_attention_backend("_native_npu")
+ else:
+ raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ")
+
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
@@ -1427,17 +1460,20 @@ def main(args):
text_encoder_one_lora_layers_to_save = None
modules_to_save = {}
for model in models:
- if isinstance(model, type(unwrap_model(transformer))):
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ model = unwrap_model(model)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
- elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+ model = unwrap_model(model)
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["text_encoder"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
- weights.pop()
+ if weights:
+ weights.pop()
FluxKontextPipeline.save_lora_weights(
output_dir,
@@ -1450,15 +1486,25 @@ def main(args):
transformer_ = None
text_encoder_one_ = None
- while len(models) > 0:
- model = models.pop()
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_ = model
- elif isinstance(model, type(unwrap_model(text_encoder_one))):
- text_encoder_one_ = model
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ transformer_ = unwrap_model(model)
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = unwrap_model(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ else:
+ transformer_ = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer"
+ )
+ transformer_.add_adapter(transformer_lora_config)
+ text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder"
+ )
lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
@@ -1890,6 +1936,10 @@ def main(args):
device=accelerator.device,
prompt=args.instance_prompt,
)
+ else:
+ prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
# Convert images to latent space
if args.cache_latents:
@@ -2054,7 +2104,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
- if accelerator.is_main_process:
+ if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
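The DeepSpeed-related changes above share one pattern: compare and save the *unwrapped* modules, and only pop from `weights` when Accelerate actually populated it. A hedged, standalone sketch of the hook body (only the transformer branch is shown; `transformer` and `unwrap_model` are assumed as in the script):

```python
from peft.utils import get_peft_model_state_dict


def collect_transformer_lora_state(models, weights, *, transformer, unwrap_model):
    """Hypothetical standalone version of the save-hook loop changed above."""
    transformer_lora_layers_to_save = None
    for model in models:
        # Under DeepSpeed the entries in `models` are wrapped, so compare the
        # unwrapped types and read the LoRA state dict from the unwrapped module.
        if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
            model = unwrap_model(model)
            transformer_lora_layers_to_save = get_peft_model_state_dict(model)
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")
        # DeepSpeed may hand the hook an empty `weights` list, so pop defensively.
        if weights:
            weights.pop()
    return transformer_lora_layers_to_save
```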
diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py
index 199a8a68ea..8cbc3a43fd 100644
--- a/examples/dreambooth/train_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -74,7 +75,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py
index ee84de66d2..8bf4895863 100644
--- a/examples/dreambooth/train_dreambooth_lora_lumina2.py
+++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -72,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py
index 231aff8bfe..56de160d6f 100644
--- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py
+++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py
@@ -13,6 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
+# ]
+# ///
+
import argparse
import copy
import itertools
@@ -75,7 +93,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
@@ -1320,7 +1338,7 @@ def main(args):
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
- latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if train_dataset.custom_instance_prompts:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, prompt_embeds_mask = compute_text_embeddings(
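For reference, a hedged sketch of the latent-caching loop that the line above belongs to (the dataloader, VAE, and accelerator setup are assumed as in the script):

```python
import torch


def cache_latents(train_dataloader, vae, accelerator):
    """Illustrative helper: cache VAE posteriors so latents can be re-sampled each epoch."""
    latents_cache = []
    with torch.no_grad():
        for batch in train_dataloader:
            pixel_values = batch["pixel_values"].to(
                accelerator.device, non_blocking=True, dtype=vae.dtype
            )
            latents_cache.append(vae.encode(pixel_values).latent_dist)
    return latents_cache
```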
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index 2c4e63fd95..2b0c1ee669 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
# /// script
# dependencies = [
@@ -24,6 +25,10 @@
# "Jinja2",
# "peft>=0.14.0",
# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
# ]
# ///
@@ -86,7 +91,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 5ab21df518..eef732c531 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -72,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 5758db8508..1ffb73cee4 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import gc
@@ -79,7 +80,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py
index b130b9ff21..d345ebb391 100644
--- a/examples/dreambooth/train_dreambooth_sd3.py
+++ b/examples/dreambooth/train_dreambooth_sd3.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -63,7 +64,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py
index 63cb770ccd..fe47e07441 100644
--- a/examples/flux-control/train_control_flux.py
+++ b/examples/flux-control/train_control_flux.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -54,7 +55,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 2990d5701a..36320449bd 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
@@ -57,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index b6b29fce27..85b85aa2fa 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index ef55321f58..acf5d8dff0 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index 56a8136ab2..a30e255953 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -52,7 +53,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 7461f5b742..57c92f3ae5 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index 64fd8ba3cb..2a0ef7d6fb 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index fd4694d862..df7cffef9b 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -51,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py
index fcce297c37..ee5dced817 100644
--- a/examples/model_search/pipeline_easy.py
+++ b/examples/model_search/pipeline_easy.py
@@ -1246,12 +1246,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
@@ -1355,12 +1352,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
below for more information.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
@@ -1504,12 +1498,9 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
@@ -1614,12 +1605,9 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
below for more information.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
@@ -1763,12 +1751,9 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
@@ -1872,12 +1857,9 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
below for more information.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
index dfb9e42ef1..b217f58d6d 100644
--- a/examples/research_projects/autoencoderkl/train_autoencoderkl.py
+++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import contextlib
diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py
index f33a65c756..c1ddb4eae1 100644
--- a/examples/research_projects/controlnet/train_controlnet_webdataset.py
+++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import functools
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
index fda2a15809..a65767d084 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import contextlib
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
index aa39b0b517..756b20bb8d 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import contextlib
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
index 46045d330b..5a1b26f886 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import contextlib
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
index 93418bf910..f1bfaa2fb5 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import contextlib
diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
index 572c69fddf..65811ae57c 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
diff --git a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb
index a39bcc5eea..3d5b8adfba 100644
--- a/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb
+++ b/examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb
@@ -1760,7 +1760,7 @@
"clip_local = None\n",
"clip_pos = None\n",
"\n",
- "# constands for data handling\n",
+ "# constants for data handling\n",
"save_traj = False\n",
"save_data = False\n",
"output_dir = \"/content/\""
diff --git a/examples/research_projects/multi_subject_dreambooth_inpainting/README.md b/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
index 32c375efea..8ddef1b83c 100644
--- a/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
+++ b/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
@@ -2,7 +2,7 @@
Please note that this project is not actively maintained. However, you can open an issue and tag @gzguevara.
-[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requieres prompt-image-mask pairs. The Unet of inpainiting models have 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself).
+[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize text2image models like Stable Diffusion given just a few (3~5) images of a subject. This project consists of **two parts**. Training Stable Diffusion for inpainting requires prompt-image-mask pairs. The UNet of inpainting models has 5 additional input channels (4 for the encoded masked image and 1 for the mask itself).
**The first part**, the `multi_inpaint_dataset.ipynb` notebook, demonstrates how make a 🤗 dataset of prompt-image-mask pairs. You can, however, skip the first part and move straight to the second part with the example datasets in this project. ([cat toy dataset masked](https://huggingface.co/datasets/gzguevara/cat_toy_masked), [mr. potato head dataset masked](https://huggingface.co/datasets/gzguevara/mr_potato_head_masked))
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index ffcc8a75c8..3d000c8c66 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
index dd4c341ca8..1af05e8b22 100644
--- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index 28bf029af4..6044607c14 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
index 148b2e7f31..89228983d4 100644
--- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
+++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
@@ -860,7 +860,7 @@ class PixArtAlphaControlnetPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
index 7dfbc8b3e5..1bd9c0161f 100644
--- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
+++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -263,6 +263,12 @@ class PromptDiffusionPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
@@ -271,6 +277,12 @@ class PromptDiffusionPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
diff --git a/examples/research_projects/rdm/pipeline_rdm.py b/examples/research_projects/rdm/pipeline_rdm.py
index 7e2095b724..9b696874c5 100644
--- a/examples/research_projects/rdm/pipeline_rdm.py
+++ b/examples/research_projects/rdm/pipeline_rdm.py
@@ -202,7 +202,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/examples/research_projects/sana/train_sana_sprint_diffusers.py b/examples/research_projects/sana/train_sana_sprint_diffusers.py
index 51db15f194..d127fee5fd 100644
--- a/examples/research_projects/sana/train_sana_sprint_diffusers.py
+++ b/examples/research_projects/sana/train_sana_sprint_diffusers.py
@@ -13,6 +13,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import io
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
index 50ab487bfe..c504056369 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
index 5ce510861a..88f6ca0f4d 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import copy
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
index 554aaedd7b..64914f5204 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import contextlib
diff --git a/examples/research_projects/vae/vae_roundtrip.py b/examples/research_projects/vae/vae_roundtrip.py
index 8388a352b2..cdc3a54fdf 100644
--- a/examples/research_projects/vae/vae_roundtrip.py
+++ b/examples/research_projects/vae/vae_roundtrip.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import typing
diff --git a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index 12586b5f57..fbf73a070e 100644
--- a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
diff --git a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
index e72152b45c..737c70665b 100644
--- a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -10,6 +10,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
diff --git a/examples/server-async/Pipelines.py b/examples/server-async/Pipelines.py
new file mode 100644
index 0000000000..f89cac6a7e
--- /dev/null
+++ b/examples/server-async/Pipelines.py
@@ -0,0 +1,91 @@
+import logging
+import os
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+from pydantic import BaseModel
+
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
+
+
+logger = logging.getLogger(__name__)
+
+
+class TextToImageInput(BaseModel):
+ model: str
+ prompt: str
+ size: str | None = None
+ n: int | None = None
+
+
+@dataclass
+class PresetModels:
+ SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
+ SD3_5: List[str] = field(
+ default_factory=lambda: [
+ "stabilityai/stable-diffusion-3.5-large",
+ "stabilityai/stable-diffusion-3.5-large-turbo",
+ "stabilityai/stable-diffusion-3.5-medium",
+ ]
+ )
+
+
+class TextToImagePipelineSD3:
+ def __init__(self, model_path: str | None = None):
+ self.model_path = model_path or os.getenv("MODEL_PATH")
+ self.pipeline: StableDiffusion3Pipeline | None = None
+ self.device: str | None = None
+
+ def start(self):
+ if torch.cuda.is_available():
+ model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
+ logger.info("Loading CUDA")
+ self.device = "cuda"
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.float16,
+ ).to(device=self.device)
+ elif torch.backends.mps.is_available():
+ model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
+ logger.info("Loading MPS for Mac M Series")
+ self.device = "mps"
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16,
+ ).to(device=self.device)
+ else:
+ raise Exception("No CUDA or MPS device available")
+
+
+class ModelPipelineInitializer:
+ def __init__(self, model: str = "", type_models: str = "t2im"):
+ self.model = model
+ self.type_models = type_models
+ self.pipeline = None
+ self.device = "cuda" if torch.cuda.is_available() else "mps"
+ self.model_type = None
+
+ def initialize_pipeline(self):
+ if not self.model:
+ raise ValueError("Model name not provided")
+
+ # Check if model exists in PresetModels
+ preset_models = PresetModels()
+
+ # Determine which model type we're dealing with
+ if self.model in preset_models.SD3:
+ self.model_type = "SD3"
+ elif self.model in preset_models.SD3_5:
+ self.model_type = "SD3_5"
+
+ # Create appropriate pipeline based on model type and type_models
+ if self.type_models == "t2im":
+ if self.model_type in ["SD3", "SD3_5"]:
+ self.pipeline = TextToImagePipelineSD3(self.model)
+ else:
+ raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
+ elif self.type_models == "t2v":
+ raise ValueError(f"Unsupported type_models: {self.type_models}")
+
+ return self.pipeline
diff --git a/examples/server-async/README.md b/examples/server-async/README.md
new file mode 100644
index 0000000000..a47ab7c7f2
--- /dev/null
+++ b/examples/server-async/README.md
@@ -0,0 +1,171 @@
+# Asynchronous server and parallel execution of models
+
+> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
+> For best throughput we recommend running 10 to 50 inferences in parallel; individual requests typically take anywhere from 25-30 seconds up to 1 minute 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two parallel inferences to avoid decoding or saving errors caused by memory shortages.)
+
+## ⚠️ IMPORTANT
+
+* The example demonstrates how to run pipelines like `StableDiffusion3Pipeline` (SD3 / SD3.5) concurrently while keeping a single copy of the heavy model parameters on GPU.
+
+## Necessary components
+
+All the components needed to create the inference server are in the current directory:
+
+```
+server-async/
+├── utils/
+├─────── __init__.py
+├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences
+├─────── requestscopedpipeline.py # RequestScoped Pipeline for inference with a single in-memory model
+├─────── utils.py # Image/video saving utilities and service configuration
+├── Pipelines.py # pipeline loader classes (SD3)
+├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
+├── test.py # Client test script for inference requests
+├── requirements.txt # Dependencies
+└── README.md # This documentation
+```
+
+## What `diffusers-async` adds / Why we needed it
+
+Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per-request.
+
+`diffusers-async` / this example addresses that by:
+
+* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
+* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.
+* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
+* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully backward-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not passing `return_scheduler=True`, the behavior is identical to the original API (see the sketch after this list).
+* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.
+* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
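+
+For example, a minimal sketch of calling the helper with and without `return_scheduler=True` (it assumes you run from `examples/server-async`, have a CUDA device, and already hold a loaded pipeline named `pipe`):
+
+```python
+from utils.scheduler import async_retrieve_timesteps
+
+# Default call: same contract as the original helper; set_timesteps runs on the scheduler you pass in.
+timesteps, num_steps = async_retrieve_timesteps(pipe.scheduler, num_inference_steps=30, device="cuda")
+
+# With return_scheduler=True the shared scheduler is left untouched: a per-request clone
+# (via clone_for_request() or a deepcopy fallback) carries the configured timesteps.
+timesteps, num_steps, local_scheduler = async_retrieve_timesteps(
+    pipe.scheduler,
+    num_inference_steps=30,
+    device="cuda",
+    return_scheduler=True,
+)
+```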
+
+## How the server works (high-level flow)
+
+1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
+2. On each HTTP inference request:
+
+ * The server uses `RequestScopedPipeline.generate(...)` which:
+
+ * automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),
+ * obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
+ * does `local_pipe = copy.copy(base_pipe)` (shallow copy),
+ * sets `local_pipe.scheduler = local_scheduler` (if possible),
+ * clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,
+ * wraps tokenizers with thread-safe locks to prevent race conditions,
+ * optionally enters a `model_cpu_offload_context()` for memory offload hooks,
+ * calls the pipeline on the local view (`local_pipe(...)`).
+3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).
+4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state (see the sketch below).
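+
+A minimal sketch of this flow outside the HTTP server (the model id, dtype, and the `handle_request` helper are illustrative assumptions; `serverasync.py` wires the same pieces into FastAPI endpoints, and the sketch assumes a CUDA GPU):
+
+```python
+import threading
+
+import torch
+from diffusers import StableDiffusion3Pipeline
+
+from utils import RequestScopedPipeline
+
+# One copy of the heavy weights, shared by every request.
+pipe = StableDiffusion3Pipeline.from_pretrained(
+    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16
+).to("cuda")
+request_pipe = RequestScopedPipeline(pipe)
+
+
+def handle_request(prompt: str):
+    # Each call works on a shallow per-request view with its own scheduler and RNG.
+    generator = torch.Generator(device="cuda").manual_seed(0)
+    result = request_pipe.generate(
+        prompt=prompt,
+        generator=generator,
+        num_inference_steps=30,
+        device="cuda",
+        output_type="pil",
+    )
+    print(f"{prompt!r}: {len(result.images)} image(s)")
+
+
+# Run a couple of requests in parallel threads (the server uses a threadpool).
+threads = [threading.Thread(target=handle_request, args=(p,)) for p in ("a cat", "a dog")]
+for t in threads:
+    t.start()
+for t in threads:
+    t.join()
+```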
+
+## How to set up and run the server
+
+### 1) Install dependencies
+
+Recommended: create a virtualenv / conda environment.
+
+```bash
+pip install diffusers
+pip install -r requirements.txt
+```
+
+### 2) Start the server
+
+Using the `serverasync.py` file that already has everything you need:
+
+```bash
+python serverasync.py
+```
+
+The server will start on `http://localhost:8500` by default with the following features:
+- FastAPI application with async lifespan management
+- Automatic model loading and pipeline initialization
+- Request counting and active inference tracking
+- Memory cleanup after each inference
+- CORS middleware for cross-origin requests
+
+### 3) Test the server
+
+Use the included test script:
+
+```bash
+python test.py
+```
+
+Or send a manual request:
+
+`POST /api/diffusers/inference` with JSON body:
+
+```json
+{
+ "prompt": "A futuristic cityscape, vibrant colors",
+ "num_inference_steps": 30,
+ "num_images_per_prompt": 1
+}
+```
+
+Response example:
+
+```json
+{
+ "response": ["http://localhost:8500/images/img123.png"]
+}
+```
+
+### 4) Server endpoints
+
+- `GET /` - Welcome message
+- `POST /api/diffusers/inference` - Main inference endpoint
+- `GET /images/{filename}` - Serve generated images
+- `GET /api/status` - Server status and memory info
+
+## Advanced Configuration
+
+### RequestScopedPipeline Parameters
+
+```python
+RequestScopedPipeline(
+ pipeline, # Base pipeline to wrap
+ mutable_attrs=None, # Custom list of attributes to clone
+ auto_detect_mutables=True, # Enable automatic detection of mutable attributes
+ tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning
+ tokenizer_lock=None, # Custom threading lock for tokenizers
+ wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler
+)
+```
+
+### BaseAsyncScheduler Features
+
+* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
+* `clone_for_request()` method for safe per-request scheduler cloning (see the sketch after this list)
+* Enhanced debugging with `__repr__` and `__str__` methods
+* Full compatibility with existing scheduler APIs
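+
+A small, self-contained sketch of the wrapper (the scheduler class here is only an example):
+
+```python
+from diffusers import FlowMatchEulerDiscreteScheduler
+
+from utils.scheduler import BaseAsyncScheduler
+
+shared = BaseAsyncScheduler(FlowMatchEulerDiscreteScheduler())
+print(shared.config.num_train_timesteps)  # attribute access is proxied to the wrapped scheduler
+
+# Per-request copy: deepcopy + set_timesteps on the copy; the shared instance stays untouched.
+local = shared.clone_for_request(num_inference_steps=30)
+print(len(local.timesteps))  # 30
+```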
+
+### Server Configuration
+
+The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:
+
+```python
+@dataclass
+class ServerConfigModels:
+ model: str = 'stabilityai/stable-diffusion-3.5-medium'
+ type_models: str = 't2im'
+ host: str = '0.0.0.0'
+ port: int = 8500
+```
+
+## Troubleshooting (quick)
+
+* `Already borrowed` — previously a Rust tokenizer concurrency error.
+ ✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.
+
+* `can't set attribute 'components'` — pipeline exposes read-only `components`.
+ ✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.
+
+* Scheduler issues:
+  * If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log the failure and fall back to the original scheduler; prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
+    ✅ Note: `async_retrieve_timesteps` is fully backward-compatible: if you don't pass `return_scheduler=True`, the behavior is unchanged.
+
+* Memory issues with large tensors:
+ ✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones.
+
+* Automatic tokenizer detection:
+ ✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers.
\ No newline at end of file
diff --git a/examples/server-async/requirements.txt b/examples/server-async/requirements.txt
new file mode 100644
index 0000000000..aafa93b702
--- /dev/null
+++ b/examples/server-async/requirements.txt
@@ -0,0 +1,10 @@
+torch
+torchvision
+transformers
+sentencepiece
+fastapi
+uvicorn
+ftfy
+accelerate
+xformers
+protobuf
\ No newline at end of file
diff --git a/examples/server-async/serverasync.py b/examples/server-async/serverasync.py
new file mode 100644
index 0000000000..b279b36f9a
--- /dev/null
+++ b/examples/server-async/serverasync.py
@@ -0,0 +1,230 @@
+import asyncio
+import gc
+import logging
+import os
+import random
+import threading
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Type
+
+import torch
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.concurrency import run_in_threadpool
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from Pipelines import ModelPipelineInitializer
+from pydantic import BaseModel
+
+from utils import RequestScopedPipeline, Utils
+
+
+@dataclass
+class ServerConfigModels:
+ model: str = "stabilityai/stable-diffusion-3.5-medium"
+ type_models: str = "t2im"
+ constructor_pipeline: Optional[Type] = None
+ custom_pipeline: Optional[Type] = None
+ components: Optional[Dict[str, Any]] = None
+ torch_dtype: Optional[torch.dtype] = None
+ host: str = "0.0.0.0"
+ port: int = 8500
+
+
+server_config = ServerConfigModels()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ logging.basicConfig(level=logging.INFO)
+ app.state.logger = logging.getLogger("diffusers-server")
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+
+ app.state.total_requests = 0
+ app.state.active_inferences = 0
+ app.state.metrics_lock = asyncio.Lock()
+ app.state.metrics_task = None
+
+ app.state.utils_app = Utils(
+ host=server_config.host,
+ port=server_config.port,
+ )
+
+ async def metrics_loop():
+ try:
+ while True:
+ async with app.state.metrics_lock:
+ total = app.state.total_requests
+ active = app.state.active_inferences
+ app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
+ await asyncio.sleep(5)
+ except asyncio.CancelledError:
+ app.state.logger.info("Metrics loop cancelled")
+ raise
+
+ app.state.metrics_task = asyncio.create_task(metrics_loop())
+
+ try:
+ yield
+ finally:
+ task = app.state.metrics_task
+ if task:
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+ try:
+ stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
+ if callable(stop_fn):
+ await run_in_threadpool(stop_fn)
+ except Exception as e:
+ app.state.logger.warning(f"Error during pipeline shutdown: {e}")
+
+ app.state.logger.info("Lifespan shutdown complete")
+
+
+app = FastAPI(lifespan=lifespan)
+
+logger = logging.getLogger("DiffusersServer.Pipelines")
+
+
+initializer = ModelPipelineInitializer(
+ model=server_config.model,
+ type_models=server_config.type_models,
+)
+model_pipeline = initializer.initialize_pipeline()
+model_pipeline.start()
+
+request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
+pipeline_lock = threading.Lock()
+
+logger.info(f"Pipeline initialized and ready to receive requests (model={server_config.model})")
+
+app.state.MODEL_INITIALIZER = initializer
+app.state.MODEL_PIPELINE = model_pipeline
+app.state.REQUEST_PIPE = request_pipe
+app.state.PIPELINE_LOCK = pipeline_lock
+
+
+class JSONBodyQueryAPI(BaseModel):
+ model: str | None = None
+ prompt: str
+ negative_prompt: str | None = None
+ num_inference_steps: int = 28
+ num_images_per_prompt: int = 1
+
+
+@app.middleware("http")
+async def count_requests_middleware(request: Request, call_next):
+ async with app.state.metrics_lock:
+ app.state.total_requests += 1
+ response = await call_next(request)
+ return response
+
+
+@app.get("/")
+async def root():
+ return {"message": "Welcome to the Diffusers Server"}
+
+
+@app.post("/api/diffusers/inference")
+async def api(json: JSONBodyQueryAPI):
+ prompt = json.prompt
+ negative_prompt = json.negative_prompt or ""
+ num_steps = json.num_inference_steps
+ num_images_per_prompt = json.num_images_per_prompt
+
+ wrapper = app.state.MODEL_PIPELINE
+ initializer = app.state.MODEL_INITIALIZER
+
+ utils_app = app.state.utils_app
+
+ if not wrapper or not wrapper.pipeline:
+ raise HTTPException(500, "Model not initialized correctly")
+ if not prompt.strip():
+ raise HTTPException(400, "No prompt provided")
+
+ def make_generator():
+ g = torch.Generator(device=initializer.device)
+ return g.manual_seed(random.randint(0, 10_000_000))
+
+ req_pipe = app.state.REQUEST_PIPE
+
+ def infer():
+ gen = make_generator()
+ return req_pipe.generate(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ generator=gen,
+ num_inference_steps=num_steps,
+ num_images_per_prompt=num_images_per_prompt,
+ device=initializer.device,
+ output_type="pil",
+ )
+
+ try:
+ async with app.state.metrics_lock:
+ app.state.active_inferences += 1
+
+ output = await run_in_threadpool(infer)
+
+ async with app.state.metrics_lock:
+ app.state.active_inferences = max(0, app.state.active_inferences - 1)
+
+ urls = [utils_app.save_image(img) for img in output.images]
+ return {"response": urls}
+
+ except Exception as e:
+ async with app.state.metrics_lock:
+ app.state.active_inferences = max(0, app.state.active_inferences - 1)
+ logger.error(f"Error during inference: {e}")
+ raise HTTPException(500, f"Error in processing: {e}")
+
+ finally:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.ipc_collect()
+ gc.collect()
+
+
+@app.get("/images/{filename}")
+async def serve_image(filename: str):
+ utils_app = app.state.utils_app
+ file_path = os.path.join(utils_app.image_dir, filename)
+ if not os.path.isfile(file_path):
+ raise HTTPException(status_code=404, detail="Image not found")
+ return FileResponse(file_path, media_type="image/png")
+
+
+@app.get("/api/status")
+async def get_status():
+ memory_info = {}
+ if torch.cuda.is_available():
+ memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
+ memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
+ memory_info = {
+ "memory_allocated_gb": round(memory_allocated, 2),
+ "memory_reserved_gb": round(memory_reserved, 2),
+ "device": torch.cuda.get_device_name(0),
+ }
+
+ return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
+
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+if __name__ == "__main__":
+ import uvicorn
+
+ uvicorn.run(app, host=server_config.host, port=server_config.port)
diff --git a/examples/server-async/test.py b/examples/server-async/test.py
new file mode 100644
index 0000000000..e67317ea8f
--- /dev/null
+++ b/examples/server-async/test.py
@@ -0,0 +1,65 @@
+import os
+import time
+import urllib.parse
+
+import requests
+
+
+SERVER_URL = "http://localhost:8500/api/diffusers/inference"
+BASE_URL = "http://localhost:8500"
+DOWNLOAD_FOLDER = "generated_images"
+WAIT_BEFORE_DOWNLOAD = 2 # seconds
+
+os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
+
+
+def save_from_url(url: str) -> str:
+ """Download the given URL (relative or absolute) and save it locally."""
+ if url.startswith("/"):
+ direct = BASE_URL.rstrip("/") + url
+ else:
+ direct = url
+ resp = requests.get(direct, timeout=60)
+ resp.raise_for_status()
+ filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
+ path = os.path.join(DOWNLOAD_FOLDER, filename)
+ with open(path, "wb") as f:
+ f.write(resp.content)
+ return path
+
+
+def main():
+ payload = {
+ "prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
+ "num_inference_steps": 30,
+ "num_images_per_prompt": 1,
+ }
+
+ print("Sending request...")
+ try:
+ r = requests.post(SERVER_URL, json=payload, timeout=480)
+ r.raise_for_status()
+ except Exception as e:
+ print(f"Request failed: {e}")
+ return
+
+ body = r.json().get("response", [])
+ # Normalize to a list
+ urls = body if isinstance(body, list) else [body] if body else []
+ if not urls:
+ print("No URLs found in the response. Check the server output.")
+ return
+
+ print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
+ time.sleep(WAIT_BEFORE_DOWNLOAD)
+
+ for u in urls:
+ try:
+ path = save_from_url(u)
+ print(f"Image saved to: {path}")
+ except Exception as e:
+ print(f"Error downloading {u}: {e}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/server-async/utils/__init__.py b/examples/server-async/utils/__init__.py
new file mode 100644
index 0000000000..731cfe491a
--- /dev/null
+++ b/examples/server-async/utils/__init__.py
@@ -0,0 +1,2 @@
+from .requestscopedpipeline import RequestScopedPipeline
+from .utils import Utils
diff --git a/examples/server-async/utils/requestscopedpipeline.py b/examples/server-async/utils/requestscopedpipeline.py
new file mode 100644
index 0000000000..57d1e25671
--- /dev/null
+++ b/examples/server-async/utils/requestscopedpipeline.py
@@ -0,0 +1,296 @@
+import copy
+import threading
+from typing import Any, Iterable, List, Optional
+
+import torch
+
+from diffusers.utils import logging
+
+from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
+
+
+logger = logging.get_logger(__name__)
+
+
+def safe_tokenize(tokenizer, *args, lock, **kwargs):
+ with lock:
+ return tokenizer(*args, **kwargs)
+
+
+class RequestScopedPipeline:
+ DEFAULT_MUTABLE_ATTRS = [
+ "_all_hooks",
+ "_offload_device",
+ "_progress_bar_config",
+ "_progress_bar",
+ "_rng_state",
+ "_last_seed",
+ "latents",
+ ]
+
+ def __init__(
+ self,
+ pipeline: Any,
+ mutable_attrs: Optional[Iterable[str]] = None,
+ auto_detect_mutables: bool = True,
+ tensor_numel_threshold: int = 1_000_000,
+ tokenizer_lock: Optional[threading.Lock] = None,
+ wrap_scheduler: bool = True,
+ ):
+ self._base = pipeline
+ self.unet = getattr(pipeline, "unet", None)
+ self.vae = getattr(pipeline, "vae", None)
+ self.text_encoder = getattr(pipeline, "text_encoder", None)
+ self.components = getattr(pipeline, "components", None)
+
+ if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
+ if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
+ pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
+
+ self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
+ self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
+
+ self._auto_detect_mutables = bool(auto_detect_mutables)
+ self._tensor_numel_threshold = int(tensor_numel_threshold)
+
+ self._auto_detected_attrs: List[str] = []
+
+ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
+ base_sched = getattr(self._base, "scheduler", None)
+ if base_sched is None:
+ return None
+
+ if not isinstance(base_sched, BaseAsyncScheduler):
+ wrapped_scheduler = BaseAsyncScheduler(base_sched)
+ else:
+ wrapped_scheduler = base_sched
+
+ try:
+ return wrapped_scheduler.clone_for_request(
+ num_inference_steps=num_inference_steps, device=device, **clone_kwargs
+ )
+ except Exception as e:
+ logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+ try:
+ return copy.deepcopy(wrapped_scheduler)
+ except Exception as e:
+ logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
+ return wrapped_scheduler
+
+ def _autodetect_mutables(self, max_attrs: int = 40):
+ if not self._auto_detect_mutables:
+ return []
+
+ if self._auto_detected_attrs:
+ return self._auto_detected_attrs
+
+ candidates: List[str] = []
+ seen = set()
+ for name in dir(self._base):
+ if name.startswith("__"):
+ continue
+ if name in self._mutable_attrs:
+ continue
+ if name in ("to", "save_pretrained", "from_pretrained"):
+ continue
+ try:
+ val = getattr(self._base, name)
+ except Exception:
+ continue
+
+ import types
+
+ # skip callables and modules
+ if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
+ continue
+
+ # containers -> candidate
+ if isinstance(val, (dict, list, set, tuple, bytearray)):
+ candidates.append(name)
+ seen.add(name)
+ else:
+ # try Tensor detection
+ try:
+ if isinstance(val, torch.Tensor):
+ if val.numel() <= self._tensor_numel_threshold:
+ candidates.append(name)
+ seen.add(name)
+ else:
+ logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
+ except Exception:
+ continue
+
+ if len(candidates) >= max_attrs:
+ break
+
+ self._auto_detected_attrs = candidates
+ logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
+ return self._auto_detected_attrs
+
+ def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
+ try:
+ cls = type(base_obj)
+ descriptor = getattr(cls, attr_name, None)
+ if isinstance(descriptor, property):
+ return descriptor.fset is None
+ if hasattr(descriptor, "__set__") is False and descriptor is not None:
+ return False
+ except Exception:
+ pass
+ return False
+
+ def _clone_mutable_attrs(self, base, local):
+ attrs_to_clone = list(self._mutable_attrs)
+ attrs_to_clone.extend(self._autodetect_mutables())
+
+ EXCLUDE_ATTRS = {
+ "components",
+ }
+
+ for attr in attrs_to_clone:
+ if attr in EXCLUDE_ATTRS:
+ logger.debug(f"Skipping excluded attr '{attr}'")
+ continue
+ if not hasattr(base, attr):
+ continue
+ if self._is_readonly_property(base, attr):
+ logger.debug(f"Skipping read-only property '{attr}'")
+ continue
+
+ try:
+ val = getattr(base, attr)
+ except Exception as e:
+ logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
+ continue
+
+ try:
+ if isinstance(val, dict):
+ setattr(local, attr, dict(val))
+ elif isinstance(val, (list, tuple, set)):
+ setattr(local, attr, list(val))
+ elif isinstance(val, bytearray):
+ setattr(local, attr, bytearray(val))
+ else:
+ # small tensors or atomic values
+ if isinstance(val, torch.Tensor):
+ if val.numel() <= self._tensor_numel_threshold:
+ setattr(local, attr, val.clone())
+ else:
+ # don't clone big tensors, keep reference
+ setattr(local, attr, val)
+ else:
+ try:
+ setattr(local, attr, copy.copy(val))
+ except Exception:
+ setattr(local, attr, val)
+ except (AttributeError, TypeError) as e:
+ logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
+ continue
+ except Exception as e:
+ logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
+ continue
+
+ def _is_tokenizer_component(self, component) -> bool:
+ if component is None:
+ return False
+
+ tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
+ has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
+
+ class_name = component.__class__.__name__.lower()
+ has_tokenizer_in_name = "tokenizer" in class_name
+
+ tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
+ has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
+
+ return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
+
+ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
+ local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
+
+ try:
+ local_pipe = copy.copy(self._base)
+ except Exception as e:
+ logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
+ local_pipe = copy.deepcopy(self._base)
+
+ if local_scheduler is not None:
+ try:
+ timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
+ local_scheduler.scheduler,
+ num_inference_steps=num_inference_steps,
+ device=device,
+ return_scheduler=True,
+ **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
+ )
+
+ final_scheduler = BaseAsyncScheduler(configured_scheduler)
+ setattr(local_pipe, "scheduler", final_scheduler)
+ except Exception:
+ logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
+
+ self._clone_mutable_attrs(self._base, local_pipe)
+
+ # 4) wrap tokenizers on the local pipe with the lock wrapper
+ tokenizer_wrappers = {} # name -> original_tokenizer
+ try:
+ # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
+ for name in dir(local_pipe):
+ if "tokenizer" in name and not name.startswith("_"):
+ tok = getattr(local_pipe, name, None)
+ if tok is not None and self._is_tokenizer_component(tok):
+ tokenizer_wrappers[name] = tok
+ setattr(
+ local_pipe,
+ name,
+ lambda *args, tok=tok, **kwargs: safe_tokenize(
+ tok, *args, lock=self._tokenizer_lock, **kwargs
+ ),
+ )
+
+ # b) wrap tokenizers in components dict
+ if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+ for key, val in local_pipe.components.items():
+ if val is None:
+ continue
+
+ if self._is_tokenizer_component(val):
+ tokenizer_wrappers[f"components[{key}]"] = val
+ local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
+ tokenizer, *args, lock=self._tokenizer_lock, **kwargs
+ )
+
+ except Exception as e:
+ logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
+
+ result = None
+ cm = getattr(local_pipe, "model_cpu_offload_context", None)
+ try:
+ if callable(cm):
+ try:
+ with cm():
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+ except TypeError:
+ # cm might be a context manager instance rather than callable
+ try:
+ with cm:
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+ except Exception as e:
+ logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+ else:
+ # no offload context available — call directly
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+
+ return result
+
+ finally:
+ try:
+ for name, tok in tokenizer_wrappers.items():
+ if name.startswith("components["):
+ key = name[len("components[") : -1]
+ local_pipe.components[key] = tok
+ else:
+ setattr(local_pipe, name, tok)
+ except Exception as e:
+ logger.debug(f"Error restoring wrapped tokenizers: {e}")
diff --git a/examples/server-async/utils/scheduler.py b/examples/server-async/utils/scheduler.py
new file mode 100644
index 0000000000..86d47cac61
--- /dev/null
+++ b/examples/server-async/utils/scheduler.py
@@ -0,0 +1,141 @@
+import copy
+import inspect
+from typing import Any, List, Optional, Union
+
+import torch
+
+
+class BaseAsyncScheduler:
+ def __init__(self, scheduler: Any):
+ self.scheduler = scheduler
+
+ def __getattr__(self, name: str):
+ if hasattr(self.scheduler, name):
+ return getattr(self.scheduler, name)
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+ def __setattr__(self, name: str, value):
+ if name == "scheduler":
+ super().__setattr__(name, value)
+ else:
+ if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
+ setattr(self.scheduler, name, value)
+ else:
+ super().__setattr__(name, value)
+
+ def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
+ local = copy.deepcopy(self.scheduler)
+ local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
+ cloned = self.__class__(local)
+ return cloned
+
+ def __repr__(self):
+ return f"BaseAsyncScheduler({repr(self.scheduler)})"
+
+ def __str__(self):
+ return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
+
+
+def async_retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
+ Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Backwards compatible: by default the function behaves exactly as before and returns
+ (timesteps_tensor, num_inference_steps)
+
+ If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
+ scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
+ or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
+ (timesteps_tensor, num_inference_steps, scheduler_in_use)
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Optional kwargs:
+ return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
+ where `scheduler_in_use` is a scheduler instance that already has timesteps set.
+ This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.
+
+ Returns:
+ `(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
+ `(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
+ """
+ # pop our optional control kwarg (keeps compatibility)
+ return_scheduler = bool(kwargs.pop("return_scheduler", False))
+
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+
+ # choose scheduler to call set_timesteps on
+ scheduler_in_use = scheduler
+ if return_scheduler:
+ # Do not mutate the provided scheduler: prefer to clone if possible
+ if hasattr(scheduler, "clone_for_request"):
+ try:
+ # clone_for_request may accept num_inference_steps or other kwargs; be permissive
+ scheduler_in_use = scheduler.clone_for_request(
+ num_inference_steps=num_inference_steps or 0, device=device
+ )
+ except Exception:
+ scheduler_in_use = copy.deepcopy(scheduler)
+ else:
+ # fallback deepcopy (scheduler tends to be smallish - acceptable)
+ scheduler_in_use = copy.deepcopy(scheduler)
+
+ # helper to test if set_timesteps supports a particular kwarg
+ def _accepts(param_name: str) -> bool:
+ try:
+ return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
+ except (ValueError, TypeError):
+ # if signature introspection fails, be permissive and attempt the call later
+ return False
+
+ # now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
+ if timesteps is not None:
+ accepts_timesteps = _accepts("timesteps")
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps_out = scheduler_in_use.timesteps
+ num_inference_steps = len(timesteps_out)
+ elif sigmas is not None:
+ accept_sigmas = _accepts("sigmas")
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps_out = scheduler_in_use.timesteps
+ num_inference_steps = len(timesteps_out)
+ else:
+ # default path
+ scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps_out = scheduler_in_use.timesteps
+
+ if return_scheduler:
+ return timesteps_out, num_inference_steps, scheduler_in_use
+ return timesteps_out, num_inference_steps
diff --git a/examples/server-async/utils/utils.py b/examples/server-async/utils/utils.py
new file mode 100644
index 0000000000..9f94330512
--- /dev/null
+++ b/examples/server-async/utils/utils.py
@@ -0,0 +1,48 @@
+import gc
+import logging
+import os
+import tempfile
+import uuid
+
+import torch
+
+
+logger = logging.getLogger(__name__)
+
+
+class Utils:
+ def __init__(self, host: str = "0.0.0.0", port: int = 8500):
+ self.service_url = f"http://{host}:{port}"
+ self.image_dir = os.path.join(tempfile.gettempdir(), "images")
+ if not os.path.exists(self.image_dir):
+ os.makedirs(self.image_dir)
+
+ self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
+ if not os.path.exists(self.video_dir):
+ os.makedirs(self.video_dir)
+
+ def save_image(self, image):
+ if hasattr(image, "to"):
+ try:
+ image = image.to("cpu")
+ except Exception:
+ pass
+
+ if isinstance(image, torch.Tensor):
+ from torchvision import transforms
+
+ to_pil = transforms.ToPILImage()
+ image = to_pil(image.squeeze(0).clamp(0, 1))
+
+ filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
+ image_path = os.path.join(self.image_dir, filename)
+ logger.info(f"Saving image to {image_path}")
+
+ image.save(image_path, format="PNG", optimize=True)
+
+ del image
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+        return f"{self.service_url}/images/{filename}"
diff --git a/examples/server/README.md b/examples/server/README.md
index 8ad0ed3cbe..f8cd58fc1c 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -9,8 +9,8 @@ This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server
Start by navigating to the `examples/server` folder and installing all of the dependencies.
```py
-pip install .
-pip install -f requirements.txt
+pip install diffusers
+pip install -r requirements.txt
```
Launch the server with the following command.
diff --git a/examples/server/requirements.in b/examples/server/requirements.in
index a469569a10..f8c35d48cd 100644
--- a/examples/server/requirements.in
+++ b/examples/server/requirements.in
@@ -6,4 +6,5 @@ py-consul
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
fastapi
-uvicorn
\ No newline at end of file
+uvicorn
+accelerate
diff --git a/examples/server/requirements.txt b/examples/server/requirements.txt
index b91a8861a0..688a4ee94f 100644
--- a/examples/server/requirements.txt
+++ b/examples/server/requirements.txt
@@ -39,7 +39,7 @@ fsspec==2024.10.0
# torch
h11==0.14.0
# via uvicorn
-huggingface-hub==0.26.1
+huggingface-hub==0.35.0
# via
# tokenizers
# transformers
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index acbee19fa5..989ac6e0c4 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import functools
@@ -60,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt
index c3ffa42f0e..be05fe3fcd 100644
--- a/examples/text_to_image/requirements.txt
+++ b/examples/text_to_image/requirements.txt
@@ -5,4 +5,4 @@ datasets>=2.19.1
ftfy
tensorboard
Jinja2
-peft==0.7.0
+peft>=0.17.0
diff --git a/examples/text_to_image/requirements_sdxl.txt b/examples/text_to_image/requirements_sdxl.txt
index 64cbc9205f..4dacc26ce4 100644
--- a/examples/text_to_image/requirements_sdxl.txt
+++ b/examples/text_to_image/requirements_sdxl.txt
@@ -5,4 +5,4 @@ ftfy
tensorboard
Jinja2
datasets
-peft==0.7.0
\ No newline at end of file
+peft>=0.17.0
\ No newline at end of file
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index bbd8fc062e..7ebf7b5465 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index 74423dcf27..c4f36879f3 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 19968c2547..663d6f6b08 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 88be919727..5fb1825f37 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index dec202fbbf..c26cb44841 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index e31ba9bd0c..caa77e4bba 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -81,7 +82,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index f5863d94b0..4a03d9bf6b 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index 1752bfd3b1..51de29a71a 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -12,6 +12,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# limitations under the License.
import argparse
import logging
@@ -76,7 +77,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 892c674575..3ffeef1364 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/vqgan/test_vqgan.py b/examples/vqgan/test_vqgan.py
index d13e102e78..a3c8ee1e84 100644
--- a/examples/vqgan/test_vqgan.py
+++ b/examples/vqgan/test_vqgan.py
@@ -24,12 +24,18 @@ import tempfile
import torch
from diffusers import VQModel
-from diffusers.utils.testing_utils import require_timm
+# Add parent directories to path to import from tests
sys.path.append("..")
+repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+if repo_root not in sys.path:
+ sys.path.insert(0, repo_root)
+
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+from tests.testing_utils import require_timm # noqa
+
logging.basicConfig(level=logging.DEBUG)
diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py
index 5ba1678d44..eeb592a3f7 100644
--- a/examples/vqgan/train_vqgan.py
+++ b/examples/vqgan/train_vqgan.py
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.35.0.dev0")
+check_min_version("0.36.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
index 599c90be57..39a364b07d 100644
--- a/scripts/convert_wan_to_diffusers.py
+++ b/scripts/convert_wan_to_diffusers.py
@@ -278,6 +278,29 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-VACE-Fun-14B":
+ config = {
+ "model_id": "alibaba-pai/Wan2.2-VACE-Fun-A14B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
+ "vace_in_channels": 96,
+ },
+ }
+ RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "Wan2.2-I2V-14B-720p":
config = {
"model_id": "Wan-AI/Wan2.2-I2V-A14B",
@@ -975,7 +998,17 @@ if __name__ == "__main__":
image_encoder=image_encoder,
image_processor=image_processor,
)
- elif "VACE" in args.model_type:
+ elif "Wan2.2-VACE" in args.model_type:
+ pipe = WanVACEPipeline(
+ transformer=transformer,
+ transformer_2=transformer_2,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ boundary_ratio=0.875,
+ )
+ elif "Wan-VACE" in args.model_type:
pipe = WanVACEPipeline(
transformer=transformer,
text_encoder=text_encoder,
diff --git a/setup.py b/setup.py
index 799150fd03..8d346ddfec 100644
--- a/setup.py
+++ b/setup.py
@@ -102,7 +102,8 @@ _deps = [
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
- "huggingface-hub>=0.34.0",
+ "httpx<1.0.0",
+ "huggingface-hub>=0.34.0,<2.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
@@ -116,7 +117,7 @@ _deps = [
"librosa",
"numpy",
"parameterized",
- "peft>=0.15.0",
+ "peft>=0.17.0",
"protobuf>=3.20.3,<4",
"pytest",
"pytest-timeout",
@@ -132,6 +133,7 @@ _deps = [
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
+ "nvidia_modelopt[hf]>=0.33.1",
"regex!=2019.12.17",
"requests",
"tensorboard",
@@ -143,6 +145,7 @@ _deps = [
"black",
"phonemizer",
"opencv-python",
+ "timm",
]
# this is a lookup table with items like:
@@ -216,7 +219,7 @@ class DepsTableUpdateCommand(Command):
extras = {}
extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
-extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft")
+extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft", "timm")
extras["test"] = deps_list(
"compel",
"GitPython",
@@ -244,6 +247,7 @@ extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
+extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
@@ -257,6 +261,7 @@ extras["dev"] = (
install_requires = [
deps["importlib_metadata"],
deps["filelock"],
+ deps["httpx"],
deps["huggingface-hub"],
deps["numpy"],
deps["regex"],
@@ -269,7 +274,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
- version="0.35.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.36.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1c25a65f50..95d559ff75 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.35.0.dev0"
+__version__ = "0.36.0.dev0"
from typing import TYPE_CHECKING
@@ -13,6 +13,7 @@ from .utils import (
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
+ is_nvidia_modelopt_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -111,6 +112,18 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
+try:
+ if not is_torch_available() and not is_accelerate_available() and not is_nvidia_modelopt_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from .utils import dummy_nvidia_modelopt_objects
+
+ _import_structure["utils.dummy_nvidia_modelopt_objects"] = [
+ name for name in dir(dummy_nvidia_modelopt_objects) if not name.startswith("_")
+ ]
+else:
+ _import_structure["quantizers.quantization_config"].append("NVIDIAModelOptConfig")
+
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -139,6 +152,7 @@ else:
"AutoGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
+ "FrequencyDecoupledGuidance",
"PerturbedAttentionGuidance",
"SkipLayerGuidance",
"SmoothedEnergyGuidance",
@@ -180,6 +194,7 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
+ "BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
"CogVideoXTransformer3DModel",
@@ -187,6 +202,7 @@ else:
"CogView4Transformer2DModel",
"ConsisIDTransformer3DModel",
"ConsistencyDecoderVAE",
+ "ContextParallelConfig",
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
@@ -214,8 +230,11 @@ else:
"MultiAdapter",
"MultiControlNetModel",
"OmniGenTransformer2DModel",
+ "ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
+ "QwenImageControlNetModel",
+ "QwenImageMultiControlNetModel",
"QwenImageTransformer2DModel",
"SanaControlNetModel",
"SanaTransformer2DModel",
@@ -367,7 +386,15 @@ else:
_import_structure["modular_pipelines"].extend(
[
"FluxAutoBlocks",
+ "FluxKontextAutoBlocks",
+ "FluxKontextModularPipeline",
"FluxModularPipeline",
+ "QwenImageAutoBlocks",
+ "QwenImageEditAutoBlocks",
+ "QwenImageEditModularPipeline",
+ "QwenImageEditPlusAutoBlocks",
+ "QwenImageEditPlusModularPipeline",
+ "QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"WanAutoBlocks",
@@ -396,6 +423,7 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
+ "BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
"CLIPImageProjection",
@@ -473,6 +501,7 @@ else:
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
+ "LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
@@ -488,6 +517,13 @@ else:
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
+ "QwenImageControlNetInpaintPipeline",
+ "QwenImageControlNetPipeline",
+ "QwenImageEditInpaintPipeline",
+ "QwenImageEditPipeline",
+ "QwenImageEditPlusPipeline",
+ "QwenImageImg2ImgPipeline",
+ "QwenImageInpaintPipeline",
"QwenImagePipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
@@ -785,6 +821,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .quantizers.quantization_config import QuantoConfig
+ try:
+ if not is_nvidia_modelopt_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_nvidia_modelopt_objects import *
+ else:
+ from .quantizers.quantization_config import NVIDIAModelOptConfig
+
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -804,6 +848,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
+ FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
@@ -841,6 +886,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
+ BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
@@ -848,6 +894,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
ConsistencyDecoderVAE,
+ ContextParallelConfig,
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
@@ -875,8 +922,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiAdapter,
MultiControlNetModel,
OmniGenTransformer2DModel,
+ ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
+ QwenImageControlNetModel,
+ QwenImageMultiControlNetModel,
QwenImageTransformer2DModel,
SanaControlNetModel,
SanaTransformer2DModel,
@@ -903,12 +953,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WanVACETransformer3DModel,
attention_backend,
)
- from .modular_pipelines import (
- ComponentsManager,
- ComponentSpec,
- ModularPipeline,
- ModularPipelineBlocks,
- )
+ from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -1007,7 +1052,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .modular_pipelines import (
FluxAutoBlocks,
+ FluxKontextAutoBlocks,
+ FluxKontextModularPipeline,
FluxModularPipeline,
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ QwenImageEditModularPipeline,
+ QwenImageEditPlusAutoBlocks,
+ QwenImageEditPlusModularPipeline,
+ QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
WanAutoBlocks,
@@ -1032,6 +1085,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
+ BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
CLIPImageProjection,
@@ -1109,6 +1163,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
+ LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
LuminaPipeline,
@@ -1124,6 +1179,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
+ QwenImageControlNetInpaintPipeline,
+ QwenImageControlNetPipeline,
+ QwenImageEditInpaintPipeline,
+ QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
+ QwenImageImg2ImgPipeline,
+ QwenImageInpaintPipeline,
QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 540aab0307..1c4ee33acb 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -30,11 +30,11 @@ import numpy as np
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
+ HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
-from requests import HTTPError
from typing_extensions import Self
from . import __version__
@@ -419,7 +419,7 @@ class ConfigMixin:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
)
- except HTTPError as err:
+ except HfHubHTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 3d14a8b3e0..6e5ac630ab 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -9,7 +9,8 @@ deps = {
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
- "huggingface-hub": "huggingface-hub>=0.34.0",
+ "httpx": "httpx<1.0.0",
+ "huggingface-hub": "huggingface-hub>=0.34.0,<2.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
@@ -23,7 +24,7 @@ deps = {
"librosa": "librosa",
"numpy": "numpy",
"parameterized": "parameterized",
- "peft": "peft>=0.15.0",
+ "peft": "peft>=0.17.0",
"protobuf": "protobuf>=3.20.3,<4",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",
@@ -39,6 +40,7 @@ deps = {
"gguf": "gguf>=0.10.0",
"torchao": "torchao>=0.7.0",
"bitsandbytes": "bitsandbytes>=0.43.3",
+ "nvidia_modelopt[hf]": "nvidia_modelopt[hf]>=0.33.1",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
@@ -50,4 +52,5 @@ deps = {
"black": "black",
"phonemizer": "phonemizer",
"opencv-python": "opencv-python",
+ "timm": "timm",
}
diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
index 1c288f00f0..23cb7a0a71 100644
--- a/src/diffusers/guiders/__init__.py
+++ b/src/diffusers/guiders/__init__.py
@@ -22,6 +22,7 @@ if is_torch_available():
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
+ from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
@@ -32,6 +33,7 @@ if is_torch_available():
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
+ FrequencyDecoupledGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py
index 81137db106..92b1fd5a1c 100644
--- a/src/diffusers/guiders/adaptive_projected_guidance.py
+++ b/src/diffusers/guiders/adaptive_projected_guidance.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
@@ -92,7 +92,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
- def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if not self._is_apg_enabled():
@@ -111,7 +111,7 @@ class AdaptiveProjectedGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
- return pred, {}
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py
index e1642211d3..5271a530ea 100644
--- a/src/diffusers/guiders/auto_guidance.py
+++ b/src/diffusers/guiders/auto_guidance.py
@@ -20,7 +20,7 @@ import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
@@ -82,15 +82,15 @@ class AutoGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
- if auto_guidance_layers is None and auto_guidance_config is None:
+ is_layer_or_config_provided = auto_guidance_layers is not None or auto_guidance_config is not None
+ is_layer_and_config_provided = auto_guidance_layers is not None and auto_guidance_config is not None
+ if not is_layer_or_config_provided:
raise ValueError(
- "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance."
+ "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable AutoGuidance."
)
- if auto_guidance_layers is not None and auto_guidance_config is not None:
+ if is_layer_and_config_provided:
raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
- if (dropout is None and auto_guidance_layers is not None) or (
- dropout is not None and auto_guidance_layers is None
- ):
+ if auto_guidance_config is None and dropout is None:
raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
if auto_guidance_layers is not None:
@@ -145,7 +145,7 @@ class AutoGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
- def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if not self._is_ag_enabled():
@@ -158,7 +158,7 @@ class AutoGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
- return pred, {}
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py
index 7e72b92fce..050590336f 100644
--- a/src/diffusers/guiders/classifier_free_guidance.py
+++ b/src/diffusers/guiders/classifier_free_guidance.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
@@ -96,7 +96,7 @@ class ClassifierFreeGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
- def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if not self._is_cfg_enabled():
@@ -109,7 +109,7 @@ class ClassifierFreeGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
- return pred, {}
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
index 85d5cc62d4..b64e356331 100644
--- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py
+++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
@@ -89,7 +89,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
- def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if self._step < self.zero_init_steps:
@@ -109,7 +109,7 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
- return pred, {}
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py
new file mode 100644
index 0000000000..93822a180e
--- /dev/null
+++ b/src/diffusers/guiders/frequency_decoupled_guidance.py
@@ -0,0 +1,327 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..utils import is_kornia_available
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+_CAN_USE_KORNIA = is_kornia_available()
+
+
+if _CAN_USE_KORNIA:
+ from kornia.geometry import pyrup as upsample_and_blur_func
+ from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
+else:
+ upsample_and_blur_func = None
+ build_laplacian_pyramid_func = None
+
+
+def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from the paper
+ (Algorithm 2).
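+
+ A minimal sketch of the intended usage (shapes are illustrative):
+
+ v0 = torch.randn(2, 4, 8, 8)
+ v1 = torch.randn(2, 4, 8, 8)
+ v0_parallel, v0_orthogonal = project(v0, v1)
+ # v0_parallel + v0_orthogonal recovers v0 (up to numerical precision); v0_parallel is the component
+ # of v0 along the direction of v1 for each batch element.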
+ """
+ # v0 shape: [B, ...]
+ # v1 shape: [B, ...]
+ # Assume first dim is a batch dim and all other dims are channel or "spatial" dims
+ all_dims_but_first = list(range(1, len(v0.shape)))
+ if upcast_to_double:
+ dtype = v0.dtype
+ v0, v1 = v0.double(), v1.double()
+ v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
+ v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
+ v0_orthogonal = v0 - v0_parallel
+ if upcast_to_double:
+ v0_parallel = v0_parallel.to(dtype)
+ v0_orthogonal = v0_orthogonal.to(dtype)
+ return v0_parallel, v0_orthogonal
+
+
+def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
+ """
+ Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
+ (Algorithm 2).
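+
+ Round-trip sketch (assumes `kornia` is installed so the module-level pyramid helpers are available):
+
+ x = torch.randn(1, 4, 64, 64)
+ pyramid = build_laplacian_pyramid_func(x, 2)
+ x_rec = build_image_from_pyramid(pyramid) # approximately recovers x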
+ """
+ # pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
+ img = pyramid[-1]
+ for i in range(len(pyramid) - 2, -1, -1):
+ img = upsample_and_blur_func(img) + pyramid[i]
+ return img
+
+
+class FrequencyDecoupledGuidance(BaseGuidance):
+ """
+ Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
+
+ FDG is a technique similar to (and based on) classifier-free guidance (CFG), used to improve generation
+ quality and condition-following in diffusion models. As with CFG, the model is jointly trained on conditional and
+ unconditional data, and a combination of the two predictions is used during inference. (If you want more details on
+ how CFG works, you can check out the CFG guider.)
+
+ FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
+ using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
+ separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
+ frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
+ to form the final FDG prediction.
+
+ For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
+ diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
+ sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
+ the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
+ example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
+
+ As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
+ paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
+ theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
+
+ The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
+ paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
+
+ Args:
+ guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
+ The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
+ frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
+ values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
+ image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
+ lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
+ descending order).
+ guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
+ `guidance_scales`.
+ parallel_weights (`float` or `List[float]`, *optional*):
+ Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
+ set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
+ (that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
+ recommended. If a list is supplied, it should be the same length as `guidance_scales`.
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [`~guiders.classifier_free_guidance.ClassifierFreeGuidance`] for more details.
+ start (`float` or `List[float]`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
+ should be the same length as `guidance_scales`.
+ stop (`float` or `List[float]`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
+ should be the same length as `guidance_scales`.
+ guidance_rescale_space (`str`, defaults to `"data"`):
+ Whether to perform guidance rescaling in `"data"` space (after the full FDG update in data space) or in
+ `"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
+ speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
+ will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
+ upcast_to_double (`bool`, defaults to `True`):
+ Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
+ float64 when performing guidance. This may give more numerically precise results at the cost of increased runtime.
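+
+ Example:
+ A minimal standalone sketch (requires `kornia`; in practice the guider is normally driven by a modular
+ pipeline that calls `prepare_inputs`/`forward` for you, and the random tensors below only stand in for
+ real model predictions):
+
+ import torch
+ from diffusers import FrequencyDecoupledGuidance
+
+ guider = FrequencyDecoupledGuidance(guidance_scales=[10.0, 5.0])
+ pred_cond = torch.randn(1, 4, 64, 64)
+ pred_uncond = torch.randn(1, 4, 64, 64)
+ out = guider.forward(pred_cond=pred_cond, pred_uncond=pred_uncond)
+ print(out.pred.shape) # torch.Size([1, 4, 64, 64])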
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
+ guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
+ parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
+ use_original_formulation: bool = False,
+ start: Union[float, List[float], Tuple[float]] = 0.0,
+ stop: Union[float, List[float], Tuple[float]] = 1.0,
+ guidance_rescale_space: str = "data",
+ upcast_to_double: bool = True,
+ ):
+ if not _CAN_USE_KORNIA:
+ raise ImportError(
+ "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
+ "it depends is not available in the current environment. You can install `kornia` with `pip install "
+ "kornia`."
+ )
+
+ # Set start to earliest start for any freq component and stop to latest stop for any freq component
+ min_start = start if isinstance(start, float) else min(start)
+ max_stop = stop if isinstance(stop, float) else max(stop)
+ super().__init__(min_start, max_stop)
+
+ self.guidance_scales = guidance_scales
+ self.levels = len(guidance_scales)
+
+ if isinstance(guidance_rescale, float):
+ self.guidance_rescale = [guidance_rescale] * self.levels
+ elif len(guidance_rescale) == self.levels:
+ self.guidance_rescale = guidance_rescale
+ else:
+ raise ValueError(
+ f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
+ f"`guidance_scales` ({len(self.guidance_scales)})"
+ )
+ # Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
+ # transforming from frequency space back to data space)
+ if guidance_rescale_space not in ["data", "freq"]:
+ raise ValueError(
+ f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
+ )
+ self.guidance_rescale_space = guidance_rescale_space
+
+ if parallel_weights is None:
+ # Use normal CFG shift (equal weights for parallel and orthogonal components)
+ self.parallel_weights = [1.0] * self.levels
+ elif isinstance(parallel_weights, float):
+ self.parallel_weights = [parallel_weights] * self.levels
+ elif len(parallel_weights) == self.levels:
+ self.parallel_weights = parallel_weights
+ else:
+ raise ValueError(
+ f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
+ f"`guidance_scales` ({len(self.guidance_scales)})"
+ )
+
+ self.use_original_formulation = use_original_formulation
+ self.upcast_to_double = upcast_to_double
+
+ if isinstance(start, float):
+ self.guidance_start = [start] * self.levels
+ elif len(start) == self.levels:
+ self.guidance_start = start
+ else:
+ raise ValueError(
+ f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
+ f"({len(self.guidance_scales)})"
+ )
+ if isinstance(stop, float):
+ self.guidance_stop = [stop] * self.levels
+ elif len(stop) == self.levels:
+ self.guidance_stop = stop
+ else:
+ raise ValueError(
+ f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
+ f"({len(self.guidance_scales)})"
+ )
+
+ def prepare_inputs(
+ self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+ ) -> List["BlockState"]:
+ if input_fields is None:
+ input_fields = self._input_fields
+
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for i in range(self.num_conditions):
+ data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
+ pred = None
+
+ if not self._is_fdg_enabled():
+ pred = pred_cond
+ else:
+ # Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
+ pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
+ pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
+
+ # From high frequencies to low frequencies, following the paper implementation
+ pred_guided_pyramid = []
+ parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
+ for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
+ if self._is_fdg_enabled_for_level(level):
+ # Get the cond/uncond preds (in freq space) at the current frequency level
+ pred_cond_freq = pred_cond_pyramid[level]
+ pred_uncond_freq = pred_uncond_pyramid[level]
+
+ shift = pred_cond_freq - pred_uncond_freq
+
+ # Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
+ if not math.isclose(parallel_weight, 1.0):
+ shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
+ shift = parallel_weight * shift_parallel + shift_orthogonal
+
+ # Apply CFG update for the current frequency level
+ pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
+ pred = pred + guidance_scale * shift
+
+ if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
+
+ # Add the current FDG guided level to the FDG prediction pyramid
+ pred_guided_pyramid.append(pred)
+ else:
+ # Add the current pred_cond_pyramid level as the "non-FDG" prediction
+ pred_guided_pyramid.append(pred_cond_freq)
+
+ # Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
+ pred = build_image_from_pyramid(pred_guided_pyramid)
+
+ # If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
+ # across all freq levels
+ if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_fdg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_fdg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
+ else:
+ is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
+
+ return is_within_range and not is_close
+
+ def _is_fdg_enabled_for_level(self, level: int) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
+ skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scales[level], 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scales[level], 1.0)
+
+ return is_within_range and not is_close
diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py
index 9dc83a7f1d..7524b5a3ea 100644
--- a/src/diffusers/guiders/guider_utils.py
+++ b/src/diffusers/guiders/guider_utils.py
@@ -20,7 +20,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
from ..configuration_utils import ConfigMixin
-from ..utils import PushToHubMixin, get_logger
+from ..utils import BaseOutput, PushToHubMixin, get_logger
if TYPE_CHECKING:
@@ -247,15 +247,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ > with `hf auth login`. You can also activate the special
+ > ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
-
-
"""
config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -284,6 +280,12 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+class GuiderOutput(BaseOutput):
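+ """Output of a guider's `forward` call: `pred` is the guided prediction, while `pred_cond` and `pred_uncond`
+ carry the underlying conditional/unconditional model predictions when available."""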
+ pred: torch.Tensor
+ pred_cond: Optional[torch.Tensor]
+ pred_uncond: Optional[torch.Tensor]
+
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py
index 1b2256732f..e294e8d0db 100644
--- a/src/diffusers/guiders/perturbed_attention_guidance.py
+++ b/src/diffusers/guiders/perturbed_attention_guidance.py
@@ -21,7 +21,7 @@ from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
@@ -197,7 +197,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ ) -> GuiderOutput:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -219,7 +219,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
- return pred, {}
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py
index 68a657960a..3530df8b0a 100644
--- a/src/diffusers/guiders/skip_layer_guidance.py
+++ b/src/diffusers/guiders/skip_layer_guidance.py
@@ -20,7 +20,7 @@ import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
@@ -192,7 +192,7 @@ class SkipLayerGuidance(BaseGuidance):
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ ) -> GuiderOutput:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
@@ -214,7 +214,7 @@ class SkipLayerGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
- return pred, {}
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py
index d8e8a3cf2f..767d20b62f 100644
--- a/src/diffusers/guiders/smoothed_energy_guidance.py
+++ b/src/diffusers/guiders/smoothed_energy_guidance.py
@@ -20,7 +20,7 @@ import torch
from ..configuration_utils import register_to_config
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
@@ -181,7 +181,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_seg: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ ) -> GuiderOutput:
pred = None
if not self._is_cfg_enabled() and not self._is_seg_enabled():
@@ -203,7 +203,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
- return pred, {}
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py
index b3187e5263..df1e69fe71 100644
--- a/src/diffusers/guiders/tangential_classifier_free_guidance.py
+++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
-from .guider_utils import BaseGuidance, rescale_noise_cfg
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
@@ -78,7 +78,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
- def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if not self._is_tcfg_enabled():
@@ -89,7 +89,7 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
- return pred, {}
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 525a0747da..524a92ea99 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -16,6 +16,7 @@ from ..utils import is_torch_available
if is_torch_available():
+ from .context_parallel import apply_context_parallel
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
index f328078ce4..f6e5bdd52d 100644
--- a/src/diffusers/hooks/_helpers.py
+++ b/src/diffusers/hooks/_helpers.py
@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
+ from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
# AttnProcessor2_0
@@ -133,16 +134,26 @@ def _register_attention_processors_metadata():
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
),
)
+
# FluxAttnProcessor
AttentionProcessorRegistry.register(
model_class=FluxAttnProcessor,
metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
)
+ # QwenDoubleStreamAttnProcessor2_0
+ AttentionProcessorRegistry.register(
+ model_class=QwenDoubleStreamAttnProcessor2_0,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0
+ ),
+ )
+
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
+ from ..models.transformers.transformer_bria import BriaTransformerBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
@@ -164,6 +175,13 @@ def _register_transformer_blocks_metadata():
return_encoder_hidden_states_index=None,
),
)
+ TransformerBlockRegistry.register(
+ model_class=BriaTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
# CogVideoX
TransformerBlockRegistry.register(
@@ -289,4 +307,5 @@ _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___h
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
# fmt: on
diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
new file mode 100644
index 0000000000..915fe453b9
--- /dev/null
+++ b/src/diffusers/hooks/context_parallel.py
@@ -0,0 +1,300 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Dict, List, Type, Union
+
+import torch
+
+
+if torch.distributed.is_available():
+ import torch.distributed._functional_collectives as funcol
+
+from ..models._modeling_parallel import (
+ ContextParallelConfig,
+ ContextParallelInput,
+ ContextParallelModelPlan,
+ ContextParallelOutput,
+)
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
+_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
+
+
+# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
+@dataclass
+class ModuleForwardMetadata:
+ cached_parameter_indices: Dict[str, int] = None
+ _cls: Type = None
+
+ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
+ kwargs = kwargs or {}
+
+ if identifier in kwargs:
+ return kwargs[identifier], True, None
+
+ if self.cached_parameter_indices is not None:
+ index = self.cached_parameter_indices.get(identifier, None)
+ if index is None:
+ raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
+ return args[index], False, index
+
+ if self._cls is None:
+ raise ValueError("Model class is not set for metadata.")
+
+ parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+ parameters = parameters[1:] # skip `self`
+ self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+
+ if identifier not in self.cached_parameter_indices:
+ raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+
+ index = self.cached_parameter_indices[identifier]
+
+ if index >= len(args):
+ raise ValueError(f"Expected {index} arguments but got {len(args)}.")
+
+ return args[index], False, index
+
+
+def apply_context_parallel(
+ module: torch.nn.Module,
+ parallel_config: ContextParallelConfig,
+ plan: Dict[str, ContextParallelModelPlan],
+) -> None:
+ """Apply context parallel on a model."""
+ logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
+
+ for module_id, cp_model_plan in plan.items():
+ submodule = _get_submodule_by_name(module, module_id)
+ if not isinstance(submodule, list):
+ submodule = [submodule]
+
+ logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
+
+ for m in submodule:
+ if isinstance(cp_model_plan, dict):
+ hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
+ hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
+ elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
+ if isinstance(cp_model_plan, ContextParallelOutput):
+ cp_model_plan = [cp_model_plan]
+ if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
+ raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
+ hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
+ hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
+ else:
+ raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
+ registry = HookRegistry.check_if_exists_or_initialize(m)
+ registry.register_hook(hook, hook_name)
+
+
+def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
+ for module_id, cp_model_plan in plan.items():
+ submodule = _get_submodule_by_name(module, module_id)
+ if not isinstance(submodule, list):
+ submodule = [submodule]
+
+ for m in submodule:
+ registry = HookRegistry.check_if_exists_or_initialize(m)
+ if isinstance(cp_model_plan, dict):
+ hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
+ elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
+ hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
+ else:
+ raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
+ registry.remove_hook(hook_name)
+
+
+class ContextParallelSplitHook(ModelHook):
+ def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
+ super().__init__()
+ self.metadata = metadata
+ self.parallel_config = parallel_config
+ self.module_forward_metadata = None
+
+ def initialize_hook(self, module):
+ cls = unwrap_module(module).__class__
+ self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
+ return module
+
+ def pre_forward(self, module, *args, **kwargs):
+ args_list = list(args)
+
+ for name, cpm in self.metadata.items():
+ if isinstance(cpm, ContextParallelInput) and cpm.split_output:
+ continue
+
+ # Maybe the parameter was passed as a keyword argument
+ input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
+ name, args_list, kwargs
+ )
+
+ if input_val is None:
+ continue
+
+ # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
+ # the output instead of input for a particular layer by setting split_output=True
+ if isinstance(input_val, torch.Tensor):
+ input_val = self._prepare_cp_input(input_val, cpm)
+ elif isinstance(input_val, (list, tuple)):
+ if len(input_val) != len(cpm):
+ raise ValueError(
+ f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
+ )
+ sharded_input_val = []
+ for i, x in enumerate(input_val):
+ if torch.is_tensor(x) and not cpm[i].split_output:
+ x = self._prepare_cp_input(x, cpm[i])
+ sharded_input_val.append(x)
+ input_val = sharded_input_val
+ else:
+ raise ValueError(f"Unsupported input type: {type(input_val)}")
+
+ if is_kwarg:
+ kwargs[name] = input_val
+ elif index is not None and index < len(args_list):
+ args_list[index] = input_val
+ else:
+ raise ValueError(
+ f"An unexpected error occurred while processing the input '{name}'. Please open an "
+ f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
+ f"example along with the full stack trace."
+ )
+
+ return tuple(args_list), kwargs
+
+ def post_forward(self, module, output):
+ is_tensor = isinstance(output, torch.Tensor)
+ is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
+
+ if not is_tensor and not is_tensor_list:
+ raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
+
+ output = [output] if is_tensor else list(output)
+ for index, cpm in self.metadata.items():
+ if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
+ continue
+ if index >= len(output):
+ raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
+ current_output = output[index]
+ current_output = self._prepare_cp_input(current_output, cpm)
+ output[index] = current_output
+
+ return output[0] if is_tensor else tuple(output)
+
+ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
+ if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
+ raise ValueError(
+ f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
+ )
+ return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
+
+
+class ContextParallelGatherHook(ModelHook):
+ def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
+ super().__init__()
+ self.metadata = metadata
+ self.parallel_config = parallel_config
+
+ def post_forward(self, module, output):
+ is_tensor = isinstance(output, torch.Tensor)
+
+ if is_tensor:
+ output = [output]
+ elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
+ raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
+
+ output = list(output)
+
+ if len(output) != len(self.metadata):
+ raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
+
+ for i, cpm in enumerate(self.metadata):
+ if cpm is None:
+ continue
+ output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
+
+ return output[0] if is_tensor else tuple(output)
+
+
+class AllGatherFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, tensor, dim, group):
+ ctx.dim = dim
+ ctx.group = group
+ ctx.world_size = torch.distributed.get_world_size(group)
+ ctx.rank = torch.distributed.get_rank(group)
+ return funcol.all_gather_tensor(tensor, dim, group=group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
+ return grad_chunks[ctx.rank], None, None
+
+
+class EquipartitionSharder:
+ @classmethod
+ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ # NOTE: the following assertion does not have to be true in general. We simply enforce it for now
+ # because the alternate case has not yet been tested/required for any model.
+ assert tensor.size()[dim] % mesh.size() == 0, (
+ "Tensor size along dimension to be sharded must be divisible by mesh size"
+ )
+
+ # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
+ # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+ return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
+
+ @classmethod
+ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ tensor = tensor.contiguous()
+ tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
+ return tensor
+
+
+def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
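+ # Resolves dotted names much like `nn.Module.get_submodule`, with one extension: a single "*" wildcard
+ # fans out over a ModuleList, e.g. "blocks.*.attn" returns the `attn` submodule of every block.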
+ if name.count("*") > 1:
+ raise ValueError("Wildcard '*' can only be used once in the name")
+ return _find_submodule_by_name(model, name)
+
+
+def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
+ if name == "":
+ return model
+ first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
+ if first_atom == "*":
+ if not isinstance(model, torch.nn.ModuleList):
+ raise ValueError("Wildcard '*' can only be used with ModuleList")
+ submodules = []
+ for submodule in model:
+ subsubmodules = _find_submodule_by_name(submodule, remaining_name)
+ if not isinstance(subsubmodules, list):
+ subsubmodules = [subsubmodules]
+ submodules.extend(subsubmodules)
+ return submodules
+ else:
+ if hasattr(model, first_atom):
+ submodule = getattr(model, first_atom)
+ return _find_submodule_by_name(submodule, remaining_name)
+ else:
+ raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py
index 53e5bd792c..a01afeffdb 100644
--- a/src/diffusers/hooks/faster_cache.py
+++ b/src/diffusers/hooks/faster_cache.py
@@ -54,11 +54,11 @@ class FasterCacheConfig:
Attributes:
spatial_attention_block_skip_range (`int`, defaults to `2`):
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
- be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
+ be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
- be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
+ be skipped `N - 1` times (i.e., cached attention states will be reused) before computing the new attention
states again.
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
The timestep range within which the spatial attention computation can be skipped without a significant loss
@@ -90,7 +90,7 @@ class FasterCacheConfig:
from the conditional branch outputs.
unconditional_batch_skip_range (`int`, defaults to `5`):
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
- computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
+ computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be reused) before
computing the new unconditional branch states again.
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
The timestep range within which the unconditional branch computation can be skipped without a significant
diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 6b6871f9dc..38f291f520 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -245,7 +245,6 @@ class ModuleGroup:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
-
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
@@ -303,9 +302,23 @@ class GroupOffloadingHook(ModelHook):
if self.group.onload_leader == module:
if self.group.onload_self:
self.group.onload_()
- if self.next_group is not None and not self.next_group.onload_self:
+
+ should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+ if should_onload_next_group:
self.next_group.onload_()
+ should_synchronize = (
+ not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+ )
+ if should_synchronize:
+ # If this group didn't onload itself, it was asynchronously onloaded by the previous group. We need to
+ # synchronize the side stream to ensure its parameters are fully loaded before proceeding with the forward
+ # pass. Without this, uninitialized weights would be used in the computation, leading to incorrect results.
+ # We only need this synchronization when it isn't already performed by the sync call in
+ # self.next_group.onload_, hence the `not should_onload_next_group` check.
+ self.group.stream.synchronize()
+
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
return args, kwargs
diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index ee3f410331..12d6aa0616 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -45,15 +45,15 @@ class PyramidAttentionBroadcastConfig:
spatial_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific spatial attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
- old attention states will be re-used) before computing the new attention states again.
+ old attention states will be reused) before computing the new attention states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific temporal attention broadcast is skipped before computing the attention
states to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times
- (i.e., old attention states will be re-used) before computing the new attention states again.
+ (i.e., old attention states will be reused) before computing the new attention states again.
cross_attention_block_skip_range (`int`, *optional*, defaults to `None`):
The number of times a specific cross-attention broadcast is skipped before computing the attention states
to re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e.,
- old attention states will be re-used) before computing the new attention states again.
+ old attention states will be reused) before computing the new attention states again.
spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the spatial attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
@@ -305,7 +305,7 @@ def _apply_pyramid_attention_broadcast_hook(
block_skip_range (`int`):
The number of times a specific attention broadcast is skipped before computing the attention states to
re-use. If this is set to the value `N`, the attention computation will be skipped `N - 1` times (i.e., old
- attention states will be re-used) before computing the new attention states again.
+ attention states will be reused) before computing the new attention states again.
current_timestep_callback (`Callable[[], int]`):
A callback function that returns the current inference timestep.
"""
diff --git a/src/diffusers/hooks/utils.py b/src/diffusers/hooks/utils.py
new file mode 100644
index 0000000000..c5260eeebe
--- /dev/null
+++ b/src/diffusers/hooks/utils.py
@@ -0,0 +1,43 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
+
+
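+# Helpers that walk `module.named_modules()` and collect layers of interest as `(name, submodule)` pairs:
+# transformer-block ModuleLists are matched by name suffix, while attention and feed-forward layers are matched
+# by class.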
+def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
+ module_list_with_transformer_blocks = []
+ for name, submodule in module.named_modules():
+ name_endswith_identifier = any(name.endswith(identifier) for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
+ is_modulelist = isinstance(submodule, torch.nn.ModuleList)
+ if name_endswith_identifier and is_modulelist:
+ module_list_with_transformer_blocks.append((name, submodule))
+ return module_list_with_transformer_blocks
+
+
+def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
+ attention_layers = []
+ for name, submodule in module.named_modules():
+ if isinstance(submodule, _ATTENTION_CLASSES):
+ attention_layers.append((name, submodule))
+ return attention_layers
+
+
+def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
+ feedforward_layers = []
+ for name, submodule in module.named_modules():
+ if isinstance(submodule, _FEEDFORWARD_CLASSES):
+ feedforward_layers.append((name, submodule))
+ return feedforward_layers
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 6a3cf77a7d..0e3082eada 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -523,6 +523,7 @@ class VaeImageProcessor(ConfigMixin):
size=(height, width),
)
image = self.pt_to_numpy(image)
+
return image
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
@@ -838,6 +839,137 @@ class VaeImageProcessor(ConfigMixin):
return image
+class InpaintProcessor(ConfigMixin):
+ """
+ Image processor for inpainting that handles both the input image and its mask.
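+
+ Example (a minimal usage sketch; `init_image` and `mask_image` are assumed to be `PIL.Image.Image` inputs and
+ `denoised` stands in for the decoded output tensor of an inpainting model):
+
+ ```py
+ processor = InpaintProcessor(vae_scale_factor=8)
+ image_t, mask_t, post_kwargs = processor.preprocess(init_image, mask_image, height=512, width=512)
+ # ... run the inpainting model on image_t / mask_t to obtain `denoised` ...
+ result = processor.postprocess(denoised, output_type="pil", **post_kwargs)
+ ```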
+ """
+
+ config_name = CONFIG_NAME
+
+ @register_to_config
+ def __init__(
+ self,
+ do_resize: bool = True,
+ vae_scale_factor: int = 8,
+ vae_latent_channels: int = 4,
+ resample: str = "lanczos",
+ reducing_gap: Optional[int] = None,
+ do_normalize: bool = True,
+ do_binarize: bool = False,
+ do_convert_grayscale: bool = False,
+ mask_do_normalize: bool = False,
+ mask_do_binarize: bool = True,
+ mask_do_convert_grayscale: bool = True,
+ ):
+ super().__init__()
+
+ self._image_processor = VaeImageProcessor(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ vae_latent_channels=vae_latent_channels,
+ resample=resample,
+ reducing_gap=reducing_gap,
+ do_normalize=do_normalize,
+ do_binarize=do_binarize,
+ do_convert_grayscale=do_convert_grayscale,
+ )
+ self._mask_processor = VaeImageProcessor(
+ do_resize=do_resize,
+ vae_scale_factor=vae_scale_factor,
+ vae_latent_channels=vae_latent_channels,
+ resample=resample,
+ reducing_gap=reducing_gap,
+ do_normalize=mask_do_normalize,
+ do_binarize=mask_do_binarize,
+ do_convert_grayscale=mask_do_convert_grayscale,
+ )
+
+ def preprocess(
+ self,
+ image: PIL.Image.Image,
+ mask: Optional[PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, dict]]:
+ """
+ Preprocess the image and, optionally, the mask. If `mask` is `None`, this behaves like the regular image
+ processor and returns only the processed image; otherwise it returns a tuple of
+ `(processed_image, processed_mask, postprocessing_kwargs)`.
+ """
+ if mask is None and padding_mask_crop is not None:
+ raise ValueError("mask must be provided if padding_mask_crop is provided")
+
+ # if mask is None, same behavior as regular image processor
+ if mask is None:
+ return self._image_processor.preprocess(image, height=height, width=width)
+
+ if padding_mask_crop is not None:
+ crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ processed_image = self._image_processor.preprocess(
+ image,
+ height=height,
+ width=width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+
+ processed_mask = self._mask_processor.preprocess(
+ mask,
+ height=height,
+ width=width,
+ resize_mode=resize_mode,
+ crops_coords=crops_coords,
+ )
+
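+ # Collect everything `postprocess` needs in order to paste the inpainted crop back onto the original inputs.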
+ if crops_coords is not None:
+ postprocessing_kwargs = {
+ "crops_coords": crops_coords,
+ "original_image": image,
+ "original_mask": mask,
+ }
+ else:
+ postprocessing_kwargs = {
+ "crops_coords": None,
+ "original_image": None,
+ "original_mask": None,
+ }
+
+ return processed_image, processed_mask, postprocessing_kwargs
+
+ def postprocess(
+ self,
+ image: torch.Tensor,
+ output_type: str = "pil",
+ original_image: Optional[PIL.Image.Image] = None,
+ original_mask: Optional[PIL.Image.Image] = None,
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
+ ) -> Union[PIL.Image.Image, List[PIL.Image.Image], np.ndarray, torch.Tensor]:
+ """
+ Postprocess the image and, when `crops_coords` is provided, apply the mask overlay to paste the inpainted
+ region back onto the original image.
+ """
+ image = self._image_processor.postprocess(
+ image,
+ output_type=output_type,
+ )
+ # optionally apply the mask overlay
+ if crops_coords is not None and (original_image is None or original_mask is None):
+ raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
+
+ elif crops_coords is not None and output_type != "pil":
+ raise ValueError("output_type must be 'pil' if crops_coords is provided")
+
+ elif crops_coords is not None:
+ image = [
+ self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
+ ]
+
+ return image
+
+
class VaeImageProcessorLDM3D(VaeImageProcessor):
"""
Image processor for VAE LDM3D.
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 3089086d54..3d75a7d875 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -544,11 +544,7 @@ class LoraBaseMixin:
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
- This is an experimental API.
-
-
+ > [!WARNING]
+ > This is an experimental API.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
@@ -628,11 +624,7 @@ class LoraBaseMixin:
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
- This is an experimental API.
-
-
+ > [!WARNING]
+ > This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
@@ -754,7 +746,11 @@ class LoraBaseMixin:
# Decompose weights into weights for denoiser and text encoders.
_component_adapter_weights = {}
for component in self._lora_loadable_modules:
- model = getattr(self, component)
+ model = getattr(self, component, None)
+ # Guard for cases like Wan: Wan2.1 and WanVace have a single denoiser, whereas Wan 2.2 has two,
+ # so a given component may not exist on every pipeline.
+ if model is None:
+ continue
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
@@ -1060,6 +1056,41 @@ class LoraBaseMixin:
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
+ @classmethod
+ def _save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
+ lora_metadata: Dict[str, Optional[dict]],
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ """
+ Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
+ pipeline types.
+ """
+ state_dict = {}
+ final_lora_adapter_metadata = {}
+
+ for prefix, layers in lora_layers.items():
+ state_dict.update(cls.pack_weights(layers, prefix))
+
+ for prefix, metadata in lora_metadata.items():
+ if metadata:
+ final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
+
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
+ )
+
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
return _func_optionally_disable_offloading(_pipeline=_pipeline)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index ba96dccbe3..89afb6529a 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -558,70 +558,62 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
ait_sd[target_key] = value
if any("guidance_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- (
- "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
- "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
- None,
- ),
- (
- "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
- "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
- None,
- ),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_guidance_in_in_layer",
+ "time_text_embed.guidance_embedder.linear_1",
+ )
+
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_guidance_in_out_layer",
+ "time_text_embed.guidance_embedder.linear_2",
)
if any("img_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_img_in",
+ "x_embedder",
)
if any("txt_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_txt_in",
+ "context_embedder",
)
if any("time_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- (
- "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
- "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
- None,
- ),
- (
- "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
- "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
- None,
- ),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_time_in_in_layer",
+ "time_text_embed.timestep_embedder.linear_1",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_time_in_out_layer",
+ "time_text_embed.timestep_embedder.linear_2",
)
if any("vector_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- (
- "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
- "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
- None,
- ),
- (
- "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
- "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
- None,
- ),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_vector_in_in_layer",
+ "time_text_embed.text_embedder.linear_1",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_vector_in_out_layer",
+ "time_text_embed.text_embedder.linear_2",
)
if any("final_layer" in k for k in sds_sd):
@@ -817,7 +809,11 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
# has both `peft` and non-peft state dict.
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
if has_peft_state_dict:
- state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
+ state_dict = {
+ k.replace("lora_down.weight", "lora_A.weight").replace("lora_up.weight", "lora_B.weight"): v
+ for k, v in state_dict.items()
+ if k.startswith("transformer.")
+ }
return state_dict
# Another weird one.
@@ -1829,6 +1825,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
)
+ def get_alpha_scales(down_weight, alpha_key):
+ rank = down_weight.shape[0]
+ alpha = original_state_dict.pop(alpha_key).item()
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+ scale_down = scale
+ scale_up = 1.0
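+ # Distribute the scale between the down and up weights; their product stays equal to `scale`, but the two
+ # factors are kept numerically comparable.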
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+ return scale_down, scale_up
+
for key in list(original_state_dict.keys()):
if key.endswith((".diff", ".diff_b")) and "norm" in key:
# NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1848,15 +1855,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
for i in range(min_block, max_block + 1):
# Self-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
- original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
- converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
- if original_key in original_state_dict:
- converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+ alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+ has_alpha = alpha_key in original_state_dict
+ original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+ converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
- original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
- converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
- if original_key in original_state_dict:
- converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+ original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+ converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+ if has_alpha:
+ down_weight = original_state_dict.pop(original_key_A)
+ up_weight = original_state_dict.pop(original_key_B)
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[converted_key_A] = down_weight * scale_down
+ converted_state_dict[converted_key_B] = up_weight * scale_up
+
+ else:
+ if original_key_A in original_state_dict:
+ converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+ if original_key_B in original_state_dict:
+ converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
original_key = f"blocks.{i}.self_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1865,15 +1883,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
# Cross-attention
for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
- original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
- converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
- if original_key in original_state_dict:
- converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+ alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+ has_alpha = alpha_key in original_state_dict
+ original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+ converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
- original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
- converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
- if original_key in original_state_dict:
- converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+ original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+ converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+ if original_key_A in original_state_dict:
+ down_weight = original_state_dict.pop(original_key_A)
+ converted_state_dict[converted_key_A] = down_weight
+ if original_key_B in original_state_dict:
+ up_weight = original_state_dict.pop(original_key_B)
+ converted_state_dict[converted_key_B] = up_weight
+ if has_alpha:
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[converted_key_A] *= scale_down
+ converted_state_dict[converted_key_B] *= scale_up
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1882,15 +1909,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
if is_i2v_lora:
for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
- original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
- converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
- if original_key in original_state_dict:
- converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+ alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+ has_alpha = alpha_key in original_state_dict
+ original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+ converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
- original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
- converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
- if original_key in original_state_dict:
- converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+ original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+ converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+ if original_key_A in original_state_dict:
+ down_weight = original_state_dict.pop(original_key_A)
+ converted_state_dict[converted_key_A] = down_weight
+ if original_key_B in original_state_dict:
+ up_weight = original_state_dict.pop(original_key_B)
+ converted_state_dict[converted_key_B] = up_weight
+ if has_alpha:
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[converted_key_A] *= scale_down
+ converted_state_dict[converted_key_B] *= scale_up
original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1899,15 +1935,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
# FFN
for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
- original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
- converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
- if original_key in original_state_dict:
- converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+ alpha_key = f"blocks.{i}.{o}.alpha"
+ has_alpha = alpha_key in original_state_dict
+ original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+ converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
- original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
- converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
- if original_key in original_state_dict:
- converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+ original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+ converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+ if original_key_A in original_state_dict:
+ down_weight = original_state_dict.pop(original_key_A)
+ converted_state_dict[converted_key_A] = down_weight
+ if original_key_B in original_state_dict:
+ up_weight = original_state_dict.pop(original_key_B)
+ converted_state_dict[converted_key_B] = up_weight
+ if has_alpha:
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[converted_key_A] *= scale_down
+ converted_state_dict[converted_key_B] *= scale_up
original_key = f"blocks.{i}.{o}.diff_b"
converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
@@ -2073,3 +2118,126 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
+
+
+def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
+ has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
+ if has_diffusion_model:
+ state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
+
+ has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+ if has_lora_unet:
+ state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
+
+ def convert_key(key: str) -> str:
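+ # Map a flat "transformer_blocks_<idx>_<layer...>" key onto the dotted diffusers naming, keeping protected
+ # multi-token names (e.g. "to_q", "add_q_proj", "img_mod") intact instead of splitting them at underscores.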
+ prefix = "transformer_blocks"
+ if "." in key:
+ base, suffix = key.rsplit(".", 1)
+ else:
+ base, suffix = key, ""
+
+ start = f"{prefix}_"
+ rest = base[len(start) :]
+
+ if "." in rest:
+ head, tail = rest.split(".", 1)
+ tail = "." + tail
+ else:
+ head, tail = rest, ""
+
+ # Protected n-grams that must keep their internal underscores
+ protected = {
+ # pairs
+ ("to", "q"),
+ ("to", "k"),
+ ("to", "v"),
+ ("to", "out"),
+ ("add", "q"),
+ ("add", "k"),
+ ("add", "v"),
+ ("txt", "mlp"),
+ ("img", "mlp"),
+ ("txt", "mod"),
+ ("img", "mod"),
+ # triplets
+ ("add", "q", "proj"),
+ ("add", "k", "proj"),
+ ("add", "v", "proj"),
+ ("to", "add", "out"),
+ }
+
+ prot_by_len = {}
+ for ng in protected:
+ prot_by_len.setdefault(len(ng), set()).add(ng)
+
+ parts = head.split("_")
+ merged = []
+ i = 0
+ lengths_desc = sorted(prot_by_len.keys(), reverse=True)
+
+ while i < len(parts):
+ matched = False
+ for L in lengths_desc:
+ if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
+ merged.append("_".join(parts[i : i + L]))
+ i += L
+ matched = True
+ break
+ if not matched:
+ merged.append(parts[i])
+ i += 1
+
+ head_converted = ".".join(merged)
+ converted_base = f"{prefix}.{head_converted}{tail}"
+ return converted_base + (("." + suffix) if suffix else "")
+
+ state_dict = {convert_key(k): v for k, v in state_dict.items()}
+
+ converted_state_dict = {}
+ all_keys = list(state_dict.keys())
+ down_key = ".lora_down.weight"
+ up_key = ".lora_up.weight"
+ a_key = ".lora_A.weight"
+ b_key = ".lora_B.weight"
+
+ has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
+ has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
+
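+ # Checkpoints may use either the kohya-style `lora_down`/`lora_up` naming (with optional alphas) or the
+ # peft-style `lora_A`/`lora_B` naming. In the first case the alpha is folded into the weights; in the second,
+ # stray alpha entries are simply dropped.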
+ if has_non_diffusers_lora_id:
+
+ def get_alpha_scales(down_weight, alpha_key):
+ rank = down_weight.shape[0]
+ alpha = state_dict.pop(alpha_key).item()
+ scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+ scale_down = scale
+ scale_up = 1.0
+ while scale_down * 2 < scale_up:
+ scale_down *= 2
+ scale_up /= 2
+ return scale_down, scale_up
+
+ for k in all_keys:
+ if k.endswith(down_key):
+ diffusers_down_key = k.replace(down_key, a_key)
+ diffusers_up_key = k.replace(down_key, b_key)
+ alpha_key = k.replace(down_key, ".alpha")
+
+ down_weight = state_dict.pop(k)
+ up_weight = state_dict.pop(k.replace(down_key, up_key))
+ scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+ converted_state_dict[diffusers_down_key] = down_weight * scale_down
+ converted_state_dict[diffusers_up_key] = up_weight * scale_up
+
+ # Already in diffusers format (lora_A/lora_B), just pop
+ elif has_diffusers_lora_id:
+ for k in all_keys:
+ if a_key in k or b_key in k:
+ converted_state_dict[k] = state_dict.pop(k)
+ elif ".alpha" in k:
+ state_dict.pop(k)
+
+ if len(state_dict) > 0:
+ raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+ converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
+ return converted_state_dict
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 45c20e505c..e25a29e1c0 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -49,6 +49,7 @@ from .lora_conversion_utils import (
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
+ _convert_non_diffusers_qwen_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
@@ -245,13 +246,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Return state dict for lora weights and the network alphas.
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
+ > [!WARNING]
+ > We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+ >
+ > This function is experimental and might change in the future.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -509,35 +505,28 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
- state_dict = {}
- lora_adapter_metadata = {}
-
- if not (unet_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+ lora_layers[cls.unet_name] = unet_lora_layers
+ lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+ lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+ lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
- if unet_lora_adapter_metadata:
- lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
- if text_encoder_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -551,11 +540,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
- This is an experimental API.
-
-
+ > [!WARNING]
+ > This is an experimental API.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
@@ -592,11 +577,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
- This is an experimental API.
-
-
+ > [!WARNING]
+ > This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
@@ -627,33 +608,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
- `self.text_encoder`.
-
- All kwargs are forwarded to `self.lora_state_dict`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
- loaded into `self.unet`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
- dict is loaded into `self.text_encoder`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -730,13 +685,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
r"""
Return state dict for lora weights and the network alphas.
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
+ > [!WARNING]
+ > We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+ >
+ > This function is experimental and might change in the future.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -973,74 +923,36 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_adapter_metadata=None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `unet`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- unet_lora_adapter_metadata:
- LoRA adapter metadata associated with the unet to be serialized with the state dict.
- text_encoder_lora_adapter_metadata:
- LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
- text_encoder_2_lora_adapter_metadata:
- LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
-
- if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
- raise ValueError(
- "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
- )
+ lora_layers = {}
+ lora_metadata = {}
if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+ lora_layers[cls.unet_name] = unet_lora_layers
+ lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+ lora_layers["text_encoder"] = text_encoder_lora_layers
+ lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+ lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+ lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
- if unet_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
-
- if text_encoder_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+ if not lora_layers:
+ raise ValueError(
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
)
- if text_encoder_2_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
- )
-
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -1052,35 +964,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -1092,21 +976,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -1132,51 +1002,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -1230,30 +1056,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
- `self.text_encoder`.
-
- All kwargs are forwarded to `self.lora_state_dict`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1322,26 +1125,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`SD3Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -1436,76 +1220,36 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_adapter_metadata=None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
- text_encoder_lora_adapter_metadata:
- LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
- text_encoder_2_lora_adapter_metadata:
- LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
-
- if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
- raise ValueError(
- "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
- )
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+ lora_layers["text_encoder"] = text_encoder_lora_layers
+ lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+ lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+ lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+ if not lora_layers:
+ raise ValueError(
+ "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
)
- if text_encoder_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
- )
-
- if text_encoder_2_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
- )
-
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
@@ -1518,35 +1262,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -1559,21 +1275,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -1595,51 +1297,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -1694,25 +1352,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1758,26 +1398,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`AuraFlowTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -1809,48 +1430,26 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
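The `save_lora_weights` refactor above replaces the per-mixin `pack_weights` / `write_lora_layers` calls with a single `_save_lora_weights` helper that takes the collected `lora_layers` and `lora_metadata` dictionaries. The helper's body is not part of this hunk; the sketch below is only a rough reconstruction of what it presumably does, assuming it lives on `LoraBaseMixin` next to the same utilities the removed code called directly.

```py
# Rough reconstruction, not the actual implementation: assumed to be defined on
# LoraBaseMixin alongside pack_weights, write_lora_layers and _pack_dict_with_prefix.
@classmethod
def _save_lora_weights(cls, save_directory, lora_layers, lora_metadata, **kwargs):
    state_dict = {}
    lora_adapter_metadata = {}
    for prefix, layers in lora_layers.items():
        # Prefix every key with its component name, e.g. "transformer.<param>".
        state_dict.update(cls.pack_weights(layers, prefix))
        if lora_metadata.get(prefix):
            lora_adapter_metadata.update(_pack_dict_with_prefix(lora_metadata[prefix], prefix))
    cls.write_lora_layers(
        state_dict=state_dict,
        save_directory=save_directory,
        lora_adapter_metadata=lora_adapter_metadata,
        **kwargs,  # is_main_process, weight_name, save_function, safe_serialization
    )
```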
@@ -1863,35 +1462,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -1904,18 +1475,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
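The worked example that used to be duplicated in every `fuse_lora` docstring is kept once here for reference; the checkpoint names come from the removed docstrings and are illustrative only.

```py
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
# ... run inference with the fused weights ...
pipeline.unfuse_lora()  # restores the original, unfused parameters
```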
@@ -1942,50 +1502,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
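With the per-mixin parameter tables gone, a single illustrative call to `lora_state_dict` is shown here. The repo id is a placeholder, and the exact return shape (state dict alone, or additionally network alphas / adapter metadata) depends on flags such as `return_lora_metadata` described in the removed text.

```py
from diffusers import FluxPipeline

# Placeholder repo id; a local directory or an in-memory state dict also works.
state_dict = FluxPipeline.lora_state_dict("some-user/flux-lora-checkpoint")
print(list(state_dict)[:5])  # inspect a few packed keys before loading them into the pipeline
```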
@@ -2239,30 +1756,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
hotswap: bool = False,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`FluxTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
@@ -2434,37 +1928,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
- state_dict = {}
- lora_adapter_metadata = {}
-
- if not (transformer_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+ lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+ lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
- if transformer_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if text_encoder_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
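At the call site, the Flux `save_lora_weights` behaves as before: pass one or both component state dicts and the helper packs and serializes them. The snippet below is a hedged sketch; the tensor and key are dummies standing in for LoRA layers collected during training (for example via peft's `get_peft_model_state_dict`), so only the calling convention is meaningful.

```py
import torch
from diffusers import FluxPipeline

# Dummy stand-in for a trained LoRA state dict; real keys follow the transformer's module names.
transformer_lora_layers = {"single_transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(8, 8)}

FluxPipeline.save_lora_weights(
    save_directory="./flux-lora-checkpoint",
    transformer_lora_layers=transformer_lora_layers,
    # text_encoder_lora_layers=... may be passed too; at least one of the two is required.
)
```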
@@ -2476,35 +1961,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
@@ -2532,11 +1989,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
- This is an experimental API.
-
-
+ > [!WARNING]
+ > This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
@@ -2847,30 +2300,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
hotswap: bool = False,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`UVit2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
@@ -3020,51 +2450,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -3118,25 +2504,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3182,26 +2550,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`CogVideoXTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3221,7 +2570,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
)
@classmethod
- # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
@@ -3233,48 +2581,26 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -3286,35 +2612,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -3326,18 +2624,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -3359,51 +2646,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -3458,25 +2701,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3522,26 +2747,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`MochiTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3573,48 +2779,26 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -3627,35 +2811,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -3668,18 +2824,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -3700,50 +2845,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -3802,25 +2904,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3866,26 +2950,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`LTXVideoTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3917,48 +2982,26 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -3971,35 +3014,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -4012,18 +3027,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -4045,51 +3049,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -4144,25 +3104,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -4208,26 +3150,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`SanaTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4259,48 +3182,26 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4313,35 +3214,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -4354,18 +3227,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -4386,50 +3248,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading original format HunyuanVideo LoRA checkpoints.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
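For context on the condensed docstring above: the `return_lora_metadata` flag that these `lora_state_dict` overrides accept can be exercised as sketched below. This is illustrative only; the repo id is a placeholder and the exact return shape can differ per mixin.

```py
# Illustrative sketch: inspect a LoRA checkpoint without loading it into a pipeline.
# "<lora-repo-or-local-path>" is a placeholder, not a checkpoint referenced by this diff.
from diffusers import HunyuanVideoPipeline

state_dict, metadata = HunyuanVideoPipeline.lora_state_dict(
    "<lora-repo-or-local-path>",
    return_lora_metadata=True,  # additionally return any adapter metadata stored in the checkpoint
)
print(f"{len(state_dict)} LoRA tensors, metadata: {metadata}")
```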
@@ -4488,25 +3307,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -4552,26 +3353,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`HunyuanVideoTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4603,48 +3385,26 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
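The `save_lora_weights` hunks here and in the mixins below all delegate to a shared `_save_lora_weights` helper whose body is not part of this diff. Judging from the call sites, component-keyed layer and metadata dicts get flattened into a single prefixed state dict before being written. The sketch below is a hypothetical reconstruction of that packing step; the helper name `pack_lora_for_saving` and the key format are assumptions, not the library's actual implementation.

```py
from typing import Dict, Optional

def pack_lora_for_saving(lora_layers: Dict[str, dict], lora_metadata: Dict[str, Optional[dict]]):
    """Hypothetical sketch: flatten per-component LoRA layers and metadata into the
    single dicts the old inline code built via pack_weights/_pack_dict_with_prefix."""
    state_dict, adapter_metadata = {}, {}
    for component, layers in lora_layers.items():
        # prefix every key with its component name, e.g. "transformer.<param>"
        state_dict.update({f"{component}.{key}": value for key, value in layers.items()})
        if lora_metadata.get(component) is not None:
            adapter_metadata.update(
                {f"{component}.{key}": value for key, value in lora_metadata[component].items()}
            )
    return state_dict, adapter_metadata

# Mirrors what the refactored save_lora_weights now hands to the helper.
packed, packed_meta = pack_lora_for_saving(
    lora_layers={"transformer": {"to_q.lora_A.weight": 0.0}},
    lora_metadata={"transformer": {"r": 16}},
)
```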
@@ -4657,35 +3417,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -4698,18 +3430,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -4730,50 +3451,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -4833,25 +3511,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -4897,26 +3557,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`Lumina2Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4948,48 +3589,26 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -5002,35 +3621,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -5043,18 +3634,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -5064,7 +3644,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
"""
- _lora_loadable_modules = ["transformer"]
+ _lora_loadable_modules = ["transformer", "transformer_2"]
transformer_name = TRANSFORMER_NAME
@classmethod
@@ -5075,50 +3655,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -5224,25 +3761,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -5269,15 +3788,35 @@ class WanLoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- metadata=metadata,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
+ load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+ if load_into_transformer_2:
+ if not hasattr(self, "transformer_2"):
+ raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute 'transformer_2'. "
+                    "Note that Wan2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+ )
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=self.transformer_2,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ else:
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name)
+ if not hasattr(self, "transformer")
+ else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
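The `_lora_loadable_modules` change and the new `load_into_transformer_2` keyword above make it possible to target the second transformer of two-stage Wan pipelines; the `SkyReelsV2LoraLoaderMixin` hunk further down gains the same branch. A hedged usage sketch with placeholder repo ids rather than real checkpoints:

```py
import torch
from diffusers import WanPipeline

# Placeholder ids: substitute a Wan 2.2 style checkpoint that actually ships a transformer_2.
pipe = WanPipeline.from_pretrained("<wan-2.2-checkpoint>", torch_dtype=torch.bfloat16)

pipe.load_lora_weights("<lora-repo-or-path>", adapter_name="high_noise")
pipe.load_lora_weights(
    "<lora-repo-or-path>",
    adapter_name="low_noise",
    load_into_transformer_2=True,  # new kwarg; raises AttributeError if the pipeline has no transformer_2
)
```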
@@ -5292,26 +3831,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`WanTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -5343,48 +3863,26 @@ class WanLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
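From the caller's perspective the `save_lora_weights` refactor above is behavior preserving: it still takes the transformer LoRA state dict plus optional adapter metadata. A sketch of a typical call, assuming a PEFT LoRA adapter has already been attached to and trained on `pipe.transformer`; the checkpoint id and metadata values are illustrative only:

```py
import torch
from peft.utils import get_peft_model_state_dict
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("<wan-checkpoint>", torch_dtype=torch.bfloat16)
# ... attach and train a LoRA adapter on pipe.transformer here ...

transformer_lora_layers = get_peft_model_state_dict(pipe.transformer)
WanPipeline.save_lora_weights(
    save_directory="./wan-lora",
    transformer_lora_layers=transformer_lora_layers,
    transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},  # optional; serialized next to the weights
)
```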
@@ -5397,35 +3895,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -5438,18 +3908,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -5471,50 +3930,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -5622,25 +4038,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -5667,15 +4065,35 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
- self.load_lora_into_transformer(
- state_dict,
- transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
- adapter_name=adapter_name,
- metadata=metadata,
- _pipeline=self,
- low_cpu_mem_usage=low_cpu_mem_usage,
- hotswap=hotswap,
- )
+ load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+ if load_into_transformer_2:
+ if not hasattr(self, "transformer_2"):
+ raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute 'transformer_2'. "
+                    "Note that Wan2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+ )
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=self.transformer_2,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+ else:
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name)
+ if not hasattr(self, "transformer")
+ else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
@@ -5690,26 +4108,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`SkyReelsV2Transformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -5741,48 +4140,26 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -5795,35 +4172,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -5836,18 +4185,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -5869,51 +4207,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -5968,25 +4262,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -6032,26 +4308,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`CogView4Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -6083,48 +4340,26 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -6137,35 +4372,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -6178,18 +4385,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -6210,50 +4406,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -6312,25 +4465,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -6376,26 +4511,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`HiDreamImageTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -6427,48 +4543,26 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -6481,35 +4575,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -6522,18 +4588,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -6548,58 +4603,13 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -6642,6 +4652,12 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+ has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
+ has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
+ has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
+ if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
+ state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
+
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
@@ -6654,25 +4670,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -6718,26 +4716,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`QwenImageTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -6769,48 +4748,26 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
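+        # The shared `_save_lora_weights` helper is expected to pack each `lora_layers`/`lora_metadata` entry
+        # under its component prefix (only the transformer here) and serialize them to a single file, taking
+        # over what the removed `pack_weights` + `write_lora_layers` calls did explicitly.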
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -6823,35 +4780,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -6864,18 +4793,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index d048298fd4..2381ccfef3 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -320,7 +320,9 @@ class PeftAdapterMixin:
# it to None
incompatible_keys = None
else:
- inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+ inject_adapter_in_model(
+ lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
+ )
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if self._prepare_lora_hotswap_kwargs is not None:
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 76fefc1260..b53647d476 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -22,8 +22,9 @@ from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self
from .. import __version__
+from ..models.model_loading_utils import _caching_allocator_warmup, _determine_device_map, _expand_device_map
from ..quantizers import DiffusersAutoQuantizer
-from ..utils import deprecate, is_accelerate_available, logging
+from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
@@ -62,8 +63,12 @@ logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate import dispatch_model, init_empty_weights
- from ..models.modeling_utils import load_model_dict_into_meta
+ from ..models.model_loading_utils import load_model_dict_into_meta
+if is_torch_version(">=", "1.9.0") and is_accelerate_available():
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
+else:
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
SINGLE_FILE_LOADABLE_CLASSES = {
"StableCascadeUNet": {
@@ -153,9 +158,17 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
+ "QwenImageTransformer2DModel": {
+ "checkpoint_mapping_fn": lambda x: x,
+ "default_subfolder": "transformer",
+ },
}
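+# Heuristic: if the instantiated model's parameter names are already a subset of the checkpoint keys, the
+# checkpoint is assumed to be in diffusers format and is used as-is; otherwise `checkpoint_mapping_fn` is
+# applied to convert it (see its use further below in `from_single_file`).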
+def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
+ return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
+
+
def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -228,6 +241,11 @@ class FromOriginalModelMixin:
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 and
+                is_accelerate_available() else `False`): Speed up model loading by only loading the pretrained weights and
+ not initializing the weights. This also tries to not use more than 1x model size in CPU memory
+ (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using
+ an older version of PyTorch, setting this argument to `True` will raise an error.
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -277,8 +295,10 @@ class FromOriginalModelMixin:
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
+ device_map = kwargs.pop("device_map", None)
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
@@ -381,19 +401,12 @@ class FromOriginalModelMixin:
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
diffusers_model_config.update(model_kwargs)
- checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
- diffusers_format_checkpoint = checkpoint_mapping_fn(
- config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
- )
- if not diffusers_format_checkpoint:
- raise SingleFileComponentError(
- f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
- )
-
- ctx = init_empty_weights if is_accelerate_available() else nullcontext
+ ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
model = cls.from_config(diffusers_model_config)
+ model_state_dict = model.state_dict()
+
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
@@ -406,6 +419,26 @@ class FromOriginalModelMixin:
else:
keep_in_fp32_modules = []
+ # Now that the model is loaded, we can determine the `device_map`
+ device_map = _determine_device_map(model, device_map, None, torch_dtype, keep_in_fp32_modules, hf_quantizer)
+ if device_map is not None:
+ expanded_device_map = _expand_device_map(device_map, model_state_dict.keys())
+ _caching_allocator_warmup(model, expanded_device_map, torch_dtype, hf_quantizer)
+
+ checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
+
+ if _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
+ diffusers_format_checkpoint = checkpoint_mapping_fn(
+ config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
+ )
+ else:
+ diffusers_format_checkpoint = checkpoint
+
+ if not diffusers_format_checkpoint:
+ raise SingleFileComponentError(
+ f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
+ )
+
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model,
@@ -415,7 +448,7 @@ class FromOriginalModelMixin:
)
device_map = None
- if is_accelerate_available():
+ if low_cpu_mem_usage:
param_device = torch.device(device) if device else torch.device("cpu")
empty_state_dict = model.state_dict()
unexpected_keys = [
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index a804ea80a9..ef6c41e3ce 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -55,11 +55,12 @@ if is_transformers_available():
if is_accelerate_available():
from accelerate import init_empty_weights
- from ..models.modeling_utils import load_model_dict_into_meta
+ from ..models.model_loading_utils import load_model_dict_into_meta
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CHECKPOINT_KEY_NAMES = {
+ "v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py
index ced81960fa..ef7b921b7d 100644
--- a/src/diffusers/loaders/transformer_flux.py
+++ b/src/diffusers/loaders/transformer_flux.py
@@ -17,7 +17,8 @@ from ..models.embeddings import (
ImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
index 1bc3a9c7a8..e3728082ef 100644
--- a/src/diffusers/loaders/transformer_sd3.py
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -16,7 +16,8 @@ from typing import Dict
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 1d698e5a8b..c5e56af156 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -30,7 +30,8 @@ from ..models.embeddings import (
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 972233bd98..457f70448a 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -25,6 +25,7 @@ from ..utils import (
_import_structure = {}
if is_torch_available():
+ _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
@@ -52,6 +53,10 @@ if is_torch_available():
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DMultiControlNetModel",
]
+ _import_structure["controlnets.controlnet_qwenimage"] = [
+ "QwenImageControlNetModel",
+ "QwenImageMultiControlNetModel",
+ ]
_import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
@@ -76,6 +81,7 @@ if is_torch_available():
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
+ _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -114,6 +120,7 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
+ from ._modeling_parallel import ContextParallelConfig, ParallelConfig
from .adapter import MultiAdapter, T2IAdapter
from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
@@ -147,6 +154,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DMultiControlNetModel,
MultiControlNetModel,
MultiControlNetUnionModel,
+ QwenImageControlNetModel,
+ QwenImageMultiControlNetModel,
SanaControlNetModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
@@ -158,6 +167,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
+ BriaTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
new file mode 100644
index 0000000000..2a1d2cc6ce
--- /dev/null
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -0,0 +1,241 @@
+# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨
+# Experimental changes are subject to change and APIs may break without warning.
+
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+
+from ..utils import get_logger
+
+
+if TYPE_CHECKING:
+ pass
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+# TODO(aryan): add support for the following:
+# - Unified Attention
+# - More dispatcher attention backends
+# - CFG/Data Parallel
+# - Tensor Parallel
+
+
+@dataclass
+class ContextParallelConfig:
+ """
+ Configuration for context parallelism.
+
+ Args:
+ ring_degree (`int`, *optional*, defaults to `1`):
+ Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
+ total number of devices in the context parallel mesh.
+ ulysses_degree (`int`, *optional*, defaults to `1`):
+ Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
+ total number of devices in the context parallel mesh.
+ convert_to_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether to convert output and LSE to float32 for ring attention numerical stability.
+ rotate_method (`str`, *optional*, defaults to `"allgather"`):
+ Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
+ is supported.
+
+ """
+
+ ring_degree: Optional[int] = None
+ ulysses_degree: Optional[int] = None
+ convert_to_fp32: bool = True
+ # TODO: support alltoall
+ rotate_method: Literal["allgather", "alltoall"] = "allgather"
+
+ _rank: int = None
+ _world_size: int = None
+ _device: torch.device = None
+ _mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _ring_local_rank: int = None
+ _ulysses_local_rank: int = None
+
+ def __post_init__(self):
+ if self.ring_degree is None:
+ self.ring_degree = 1
+ if self.ulysses_degree is None:
+ self.ulysses_degree = 1
+
+ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
+ self._rank = rank
+ self._world_size = world_size
+ self._device = device
+ self._mesh = mesh
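+        # `mesh` is expected to be a 2D DeviceMesh with named dims ("ring", "ulysses") whose sizes match
+        # `ring_degree` and `ulysses_degree`; the flattened view built below spans the whole CP group.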
+ if self.ring_degree is None:
+ self.ring_degree = 1
+ if self.ulysses_degree is None:
+ self.ulysses_degree = 1
+ if self.rotate_method != "allgather":
+ raise NotImplementedError(
+ f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+ )
+ if self._flattened_mesh is None:
+ self._flattened_mesh = self._mesh._flatten()
+ if self._ring_mesh is None:
+ self._ring_mesh = self._mesh["ring"]
+ if self._ulysses_mesh is None:
+ self._ulysses_mesh = self._mesh["ulysses"]
+ if self._ring_local_rank is None:
+ self._ring_local_rank = self._ring_mesh.get_local_rank()
+ if self._ulysses_local_rank is None:
+ self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+
+
+@dataclass
+class ParallelConfig:
+ """
+ Configuration for applying different parallelisms.
+
+ Args:
+ context_parallel_config (`ContextParallelConfig`, *optional*):
+ Configuration for context parallelism.
+ """
+
+ context_parallel_config: Optional[ContextParallelConfig] = None
+
+ _rank: int = None
+ _world_size: int = None
+ _device: torch.device = None
+ _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+
+ def setup(
+ self,
+ rank: int,
+ world_size: int,
+ device: torch.device,
+ *,
+ cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+ ):
+ self._rank = rank
+ self._world_size = world_size
+ self._device = device
+ self._cp_mesh = cp_mesh
+ if self.context_parallel_config is not None:
+ self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+
+
+@dataclass(frozen=True)
+class ContextParallelInput:
+ """
+    Configuration for splitting an input tensor across the context parallel region.
+
+ Args:
+ split_dim (`int`):
+ The dimension along which to split the tensor.
+ expected_dims (`int`, *optional*):
+ The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+ tensor has the expected number of dimensions before splitting.
+ split_output (`bool`, *optional*, defaults to `False`):
+ Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
+            This is useful for layers whose outputs should be split after the layer does some preprocessing on its
+            inputs (e.g., RoPE).
+ """
+
+ split_dim: int
+ expected_dims: Optional[int] = None
+ split_output: bool = False
+
+ def __repr__(self):
+ return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
+
+
+@dataclass(frozen=True)
+class ContextParallelOutput:
+ """
+    Configuration for gathering an output tensor across the context parallel region.
+
+ Args:
+ gather_dim (`int`):
+ The dimension along which to gather the tensor.
+ expected_dims (`int`, *optional*):
+ The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+ tensor has the expected number of dimensions before gathering.
+ """
+
+ gather_dim: int
+ expected_dims: Optional[int] = None
+
+ def __repr__(self):
+ return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
+
+
+# A dictionary where keys denote the input to be split across context parallel region, and the
+# value denotes the sharding configuration.
+# If the key is a string, it denotes the name of the parameter in the forward function.
+# If the key is an integer, split_output must be set to True, and it denotes the index of the output
+# to be split across context parallel region.
+ContextParallelInputType = Dict[
+ Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
+]
+
+# A dictionary where keys denote the output to be gathered across context parallel region, and the
+# value denotes the gathering configuration.
+ContextParallelOutputType = Union[
+ ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
+]
+
+# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
+# the module should be split/gathered across context parallel region.
+ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+
+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+# "": {
+# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+# },
+# "pos_embed": {
+# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+# },
+# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
+# split/gathered according to this at the respective module level. Here, the following happens:
+# - "":
+# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+#   the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
+# - "pos_embed":
+#   we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
+# we can individually specify how they should be split
+# - "proj_out":
+# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
+# layer forward has run).
+#
+# ContextParallelInput:
+# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+#
+# ContextParallelOutput:
+#   specifies how to gather the output tensor in the post-forward hook of the layer it is attached to
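+#
+# Minimal usage sketch (illustrative only; assumes torch.distributed is initialized and that the caller
+# builds the ("ring", "ulysses") device mesh and invokes `setup`, as a model-level entry point would):
+#
+#     from diffusers.models import ContextParallelConfig, ParallelConfig
+#
+#     cp_config = ContextParallelConfig(ring_degree=2, ulysses_degree=1)
+#     parallel_config = ParallelConfig(context_parallel_config=cp_config)
+#     # e.g. parallel_config.setup(rank, world_size, device, cp_mesh=mesh)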
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index c720b37955..5164cf311d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -111,11 +111,7 @@ class AttentionMixin:
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+        > [!WARNING]
+        > This API is 🧪 experimental.
"""
for module in self.modules():
if isinstance(module, AttentionModuleMixin):
@@ -241,7 +237,7 @@ class AttentionModuleMixin:
op_fw, op_bw = attention_op
dtype, *_ = op_fw.SUPPORTED_DTYPES
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
- _ = xops.memory_efficient_attention(q, q, q)
+ _ = xops.ops.memory_efficient_attention(q, q, q)
except Exception as e:
raise e
@@ -674,7 +670,7 @@ class JointTransformerBlock(nn.Module):
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index c00ec7dd6e..e169491099 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,15 +17,20 @@ import functools
import inspect
import math
from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
+
+if torch.distributed.is_available():
+ import torch.distributed._functional_collectives as funcol
+
from ..utils import (
get_logger,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
+ is_kernels_available,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
@@ -35,9 +40,12 @@ from ..utils import (
is_xformers_available,
is_xformers_version,
)
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
+if TYPE_CHECKING:
+ from ._modeling_parallel import ParallelConfig
+
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
@@ -55,9 +63,12 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
flash_attn_func = None
flash_attn_varlen_func = None
+ _wrapped_flash_attn_backward = None
+ _wrapped_flash_attn_forward = None
if _CAN_USE_FLASH_ATTN_3:
@@ -67,6 +78,17 @@ else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
+if DIFFUSERS_ENABLE_HUB_KERNELS:
+ if not is_kernels_available():
+ raise ImportError(
+            "To use the FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
+ )
+ from ..utils.kernels_utils import _get_fa3_from_hub
+
+ flash_attn_interface_hub = _get_fa3_from_hub()
+ flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+else:
+ flash_attn_3_func_hub = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
@@ -110,6 +132,27 @@ if _CAN_USE_XFORMERS_ATTN:
else:
xops = None
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
+if is_torch_version(">=", "2.4.0"):
+ _custom_op = torch.library.custom_op
+ _register_fake = torch.library.register_fake
+else:
+
+ def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+ def wrap(func):
+ return func
+
+ return wrap if fn is None else fn
+
+ def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+ def wrap(func):
+ return func
+
+ return wrap if fn is None else fn
+
+ _custom_op = custom_op_no_op
+ _register_fake = register_fake_no_op
+
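+# NOTE: On PyTorch < 2.4 the decorators above degrade to no-ops, so the FA3 forward wrapper defined later in
+# this file runs as a plain Python function rather than a registered custom op, and fullgraph tracing of it
+# is not expected to work on those versions.
+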
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -132,6 +175,8 @@ class AttentionBackendName(str, Enum):
FLASH_VARLEN = "flash_varlen"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
+ _FLASH_3_HUB = "_flash_3_hub"
+ # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
# PyTorch native
FLEX = "flex"
@@ -162,17 +207,24 @@ class _AttentionBackendRegistry:
_backends = {}
_constraints = {}
_supported_arg_names = {}
+ _supports_context_parallel = {}
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
_checks_enabled = DIFFUSERS_ATTN_CHECKS
@classmethod
- def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
+ def register(
+ cls,
+ backend: AttentionBackendName,
+ constraints: Optional[List[Callable]] = None,
+ supports_context_parallel: bool = False,
+ ):
logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
def decorator(func):
cls._backends[backend] = func
cls._constraints[backend] = constraints or []
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
+ cls._supports_context_parallel[backend] = supports_context_parallel
return func
return decorator
@@ -185,6 +237,17 @@ class _AttentionBackendRegistry:
def list_backends(cls):
return list(cls._backends.keys())
+ @classmethod
+ def _is_context_parallel_enabled(
+ cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
+ ) -> bool:
+        supports_context_parallel = cls._supports_context_parallel.get(backend, False)
+ is_degree_greater_than_1 = parallel_config is not None and (
+ parallel_config.context_parallel_config.ring_degree > 1
+ or parallel_config.context_parallel_config.ulysses_degree > 1
+ )
+ return supports_context_parallel and is_degree_greater_than_1
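+        # Context parallel attention is only used when the backend opted in via
+        # `register(..., supports_context_parallel=True)` and the configured ring/ulysses degree is above one.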
+
@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
@@ -218,6 +281,7 @@ def dispatch_attention_fn(
attention_kwargs: Optional[Dict[str, Any]] = None,
*,
backend: Optional[AttentionBackendName] = None,
+ parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
attention_kwargs = attention_kwargs or {}
@@ -229,6 +293,14 @@ def dispatch_attention_fn(
backend_name = AttentionBackendName(backend)
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+ if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
+ backend_name, parallel_config
+ ):
+ raise ValueError(
+ f"Backend {backend_name} either does not support context parallelism or context parallelism "
+ f"was enabled with a world size of 1."
+ )
+
kwargs = {
"query": query,
"key": key,
@@ -238,6 +310,7 @@ def dispatch_attention_fn(
"is_causal": is_causal,
"scale": scale,
**attention_kwargs,
+ "_parallel_config": parallel_config,
}
if is_torch_version(">=", "2.5.0"):
kwargs["enable_gqa"] = enable_gqa
@@ -330,6 +403,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)
+ # TODO: add support Hub variant of FA3 varlen later
+ elif backend in [AttentionBackendName._FLASH_3_HUB]:
+ if not DIFFUSERS_ENABLE_HUB_KERNELS:
+ raise RuntimeError(
+ f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+ )
+ if not is_kernels_available():
+ raise RuntimeError(
+ f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+ )
+
elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
@@ -473,25 +557,623 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
-
-
-# TODO: library.custom_op and register_fake probably need version guards?
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3_original(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ softmax_scale: Optional[float] = None,
+ causal: bool = False,
+ qv: Optional[torch.Tensor] = None,
+ q_descale: Optional[torch.Tensor] = None,
+ k_descale: Optional[torch.Tensor] = None,
+ v_descale: Optional[torch.Tensor] = None,
+ attention_chunk: int = 0,
+ softcap: float = 0.0,
+ num_splits: int = 1,
+ pack_gqa: Optional[bool] = None,
+ deterministic: bool = False,
+ sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
- out, lse = flash_attn_3_func(query, key, value)
+ # Hardcoded for now because pytorch does not support tuple/int type hints
+ window_size = (-1, -1)
+ out, lse, *_ = flash_attn_3_func(
+ q=q,
+ k=k,
+ v=v,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ qv=qv,
+ q_descale=q_descale,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ window_size=window_size,
+ attention_chunk=attention_chunk,
+ softcap=softcap,
+ num_splits=num_splits,
+ pack_gqa=pack_gqa,
+ deterministic=deterministic,
+ sm_margin=sm_margin,
+ )
lse = lse.permute(0, 2, 1)
return out, lse
-@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
-def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- batch_size, seq_len, num_heads, head_dim = query.shape
+@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
+def _(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ softmax_scale: Optional[float] = None,
+ causal: bool = False,
+ qv: Optional[torch.Tensor] = None,
+ q_descale: Optional[torch.Tensor] = None,
+ k_descale: Optional[torch.Tensor] = None,
+ v_descale: Optional[torch.Tensor] = None,
+ attention_chunk: int = 0,
+ softcap: float = 0.0,
+ num_splits: int = 1,
+ pack_gqa: Optional[bool] = None,
+ deterministic: bool = False,
+ sm_margin: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ window_size = (-1, -1) # noqa: F841
+ # A lot of the parameters here are not yet used in any way within diffusers.
+ # We can safely ignore for now and keep the fake op shape propagation simple.
+ batch_size, seq_len, num_heads, head_dim = q.shape
lse_shape = (batch_size, seq_len, num_heads)
- return torch.empty_like(query), query.new_empty(lse_shape)
+ return torch.empty_like(q), q.new_empty(lse_shape)
+
+
+# ===== Helper functions to use attention backends with templated CP autograd functions =====
+
+
+# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
+# forward declaration:
+# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+def _cudnn_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if enable_gqa:
+ raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
+
+ tensors_to_save = ()
+
+ # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
+ # if the input tensors are not contiguous.
+ query = query.transpose(1, 2).contiguous()
+ key = key.transpose(1, 2).contiguous()
+ value = value.transpose(1, 2).contiguous()
+ tensors_to_save += (query, key, value)
+
+ out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+ torch.ops.aten._scaled_dot_product_cudnn_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_mask,
+ compute_log_sumexp=return_lse,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ return_debug_mask=False,
+ scale=scale,
+ )
+ )
+
+ tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+ if _save_ctx:
+ ctx.save_for_backward(*tensors_to_save)
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+ ctx.attn_mask = attn_mask
+ ctx.max_q = max_q
+ ctx.max_k = max_k
+
+ out = out.transpose(1, 2).contiguous()
+ if lse is not None:
+ lse = lse.transpose(1, 2).contiguous()
+ return (out, lse) if return_lse else out
+
+
+# backward declaration:
+# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+def _cudnn_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+):
+ query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+ grad_out = grad_out.transpose(1, 2).contiguous()
+ key = key.transpose(1, 2).contiguous()
+ value = value.transpose(1, 2).contiguous()
+
+ # Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341
+ grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
+ grad_out,
+ query,
+ key,
+ value,
+ out,
+ logsumexp=lse,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ attn_bias=ctx.attn_mask,
+ cum_seq_q=cum_seq_q,
+ cum_seq_k=cum_seq_k,
+ max_q=ctx.max_q,
+ max_k=ctx.max_k,
+ dropout_p=ctx.dropout_p,
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ )
+ grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))
+
+ return grad_query, grad_key, grad_value
+
+
+# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
+def _flash_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
+ if enable_gqa:
+ raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")
+
+ # Hardcoded for now
+ window_size = (-1, -1)
+ softcap = 0.0
+ alibi_slopes = None
+ deterministic = False
+ grad_enabled = any(x.requires_grad for x in (query, key, value))
+
+ if scale is None:
+ scale = query.shape[-1] ** (-0.5)
+
+    # flash-attn only returns the LSE if dropout_p > 0, so we need to work around that.
+ if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
+ dropout_p = dropout_p if dropout_p > 0 else 1e-30
+
+ with torch.set_grad_enabled(grad_enabled):
+ out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
+ query,
+ key,
+ value,
+ dropout_p,
+ scale,
+ is_causal,
+ window_size[0],
+ window_size[1],
+ softcap,
+ alibi_slopes,
+ return_lse,
+ )
+ lse = lse.permute(0, 2, 1)
+
+ if _save_ctx:
+ ctx.save_for_backward(query, key, value, out, lse, rng_state)
+ ctx.dropout_p = dropout_p
+ ctx.scale = scale
+ ctx.is_causal = is_causal
+ ctx.window_size = window_size
+ ctx.softcap = softcap
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+
+ return (out, lse) if return_lse else out
+
+
+def _flash_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+):
+ query, key, value, out, lse, rng_state = ctx.saved_tensors
+ grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+
+ lse_d = _wrapped_flash_attn_backward( # noqa: F841
+ grad_out,
+ query,
+ key,
+ value,
+ out,
+ lse,
+ grad_query,
+ grad_key,
+ grad_value,
+ ctx.dropout_p,
+ ctx.scale,
+ ctx.is_causal,
+ ctx.window_size[0],
+ ctx.window_size[1],
+ ctx.softcap,
+ ctx.alibi_slopes,
+ ctx.deterministic,
+ rng_state,
+ )
+
+ # Head dimension may have been padded
+ grad_query = grad_query[..., : grad_out.shape[-1]]
+ grad_key = grad_key[..., : grad_out.shape[-1]]
+ grad_value = grad_value[..., : grad_out.shape[-1]]
+
+ return grad_query, grad_key, grad_value
+
+
+def _sage_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not yet supported for Sage attention.")
+ if dropout_p > 0.0:
+ raise ValueError("`dropout_p` is not yet supported for Sage attention.")
+ if enable_gqa:
+ raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
+
+ out = sageattn(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+ lse = None
+ if return_lse:
+ out, lse, *_ = out
+ lse = lse.permute(0, 2, 1)
+
+ return (out, lse) if return_lse else out
+
+
+def _sage_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+):
+ raise NotImplementedError("Backward pass is not implemented for Sage attention.")
+
+
+# ===== Context parallel =====
+
+
+# Reference:
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
+# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
+def _wait_tensor(tensor):
+ if isinstance(tensor, funcol.AsyncCollectiveTensor):
+ tensor = tensor.wait()
+ return tensor
+
+
+def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
+ shape = x.shape
+ # HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization
+ # to benchmark triton codegen fails somewhere:
+ # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
+ # ValueError: Tensors must be contiguous
+ x = x.flatten()
+ x = funcol.all_to_all_single(x, None, None, group)
+ x = x.reshape(shape)
+ x = _wait_tensor(x)
+ return x
+
+
+class TemplatedRingAttention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ dropout_p: float,
+ is_causal: bool,
+ scale: Optional[float],
+ enable_gqa: bool,
+ return_lse: bool,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+ ):
+ ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+ rank = _parallel_config.context_parallel_config._ring_local_rank
+ world_size = _parallel_config.context_parallel_config.ring_degree
+ next_rank = (rank + 1) % world_size
+ prev_out = prev_lse = None
+
+ ctx.forward_op = forward_op
+ ctx.backward_op = backward_op
+ ctx.q_shape = query.shape
+ ctx.kv_shape = key.shape
+ ctx._parallel_config = _parallel_config
+
+ kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
+ kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
+ kv_buffer = kv_buffer.chunk(world_size)
+
+ for i in range(world_size):
+ if i > 0:
+ kv = kv_buffer[next_rank]
+ key_numel = key.numel()
+ key = kv[:key_numel].reshape_as(key)
+ value = kv[key_numel:].reshape_as(value)
+ next_rank = (next_rank + 1) % world_size
+
+ out, lse = forward_op(
+ ctx,
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ True,
+ _save_ctx=i == 0,
+ _parallel_config=_parallel_config,
+ )
+
+ if _parallel_config.context_parallel_config.convert_to_fp32:
+ out = out.to(torch.float32)
+ lse = lse.to(torch.float32)
+
+ lse = lse.unsqueeze(-1)
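+            # Numerically stable online-softmax style merge: fold this ring step's partial output and
+            # log-sum-exp into the running accumulators.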
+ if prev_out is not None:
+ out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
+ lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
+ prev_out = out
+ prev_lse = lse
+
+ out = out.to(query.dtype)
+ lse = lse.squeeze(-1)
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ ):
+ ring_mesh = ctx._parallel_config.context_parallel_config._ring_mesh
+ rank = ctx._parallel_config.context_parallel_config._ring_local_rank
+ world_size = ctx._parallel_config.context_parallel_config.ring_degree
+ next_rank = (rank + 1) % world_size
+ next_ranks = list(range(1, world_size)) + [0]
+
+ accum_dtype = torch.float32 if ctx._parallel_config.context_parallel_config.convert_to_fp32 else grad_out.dtype
+ grad_query = torch.zeros(ctx.q_shape, dtype=accum_dtype, device=grad_out.device)
+ grad_key = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
+ grad_value = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
+ next_grad_kv = None
+
+ query, key, value, *_ = ctx.saved_tensors
+ kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
+ kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
+ kv_buffer = kv_buffer.chunk(world_size)
+
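+ # Revisit the KV shards in the same ring order as the forward pass: compute this shard's
+ # gradients with backward_op, accumulate them, and circulate the accumulated KV gradients
+ # around the ring (permute_tensor) so each rank ends up with the gradient for its own shard.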
+ for i in range(world_size):
+ if i > 0:
+ kv = kv_buffer[next_rank]
+ key_numel = key.numel()
+ key = kv[:key_numel].reshape_as(key)
+ value = kv[key_numel:].reshape_as(value)
+ next_rank = (next_rank + 1) % world_size
+
+ grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
+
+ if i > 0:
+ grad_kv_buffer = _wait_tensor(next_grad_kv)
+ grad_key_numel = grad_key.numel()
+ grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
+ grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)
+
+ grad_query += grad_query_op
+ grad_key += grad_key_op
+ grad_value += grad_value_op
+
+ if i < world_size - 1:
+ grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous()
+ next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group())
+
+ grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+
+
+class TemplatedUlyssesAttention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ dropout_p: float,
+ is_causal: bool,
+ scale: Optional[float],
+ enable_gqa: bool,
+ return_lse: bool,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+ ):
+ ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+ world_size = _parallel_config.context_parallel_config.ulysses_degree
+ group = ulysses_mesh.get_group()
+
+ ctx.forward_op = forward_op
+ ctx.backward_op = backward_op
+ ctx._parallel_config = _parallel_config
+
+ B, S_Q_LOCAL, H, D = query.shape
+ _, S_KV_LOCAL, _, _ = key.shape
+ H_LOCAL = H // world_size
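+ # Ulysses re-sharding via all-to-all: trade the local sequence shard of all H heads for the
+ # full sequence of H_LOCAL heads, so the inner attention op sees complete sequences.
+ # Layout change per tensor: (B, S_LOCAL, H, D) -> (B, S, H_LOCAL, D).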
+ query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
+ query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
+
+ out = forward_op(
+ ctx,
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ _save_ctx=True,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse, *_ = out
+
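+ # Reverse the re-sharding: convert the full-sequence / local-head output back to the
+ # local-sequence / all-head layout, i.e. (B, S_Q, H_LOCAL, D) -> (B, S_Q_LOCAL, H, D).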
+ out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+ out = _all_to_all_single(out, group)
+ out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
+
+ if return_lse:
+ lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
+ lse = _all_to_all_single(lse, group)
+ lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
+ else:
+ lse = None
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ ):
+ ulysses_mesh = ctx._parallel_config.context_parallel_config._ulysses_mesh
+ world_size = ctx._parallel_config.context_parallel_config.ulysses_degree
+ group = ulysses_mesh.get_group()
+
+ B, S_LOCAL, H, D = grad_out.shape
+ H_LOCAL = H // world_size
+
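+ # Mirror of the forward re-sharding: bring grad_out to the full-sequence / local-head layout,
+ # run backward_op there, then re-shard the resulting q/k/v gradients back to (B, S_LOCAL, H, D).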
+ grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ grad_out = _all_to_all_single(grad_out, group)
+ grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+
+ grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
+
+ grad_query, grad_key, grad_value = (
+ x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+ for x in (grad_query_op, grad_key_op, grad_value_op)
+ )
+ grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value))
+ grad_query, grad_key, grad_value = (
+ x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+
+
+def _templated_context_parallel_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ *,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if attn_mask is not None:
+ raise ValueError("Attention mask is not yet supported for templated attention.")
+ if is_causal:
+ raise ValueError("Causal attention is not yet supported for templated attention.")
+ if enable_gqa:
+ raise ValueError("GQA is not yet supported for templated attention.")
+
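+ # Dispatch based on the configured context-parallel degrees: ring attention shards the KV
+ # sequence across ranks, Ulysses attention re-shards the attention heads via all-to-all.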
+ # TODO: add support for unified attention with ring/ulysses degree both being > 1
+ if _parallel_config.context_parallel_config.ring_degree > 1:
+ return TemplatedRingAttention.apply(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op,
+ backward_op,
+ _parallel_config,
+ )
+ elif _parallel_config.context_parallel_config.ulysses_degree > 1:
+ return TemplatedUlyssesAttention.apply(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op,
+ backward_op,
+ _parallel_config,
+ )
+ else:
+ raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
# ===== Attention backends =====
@@ -500,34 +1182,50 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
)
def _flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
- scale: Optional[float] = None,
is_causal: bool = False,
- window_size: Tuple[int, int] = (-1, -1),
- softcap: float = 0.0,
- alibi_slopes: Optional[torch.Tensor] = None,
- deterministic: bool = False,
- return_attn_probs: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- out = flash_attn_func(
- q=query,
- k=key,
- v=value,
- dropout_p=dropout_p,
- softmax_scale=scale,
- causal=is_causal,
- window_size=window_size,
- softcap=softcap,
- alibi_slopes=alibi_slopes,
- deterministic=deterministic,
- return_attn_probs=return_attn_probs,
- )
- return out
+ lse = None
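+ # Without a parallel config, call the flash-attention kernel directly; under context
+ # parallelism, route through the templated ring/Ulysses wrapper with the flash-attention
+ # forward/backward ops plugged in.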
+ if _parallel_config is None:
+ out = flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_attn_probs=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ None,
+ dropout_p,
+ is_causal,
+ scale,
+ False,
+ return_lse,
+ forward_op=_flash_attention_forward_op,
+ backward_op=_flash_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse = out
+
+ return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -538,19 +1236,12 @@ def _flash_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
- cu_seqlens_q: Optional[torch.Tensor] = None,
- cu_seqlens_k: Optional[torch.Tensor] = None,
- max_seqlen_q: Optional[int] = None,
- max_seqlen_k: Optional[int] = None,
+ attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
- window_size: Tuple[int, int] = (-1, -1),
- softcap: float = 0.0,
- alibi_slopes: Optional[torch.Tensor] = None,
- deterministic: bool = False,
- return_attn_probs: bool = False,
- attn_mask: Optional[torch.Tensor] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
@@ -558,16 +1249,11 @@ def _flash_varlen_attention(
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
- if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
- (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
- _prepare_for_flash_attn_or_sage_varlen(
- batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
- )
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
- else:
- seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
- cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
- cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+ )
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -590,11 +1276,7 @@ def _flash_varlen_attention(
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
- window_size=window_size,
- softcap=softcap,
- alibi_slopes=alibi_slopes,
- deterministic=deterministic,
- return_attn_probs=return_attn_probs,
+ return_attn_probs=return_lse,
)
out = out.unflatten(0, (batch_size, -1))
@@ -606,6 +1288,29 @@ def _flash_varlen_attention(
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ out, lse = _wrapped_flash_attn_3(
+ q=query,
+ k=key,
+ v=value,
+ softmax_scale=scale,
+ causal=is_causal,
+ )
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_3_HUB,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -615,8 +1320,9 @@ def _flash_attention_3(
softcap: float = 0.0,
deterministic: bool = False,
return_attn_probs: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- out, lse, *_ = flash_attn_3_func(
+ out = flash_attn_3_func_hub(
q=query,
k=key,
v=value,
@@ -627,14 +1333,16 @@ def _flash_attention_3(
k_descale=None,
v_descale=None,
window_size=window_size,
- attention_chunk=0,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
+ return_attn_probs=return_attn_probs,
)
- return (out, lse) if return_attn_probs else out
+ # When `return_attn_probs` is True, the call above returns a tuple of
+ # the attention output and the LSE.
+ return (out[0], out[1]) if return_attn_probs else out
@_AttentionBackendRegistry.register(
@@ -645,17 +1353,11 @@ def _flash_varlen_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
- cu_seqlens_q: Optional[torch.Tensor] = None,
- cu_seqlens_k: Optional[torch.Tensor] = None,
- max_seqlen_q: Optional[int] = None,
- max_seqlen_k: Optional[int] = None,
+ attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
- window_size: Tuple[int, int] = (-1, -1),
- softcap: float = 0.0,
- deterministic: bool = False,
- return_attn_probs: bool = False,
- attn_mask: Optional[torch.Tensor] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
@@ -663,16 +1365,11 @@ def _flash_varlen_attention_3(
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
- if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
- (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
- _prepare_for_flash_attn_or_sage_varlen(
- batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
- )
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
- else:
- seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
- cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
- cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+ )
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -692,24 +1389,12 @@ def _flash_varlen_attention_3(
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
- seqused_q=None,
- seqused_k=None,
softmax_scale=scale,
causal=is_causal,
- qv=None,
- q_descale=None,
- k_descale=None,
- v_descale=None,
- window_size=window_size,
- softcap=softcap,
- num_splits=1,
- pack_gqa=None,
- deterministic=deterministic,
- sm_margin=0,
)
out = out.unflatten(0, (batch_size, -1))
- return (out, lse) if return_attn_probs else out
+ return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -725,7 +1410,7 @@ def _native_flex_attention(
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
- kernel_options: Optional[Dict[str, Any]] = None,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
# TODO: should we LRU cache the block mask creation?
score_mod = None
@@ -770,7 +1455,6 @@ def _native_flex_attention(
scale=scale,
enable_gqa=enable_gqa,
return_lse=return_lse,
- kernel_options=kernel_options,
)
out = out.permute(0, 2, 1, 3)
return out
@@ -789,7 +1473,11 @@ def _native_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
@@ -808,6 +1496,7 @@ def _native_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_CUDNN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
)
def _native_cudnn_attention(
query: torch.Tensor,
@@ -818,21 +1507,43 @@ def _native_cudnn_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
- with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
- out = torch.nn.functional.scaled_dot_product_attention(
- query=query,
- key=key,
- value=value,
- attn_mask=attn_mask,
- dropout_p=dropout_p,
- is_causal=is_causal,
- scale=scale,
- enable_gqa=enable_gqa,
+ lse = None
+ if _parallel_config is None and not return_lse:
+ query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op=_cudnn_attention_forward_op,
+ backward_op=_cudnn_attention_backward_op,
+ _parallel_config=_parallel_config,
)
- out = out.permute(0, 2, 1, 3)
- return out
+ if return_lse:
+ out, lse = out
+
+ return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -848,7 +1559,11 @@ def _native_efficient_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
@@ -877,7 +1592,11 @@ def _native_flash_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native flash attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
@@ -907,7 +1626,11 @@ def _native_math_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native math attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
out = torch.nn.functional.scaled_dot_product_attention(
@@ -934,21 +1657,28 @@ def _native_npu_attention(
value: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- return npu_fusion_attention(
+ if return_lse:
+ raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
+ query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+ out = npu_fusion_attention(
query,
key,
value,
- query.size(2), # num_heads
- input_layout="BSND",
+ query.size(1), # num_heads
+ input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
- next_tokens=65536,
+ next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
+ out = out.transpose(1, 2).contiguous()
+ return out
# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
@@ -961,7 +1691,11 @@ def _native_xla_attention(
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
query = query / math.sqrt(query.shape[-1])
out = xla_flash_attention(
@@ -977,6 +1711,7 @@ def _native_xla_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
)
def _sage_attention(
query: torch.Tensor,
@@ -985,16 +1720,40 @@ def _sage_attention(
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- return sageattn(
- q=query,
- k=key,
- v=value,
- tensor_layout="NHD",
- is_causal=is_causal,
- sm_scale=scale,
- return_lse=return_lse,
- )
+ lse = None
+ if _parallel_config is None:
+ out = sageattn(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ None,
+ 0.0,
+ is_causal,
+ scale,
+ False,
+ return_lse,
+ forward_op=_sage_attention_forward_op,
+ backward_op=_sage_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse = out
+
+ return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -1005,31 +1764,26 @@ def _sage_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
- cu_seqlens_q: Optional[torch.Tensor] = None,
- cu_seqlens_k: Optional[torch.Tensor] = None,
- max_seqlen_q: Optional[int] = None,
- max_seqlen_k: Optional[int] = None,
+ attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
- smooth_k: bool = True,
- attn_mask: Optional[torch.Tensor] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")
+
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
- if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
- (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
- _prepare_for_flash_attn_or_sage_varlen(
- batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
- )
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
- else:
- seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
- cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
- cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+ )
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -1051,7 +1805,6 @@ def _sage_varlen_attention(
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scale,
- smooth_k=smooth_k,
)
out = out.unflatten(0, (batch_size, -1))
@@ -1068,11 +1821,8 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
- qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
- pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
- smooth_k: bool = True,
- smooth_v: bool = False,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda(
q=query,
@@ -1080,11 +1830,7 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
v=value,
tensor_layout="NHD",
is_causal=is_causal,
- qk_quant_gran=qk_quant_gran,
sm_scale=scale,
- pv_accum_dtype=pv_accum_dtype,
- smooth_k=smooth_k,
- smooth_v=smooth_v,
return_lse=return_lse,
)
@@ -1099,10 +1845,8 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
- qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
- pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
- smooth_k: bool = True,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query,
@@ -1110,10 +1854,7 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
v=value,
tensor_layout="NHD",
is_causal=is_causal,
- qk_quant_gran=qk_quant_gran,
sm_scale=scale,
- pv_accum_dtype=pv_accum_dtype,
- smooth_k=smooth_k,
return_lse=return_lse,
)
@@ -1128,11 +1869,8 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
- qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
- pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
- smooth_k: bool = True,
- smooth_v: bool = False,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_cuda(
q=query,
@@ -1140,11 +1878,7 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
v=value,
tensor_layout="NHD",
is_causal=is_causal,
- qk_quant_gran=qk_quant_gran,
sm_scale=scale,
- pv_accum_dtype=pv_accum_dtype,
- smooth_k=smooth_k,
- smooth_v=smooth_v,
return_lse=return_lse,
)
@@ -1159,19 +1893,16 @@ def _sage_qk_int8_pv_fp16_triton_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
- quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
- smooth_k: bool = True,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_triton(
q=query,
k=key,
v=value,
tensor_layout="NHD",
- quantization_backend=quantization_backend,
is_causal=is_causal,
sm_scale=scale,
- smooth_k=smooth_k,
return_lse=return_lse,
)
@@ -1189,7 +1920,12 @@ def _xformers_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("xformers attention backend does not support setting `return_lse=True`.")
+
batch_size, seq_len_q, num_heads_q, _ = query.shape
_, seq_len_kv, num_heads_kv, _ = key.shape
diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py
index 17e6f33df0..1bde62e5c6 100644
--- a/src/diffusers/models/attention_flax.py
+++ b/src/diffusers/models/attention_flax.py
@@ -19,6 +19,11 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
"""Multi-head dot product attention with a limited number of queries."""
@@ -151,6 +156,11 @@ class FlaxAttention(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5
@@ -277,6 +287,11 @@ class FlaxBasicTransformerBlock(nn.Module):
split_head_dim: bool = False
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
# self attention (or cross_attention if only_cross_attention is True)
self.attn1 = FlaxAttention(
self.dim,
@@ -365,6 +380,11 @@ class FlaxTransformer2DModel(nn.Module):
split_head_dim: bool = False
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
inner_dim = self.n_heads * self.d_head
@@ -454,6 +474,11 @@ class FlaxFeedForward(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
# The second linear layer needs to be called
# net_2 for now to match the index of the Sequential layer
self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
@@ -484,6 +509,11 @@ class FlaxGEGLU(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
inner_dim = self.dim * 4
self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.dropout)
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 990245de17..66455d733a 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -3669,11 +3669,7 @@ class FusedAttnProcessor2_0:
fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is currently 🧪 experimental in nature and can change in future.
-
-
+ > [!WARNING] > This API is currently 🧪 experimental and can change in the future.
"""
def __init__(self):
diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py
index bfe386f1f6..a95b0ae64a 100644
--- a/src/diffusers/models/auto_model.py
+++ b/src/diffusers/models/auto_model.py
@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import logging
+from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
logger = logging.get_logger(__name__)
@@ -114,16 +115,14 @@ class AutoModel(ConfigMixin):
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+ trust_remote_code (`bool`, *optional*, defaults to `False`):
+ Whether to trust and execute custom modeling code defined in the model repository.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf auth login`. You can also activate the special
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
-
-
Example:
```py
@@ -140,22 +139,22 @@ class AutoModel(ConfigMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- token = kwargs.pop("token", None)
- local_files_only = kwargs.pop("local_files_only", False)
- revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", False)
- load_config_kwargs = {
- "cache_dir": cache_dir,
- "force_download": force_download,
- "proxies": proxies,
- "token": token,
- "local_files_only": local_files_only,
- "revision": revision,
- }
+ hub_kwargs_names = [
+ "cache_dir",
+ "force_download",
+ "local_files_only",
+ "proxies",
+ "resume_download",
+ "revision",
+ "token",
+ ]
+ hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
+
+ # load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
+ load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder", "resume_download"]}
library = None
orig_class_name = None
@@ -189,15 +188,35 @@ class AutoModel(ConfigMixin):
else:
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
- from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
+ has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
+ trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
+ if not has_remote_code and trust_remote_code:
+ raise ValueError(
+ "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
+ )
- model_cls, _ = get_class_obj_and_candidates(
- library_name=library,
- class_name=orig_class_name,
- importable_classes=ALL_IMPORTABLE_CLASSES,
- pipelines=None,
- is_pipeline_module=False,
- )
+ if has_remote_code and trust_remote_code:
+ class_ref = config["auto_map"][cls.__name__]
+ module_file, class_name = class_ref.split(".")
+ module_file = module_file + ".py"
+ model_cls = get_class_from_dynamic_module(
+ pretrained_model_or_path,
+ subfolder=subfolder,
+ module_file=module_file,
+ class_name=class_name,
+ **hub_kwargs,
+ **kwargs,
+ )
+ else:
+ from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
+
+ model_cls, _ = get_class_obj_and_candidates(
+ library_name=library,
+ class_name=orig_class_name,
+ importable_classes=ALL_IMPORTABLE_CLASSES,
+ pipelines=None,
+ is_pipeline_module=False,
+ )
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py
index 9c7d6360e0..783f22e97d 100644
--- a/src/diffusers/models/autoencoders/autoencoder_dc.py
+++ b/src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -299,6 +299,7 @@ class Decoder(nn.Module):
act_fn: Union[str, Tuple[str]] = "silu",
upsample_block_type: str = "pixel_shuffle",
in_shortcut: bool = True,
+ conv_act_fn: str = "relu",
):
super().__init__()
@@ -349,7 +350,7 @@ class Decoder(nn.Module):
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
- self.conv_act = nn.ReLU()
+ self.conv_act = get_activation(conv_act_fn)
self.conv_out = None
if layers_per_block[0] > 0:
@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The normalization type(s) to use in the decoder.
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
The activation function(s) to use in the decoder.
+ encoder_out_shortcut (`bool`, defaults to `True`):
+ Whether to use shortcut at the end of the encoder.
+ decoder_in_shortcut (`bool`, defaults to `True`):
+ Whether to use shortcut at the beginning of the decoder.
+ decoder_conv_act_fn (`str`, defaults to `"relu"`):
+ The activation function to use at the end of the decoder.
scaling_factor (`float`, defaults to `1.0`):
The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
@@ -441,6 +448,9 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
downsample_block_type: str = "pixel_unshuffle",
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
decoder_act_fns: Union[str, Tuple[str]] = "silu",
+ encoder_out_shortcut: bool = True,
+ decoder_in_shortcut: bool = True,
+ decoder_conv_act_fn: str = "relu",
scaling_factor: float = 1.0,
) -> None:
super().__init__()
@@ -454,6 +464,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
layers_per_block=encoder_layers_per_block,
qkv_multiscales=encoder_qkv_multiscales,
downsample_block_type=downsample_block_type,
+ out_shortcut=encoder_out_shortcut,
)
self.decoder = Decoder(
in_channels=in_channels,
@@ -466,6 +477,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
norm_type=decoder_norm_types,
act_fn=decoder_act_fns,
upsample_block_type=upsample_block_type,
+ in_shortcut=decoder_in_shortcut,
+ conv_act_fn=decoder_conv_act_fn,
)
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
@@ -604,7 +617,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
returned.
"""
if self.use_slicing and z.size(0) > 1:
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 9a4375a36b..d823c2fb8b 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -532,11 +532,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -556,11 +552,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
index 7b0f9889a5..dc5e775f67 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
@@ -18,7 +18,6 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
index 87ac406592..9872cf0968 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
@@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index d84a0861e9..f95c4cf374 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -17,7 +17,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
@@ -1052,7 +1051,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
is_residual=is_residual,
)
- self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+ self.spatial_compression_ratio = scale_factor_spatial
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
@@ -1145,12 +1144,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape
- if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
- return self.tiled_encode(x)
-
self.clear_cache()
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py
index 90ef438d25..7ce352879d 100644
--- a/src/diffusers/models/controlnets/__init__.py
+++ b/src/diffusers/models/controlnets/__init__.py
@@ -9,6 +9,7 @@ if is_torch_available():
HunyuanDiT2DControlNetModel,
HunyuanDiT2DMultiControlNetModel,
)
+ from .controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
from .controlnet_sana import SanaControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
from .controlnet_sparsectrl import (
diff --git a/src/diffusers/models/controlnets/controlnet_flax.py b/src/diffusers/models/controlnets/controlnet_flax.py
index 4b2148666e..f7a8b98fa2 100644
--- a/src/diffusers/models/controlnets/controlnet_flax.py
+++ b/src/diffusers/models/controlnets/controlnet_flax.py
@@ -20,7 +20,7 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from ..unets.unet_2d_blocks_flax import (
@@ -30,6 +30,9 @@ from ..unets.unet_2d_blocks_flax import (
)
+logger = logging.get_logger(__name__)
+
+
@flax.struct.dataclass
class FlaxControlNetOutput(BaseOutput):
"""
@@ -50,6 +53,11 @@ class FlaxControlNetConditioningEmbedding(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.conv_in = nn.Conv(
self.block_out_channels[0],
kernel_size=(3, 3),
@@ -184,6 +192,11 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
return self.init(rngs, sample, timesteps, encoder_hidden_states, controlnet_cond)["params"]
def setup(self) -> None:
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py
new file mode 100644
index 0000000000..7c4955eb58
--- /dev/null
+++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py
@@ -0,0 +1,359 @@
+# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import AttentionProcessor
+from ..cache_utils import CacheMixin
+from ..controlnets.controlnet import zero_module
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..transformers.transformer_qwenimage import (
+ QwenEmbedRope,
+ QwenImageTransformerBlock,
+ QwenTimestepProjEmbeddings,
+ RMSNorm,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class QwenImageControlNetOutput(BaseOutput):
+ controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 64,
+ out_channels: Optional[int] = 16,
+ num_layers: int = 60,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 3584,
+ axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+ extra_condition_channels: int = 0, # for controlnet-inpainting
+ ):
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
+
+ self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+ self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
+
+ self.img_in = nn.Linear(in_channels, self.inner_dim)
+ self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ QwenImageTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # controlnet_blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+ self.controlnet_x_embedder = zero_module(
+ torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
+ )
+
+ self.gradient_checkpointing = False
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self):
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ @classmethod
+ def from_transformer(
+ cls,
+ transformer,
+ num_layers: int = 5,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ load_weights_from_transformer=True,
+ extra_condition_channels: int = 0,
+ ):
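+ # Mirror the transformer's config with a reduced number of layers and optionally copy the
+ # overlapping weights (strict=False because the ControlNet keeps fewer transformer blocks);
+ # the ControlNet-specific controlnet_x_embedder is re-zeroed afterwards.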
+ config = dict(transformer.config)
+ config["num_layers"] = num_layers
+ config["attention_head_dim"] = attention_head_dim
+ config["num_attention_heads"] = num_attention_heads
+ config["extra_condition_channels"] = extra_condition_channels
+
+ controlnet = cls.from_config(config)
+
+ if load_weights_from_transformer:
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+ controlnet.img_in.load_state_dict(transformer.img_in.state_dict())
+ controlnet.txt_in.load_state_dict(transformer.txt_in.state_dict())
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
+
+ return controlnet
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ encoder_hidden_states: torch.Tensor = None,
+ encoder_hidden_states_mask: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+ txt_seq_lens: Optional[List[int]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`QwenImageControlNetModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, in_channels)`):
+ Input `hidden_states`.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ timestep (`torch.LongTensor`):
+ Used to indicate the denoising step.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.img_in(hidden_states)
+
+ # add
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
+
+ temb = self.time_text_embed(timestep, hidden_states)
+
+ image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+
+ timestep = timestep.to(hidden_states.dtype)
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+ block_samples = ()
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ encoder_hidden_states_mask,
+ temb,
+ image_rotary_emb,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+ block_samples = block_samples + (hidden_states,)
+
+ # controlnet block
+ controlnet_block_samples = ()
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
+ block_sample = controlnet_block(block_sample)
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
+
+ # scaling
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return controlnet_block_samples
+
+ return QwenImageControlNetOutput(
+ controlnet_block_samples=controlnet_block_samples,
+ )
+
+
+class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ r"""
+ `QwenImageMultiControlNetModel` wrapper class for Multi-QwenImageControlNetModel
+
+ This module is a wrapper for multiple instances of the `QwenImageControlNetModel`. The `forward()` API is designed
+ to be compatible with `QwenImageControlNetModel`.
+
+ Args:
+ controlnets (`List[QwenImageControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `QwenImageControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: List[torch.Tensor],
+ conditioning_scale: List[float],
+ encoder_hidden_states: torch.Tensor = None,
+ encoder_hidden_states_mask: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_shapes: Optional[List[Tuple[int, int, int]]] = None,
+ txt_seq_lens: Optional[List[int]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[QwenImageControlNetOutput, Tuple]:
+ # ControlNet-Union with multiple conditions
+ # only one ControlNet is loaded to save memory
+ if len(self.nets) == 1:
+ controlnet = self.nets[0]
+
+ for i, (image, scale) in enumerate(zip(controlnet_cond, conditioning_scale)):
+ block_samples = controlnet(
+ hidden_states=hidden_states,
+ controlnet_cond=image,
+ conditioning_scale=scale,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ timestep=timestep,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ else:
+ if block_samples is not None and control_block_samples is not None:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+ ]
+ else:
+ raise ValueError("QwenImageMultiControlNetModel only supports a single controlnet-union now.")
+
+ return control_block_samples
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 8d892cb3b6..0641c8bc01 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -270,11 +270,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -294,11 +290,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
index aabae709e9..f5c69b9a46 100644
--- a/src/diffusers/models/controlnets/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -16,7 +16,6 @@ from math import gcd
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import Tensor, nn
from ...configuration_utils import ConfigMixin, register_to_config
@@ -980,11 +979,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -1004,11 +999,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py
index 1e7e84edea..3790905e58 100644
--- a/src/diffusers/models/embeddings_flax.py
+++ b/src/diffusers/models/embeddings_flax.py
@@ -16,6 +16,11 @@ import math
import flax.linen as nn
import jax.numpy as jnp
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
def get_sinusoidal_embeddings(
timesteps: jnp.ndarray,
@@ -76,6 +81,11 @@ class FlaxTimestepEmbedding(nn.Module):
The data type for the embedding parameters.
"""
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32
@@ -104,6 +114,11 @@ class FlaxTimesteps(nn.Module):
flip_sin_to_cos: bool = False
freq_shift: float = 1
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 4e2d24b750..8b48ba6b48 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -14,12 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
import importlib
import inspect
-import math
import os
from array import array
from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -31,6 +32,7 @@ from huggingface_hub.utils import EntryNotFoundError
from ..quantizers import DiffusersQuantizer
from ..utils import (
+ DEFAULT_HF_PARALLEL_LOADING_WORKERS,
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -110,9 +112,6 @@ def _determine_device_map(
device_map_kwargs["max_memory"] = max_memory
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
- if hf_quantizer is not None:
- hf_quantizer.validate_environment(device_map=device_map)
-
return device_map
@@ -310,6 +309,161 @@ def load_model_dict_into_meta(
return offload_index, state_dict_index
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+ """
+ Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+ checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
+ parameters.
+
+ """
+ if model_to_load.device.type == "meta":
+ return False
+
+ if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+ return False
+
+ # Some models explicitly do not support param buffer assignment
+ if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+ logger.debug(
+ f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+ )
+ return False
+
+ # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+ first_key = next(iter(model_to_load.state_dict().keys()))
+ if start_prefix + first_key in state_dict:
+ return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+ return False
+
+
+def _load_shard_file(
+ shard_file,
+ model,
+ model_state_dict,
+ device_map=None,
+ dtype=None,
+ hf_quantizer=None,
+ keep_in_fp32_modules=None,
+ dduf_entries=None,
+ loaded_keys=None,
+ unexpected_keys=None,
+ offload_index=None,
+ offload_folder=None,
+ state_dict_index=None,
+ state_dict_folder=None,
+ ignore_mismatched_sizes=False,
+ low_cpu_mem_usage=False,
+):
+ state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = []
+ if low_cpu_mem_usage:
+ offload_index, state_dict_index = load_model_dict_into_meta(
+ model,
+ state_dict,
+ device_map=device_map,
+ dtype=dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ unexpected_keys=unexpected_keys,
+ offload_folder=offload_folder,
+ offload_index=offload_index,
+ state_dict_index=state_dict_index,
+ state_dict_folder=state_dict_folder,
+ )
+ else:
+ assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+ error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+ return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _load_shard_files_with_threadpool(
+ shard_files,
+ model,
+ model_state_dict,
+ device_map=None,
+ dtype=None,
+ hf_quantizer=None,
+ keep_in_fp32_modules=None,
+ dduf_entries=None,
+ loaded_keys=None,
+ unexpected_keys=None,
+ offload_index=None,
+ offload_folder=None,
+ state_dict_index=None,
+ state_dict_folder=None,
+ ignore_mismatched_sizes=False,
+ low_cpu_mem_usage=False,
+):
+ # Do not spawn any more workers than needed
+ num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+ logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+ error_msgs = []
+ mismatched_keys = []
+
+ load_one = functools.partial(
+ _load_shard_file,
+ model=model,
+ model_state_dict=model_state_dict,
+ device_map=device_map,
+ dtype=dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ dduf_entries=dduf_entries,
+ loaded_keys=loaded_keys,
+ unexpected_keys=unexpected_keys,
+ offload_index=offload_index,
+ offload_folder=offload_folder,
+ state_dict_index=state_dict_index,
+ state_dict_folder=state_dict_folder,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
+
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+ futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+ for future in as_completed(futures):
+ result = future.result()
+ offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+ error_msgs += _error_msgs
+ mismatched_keys += _mismatched_keys
+ pbar.update(1)
+
+ return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+ # If the checkpoint is sharded, we may not have the key here.
+ if checkpoint_key not in state_dict:
+ continue
+
+ if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+
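For context, `_load_shard_files_with_threadpool` above is the standard `ThreadPoolExecutor`/`as_completed` fan-out applied to checkpoint shards. A stripped-down sketch of the same pattern, with a dummy loader standing in for `_load_shard_file` (file names and the dtype argument are illustrative):

```py
# Minimal sketch of the shard-loading fan-out, assuming a dummy per-shard loader.
import functools
from concurrent.futures import ThreadPoolExecutor, as_completed

def load_one_shard(shard_file, dtype="float32"):
    # Placeholder for _load_shard_file: returns (mismatched_keys, error_msgs).
    return [], [f"loaded {shard_file} as {dtype}"]

shard_files = ["model-00001-of-00003.safetensors",
               "model-00002-of-00003.safetensors",
               "model-00003-of-00003.safetensors"]
num_workers = min(len(shard_files), 8)  # do not spawn more workers than shards
load_fn = functools.partial(load_one_shard, dtype="bfloat16")

mismatched_keys, error_msgs = [], []
with ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = [executor.submit(load_fn, f) for f in shard_files]
    for future in as_completed(futures):
        _mismatched, _errors = future.result()
        mismatched_keys += _mismatched
        error_msgs += _errors
```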
def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
@@ -559,27 +713,39 @@ def _expand_device_map(device_map, param_names):
# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
-def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+def _caching_allocator_warmup(
+ model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
"""
This function warms up the caching allocator based on the size of the model tensors that will reside on each
device. It allows making one large call to malloc, instead of many smaller calls later when loading the model,
which is actually the loading speed bottleneck. Calling this function can cut the model loading time by a
very large margin.
"""
- # Remove disk and cpu devices, and cast to proper torch.device
+ factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
+
+ # Keep only accelerator devices
accelerator_device_map = {
param: torch.device(device)
for param, device in expanded_device_map.items()
if str(device) not in ["cpu", "disk"]
}
- parameter_count = defaultdict(lambda: 0)
+ if not accelerator_device_map:
+ return
+
+ elements_per_device = defaultdict(int)
for param_name, device in accelerator_device_map.items():
try:
- param = model.get_parameter(param_name)
+ p = model.get_parameter(param_name)
except AttributeError:
- param = model.get_buffer(param_name)
- parameter_count[device] += math.prod(param.shape)
+ try:
+ p = model.get_buffer(param_name)
+ except AttributeError:
+ raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
+ # TODO: account for TP when needed.
+ elements_per_device[device] += p.numel()
# This will kick off the caching allocator to avoid having to Malloc afterwards
- for device, param_count in parameter_count.items():
- _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+ for device, elem_count in elements_per_device.items():
+ warmup_elems = max(1, elem_count // factor)
+ _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
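The warmup boils down to one oversized `torch.empty` per accelerator device: element counts are accumulated per device and divided by the warm-up factor (2 when no quantizer is involved, otherwise whatever `get_cuda_warm_up_factor()` returns). A small sketch of that bookkeeping, with made-up parameter names and shapes and the actual allocation left commented out so it runs without a GPU:

```py
# Sketch of the per-device element accounting behind _caching_allocator_warmup.
import math
from collections import defaultdict
import torch

factor = 2  # default warm-up factor when no quantizer is involved
expanded_device_map = {
    "blocks.0.attn.to_q.weight": "cpu",     # skipped: not an accelerator device
    "blocks.0.attn.to_k.weight": "cuda:0",
    "blocks.0.ff.proj.weight": "cuda:0",
}
param_shapes = {  # illustrative shapes only
    "blocks.0.attn.to_q.weight": (3072, 3072),
    "blocks.0.attn.to_k.weight": (3072, 3072),
    "blocks.0.ff.proj.weight": (12288, 3072),
}

elements_per_device = defaultdict(int)
for name, device in expanded_device_map.items():
    if device in ("cpu", "disk"):
        continue
    elements_per_device[torch.device(device)] += math.prod(param_shapes[name])

for device, elem_count in elements_per_device.items():
    warmup_elems = max(1, elem_count // factor)
    # One large allocation primes the caching allocator for this device:
    # _ = torch.empty(warmup_elems, dtype=torch.bfloat16, device=device, requires_grad=False)
    print(device, warmup_elems)
```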
diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py
index 010b737745..fd19578321 100644
--- a/src/diffusers/models/modeling_flax_utils.py
+++ b/src/diffusers/models/modeling_flax_utils.py
@@ -26,11 +26,11 @@ from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
+ HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
-from requests import HTTPError
from .. import __version__, is_torch_available
from ..utils import (
@@ -227,15 +227,9 @@ class FlaxModelMixin(PushToHubMixin):
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given `dtype`.
-
-
- This only specifies the dtype of the *computation* and does not influence the dtype of model
- parameters.
-
- If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
- [`~FlaxModelMixin.to_bf16`].
-
-
+ > [!TIP] > This only specifies the dtype of the *computation* and does not influence the dtype of model
+ parameters. If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
+ [`~FlaxModelMixin.to_bf16`].
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments are passed to the underlying model's `__init__` method.
@@ -290,6 +284,10 @@ class FlaxModelMixin(PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -381,7 +379,7 @@ class FlaxModelMixin(PushToHubMixin):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
)
- except HTTPError as err:
+ except HfHubHTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
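The exception swap matters because `hf_hub_download` raises `huggingface_hub`'s own `HfHubHTTPError` for failed Hub requests rather than a bare `requests.HTTPError`. A hedged sketch of the resulting handling pattern (the repo id and filename are placeholders):

```py
# Sketch of catching Hub errors after the exception-class swap; names are illustrative.
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, HfHubHTTPError

try:
    path = hf_hub_download("some-org/some-flax-model", filename="flax_model.msgpack")
except EntryNotFoundError:
    raise EnvironmentError("The repository does not contain a flax_model.msgpack file.")
except HfHubHTTPError as err:
    raise EnvironmentError(f"There was a connection error when reaching the Hub:\n{err}")
```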
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 815f12a707..1af7ba9ac5 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -15,6 +15,7 @@
# limitations under the License.
import copy
+import functools
import inspect
import itertools
import json
@@ -42,6 +43,7 @@ from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
+ HF_ENABLE_PARALLEL_LOADING,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
@@ -63,15 +65,15 @@ from ..utils.hub_utils import (
populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
+from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
- _find_mismatched_keys,
- _load_state_dict_into_model,
- load_model_dict_into_meta,
+ _load_shard_file,
+ _load_shard_files_with_threadpool,
load_state_dict,
)
@@ -208,34 +210,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
return last_tuple[1].dtype
-def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
- """
- Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
- checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
- parameters.
-
- """
- if model_to_load.device.type == "meta":
- return False
-
- if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
- return False
-
- # Some models explicitly do not support param buffer assignment
- if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
- logger.debug(
- f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
- )
- return False
-
- # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
- first_key = next(iter(model_to_load.state_dict().keys()))
- if start_prefix + first_key in state_dict:
- return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
-
- return False
-
-
@contextmanager
def no_init_weights():
"""
@@ -275,6 +249,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
_repeated_blocks = []
+ _parallel_config = None
+ _cp_plan = None
def __init__(self):
super().__init__()
@@ -427,12 +403,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
inference. Speed up during training is not guaranteed.
-
-
- ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
- precedent.
-
-
+ > [!WARNING] > ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient
+ attention takes precedence.
Parameters:
attention_op (`Callable`, *optional*):
@@ -647,8 +619,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
def reset_attention_backend(self) -> None:
"""
- Resets the attention backend for the model. Following calls to `forward` will use the environment default or
- the torch native scaled dot product attention.
+ Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
+ set, or the torch native scaled dot product attention.
"""
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
@@ -941,15 +913,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf auth login`. You can also activate the special
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
-
-
Example:
```py
@@ -987,6 +955,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
+ parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
+
+ is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
+ if is_parallel_loading_enabled and not low_cpu_mem_usage:
+ raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
@@ -1323,6 +1296,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
+ is_parallel_loading_enabled=is_parallel_loading_enabled,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1362,6 +1336,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
+ if parallel_config is not None:
+ model.enable_parallelism(config=parallel_config)
+
if output_loading_info:
return model, loading_info
@@ -1500,6 +1477,73 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
)
+ def enable_parallelism(
+ self,
+ *,
+ config: Union[ParallelConfig, ContextParallelConfig],
+ cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
+ ):
+ from ..hooks.context_parallel import apply_context_parallel
+ from .attention import AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ logger.warning(
+ "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
+ )
+
+ if isinstance(config, ContextParallelConfig):
+ config = ParallelConfig(context_parallel_config=config)
+
+ if not torch.distributed.is_initialized():
+ raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
+
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+ device_type = torch._C._get_accelerator().type
+ device_module = torch.get_device_module(device_type)
+ device = torch.device(device_type, rank % device_module.device_count())
+
+ cp_mesh = None
+ if config.context_parallel_config is not None:
+ cp_config = config.context_parallel_config
+ if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
+ raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+ if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
+ raise ValueError(
+ "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+ )
+ if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
+ raise ValueError(
+ f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
+ )
+ cp_mesh = torch.distributed.device_mesh.init_device_mesh(
+ device_type=device_type,
+ mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
+ mesh_dim_names=("ring", "ulysses"),
+ )
+
+ config.setup(rank, world_size, device, cp_mesh=cp_mesh)
+
+ if cp_plan is None and self._cp_plan is None:
+ raise ValueError(
+ "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
+ )
+ cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+
+ if config.context_parallel_config is not None:
+ apply_context_parallel(self, config.context_parallel_config, cp_plan)
+
+ self._parallel_config = config
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_parallel_config"):
+ continue
+ processor._parallel_config = config
+
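Because `enable_parallelism` requires an initialized process group and a context-parallel plan, it is meant to be called from a `torchrun`-launched script. A hedged sketch of the intended call pattern; the checkpoint id is a placeholder, `ring_degree` is assumed to be a keyword of `ContextParallelConfig` (the validation above suggests it is), and the whole API is flagged experimental:

```py
# Illustrative sketch only: run under `torchrun --nproc-per-node=2 script.py`.
# Requires a model that defines `_cp_plan` (or an explicit `cp_plan=` argument).
import torch
import torch.distributed as dist
from diffusers import AutoModel
from diffusers.models._modeling_parallel import ContextParallelConfig

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

transformer = AutoModel.from_pretrained(
    "some-org/some-transformer", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Split the sequence across ranks with ring attention; ulysses_degree stays 1,
# since mixing both degrees is rejected by the checks above.
transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))
```

The `parallel_config=` kwarg added to `from_pretrained` earlier in this diff ends up calling the same method, so either entry point should be equivalent.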
@classmethod
def _load_pretrained_model(
cls,
@@ -1518,6 +1562,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_state_dict: Optional[bool] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+ is_parallel_loading_enabled: Optional[bool] = False,
):
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
@@ -1531,6 +1576,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+ mismatched_keys = []
+ error_msgs = []
+
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -1550,10 +1598,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# tensors using their expected shape and not performing any initialization of the memory (empty data).
# When the actual device allocations happen, the allocator already has a pool of unused device memory
# that it can re-use for faster loading of the model.
- # TODO: add support for warmup with hf_quantizer
- if device_map is not None and hf_quantizer is None:
+ if device_map is not None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
- _caching_allocator_warmup(model, expanded_device_map, dtype)
+ _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
state_dict_folder, state_dict_index = None, None
@@ -1566,37 +1613,39 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# if state dict is not None, it means that we don't need to read the files from resolved_model_file also
resolved_model_file = [state_dict]
- if len(resolved_model_file) > 1:
- resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+ # Prepare the loading function with the arguments shared between both loading paths.
+ load_fn = functools.partial(
+ _load_shard_files_with_threadpool if is_parallel_loading_enabled else _load_shard_file,
+ model=model,
+ model_state_dict=model_state_dict,
+ device_map=device_map,
+ dtype=dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ dduf_entries=dduf_entries,
+ loaded_keys=loaded_keys,
+ unexpected_keys=unexpected_keys,
+ offload_index=offload_index,
+ offload_folder=offload_folder,
+ state_dict_index=state_dict_index,
+ state_dict_folder=state_dict_folder,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )
- mismatched_keys = []
- assign_to_params_buffers = None
- error_msgs = []
+ if is_parallel_loading_enabled:
+ offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(resolved_model_file)
+ error_msgs += _error_msgs
+ mismatched_keys += _mismatched_keys
+ else:
+ shard_files = resolved_model_file
+ if len(resolved_model_file) > 1:
+ shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
- for shard_file in resolved_model_file:
- state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
- mismatched_keys += _find_mismatched_keys(
- state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
- )
-
- if low_cpu_mem_usage:
- offload_index, state_dict_index = load_model_dict_into_meta(
- model,
- state_dict,
- device_map=device_map,
- dtype=dtype,
- hf_quantizer=hf_quantizer,
- keep_in_fp32_modules=keep_in_fp32_modules,
- unexpected_keys=unexpected_keys,
- offload_folder=offload_folder,
- offload_index=offload_index,
- state_dict_index=state_dict_index,
- state_dict_folder=state_dict_folder,
- )
- else:
- if assign_to_params_buffers is None:
- assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
- error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+ for shard_file in shard_files:
+ offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
+ error_msgs += _error_msgs
+ mismatched_keys += _mismatched_keys
empty_device_cache()
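The refactor collapses the old per-shard loop into a single `functools.partial`: the target is chosen once (the threadpool loader when `HF_ENABLE_PARALLEL_LOADING` is set, the plain per-shard loader otherwise) and then called either on the whole shard list or shard by shard. A minimal sketch of that dispatch shape, with dummy loaders standing in for the real ones:

```py
# Dispatch sketch: pick the loader once, then feed it either the whole list of
# shards (parallel path) or one shard at a time (sequential path). Dummy functions.
import functools

def load_shard(shard_file, dtype=None):
    return f"{shard_file} ({dtype})"

def load_shards_with_threadpool(shard_files, dtype=None):
    return [load_shard(f, dtype=dtype) for f in shard_files]

is_parallel_loading_enabled = True  # in the real code: HF_ENABLE_PARALLEL_LOADING
resolved_model_file = ["shard-0.safetensors", "shard-1.safetensors"]

load_fn = functools.partial(
    load_shards_with_threadpool if is_parallel_loading_enabled else load_shard,
    dtype="bfloat16",
)

if is_parallel_loading_enabled:
    results = load_fn(resolved_model_file)                # one call over all shards
else:
    results = [load_fn(f) for f in resolved_model_file]   # per-shard calls
```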
diff --git a/src/diffusers/models/resnet_flax.py b/src/diffusers/models/resnet_flax.py
index 9c80932c5c..9bedaa9a36 100644
--- a/src/diffusers/models/resnet_flax.py
+++ b/src/diffusers/models/resnet_flax.py
@@ -15,12 +15,22 @@ import flax.linen as nn
import jax
import jax.numpy as jnp
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
class FlaxUpsample2D(nn.Module):
out_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -45,6 +55,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.conv = nn.Conv(
self.out_channels,
kernel_size=(3, 3),
@@ -68,6 +83,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 5550fed92d..b60f0636e6 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -17,6 +17,7 @@ if is_torch_available():
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
+ from .transformer_bria import BriaTransformer2DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index a8d275d142..bf6d9e1b38 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
return selected_indices
- def forward(self, latent):
+ def forward(self, latent) -> torch.Tensor:
batch_size, num_channels, height, width = latent.size()
latent = latent.view(
batch_size,
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> torch.Tensor:
residual = hidden_states
attention_kwargs = attention_kwargs or {}
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
residual_context = encoder_hidden_states
attention_kwargs = attention_kwargs or {}
@@ -431,11 +431,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -455,11 +451,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
@@ -472,7 +464,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
timestep: torch.LongTensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index a8c98bccb8..9e0afdee66 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}
@@ -397,11 +397,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -421,11 +417,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
@@ -441,7 +433,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py
index 41632dbd47..91fe811f00 100644
--- a/src/diffusers/models/transformers/consisid_transformer_3d.py
+++ b/src/diffusers/models/transformers/consisid_transformer_3d.py
@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
id_cond: Optional[torch.Tensor] = None,
id_vit_hidden: Optional[torch.Tensor] = None,
return_dict: bool = True,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index f634718788..fbe9fe8df9 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -324,11 +324,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -348,11 +344,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py
index 84b1175386..bed5e69c2d 100644
--- a/src/diffusers/models/transformers/lumina_nextdit2d.py
+++ b/src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
encoder_mask: torch.Tensor,
temb: torch.Tensor,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> torch.Tensor:
"""
Perform a forward pass through the LuminaNextDiTBlock.
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
image_rotary_emb: torch.Tensor,
cross_attention_kwargs: Dict[str, Any] = None,
return_dict=True,
- ) -> torch.Tensor:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
Forward pass of LuminaNextDiT.
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 40a14bfd9b..5a22144228 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -258,11 +258,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -282,11 +278,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py
index 969e6db122..ac9b3fca41 100644
--- a/src/diffusers/models/transformers/stable_audio_transformer.py
+++ b/src/diffusers/models/transformers/stable_audio_transformer.py
@@ -18,7 +18,6 @@ from typing import Dict, Optional, Union
import numpy as np
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py
new file mode 100644
index 0000000000..d54679306e
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_bria.py
@@ -0,0 +1,725 @@
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "BriaAttention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
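The fused path relies on a single `to_qkv` linear whose output is chunked into three equal parts; when the fused weight is built by concatenating the separate `to_q`/`to_k`/`to_v` weights, both paths produce the same projections. A small numerical check of that equivalence (dimensions are illustrative):

```py
# Equivalence check between separate Q/K/V projections and a fused projection
# whose weight concatenates the three separate weights (illustrative dims).
import torch
import torch.nn as nn

dim = 64
x = torch.randn(2, 10, dim)

to_q = nn.Linear(dim, dim, bias=False)
to_k = nn.Linear(dim, dim, bias=False)
to_v = nn.Linear(dim, dim, bias=False)
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))

q1, k1, v1 = to_q(x), to_k(x), to_v(x)
q2, k2, v2 = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(q1, q2, atol=1e-5)
assert torch.allclose(k1, k2, atol=1e-5)
assert torch.allclose(v1, v2, atol=1e-5)
```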
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[np.ndarray, int],
+ theta: float = 10000.0,
+ use_real=False,
+ linear_factor=1.0,
+ ntk_factor=1.0,
+ repeat_interleave_real=True,
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
+ position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values
+ in complex64 data type.
+
+ Args:
+ dim (`int`): Dimension of the frequency tensor.
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+ theta (`float`, *optional*, defaults to 10000.0):
+ Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (`bool`, *optional*):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ linear_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the context extrapolation. Defaults to 1.0.
+ ntk_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+ Otherwise, they are concatenated with themselves.
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+ the dtype of the frequency tensor.
+ Returns:
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+ """
+ assert dim % 2 == 0
+
+ if isinstance(pos, int):
+ pos = torch.arange(pos)
+ if isinstance(pos, np.ndarray):
+ pos = torch.from_numpy(pos) # type: ignore # [S]
+
+ theta = theta * ntk_factor
+ freqs = (
+ 1.0
+ / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+ / linear_factor
+ ) # [D/2]
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
+ if use_real and repeat_interleave_real:
+ # bria
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ elif use_real:
+ # stable audio, allegro
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ # lumina
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
+
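With `use_real=True` and `repeat_interleave_real=True` (the settings Bria uses), the function returns cosine and sine tables of shape `[S, D]` built from `D/2` frequencies `1 / theta**(2i/D)` that are interleaved pairwise. A tiny standalone rerun of that construction, mirroring the code above for `dim=8` and four positions:

```py
# Standalone rerun of the cos/sin construction for dim=8, four positions.
import torch

dim, theta = 8, 10000.0
pos = torch.arange(4, dtype=torch.float32)                                       # [S] = [4]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))    # [D/2] = [4]
freqs = torch.outer(pos, freqs)                                                  # [S, D/2]
freqs_cos = freqs.cos().repeat_interleave(2, dim=1)                              # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1)                              # [S, D]
print(freqs_cos.shape, freqs_sin.shape)                                          # torch.Size([4, 8]) twice
```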
+class BriaAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "BriaAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class BriaAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = BriaAttnProcessor
+ _available_processors = [
+ BriaAttnProcessor,
+ ]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ context_pre_only: Optional[bool] = None,
+ pre_only: bool = False,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.pre_only:
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
+class BriaEmbedND(torch.nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
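`BriaEmbedND` applies the 1D rotary embedding independently to each axis of the position ids (Flux-style: a text axis plus two spatial axes, `axes_dim=[16, 56, 56]` by default) and concatenates the results, so the last dimension of the returned cos/sin tables equals `sum(axes_dim)`. A hedged shape sketch, assuming the new module is importable exactly as added in this diff and using an illustrative id layout:

```py
# Shape sketch for the multi-axis rotary ids (values and id layout are illustrative).
import torch
from diffusers.models.transformers.transformer_bria import BriaEmbedND

axes_dim = [16, 56, 56]                       # text axis + 2 spatial axes
pos_embed = BriaEmbedND(theta=10000, axes_dim=axes_dim)

height, width = 8, 8                          # latent patches, illustrative
img_ids = torch.zeros(height * width, 3)
img_ids[:, 1] = torch.arange(height).repeat_interleave(width)   # row index
img_ids[:, 2] = torch.arange(width).repeat(height)              # column index

freqs_cos, freqs_sin = pos_embed(img_ids)
print(freqs_cos.shape)                        # torch.Size([64, 128]) == (S, sum(axes_dim))
```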
+class BriaTimesteps(nn.Module):
+ def __init__(
+ self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
+ ):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+ self.scale = scale
+ self.time_theta = time_theta
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ scale=self.scale,
+ max_period=self.time_theta,
+ )
+ return t_emb
+
+
+class BriaTimestepProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, time_theta):
+ super().__init__()
+
+ self.time_proj = BriaTimesteps(
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
+ )
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(self, timestep, dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
+ return timesteps_emb
+
+
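`BriaTimestepProjEmbeddings` first builds a 256-channel sinusoidal embedding of the (possibly fractional) timestep and then maps it to the transformer width with a `TimestepEmbedding` MLP. A shape-only sketch using the same building blocks from `diffusers.models.embeddings`; the timestep values and the 3072 width (24 heads × 128 head dim, the defaults below) are illustrative:

```py
# Shape sketch of the timestep path: sinusoidal proj (N, 256) -> MLP -> (N, inner_dim).
import torch
from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding

inner_dim = 3072                                  # 24 heads * 128 head dim
timestep = torch.tensor([0.25, 0.8])              # illustrative fractional timesteps

timesteps_proj = get_timestep_embedding(
    timestep, 256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000
)
timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)
temb = timestep_embedder(timesteps_proj)
print(timesteps_proj.shape, temb.shape)           # torch.Size([2, 256]) torch.Size([2, 3072])
```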
+class BriaPosEmbed(torch.nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
+class BriaTransformerBlock(nn.Module):
+ def __init__(
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
+ super().__init__()
+
+ self.norm1 = AdaLayerNormZero(dim)
+ self.norm1_context = AdaLayerNormZero(dim)
+
+ self.attn = BriaAttention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=BriaAttnProcessor(),
+ eps=eps,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+ attention_kwargs = attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+@maybe_allow_in_graph
+class BriaSingleTransformerBlock(nn.Module):
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.norm = AdaLayerNormZeroSingle(dim)
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+ processor = BriaAttnProcessor()
+
+ self.attn = BriaAttention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ attention_kwargs = attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **attention_kwargs,
+ )
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
+
+
+class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ """
+ The Transformer model used by Bria. Based on the Flux transformer with several changes:
+ - no pooled embeddings
+ - zero padding is used for prompts
+ - no guidance embedding, since this is not a distilled version
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Parameters:
+ patch_size (`int`): Patch size to turn the input data into small patches.
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = None,
+ guidance_embeds: bool = False,
+ axes_dims_rope: List[int] = [16, 56, 56],
+ rope_theta=10000,
+ time_theta=10000,
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+ self.pos_embed = BriaEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
+
+ self.time_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
+ if guidance_embeds:
+ self.guidance_embed = BriaTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BriaTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ BriaSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_single_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ controlnet_block_samples=None,
+ controlnet_single_block_samples=None,
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
+ """
+ The [`BriaTransformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep (`torch.LongTensor`):
+ Used to indicate denoising step.
+ controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
+ A list of tensors that, if specified, are added to the residuals of the transformer blocks.
+ controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
+ A list of tensors that, if specified, are added to the residuals of the single transformer blocks.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype)
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype)
+
+ temb = self.time_embed(timestep, dtype=hidden_states.dtype)
+
+ if guidance is not None:
+ temb = temb + self.guidance_embed(guidance, dtype=hidden_states.dtype)
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if len(txt_ids.shape) == 3:
+ txt_ids = txt_ids[0]
+
+ if len(img_ids.shape) == 3:
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ attention_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ attention_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # controlnet residual
+ if controlnet_single_block_samples is not None:
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ + controlnet_single_block_samples[index_block // interval_control]
+ )
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
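# Editor's note: illustrative sketch, not part of the patch. The BriaTransformer2DModel added above
# follows the Flux layout -- packed image latents of shape (batch, image_seq_len, in_channels), T5-style
# text embeddings, and per-token position ids for the rotary embedding. The module path and the tiny
# config below are assumptions for demonstration only; the released checkpoint uses the defaults from
# __init__ (19 dual-stream blocks, 38 single-stream blocks, 24 heads of dim 128, joint_attention_dim=4096).
import torch
from diffusers.models.transformers.transformer_bria import BriaTransformer2DModel  # assumed module path

model = BriaTransformer2DModel(
    patch_size=1,
    in_channels=8,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=16,
    axes_dims_rope=[4, 2, 2],  # per-axis rotary dims; must be even and sum to attention_head_dim
)

batch, img_seq, txt_seq = 1, 16, 4
hidden_states = torch.randn(batch, img_seq, 8)           # packed image latents
encoder_hidden_states = torch.randn(batch, txt_seq, 16)  # (zero-padded) text encoder output
timestep = torch.tensor([500.0])
img_ids = torch.zeros(img_seq, 3)                        # (seq_len, 3) position ids, as in Flux
txt_ids = torch.zeros(txt_seq, 3)

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
)
print(out.sample.shape)  # expected: (1, 16, patch_size**2 * in_channels) == (1, 16, 8)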
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 77f15f6ca6..7356f4a606 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Dict, Union
+from typing import Dict, Tuple, Union
import torch
import torch.nn as nn
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
target_size: torch.Tensor,
crop_coords: torch.Tensor,
return_dict: bool = True,
- ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
The [`CogView3PlusTransformer2DModel`] forward method.
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index dc45befb98..64e9a538a7 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -28,7 +28,7 @@ from ..cache_utils import CacheMixin
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
+from ..normalization import LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Timestep conditioning
(
norm_hidden_states,
@@ -584,6 +584,38 @@ class CogView4RotaryPosEmbed(nn.Module):
return (freqs.cos(), freqs.sin())
+class CogView4AdaLayerNormContinuous(nn.Module):
+ """
+ CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches the Megatron reference: no
+ activation is applied to the conditioning embedding before the Linear.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ elementwise_affine: bool = True,
+ eps: float = 1e-5,
+ bias: bool = True,
+ norm_type: str = "layer_norm",
+ ):
+ super().__init__()
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+ # Unlike the stock AdaLayerNormContinuous, no SiLU is applied to the conditioning embedding here.
+ emb = self.linear(conditioning_embedding.to(x.dtype))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
@@ -666,7 +698,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
)
# 4. Output projection
- self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+ self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
self.gradient_checkpointing = False
@@ -685,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
- ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
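# Editor's note: illustrative sketch, not part of the patch. The only behavioural difference between
# CogView4AdaLayerNormContinuous (above) and the stock AdaLayerNormContinuous is that the conditioning
# embedding goes straight into the Linear, with no SiLU in front of it. The check below copies the Linear
# weights from one module into the other and shows the outputs only agree once the SiLU is re-inserted by
# hand; the CogView4 import path is an assumption based on this patch.
import torch
import torch.nn.functional as F
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_cogview4 import CogView4AdaLayerNormContinuous  # assumed path

dim, cond_dim = 8, 8
stock = AdaLayerNormContinuous(dim, cond_dim, elementwise_affine=False)
cogview = CogView4AdaLayerNormContinuous(dim, cond_dim, elementwise_affine=False)
cogview.linear.load_state_dict(stock.linear.state_dict())

x = torch.randn(2, 5, dim)
cond = torch.randn(2, cond_dim)
with torch.no_grad():
    out_stock = stock(x, cond)             # SiLU(cond) -> Linear -> modulate
    out_cogview = cogview(x, cond)         # cond -> Linear -> modulate (this patch)
    out_manual = cogview(x, F.silu(cond))  # re-inserting the SiLU recovers the stock path

print(torch.allclose(out_stock, out_manual, atol=1e-6))   # True: same math once SiLU is applied
print(torch.allclose(out_stock, out_cogview, atol=1e-6))  # False in general: the SiLU is what was removed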
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 9080cd508d..1a44644324 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -22,9 +22,9 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.import_utils import is_torch_npu_available
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -74,6 +74,7 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st
class FluxAttnProcessor:
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -115,7 +116,12 @@ class FluxAttnProcessor:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
- query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
@@ -137,6 +143,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
_attention_backend = None
+ _parallel_config = None
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
@@ -221,6 +228,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
@@ -253,6 +261,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
@@ -354,25 +363,13 @@ class FluxSingleTransformerBlock(nn.Module):
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
- if is_torch_npu_available():
- from ..attention_processor import FluxAttnProcessor2_0_NPU
-
- deprecation_message = (
- "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
- "should be set explicitly using the `set_attn_processor` method."
- )
- deprecate("npu_processor", "0.34.0", deprecation_message)
- processor = FluxAttnProcessor2_0_NPU()
- else:
- processor = FluxAttnProcessor()
-
self.attn = FluxAttention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
- processor=processor,
+ processor=FluxAttnProcessor(),
eps=1e-6,
pre_only=True,
)
@@ -384,7 +381,7 @@ class FluxSingleTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -569,6 +566,15 @@ class FluxTransformer2DModel(
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
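# Editor's note: illustrative sketch, not part of the patch. The _cp_plan added to FluxTransformer2DModel
# above is declarative metadata for context parallelism: ContextParallelInput entries say which forward
# inputs get sharded across ranks (sequence dim 1 for hidden_states/encoder_hidden_states, dim 0 for the
# id tensors), and ContextParallelOutput says the proj_out result is gathered back along dim 1. The toy
# below only mimics that bookkeeping on one process with chunk/cat; the real implementation lives in
# diffusers' _modeling_parallel hooks and uses distributed collectives instead.
import torch

world_size = 2
hidden_states = torch.randn(1, 8, 16)  # (batch, seq, dim): split on dim 1 per the plan
img_ids = torch.zeros(8, 3)            # (seq, 3): split on dim 0 per the plan

hidden_shards = torch.chunk(hidden_states, world_size, dim=1)
id_shards = torch.chunk(img_ids, world_size, dim=0)
assert len(hidden_shards) == len(id_shards) == world_size

# each "rank" would run the transformer on its shard of the sequence ...
per_rank_out = [shard * 2.0 for shard in hidden_shards]  # stand-in for the sharded forward pass

# ... and ContextParallelOutput(gather_dim=1) stitches the pieces back into the full sequence
output = torch.cat(per_rank_out, dim=1)
assert output.shape == hidden_states.shape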
diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
index 77902dcf58..4a5aee29ab 100644
--- a/src/diffusers/models/transformers/transformer_hidream_image.py
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
- def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
+ def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
self.out_channels = out_channels
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
- def forward(self, latent):
+ def forward(self, latent) -> torch.Tensor:
latent = self.proj(latent)
return latent
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
wtype = hidden_states.dtype
(
shift_msa_i,
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
- ) -> torch.Tensor:
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
return self.block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
**kwargs,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if encoder_hidden_states is not None:
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 6944a6c536..bc857ccab4 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
index c2eb7fd2a7..60b40fff3c 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
indices_latents_history_4x: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 79149fb760..685c73c07c 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -51,6 +52,7 @@ class LTXVideoAttnProcessor:
"""
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
@@ -100,6 +102,7 @@ class LTXVideoAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
@@ -350,7 +353,9 @@ class LTXVideoTransformerBlock(nn.Module):
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
- ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+ ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+ batch_size, temb.size(1), num_ada_params, -1
+ )
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
@@ -409,6 +414,18 @@ class LTXVideoTransformer3DModel(
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTXVideoTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
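# Editor's note: illustrative sketch, not part of the patch. The LTX block above keeps a learnable
# scale_shift_table of shape (6, dim) and adds the projected timestep embedding to it, producing the six
# AdaLN modulation tensors (shift/scale/gate for attention and for the feed-forward). The .to(temb.device)
# introduced by the patch only matters when the table and temb land on different devices (e.g. offloading);
# the arithmetic itself is a plain broadcast add, sketched here with toy shapes.
import torch

batch, seq, dim = 2, 4, 8
scale_shift_table = torch.randn(6, dim)  # learnable, shared across tokens
temb = torch.randn(batch, seq, 6 * dim)  # per-token conditioning from the timestep embedder

ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(batch, seq, 6, dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
print(shift_msa.shape)  # (2, 4, 8): one modulation vector per token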
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 961ed72b73..05379270c1 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import functools
import math
from typing import Any, Dict, List, Optional, Tuple, Union
+import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -24,7 +25,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -160,8 +162,8 @@ class QwenEmbedRope(nn.Module):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
- pos_index = torch.arange(1024)
- neg_index = torch.arange(1024).flip(0) * -1 - 1
+ pos_index = torch.arange(4096)
+ neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
@@ -180,7 +182,7 @@ class QwenEmbedRope(nn.Module):
)
self.rope_cache = {}
- # 是否使用 scale rope
+ # Do not use register_buffer here: it would cause the complex frequencies to lose their imaginary part.
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
@@ -204,38 +206,54 @@ class QwenEmbedRope(nn.Module):
if isinstance(video_fhw, list):
video_fhw = video_fhw[0]
- frame, height, width = video_fhw
- rope_key = f"{frame}_{height}_{width}"
+ if not isinstance(video_fhw, list):
+ video_fhw = [video_fhw]
- if rope_key not in self.rope_cache:
- seq_lens = frame * height * width
- freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
- freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
- freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
- if self.scale_rope:
- freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
- freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
- freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
- freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+ vid_freqs = []
+ max_vid_index = 0
+ for idx, fhw in enumerate(video_fhw):
+ frame, height, width = fhw
+ rope_key = f"{idx}_{height}_{width}"
+ if not torch.compiler.is_compiling():
+ if rope_key not in self.rope_cache:
+ self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+ video_freq = self.rope_cache[rope_key]
else:
- freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
- freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+ video_freq = self._compute_video_freqs(frame, height, width, idx)
+ video_freq = video_freq.to(device)
+ vid_freqs.append(video_freq)
- freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
- self.rope_cache[rope_key] = freqs.clone().contiguous()
- vid_freqs = self.rope_cache[rope_key]
-
- if self.scale_rope:
- max_vid_index = max(height // 2, width // 2)
- else:
- max_vid_index = max(height, width)
+ if self.scale_rope:
+ max_vid_index = max(height // 2, width // 2, max_vid_index)
+ else:
+ max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+ vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
+ @functools.lru_cache(maxsize=None)
+ def _compute_video_freqs(self, frame, height, width, idx=0):
+ seq_lens = frame * height * width
+ freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+ freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+ freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+ if self.scale_rope:
+ freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+ freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+ freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+ freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+ else:
+ freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+ freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+ freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+ return freqs.clone().contiguous()
+
class QwenDoubleStreamAttnProcessor2_0:
"""
@@ -244,6 +262,7 @@ class QwenDoubleStreamAttnProcessor2_0:
"""
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -317,6 +336,7 @@ class QwenDoubleStreamAttnProcessor2_0:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
# Reshape back
@@ -453,7 +473,9 @@ class QwenImageTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
-class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class QwenImageTransformer2DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
"""
The Transformer model introduced in Qwen.
@@ -482,6 +504,19 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
_supports_gradient_checkpointing = True
_no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+ _repeated_blocks = ["QwenImageTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ "pos_embed": {
+ 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
@@ -535,6 +570,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
txt_seq_lens: Optional[List[int]] = None,
guidance: torch.Tensor = None, # TODO: this should probably be removed
attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples=None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
@@ -614,6 +650,12 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
joint_attention_kwargs=attention_kwargs,
)
+ # controlnet residual
+ if controlnet_block_samples is not None:
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+ interval_control = int(np.ceil(interval_control))
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
# Use only the image part (hidden_states) from the dual-stream blocks
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
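# Editor's note: illustrative sketch, not part of the patch. QwenEmbedRope._compute_video_freqs above
# builds one rotary frequency row per latent token by expanding the per-axis tables (frame, height, width)
# over the full grid and concatenating them along the last dim; the result is cached per
# (index, height, width) and additionally via functools.lru_cache. The toy below reproduces only that
# assembly with made-up tables, so the output shape -- (frame*height*width, sum of axis widths) -- is easy
# to see; the scale_rope handling is omitted.
import torch

frame, height, width = 2, 3, 4
t_dim, h_dim, w_dim = 4, 6, 6  # stand-ins for the per-axis frequency widths

freqs_t = torch.randn(frame, t_dim)
freqs_h = torch.randn(height, h_dim)
freqs_w = torch.randn(width, w_dim)

ft = freqs_t.view(frame, 1, 1, -1).expand(frame, height, width, -1)
fh = freqs_h.view(1, height, 1, -1).expand(frame, height, width, -1)
fw = freqs_w.view(1, 1, width, -1).expand(frame, height, width, -1)

freqs = torch.cat([ft, fh, fw], dim=-1).reshape(frame * height * width, -1)
print(freqs.shape)  # (24, 16): one frequency row per video token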
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index edf77a7df7..762d89c303 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -280,11 +280,7 @@ class SD3Transformer2DModel(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING]
+ > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -304,11 +300,7 @@ class SD3Transformer2DModel(
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING]
+ > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
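# Editor's note: illustrative sketch, not part of the patch. The docstring edits above only reformat the
# experimental-API warning on fuse_qkv_projections/unfuse_qkv_projections. Typical usage is shown below;
# the checkpoint id is the public (gated) SD3 medium repo, and downloading it is an assumption of this
# sketch rather than something the patch requires.
import torch
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", torch_dtype=torch.float16
)
transformer.fuse_qkv_projections()    # experimental: merges q/k/v (and k/v for cross-attention) projections
# ... run inference with the fused projections ...
transformer.unfuse_qkv_projections()  # restores the original, unfused attention processors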
diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py
index 236fca690a..6b600aa224 100644
--- a/src/diffusers/models/transformers/transformer_skyreels_v2.py
+++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The SkyReels Team, The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,9 +21,10 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import (
PixArtAlphaTextProjection,
@@ -39,20 +40,54 @@ from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class SkyReelsV2AttnProcessor2_0:
+def _get_qkv_projections(
+ attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states_img: torch.Tensor):
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+ return key_img, value_img
+
+
+class SkyReelsV2AttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
- "SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ "SkyReelsV2AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
- attn: Attention,
+ attn: "SkyReelsV2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
- rotary_emb: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
@@ -60,58 +95,68 @@ class SkyReelsV2AttnProcessor2_0:
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- query = attn.to_q(hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
if rotary_emb is not None:
- def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
- x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2)))
- x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
- return x_out.type_as(hidden_states)
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
- query = apply_rotary_emb(query, rotary_emb)
- key = apply_rotary_emb(key, rotary_emb)
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
- key_img = attn.add_k_proj(encoder_hidden_states_img)
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
- value_img = attn.add_v_proj(encoder_hidden_states_img)
- key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
- hidden_states_img = F.scaled_dot_product_attention(
- query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
- hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+ hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
- hidden_states = F.scaled_dot_product_attention(
+ hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
@@ -122,7 +167,122 @@ class SkyReelsV2AttnProcessor2_0:
return hidden_states
-# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding
+class SkyReelsV2AttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "The SkyReelsV2AttnProcessor2_0 class is deprecated and will be removed in a future version. "
+ "Please use SkyReelsV2AttnProcessor instead. "
+ )
+ deprecate("SkyReelsV2AttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+ return SkyReelsV2AttnProcessor(*args, **kwargs)
+
+
+class SkyReelsV2Attention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = SkyReelsV2AttnProcessor
+ _available_processors = [SkyReelsV2AttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ is_cross_attention=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.is_cross_attention = cross_attention_dim_head is not None
+
+ self.set_processor(processor)
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
class SkyReelsV2ImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
@@ -213,7 +373,11 @@ class SkyReelsV2TimeTextImageEmbedding(nn.Module):
class SkyReelsV2RotaryPosEmbed(nn.Module):
def __init__(
- self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
):
super().__init__()
@@ -223,37 +387,55 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
- freqs = []
for dim in [t_dim, h_dim, w_dim]:
- freq = get_1d_rotary_pos_embed(
- dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
)
- freqs.append(freq)
- self.freqs = torch.cat(freqs, dim=1)
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
- freqs = self.freqs.to(hidden_states.device)
- freqs = freqs.split_with_sizes(
- [
- self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
- self.attention_head_dim // 6,
- self.attention_head_dim // 6,
- ],
- dim=1,
- )
+ split_sizes = [
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+ self.attention_head_dim // 3,
+ self.attention_head_dim // 3,
+ ]
- freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
- freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
- return freqs
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+@maybe_allow_in_graph
class SkyReelsV2TransformerBlock(nn.Module):
def __init__(
self,
@@ -269,33 +451,24 @@ class SkyReelsV2TransformerBlock(nn.Module):
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
- self.attn1 = Attention(
- query_dim=dim,
+ self.attn1 = SkyReelsV2Attention(
+ dim=dim,
heads=num_heads,
- kv_heads=num_heads,
dim_head=dim // num_heads,
- qk_norm=qk_norm,
eps=eps,
- bias=True,
- cross_attention_dim=None,
- out_bias=True,
- processor=SkyReelsV2AttnProcessor2_0(),
+ cross_attention_dim_head=None,
+ processor=SkyReelsV2AttnProcessor(),
)
# 2. Cross-attention
- self.attn2 = Attention(
- query_dim=dim,
+ self.attn2 = SkyReelsV2Attention(
+ dim=dim,
heads=num_heads,
- kv_heads=num_heads,
dim_head=dim // num_heads,
- qk_norm=qk_norm,
eps=eps,
- bias=True,
- cross_attention_dim=None,
- out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
- added_proj_bias=True,
- processor=SkyReelsV2AttnProcessor2_0(),
+ cross_attention_dim_head=dim // num_heads,
+ processor=SkyReelsV2AttnProcessor(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -321,15 +494,15 @@ class SkyReelsV2TransformerBlock(nn.Module):
# For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
+
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
- attn_output = self.attn1(
- hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
- )
+ attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
- attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
@@ -338,10 +511,13 @@ class SkyReelsV2TransformerBlock(nn.Module):
)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
return hidden_states
-class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class SkyReelsV2Transformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
r"""
A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
@@ -389,6 +565,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
_no_split_modules = ["SkyReelsV2TransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["SkyReelsV2TransformerBlock"]
@register_to_config
def __init__(
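# Editor's note: illustrative sketch, not part of the patch. The rewritten SkyReelsV2 rotary embedding
# above switches from complex multiplication to an explicit real-valued cos/sin form (the tables come from
# get_1d_rotary_pos_embed with use_real=True and repeat_interleave_real=True, so every frequency appears
# twice). The standalone check below shows the two formulations agree on a random input; shapes follow the
# (batch, seq, heads, head_dim) layout the new processor uses.
import torch

head_dim = 8
x = torch.randn(1, 4, 2, head_dim)                     # (batch, seq, heads, head_dim)
angles = torch.rand(1, 4, 1, head_dim // 2) * 6.28
freqs_cos = angles.cos().repeat_interleave(2, dim=-1)  # (1, seq, 1, head_dim), each frequency duplicated
freqs_sin = angles.sin().repeat_interleave(2, dim=-1)

# real-valued form, as in SkyReelsV2AttnProcessor.apply_rotary_emb
x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1)
cos, sin = freqs_cos[..., 0::2], freqs_sin[..., 1::2]
out_real = torch.empty_like(x)
out_real[..., 0::2] = x1 * cos - x2 * sin
out_real[..., 1::2] = x1 * sin + x2 * cos

# complex form, as in the removed SkyReelsV2AttnProcessor2_0 implementation
x_complex = torch.view_as_complex(x.to(torch.float32).unflatten(-1, (-1, 2)))
freqs_complex = torch.polar(torch.ones_like(angles), angles)  # cos(a) + i*sin(a)
out_complex = torch.view_as_real(x_complex * freqs_complex).flatten(-2)

print(torch.allclose(out_real, out_complex, atol=1e-5))  # True: same rotation, different bookkeeping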
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 968a0369c2..dd75fb124f 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -66,6 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
class WanAttnProcessor:
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -132,6 +134,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -144,6 +147,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
@@ -539,6 +543,19 @@ class WanTransformer3DModel(
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]
+ _cp_plan = {
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ },
+ "blocks.0": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "blocks.*": {
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
@@ -665,12 +682,12 @@ class WanTransformer3DModel(
# 5. Output norm, projection & unpatchify
if temb.ndim == 3:
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
- shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
shift = shift.squeeze(2)
scale = scale.squeeze(2)
else:
# batch_size, inner_dim
- shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py
index e039d36219..30c38c244a 100644
--- a/src/diffusers/models/transformers/transformer_wan_vace.py
+++ b/src/diffusers/models/transformers/transformer_wan_vace.py
@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
+from ..attention import AttentionMixin, FeedForward
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -103,7 +103,7 @@ class WanVACETransformerBlock(nn.Module):
control_hidden_states = control_hidden_states + hidden_states
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
- self.scale_shift_table + temb.float()
+ self.scale_shift_table.to(temb.device) + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
@@ -134,7 +134,9 @@ class WanVACETransformerBlock(nn.Module):
return conditioning_states, control_hidden_states
-class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanVACETransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
r"""
A Transformer model for video-like data used in the Wan model.
@@ -359,7 +361,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
hidden_states = hidden_states + control_hint * scale
# 6. Output norm, projection & unpatchify
- shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py
index 4f57f3349b..4c4c528a59 100644
--- a/src/diffusers/models/unets/unet_1d.py
+++ b/src/diffusers/models/unets/unet_1d.py
@@ -82,6 +82,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
+ time_embedding_dim: Optional[int] = None,
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
@@ -100,15 +101,23 @@ class UNet1DModel(ModelMixin, ConfigMixin):
# time
if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
- embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
- timestep_input_dim = 2 * block_out_channels[0]
+ timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
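# Editor's note: illustrative sketch, not part of the patch. GaussianFourierProjection(embedding_size=E)
# emits 2*E features (a sin and a cos per random frequency), so the new time_embedding_dim=D option uses
# embedding_size=D//2 to keep the projection output at exactly D channels -- hence the "divisible by 2"
# check. The removed lines hard-coded embedding_size=8 (16 channels) while declaring
# timestep_input_dim = 2 * block_out_channels[0]; the patch makes the two agree. The arithmetic:
block_out_channels = (32, 64, 128)

default_dim = block_out_channels[0] * 2     # new fourier default: 64
assert default_dim % 2 == 0
default_input_dim = 2 * (default_dim // 2)  # 64 channels out of the Fourier projection

custom_dim = 96                             # explicit time_embedding_dim override, now supported
assert custom_dim % 2 == 0
custom_input_dim = 2 * (custom_dim // 2)    # 96 channels

print(default_input_dim, custom_input_dim)  # 64 96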
diff --git a/src/diffusers/models/unets/unet_2d_blocks_flax.py b/src/diffusers/models/unets/unet_2d_blocks_flax.py
index abd025165e..6e6005afdc 100644
--- a/src/diffusers/models/unets/unet_2d_blocks_flax.py
+++ b/src/diffusers/models/unets/unet_2d_blocks_flax.py
@@ -15,10 +15,14 @@
import flax.linen as nn
import jax.numpy as jnp
+from ...utils import logging
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
+logger = logging.get_logger(__name__)
+
+
class FlaxCrossAttnDownBlock2D(nn.Module):
r"""
Cross Attention 2D Downsizing block - original architecture from Unet transformers:
@@ -60,6 +64,11 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
attentions = []
@@ -135,6 +144,11 @@ class FlaxDownBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
for i in range(self.num_layers):
@@ -208,6 +222,11 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
attentions = []
@@ -288,6 +307,11 @@ class FlaxUpBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
for i in range(self.num_layers):
@@ -356,6 +380,11 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
transformer_layers_per_block: int = 1
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
# there is always at least one resnet
resnets = [
FlaxResnetBlock2D(
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 736deb28c3..f04d3dfa01 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
@@ -872,11 +871,7 @@ class UNet2DConditionModel(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -895,11 +890,7 @@ class UNet2DConditionModel(
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
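The fused-QKV API these docstrings describe can be exercised on a small, randomly initialized UNet; the tiny config below is a throwaway setup for illustration only, not a recommended configuration:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

unet.fuse_qkv_projections()    # fuse q/k/v (self-attn) and k/v (cross-attn) projections
unet.unfuse_qkv_projections()  # restore the original attention processors
```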
diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py
index 7c21ddb690..8d9a309afb 100644
--- a/src/diffusers/models/unets/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unets/unet_2d_condition_flax.py
@@ -20,7 +20,7 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ...configuration_utils import ConfigMixin, flax_register_to_config
-from ...utils import BaseOutput
+from ...utils import BaseOutput, logging
from ..embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from ..modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import (
@@ -32,6 +32,9 @@ from .unet_2d_blocks_flax import (
)
+logger = logging.get_logger(__name__)
+
+
@flax.struct.dataclass
class FlaxUNet2DConditionOutput(BaseOutput):
"""
@@ -163,6 +166,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
def setup(self) -> None:
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
block_out_channels = self.block_out_channels
time_embed_dim = block_out_channels[0] * 4
diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index bd67ea414a..6a119185b8 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -18,7 +18,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
@@ -508,11 +507,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -532,11 +527,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 8449bf894c..3dba8edca7 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -16,7 +16,6 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
@@ -472,11 +471,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -496,11 +491,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/unets/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py
index 423669a22f..27241ce2e6 100644
--- a/src/diffusers/models/unets/unet_kandinsky3.py
+++ b/src/diffusers/models/unets/unet_kandinsky3.py
@@ -16,7 +16,6 @@ from dataclasses import dataclass
from typing import Dict, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 0a112b5249..18d5eb917f 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -18,7 +18,6 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
@@ -1911,11 +1910,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -1935,11 +1930,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py
index 93398a51ea..13653b9037 100644
--- a/src/diffusers/models/vae_flax.py
+++ b/src/diffusers/models/vae_flax.py
@@ -25,10 +25,13 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
from .modeling_flax_utils import FlaxModelMixin
+logger = logging.get_logger(__name__)
+
+
@flax.struct.dataclass
class FlaxDecoderOutput(BaseOutput):
"""
@@ -73,6 +76,10 @@ class FlaxUpsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
@@ -107,6 +114,11 @@ class FlaxDownsample2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.conv = nn.Conv(
self.in_channels,
kernel_size=(3, 3),
@@ -149,6 +161,11 @@ class FlaxResnetBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
out_channels = self.in_channels if self.out_channels is None else self.out_channels
self.norm1 = nn.GroupNorm(num_groups=self.groups, epsilon=1e-6)
@@ -221,6 +238,11 @@ class FlaxAttentionBlock(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.num_heads = self.channels // self.num_head_channels if self.num_head_channels is not None else 1
dense = partial(nn.Dense, self.channels, dtype=self.dtype)
@@ -302,6 +324,11 @@ class FlaxDownEncoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
@@ -359,6 +386,11 @@ class FlaxUpDecoderBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnets = []
for i in range(self.num_layers):
in_channels = self.in_channels if i == 0 else self.out_channels
@@ -413,6 +445,11 @@ class FlaxUNetMidBlock2D(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
resnet_groups = self.resnet_groups if self.resnet_groups is not None else min(self.in_channels // 4, 32)
# there is always at least one resnet
@@ -504,6 +541,11 @@ class FlaxEncoder(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
block_out_channels = self.block_out_channels
# in
self.conv_in = nn.Conv(
@@ -616,6 +658,11 @@ class FlaxDecoder(nn.Module):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
block_out_channels = self.block_out_channels
# z to block_in
@@ -788,6 +835,11 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
dtype: jnp.dtype = jnp.float32
def setup(self):
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
self.encoder = FlaxEncoder(
in_channels=self.config.in_channels,
out_channels=self.config.latent_channels,
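These Flax deprecation notices all go through diffusers' logger, so they can be surfaced or silenced with the standard verbosity helpers while migrating; a small sketch:

```python
from diffusers.utils import logging

logging.set_verbosity_error()    # hide warnings, including the Flax deprecation notice
# ... instantiate Flax modules here without the warning ...
logging.set_verbosity_warning()  # back to the default level
```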
diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py
index e0f2e31388..86ed735134 100644
--- a/src/diffusers/modular_pipelines/__init__.py
+++ b/src/diffusers/modular_pipelines/__init__.py
@@ -7,9 +7,15 @@ from ..utils import (
get_objects_from_module,
is_torch_available,
is_transformers_available,
+ logging,
)
+logger = logging.get_logger(__name__)
+logger.warning(
+ "Modular Diffusers is currently an experimental feature under active development. The API is subject to breaking changes in future releases."
+)
+
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {}
@@ -25,7 +31,6 @@ else:
_import_structure["modular_pipeline"] = [
"ModularPipelineBlocks",
"ModularPipeline",
- "PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks",
@@ -41,7 +46,20 @@ else:
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
- _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
+ _import_structure["flux"] = [
+ "FluxAutoBlocks",
+ "FluxModularPipeline",
+ "FluxKontextAutoBlocks",
+ "FluxKontextModularPipeline",
+ ]
+ _import_structure["qwenimage"] = [
+ "QwenImageAutoBlocks",
+ "QwenImageModularPipeline",
+ "QwenImageEditModularPipeline",
+ "QwenImageEditAutoBlocks",
+ "QwenImageEditPlusModularPipeline",
+ "QwenImageEditPlusAutoBlocks",
+ ]
_import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -52,28 +70,26 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .components_manager import ComponentsManager
- from .flux import FluxAutoBlocks, FluxModularPipeline
+ from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
LoopSequentialPipelineBlocks,
ModularPipeline,
ModularPipelineBlocks,
- PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
- from .modular_pipeline_utils import (
- ComponentSpec,
- ConfigSpec,
- InputParam,
- InsertableDict,
- OutputParam,
- )
- from .stable_diffusion_xl import (
- StableDiffusionXLAutoBlocks,
- StableDiffusionXLModularPipeline,
+ from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam
+ from .qwenimage import (
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ QwenImageEditModularPipeline,
+ QwenImageEditPlusAutoBlocks,
+ QwenImageEditPlusModularPipeline,
+ QwenImageModularPipeline,
)
+ from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import WanAutoBlocks, WanModularPipeline
else:
import sys
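A quick look at the import surface after this change: `PipelineBlock` is no longer exported, so custom blocks subclass `ModularPipelineBlocks` instead, and importing the subpackage emits the experimental-feature notice added above.

```python
# Triggers the module-level warning:
# "Modular Diffusers is currently an experimental feature under active development. ..."
from diffusers.modular_pipelines import (
    ComponentsManager,
    FluxKontextAutoBlocks,
    FluxKontextModularPipeline,
    ModularPipelineBlocks,
    QwenImageEditAutoBlocks,
    QwenImageModularPipeline,
    SequentialPipelineBlocks,
)
```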
diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py
index f48a227e2e..9dd8035c44 100644
--- a/src/diffusers/modular_pipelines/components_manager.py
+++ b/src/diffusers/modular_pipelines/components_manager.py
@@ -25,6 +25,7 @@ from ..utils import (
is_accelerate_available,
logging,
)
+from ..utils.torch_utils import get_device
if is_accelerate_available():
@@ -161,7 +162,9 @@ class AutoOffloadStrategy:
current_module_size = model.get_memory_footprint()
- mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
+ device_type = execution_device.type
+ device_module = getattr(torch, device_type, torch.cuda)
+ mem_on_device = device_module.mem_get_info(execution_device.index)[0]
mem_on_device = mem_on_device - self.memory_reserve_margin
if current_module_size < mem_on_device:
return []
@@ -283,11 +286,7 @@ class ComponentsManager:
encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory
management, and component organization.
-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
Example:
```python
@@ -301,7 +300,7 @@ class ComponentsManager:
cm.add("vae", vae_model, collection="sdxl")
# Enable auto offloading
- cm.enable_auto_cpu_offload(device="cuda")
+ cm.enable_auto_cpu_offload()
# Retrieve components
unet = cm.get_one(name="unet", collection="sdxl")
@@ -490,6 +489,8 @@ class ComponentsManager:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
+ if torch.xpu.is_available():
+ torch.xpu.empty_cache()
# YiYi TODO: rename to search_components for now, may remove this method
def search_components(
@@ -678,7 +679,7 @@ class ComponentsManager:
return get_return_dict(matches, return_dict_with_names)
- def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
+ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, memory_reserve_margin="3GB"):
"""
Enable automatic CPU offloading for all components.
@@ -704,6 +705,8 @@ class ComponentsManager:
self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
+ if device is None:
+ device = get_device()
device = torch.device(device)
if device.index is None:
device = torch.device(f"{device.type}:{0}")
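Sketch of the device-agnostic offloading default, assuming an accelerator and `accelerate` are available; the tiny UNet config is arbitrary and only keeps the example self-contained:

```python
from diffusers import UNet2DConditionModel
from diffusers.modular_pipelines import ComponentsManager

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

cm = ComponentsManager()
cm.add("unet", unet, collection="demo")
# Previously this required device="cuda"; the device is now resolved via get_device(),
# so the same call works on CUDA, XPU, and other supported accelerators.
cm.enable_auto_cpu_offload()
```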
diff --git a/src/diffusers/modular_pipelines/flux/__init__.py b/src/diffusers/modular_pipelines/flux/__init__.py
index 2891edf790..ec00986611 100644
--- a/src/diffusers/modular_pipelines/flux/__init__.py
+++ b/src/diffusers/modular_pipelines/flux/__init__.py
@@ -25,14 +25,18 @@ else:
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
+ "AUTO_BLOCKS_KONTEXT",
+ "FLUX_KONTEXT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
- "FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
+ "FluxKontextAutoBlocks",
+ "FluxKontextAutoDenoiseStep",
+ "FluxKontextBeforeDenoiseStep",
]
- _import_structure["modular_pipeline"] = ["FluxModularPipeline"]
+ _import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -45,13 +49,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
+ AUTO_BLOCKS_KONTEXT,
+ FLUX_KONTEXT_BLOCKS,
TEXT2IMAGE_BLOCKS,
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
+ FluxKontextAutoBlocks,
+ FluxKontextAutoDenoiseStep,
+ FluxKontextBeforeDenoiseStep,
)
- from .modular_pipeline import FluxModularPipeline
+ from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
else:
import sys
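A hedged sketch of how the new Kontext preset slots into the usual modular workflow; `SequentialPipelineBlocks.from_blocks_dict` is the assembly entry point assumed here, and building the blocks loads no model weights:

```python
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.flux import AUTO_BLOCKS, AUTO_BLOCKS_KONTEXT

flux_blocks = SequentialPipelineBlocks.from_blocks_dict(AUTO_BLOCKS)
kontext_blocks = SequentialPipelineBlocks.from_blocks_dict(AUTO_BLOCKS_KONTEXT)
print(kontext_blocks)  # shows the Kontext before_denoise / denoise / decode sub-blocks
```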
diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py
index ffc77bb24f..c098b7d4f1 100644
--- a/src/diffusers/modular_pipelines/flux/before_denoise.py
+++ b/src/diffusers/modular_pipelines/flux/before_denoise.py
@@ -18,10 +18,11 @@ from typing import List, Optional, Union
import numpy as np
import torch
+from ...pipelines import FluxPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
-from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
@@ -103,120 +104,55 @@ def calculate_shift(
return mu
-def _pack_latents(latents, batch_size, num_channels_latents, height, width):
- latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
- latents = latents.permute(0, 2, 4, 1, 3, 5)
- latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
- return latents
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
-def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height, width, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+def _get_initial_timesteps_and_optionals(
+ transformer,
+ scheduler,
+ batch_size,
+ height,
+ width,
+ vae_scale_factor,
+ num_inference_steps,
+ guidance_scale,
+ sigmas,
+ device,
+):
+ image_seq_len = (int(height) // vae_scale_factor // 2) * (int(width) // vae_scale_factor // 2)
- latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
- latent_image_ids = latent_image_ids.reshape(
- latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
+ sigmas = None
+ mu = calculate_shift(
+ image_seq_len,
+ scheduler.config.get("base_image_seq_len", 256),
+ scheduler.config.get("max_image_seq_len", 4096),
+ scheduler.config.get("base_shift", 0.5),
+ scheduler.config.get("max_shift", 1.15),
)
+ timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
+ if transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(batch_size)
+ else:
+ guidance = None
- return latent_image_ids.to(device=device, dtype=dtype)
+ return timesteps, num_inference_steps, sigmas, guidance
-class FluxInputStep(PipelineBlock):
- model_name = "flux"
-
- @property
- def description(self) -> str:
- return (
- "Input processing step that:\n"
- " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
- " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
- "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
- "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
- "have a final batch_size of batch_size * num_images_per_prompt."
- )
-
- @property
- def inputs(self) -> List[InputParam]:
- return [
- InputParam("num_images_per_prompt", default=1),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
- InputParam(
- "prompt_embeds",
- required=True,
- type_hint=torch.Tensor,
- description="Pre-generated text embeddings. Can be generated from text_encoder step.",
- ),
- InputParam(
- "pooled_prompt_embeds",
- type_hint=torch.Tensor,
- description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
- ),
- # TODO: support negative embeddings?
- ]
-
- @property
- def intermediate_outputs(self) -> List[str]:
- return [
- OutputParam(
- "batch_size",
- type_hint=int,
- description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
- ),
- OutputParam(
- "dtype",
- type_hint=torch.dtype,
- description="Data type of model tensor inputs (determined by `prompt_embeds`)",
- ),
- OutputParam(
- "prompt_embeds",
- type_hint=torch.Tensor,
- description="text embeddings used to guide the image generation",
- ),
- OutputParam(
- "pooled_prompt_embeds",
- type_hint=torch.Tensor,
- description="pooled text embeddings used to guide the image generation",
- ),
- # TODO: support negative embeddings?
- ]
-
- def check_inputs(self, components, block_state):
- if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
- if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
- raise ValueError(
- "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
- f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
- f" {block_state.pooled_prompt_embeds.shape}."
- )
-
- @torch.no_grad()
- def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
- # TODO: consider adding negative embeddings?
- block_state = self.get_block_state(state)
- self.check_inputs(components, block_state)
-
- block_state.batch_size = block_state.prompt_embeds.shape[0]
- block_state.dtype = block_state.prompt_embeds.dtype
-
- _, seq_len, _ = block_state.prompt_embeds.shape
- block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
- block_state.prompt_embeds = block_state.prompt_embeds.view(
- block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
- )
- self.set_block_state(state, block_state)
-
- return components, state
-
-
-class FluxSetTimestepsStep(PipelineBlock):
+class FluxSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -235,17 +171,15 @@ class FluxSetTimestepsStep(PipelineBlock):
InputParam("sigmas"),
InputParam("guidance_scale", default=3.5),
InputParam("latents", type_hint=torch.Tensor),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("height", type_hint=int),
+ InputParam("width", type_hint=int),
InputParam(
- "latents",
+ "batch_size",
required=True,
- type_hint=torch.Tensor,
- description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
- )
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
+ ),
]
@property
@@ -264,39 +198,127 @@ class FluxSetTimestepsStep(PipelineBlock):
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
+
scheduler = components.scheduler
+ transformer = components.transformer
- latents = block_state.latents
- image_seq_len = latents.shape[1]
-
- num_inference_steps = block_state.num_inference_steps
- sigmas = block_state.sigmas
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
- if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
- sigmas = None
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+ timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
+ transformer,
+ scheduler,
+ batch_size,
+ block_state.height,
+ block_state.width,
+ components.vae_scale_factor,
+ block_state.num_inference_steps,
+ block_state.guidance_scale,
+ block_state.sigmas,
+ block_state.device,
+ )
+ block_state.timesteps = timesteps
+ block_state.num_inference_steps = num_inference_steps
block_state.sigmas = sigmas
- mu = calculate_shift(
- image_seq_len,
- scheduler.config.get("base_image_seq_len", 256),
- scheduler.config.get("max_image_seq_len", 4096),
- scheduler.config.get("base_shift", 0.5),
- scheduler.config.get("max_shift", 1.15),
+ block_state.guidance = guidance
+
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ components.scheduler.set_begin_index(0)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+ @property
+ def description(self) -> str:
+        return "Step that sets the scheduler's timesteps for img2img inference, truncating the schedule according to `strength`"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ InputParam("strength", default=0.6),
+ InputParam("guidance_scale", default=3.5),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("height", type_hint=int),
+ InputParam("width", type_hint=int),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+ OutputParam(
+ "num_inference_steps",
+ type_hint=int,
+ description="The number of denoising steps to perform at inference time",
+ ),
+ OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self.scheduler->scheduler
+ def get_timesteps(scheduler, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = scheduler.timesteps[t_start * scheduler.order :]
+ if hasattr(scheduler, "set_begin_index"):
+ scheduler.set_begin_index(t_start * scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.device = components._execution_device
+
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+
+ scheduler = components.scheduler
+ transformer = components.transformer
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+ timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
+ transformer,
+ scheduler,
+ batch_size,
+ block_state.height,
+ block_state.width,
+ components.vae_scale_factor,
+ block_state.num_inference_steps,
+ block_state.guidance_scale,
+ block_state.sigmas,
+ block_state.device,
)
- block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
- scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
+ timesteps, num_inference_steps = self.get_timesteps(
+ scheduler, num_inference_steps, block_state.strength, block_state.device
)
- if components.transformer.config.guidance_embeds:
- guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
- guidance = guidance.expand(latents.shape[0])
- else:
- guidance = None
+ block_state.timesteps = timesteps
+ block_state.num_inference_steps = num_inference_steps
+ block_state.sigmas = sigmas
block_state.guidance = guidance
self.set_block_state(state, block_state)
return components, state
-class FluxPrepareLatentsStep(PipelineBlock):
+class FluxPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -305,7 +327,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
@property
def description(self) -> str:
- return "Prepare latents step that prepares the latents for the text-to-video generation process"
+ return "Prepare latents step that prepares the latents for the text-to-image generation process"
@property
def inputs(self) -> List[InputParam]:
@@ -314,11 +336,6 @@ class FluxPrepareLatentsStep(PipelineBlock):
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
- ]
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- return [
InputParam("generator"),
InputParam(
"batch_size",
@@ -335,11 +352,6 @@ class FluxPrepareLatentsStep(PipelineBlock):
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
),
- OutputParam(
- "latent_image_ids",
- type_hint=torch.Tensor,
- description="IDs computed from the image sequence needed for RoPE",
- ),
]
@staticmethod
@@ -363,20 +375,13 @@ class FluxPrepareLatentsStep(PipelineBlock):
generator,
latents=None,
):
- # Couldn't use the `prepare_latents` method directly from Flux because I decided to copy over
- # the packing methods here. So, for example, `comp._pack_latents()` won't work if we were
- # to go with the "# Copied from ..." approach. Or maybe there's a way?
-
- # VAE applies 8x compression on images but we must also account for packing which requires
- # latent height and width to be divisible by 2.
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
- latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
- return latents.to(device=device, dtype=dtype), latent_image_ids
+ return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -384,28 +389,25 @@ class FluxPrepareLatentsStep(PipelineBlock):
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
+ # TODO: move packing latents code to a patchifier similar to Qwen
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+ latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)
- latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
- return latents, latent_image_ids
+ return latents
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
-
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
- block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
-
- block_state.latents, block_state.latent_image_ids = self.prepare_latents(
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+ block_state.latents = self.prepare_latents(
components,
- block_state.batch_size * block_state.num_images_per_prompt,
+ batch_size,
block_state.num_channels_latents,
block_state.height,
block_state.width,
@@ -418,3 +420,200 @@ class FluxPrepareLatentsStep(PipelineBlock):
self.set_block_state(state, block_state)
return components, state
+
+
+class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+        return (
+            "Step that adds noise to image latents for image-to-image. Should be run after `set_timesteps` and "
+            "`prepare_latents`. Both noise and image latents should already be patchified."
+        )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+                description="The initial random noise; can be generated in the prepare latents step.",
+ ),
+ InputParam(
+ name="image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
+ ),
+ InputParam(
+ name="timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="initial_noise",
+ type_hint=torch.Tensor,
+                description="The initial random noise used for inpainting denoising.",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(image_latents, latents):
+ if image_latents.shape[0] != latents.shape[0]:
+ raise ValueError(
+                f"`image_latents` must have the same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
+ )
+
+ if image_latents.ndim != 3:
+ raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(image_latents=block_state.image_latents, latents=block_state.latents)
+
+ # prepare latent timestep
+ latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
+
+ # make copy of initial_noise
+ block_state.initial_noise = block_state.latents
+
+ # scale noise
+ block_state.latents = components.scheduler.scale_noise(
+ block_state.image_latents, latent_timestep, block_state.latents
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the RoPE inputs for the denoising process. Should be placed after text encoder and latent preparation steps."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="txt_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
+ ),
+ OutputParam(
+ name="img_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the image latents, used for RoPE calculation.",
+ ),
+ ]
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ prompt_embeds = block_state.prompt_embeds
+ device, dtype = prompt_embeds.device, prompt_embeds.dtype
+ block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+ device=prompt_embeds.device, dtype=prompt_embeds.dtype
+ )
+
+ height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+ block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="image_height"),
+ InputParam(name="image_width"),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="prompt_embeds"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="txt_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
+ ),
+ OutputParam(
+ name="img_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the image latents, used for RoPE calculation.",
+ ),
+ ]
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ prompt_embeds = block_state.prompt_embeds
+ device, dtype = prompt_embeds.device, prompt_embeds.dtype
+ block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+ device=prompt_embeds.device, dtype=prompt_embeds.dtype
+ )
+
+ img_ids = None
+ if (
+ getattr(block_state, "image_height", None) is not None
+ and getattr(block_state, "image_width", None) is not None
+ ):
+ image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
+            image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
+ img_ids = FluxPipeline._prepare_latent_image_ids(
+ None, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ img_ids[..., 0] = 1
+
+ height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+ latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
+
+ if img_ids is not None:
+ latent_ids = torch.cat([latent_ids, img_ids], dim=0)
+
+ block_state.img_ids = latent_ids
+
+ self.set_block_state(state, block_state)
+
+ return components, state
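Shape sketch of the `FluxPipeline` helpers the blocks above now reuse instead of keeping local copies; the sizes are chosen arbitrarily for illustration:

```python
import torch
from diffusers import FluxPipeline

latents = torch.randn(1, 16, 64, 64)  # (B, C, H, W) latents before packing
packed = FluxPipeline._pack_latents(latents, 1, 16, 64, 64)
print(packed.shape)  # torch.Size([1, 1024, 64]) = (B, H/2 * W/2, 4 * C)

# One RoPE id row per 2x2 latent patch, as prepared in FluxRoPEInputsStep.
img_ids = FluxPipeline._prepare_latent_image_ids(None, 32, 32, torch.device("cpu"), torch.float32)
print(img_ids.shape)  # torch.Size([1024, 3])
```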
diff --git a/src/diffusers/modular_pipelines/flux/decoders.py b/src/diffusers/modular_pipelines/flux/decoders.py
index 8d561d38c6..846549b1a3 100644
--- a/src/diffusers/modular_pipelines/flux/decoders.py
+++ b/src/diffusers/modular_pipelines/flux/decoders.py
@@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL
from ...utils import logging
from ...video_processor import VaeImageProcessor
-from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
return latents
-class FluxDecodeStep(PipelineBlock):
+class FluxDecodeStep(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -70,17 +70,12 @@ class FluxDecodeStep(PipelineBlock):
InputParam("output_type", default="pil"),
InputParam("height", default=1024),
InputParam("width", default=1024),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
- )
+ ),
]
@property
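Round-trip sketch pairing the packing helper with the decode step's `_unpack_latents`, assuming the standard Flux VAE scale factor of 8 (so a 1024x1024 image corresponds to 128x128 latents):

```python
import torch
from diffusers import FluxPipeline
from diffusers.modular_pipelines.flux.decoders import _unpack_latents

packed = FluxPipeline._pack_latents(torch.randn(1, 16, 128, 128), 1, 16, 128, 128)
unpacked = _unpack_latents(packed, height=1024, width=1024, vae_scale_factor=8)
print(packed.shape, unpacked.shape)
# torch.Size([1, 4096, 64]) torch.Size([1, 16, 128, 128])
```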
diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py
index c4619c17fb..b1796bb63c 100644
--- a/src/diffusers/modular_pipelines/flux/denoise.py
+++ b/src/diffusers/modular_pipelines/flux/denoise.py
@@ -22,7 +22,7 @@ from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
- PipelineBlock,
+ ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -32,7 +32,7 @@ from .modular_pipeline import FluxModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class FluxLoopDenoiser(PipelineBlock):
+class FluxLoopDenoiser(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -49,11 +49,8 @@ class FluxLoopDenoiser(PipelineBlock):
@property
def inputs(self) -> List[Tuple[str, Any]]:
- return [InputParam("joint_attention_kwargs")]
-
- @property
- def intermediate_inputs(self) -> List[str]:
return [
+ InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
@@ -79,18 +76,17 @@ class FluxLoopDenoiser(PipelineBlock):
description="Pooled prompt embeddings",
),
InputParam(
- "text_ids",
+ "txt_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from text sequence needed for RoPE",
),
InputParam(
- "latent_image_ids",
+ "img_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from image sequence needed for RoPE",
),
- # TODO: guidance
]
@torch.no_grad()
@@ -104,8 +100,8 @@ class FluxLoopDenoiser(PipelineBlock):
encoder_hidden_states=block_state.prompt_embeds,
pooled_projections=block_state.pooled_prompt_embeds,
joint_attention_kwargs=block_state.joint_attention_kwargs,
- txt_ids=block_state.text_ids,
- img_ids=block_state.latent_image_ids,
+ txt_ids=block_state.txt_ids,
+ img_ids=block_state.img_ids,
return_dict=False,
)[0]
block_state.noise_pred = noise_pred
@@ -113,7 +109,97 @@ class FluxLoopDenoiser(PipelineBlock):
return components, block_state
-class FluxLoopAfterDenoiser(PipelineBlock):
+class FluxKontextLoopDenoiser(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("transformer", FluxTransformer2DModel)]
+
+ @property
+ def description(self) -> str:
+ return (
+            "Step within the denoising loop that denoises the latents for Flux Kontext. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `FluxDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("joint_attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "image_latents",
+ type_hint=torch.Tensor,
+ description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "guidance",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Guidance scale as a tensor",
+ ),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Prompt embeddings",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pooled prompt embeddings",
+ ),
+ InputParam(
+ "txt_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from text sequence needed for RoPE",
+ ),
+ InputParam(
+ "img_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from latent sequence needed for RoPE",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(
+ self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+ ) -> PipelineState:
+ latents = block_state.latents
+ latent_model_input = latents
+ image_latents = block_state.image_latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
+
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ noise_pred = components.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=block_state.guidance,
+ encoder_hidden_states=block_state.prompt_embeds,
+ pooled_projections=block_state.pooled_prompt_embeds,
+ joint_attention_kwargs=block_state.joint_attention_kwargs,
+ txt_ids=block_state.txt_ids,
+ img_ids=block_state.img_ids,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+ block_state.noise_pred = noise_pred
+
+ return components, block_state
+
+
+class FluxLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux"
@property
@@ -175,7 +261,7 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
]
@property
- def loop_intermediate_inputs(self) -> List[InputParam]:
+ def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
@@ -198,9 +284,6 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
- # We set the index here to remove DtoH sync, helpful especially during compilation.
- # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
- components.scheduler.set_begin_index(0)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
@@ -223,8 +306,25 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
- "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `FluxLoopDenoiser`\n"
" - `FluxLoopAfterDenoiser`\n"
- "This block supports text2image tasks."
+ "This block supports both text2image and img2img tasks."
+ )
+
+
+class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
+ model_name = "flux-kontext"
+ block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
+ block_names = ["denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+            "Denoise step that iteratively denoises the latents. \n"
+ "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `FluxKontextLoopDenoiser`\n"
+ " - `FluxLoopAfterDenoiser`\n"
+ "This block supports both text2image and img2img tasks."
)
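The new loop wrapper can be instantiated and inspected without loading any model weights; a small sketch of its composition as declared above:

```python
from diffusers.modular_pipelines.flux.denoise import FluxKontextDenoiseStep

step = FluxKontextDenoiseStep()
print(step.block_names)  # ['denoiser', 'after_denoiser']
print(step.description)  # FluxKontextLoopDenoiser runs before FluxLoopAfterDenoiser
```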
diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py
index 9bf2f54eec..b71962bd93 100644
--- a/src/diffusers/modular_pipelines/flux/encoders.py
+++ b/src/diffusers/modular_pipelines/flux/encoders.py
@@ -19,10 +19,13 @@ import regex as re
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
-from ..modular_pipeline import PipelineBlock, PipelineState
-from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
@@ -50,12 +53,245 @@ def prompt_clean(text):
return text
-class FluxTextEncoderStep(PipelineBlock):
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
+
+ image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
+
+ return image_latents
+
+
+class FluxProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux"
@property
def description(self) -> str:
- return "Text Encoder step that generate text_embeddings to guide the video generation"
+ return "Image Preprocess step."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam(name="processed_image")]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.resized_image is None and block_state.image is None:
+ raise ValueError("`resized_image` and `image` cannot be None at the same time")
+
+ if block_state.resized_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ else:
+ width, height = block_state.resized_image[0].size
+ image = block_state.resized_image
+
+ block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ def __init__(self, _auto_resize=True):
+ self._auto_resize = _auto_resize
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return (
+ "Image preprocess step for Flux Kontext. The preprocessed image goes to the VAE.\n"
+            "Kontext also works as a T2I model when no input image is provided."
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("image")]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam(name="processed_image")]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState):
+ from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
+
+ block_state = self.get_block_state(state)
+ images = block_state.image
+
+ if images is None:
+ block_state.processed_image = None
+
+ else:
+ multiple_of = components.image_processor.config.vae_scale_factor
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if is_valid_image(images):
+ images = [images]
+
+ img = images[0]
+ image_height, image_width = components.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ if self._auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ images = components.image_processor.resize(images, image_height, image_width)
+ block_state.processed_image = components.image_processor.preprocess(images, image_height, image_width)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
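+# Editor's sketch of the resolution snapping above (hypothetical input; assumes (1248, 832) is
+# among PREFERRED_KONTEXT_RESOLUTIONS):
+#   aspect_ratio = 2048 / 1365   # ~1.50 for a 2048x1365 (width x height) image
+#   min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
+#   # -> selects (w, h) == (1248, 832); both are already multiples of 16, so the image is
+#   # resized and preprocessed to 1248x832 before VAE encoding.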
+
+class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ def __init__(
+ self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"
+ ):
+ """Initialize a VAE encoder step for converting images to latent representations.
+
+ Both the input and output names are configurable, so this block can be reused for different image
+ inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
+
+ Args:
+ input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
+ Examples: "processed_image" or "processed_control_image"
+ output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
+ Examples: "image_latents" or "control_image_latents"
+ sample_mode (str, optional): Sampling mode used when encoding with the VAE. Defaults to "sample".
+
+ Examples:
+ # Basic usage with default settings:
+ # FluxVaeEncoderDynamicStep()
+
+ # Custom input/output names for a control image:
+ # FluxVaeEncoderDynamicStep(
+ #     input_name="processed_control_image", output_name="control_image_latents"
+ # )
+ """
+ self._image_input_name = input_name
+ self._image_latents_output_name = output_name
+ self.sample_mode = sample_mode
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [ComponentSpec("vae", AutoencoderKL)]
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [InputParam(self._image_input_name), InputParam("generator")]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ self._image_latents_output_name,
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ image = getattr(block_state, self._image_input_name)
+
+ if image is None:
+ setattr(block_state, self._image_latents_output_name, None)
+ else:
+ device = components._execution_device
+ dtype = components.vae.dtype
+ image = image.to(device=device, dtype=dtype)
+
+ # Encode image into latents
+ image_latents = encode_vae_image(
+ image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode
+ )
+ setattr(block_state, self._image_latents_output_name, image_latents)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxTextEncoderStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generates text_embeddings to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -66,15 +302,12 @@ class FluxTextEncoderStep(PipelineBlock):
ComponentSpec("tokenizer_2", T5TokenizerFast),
]
- @property
- def expected_configs(self) -> List[ConfigSpec]:
- return []
-
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
+ InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("joint_attention_kwargs"),
]
@@ -83,19 +316,16 @@ class FluxTextEncoderStep(PipelineBlock):
return [
OutputParam(
"prompt_embeds",
+ kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="pooled text embeddings used to guide the image generation",
),
- OutputParam(
- "text_ids",
- type_hint=torch.Tensor,
- description="ids from the text sequence for RoPE",
- ),
]
@staticmethod
@@ -106,16 +336,10 @@ class FluxTextEncoderStep(PipelineBlock):
@staticmethod
def _get_t5_prompt_embeds(
- components,
- prompt: Union[str, List[str]],
- num_images_per_prompt: int,
- max_sequence_length: int,
- device: torch.device,
+ components, prompt: Union[str, List[str]], max_sequence_length: int, device: torch.device
):
dtype = components.text_encoder_2.dtype
-
prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)
@@ -141,23 +365,11 @@ class FluxTextEncoderStep(PipelineBlock):
prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- _, seq_len, _ = prompt_embeds.shape
-
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
return prompt_embeds
@staticmethod
- def _get_clip_prompt_embeds(
- components,
- prompt: Union[str, List[str]],
- num_images_per_prompt: int,
- device: torch.device,
- ):
+ def _get_clip_prompt_embeds(components, prompt: Union[str, List[str]], device: torch.device):
prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, components.tokenizer)
@@ -187,10 +399,6 @@ class FluxTextEncoderStep(PipelineBlock):
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
return prompt_embeds
@staticmethod
@@ -199,34 +407,11 @@ class FluxTextEncoderStep(PipelineBlock):
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
device: Optional[torch.device] = None,
- num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
- r"""
- Encodes the prompt into text encoder hidden states.
-
- Args:
- prompt (`str` or `List[str]`, *optional*):
- prompt to be encoded
- prompt_2 (`str` or `List[str]`, *optional*):
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
- used in all text-encoders
- device: (`torch.device`):
- torch device
- num_images_per_prompt (`int`):
- number of images that should be generated per prompt
- prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
- provided, text embeddings will be generated from `prompt` input argument.
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
- lora_scale (`float`, *optional*):
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
- """
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
@@ -251,12 +436,10 @@ class FluxTextEncoderStep(PipelineBlock):
components,
prompt=prompt,
device=device,
- num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
components,
prompt=prompt_2,
- num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
@@ -271,10 +454,7 @@ class FluxTextEncoderStep(PipelineBlock):
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
- dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16
- text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-
- return prompt_embeds, pooled_prompt_embeds, text_ids
+ return prompt_embeds, pooled_prompt_embeds
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
@@ -290,14 +470,14 @@ class FluxTextEncoderStep(PipelineBlock):
if block_state.joint_attention_kwargs is not None
else None
)
- (block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt(
+ block_state.prompt_embeds, block_state.pooled_prompt_embeds = self.encode_prompt(
components,
prompt=block_state.prompt,
prompt_2=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
device=block_state.device,
- num_images_per_prompt=1, # hardcoded for now.
+ max_sequence_length=block_state.max_sequence_length,
lora_scale=block_state.text_encoder_lora_scale,
)
diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py
new file mode 100644
index 0000000000..e1bc17f5ff
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/inputs.py
@@ -0,0 +1,359 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+from ...pipelines import FluxPipeline
+from ...utils import logging
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import InputParam, OutputParam
+
+# TODO: consider making these common utilities for modular if they are not pipeline-specific.
+from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size
+from .modular_pipeline import FluxModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+class FluxTextInputStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Text input processing step that standardizes text embeddings for the pipeline.\n"
+ "This step:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_images_per_prompt", default=1),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
+ ),
+ # TODO: support negative embeddings?
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ # TODO: support negative embeddings?
+ ]
+
+ def check_inputs(self, components, block_state):
+ if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
+ if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
+ raise ValueError(
+ "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
+ f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
+ f" {block_state.pooled_prompt_embeds.shape}."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ # TODO: consider adding negative embeddings?
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+ self.set_block_state(state, block_state)
+
+ return components, state
+
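+# Editor's sketch (hypothetical shapes) of the batch expansion above: with `prompt_embeds` of
+# shape (2, 512, 4096) and num_images_per_prompt=2,
+#   prompt_embeds.repeat(1, 2, 1)   # -> (2, 1024, 4096)
+#   .view(2 * 2, 512, -1)           # -> (4, 512, 4096); `batch_size` remains 2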
+
+# Adapted from `QwenImageInputsDynamicStep`
+class FluxInputsDynamicStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["image_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
+ if not isinstance(image_latent_inputs, list):
+ image_latent_inputs = [image_latent_inputs]
+ if not isinstance(additional_batch_inputs, list):
+ additional_batch_inputs = [additional_batch_inputs]
+
+ self._image_latent_inputs = image_latent_inputs
+ self._additional_batch_inputs = additional_batch_inputs
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ # Functionality section
+ summary_section = (
+ "Input processing step that:\n"
+ " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
+ " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
+ )
+
+ # Inputs info
+ inputs_info = ""
+ if self._image_latent_inputs or self._additional_batch_inputs:
+ inputs_info = "\n\nConfigured inputs:"
+ if self._image_latent_inputs:
+ inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
+ if self._additional_batch_inputs:
+ inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+ return summary_section + inputs_info + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ # Add image latent inputs
+ for image_latent_input_name in self._image_latent_inputs:
+ inputs.append(InputParam(name=image_latent_input_name))
+
+ # Add additional batch inputs
+ for input_name in self._additional_batch_inputs:
+ inputs.append(InputParam(name=input_name))
+
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
+ OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
+ ]
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate height/width from latents
+ height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
+ # 2. Patchify the image latent tensor
+ # TODO: Implement patchifier for Flux.
+ latent_height, latent_width = image_latent_tensor.shape[2:]
+ image_latent_tensor = FluxPipeline._pack_latents(
+ image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
+ )
+
+ # 3. Expand batch size
+ blocks_names: List[str]
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
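+# Editor's sketch (illustrative; "control_image_latents" is an assumed extra input, not defined here):
+#   FluxInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
+# would derive height/width from, pack, and batch-expand both latent tensors whenever they are
+# present in the pipeline state.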
+
+class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
+ model_name = "flux-kontext"
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate height/width from latents
+ # Unlike `FluxInputsDynamicStep`, we don't overwrite `block_state.height` and `block_state.width`
+ height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
+ # 2. Patchify the image latent tensor
+ # TODO: Implement patchifier for Flux.
+ latent_height, latent_width = image_latent_tensor.shape[2:]
+ image_latent_tensor = FluxPipeline._pack_latents(
+ image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
+ )
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxKontextSetResolutionStep(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ @property
+ def description(self):
+ return (
+ "Determines the height and width to be used during the subsequent computations.\n"
+ "It should always be placed _before_ the latent preparation step."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="max_area", type_hint=int, default=1024**2),
+ ]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="height", type_hint=int, description="The height of the initial noisy latents"),
+ OutputParam(name="width", type_hint=int, description="The width of the initial noisy latents"),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ self.check_inputs(height, width, components.vae_scale_factor)
+
+ original_height, original_width = height, width
+ max_area = block_state.max_area
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = components.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
+
+ block_state.height = height
+ block_state.width = width
+
+ self.set_block_state(state, block_state)
+ return components, state
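+
+
+# Editor's sketch of the area capping above (hypothetical numbers, assuming vae_scale_factor == 8):
+#   height, width, max_area = 1024, 1536, 1024 ** 2   # both pass the % 16 check; aspect_ratio = 1.5
+#   round((max_area * 1.5) ** 0.5), round((max_area / 1.5) ** 0.5)   # -> 1254, 836
+#   1254 // 16 * 16, 836 // 16 * 16                                  # -> 1248, 832 (a warning is logged)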
diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py
index b170673037..a80bc2a5f7 100644
--- a/src/diffusers/modular_pipelines/flux/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py
@@ -15,40 +15,143 @@
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
-from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
+from .before_denoise import (
+ FluxImg2ImgPrepareLatentsStep,
+ FluxImg2ImgSetTimestepsStep,
+ FluxKontextRoPEInputsStep,
+ FluxPrepareLatentsStep,
+ FluxRoPEInputsStep,
+ FluxSetTimestepsStep,
+)
from .decoders import FluxDecodeStep
-from .denoise import FluxDenoiseStep
-from .encoders import FluxTextEncoderStep
+from .denoise import FluxDenoiseStep, FluxKontextDenoiseStep
+from .encoders import (
+ FluxKontextProcessImagesInputStep,
+ FluxProcessImagesInputStep,
+ FluxTextEncoderStep,
+ FluxVaeEncoderDynamicStep,
+)
+from .inputs import (
+ FluxInputsDynamicStep,
+ FluxKontextInputsDynamicStep,
+ FluxKontextSetResolutionStep,
+ FluxTextInputStep,
+)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-# before_denoise: text2vid
-class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
- block_classes = [
- FluxInputStep,
- FluxPrepareLatentsStep,
- FluxSetTimestepsStep,
- ]
- block_names = ["input", "prepare_latents", "set_timesteps"]
+# vae encoder (run before before_denoise)
+FluxImg2ImgVaeEncoderBlocks = InsertableDict(
+ [("preprocess", FluxProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep())]
+)
+
+
+class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "flux"
+
+ block_classes = FluxImg2ImgVaeEncoderBlocks.values()
+ block_names = FluxImg2ImgVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
+
+
+class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [FluxImg2ImgVaeEncoderStep]
+ block_names = ["img2img"]
+ block_trigger_inputs = ["image"]
@property
def description(self):
return (
- "Before denoise step that prepare the inputs for the denoise step.\n"
- + "This is a sequential pipeline blocks:\n"
- + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
- + " - `FluxPrepareLatentsStep` is used to prepare the latents\n"
- + " - `FluxSetTimestepsStep` is used to set the timesteps\n"
+ "VAE encoder step that encodes the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block that works for img2img tasks.\n"
+ + " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided.\n"
+ + " - if `image` is not provided, this step will be skipped."
)
-# before_denoise: all task (text2vid,)
+# Flux Kontext vae encoder (run before before_denoise)
+
+FluxKontextVaeEncoderBlocks = InsertableDict(
+ [("preprocess", FluxKontextProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax"))]
+)
+
+
+class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "flux-kontext"
+
+ block_classes = FluxKontextVaeEncoderBlocks.values()
+ block_names = FluxKontextVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
+
+
+class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextVaeEncoderStep]
+ block_names = ["img2img"]
+ block_trigger_inputs = ["image"]
+
+ @property
+ def description(self):
+ return (
+ "VAE encoder step that encodes the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block that works for img2img tasks.\n"
+ + " - `FluxKontextVaeEncoderStep` (img2img) is used when only `image` is provided.\n"
+ + " - if `image` is not provided, this step will be skipped."
+ )
+
+
+# before_denoise: text2img
+FluxBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ]
+)
+
+
+class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = FluxBeforeDenoiseBlocks.values()
+ block_names = FluxBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
+
+
+# before_denoise: img2img
+FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
+ ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ]
+)
+
+
+class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
+ block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepares the inputs for the denoise step for the img2img task."
+
+
+# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
- block_classes = [FluxBeforeDenoiseStep]
- block_names = ["text2image"]
- block_trigger_inputs = [None]
+ model_name = "flux-kontext"
+ block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
+ block_names = ["img2img", "text2image"]
+ block_trigger_inputs = ["image_latents", None]
@property
def description(self):
@@ -56,6 +159,45 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2image.\n"
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ + " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
+ )
+
+
+# before_denoise: FluxKontext
+
+FluxKontextBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
+ ]
+)
+
+
+class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = FluxKontextBeforeDenoiseBlocks.values()
+ block_names = FluxKontextBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepares the inputs for the denoise step\n"
+ "for the text2img/img2img tasks of Flux Kontext."
+ )
+
+
+class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextBeforeDenoiseStep, FluxBeforeDenoiseStep]
+ block_names = ["img2img", "text2image"]
+ block_trigger_inputs = ["image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepares the inputs for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2image and img2img tasks.\n"
+ + " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ + " - `FluxKontextBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
)
@@ -69,12 +211,29 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
- "This is a auto pipeline block that works for text2image tasks."
- " - `FluxDenoiseStep` (denoise) for text2image tasks."
+ "This is an auto pipeline block that works for text2image and img2img tasks."
+ " - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
-# decode: all task (text2img, img2img, inpainting)
+# denoise: Flux Kontext
+
+
+class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextDenoiseStep]
+ block_names = ["denoise"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self) -> str:
+ "Denoise step that iteratively denoises the latents for Flux Kontext. "
+ "This is an auto pipeline block that works for text2image and img2img tasks."
+ " - `FluxKontextDenoiseStep` (denoise) for text2image and img2img tasks."
+ " - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
+ )
+
+
+# decode: all task (text2img, img2img)
class FluxAutoDecodeStep(AutoPipelineBlocks):
block_classes = [FluxDecodeStep]
block_names = ["non-inpaint"]
@@ -82,44 +241,206 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
@property
def description(self):
- return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`"
+ return "Decode step that decodes the denoised latents into image outputs.\n - `FluxDecodeStep`"
-# text2image
-class FluxAutoBlocks(SequentialPipelineBlocks):
- block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep]
- block_names = ["text_encoder", "before_denoise", "denoise", "decoder"]
+# inputs: text2image/img2img
+FluxImg2ImgBlocks = InsertableDict(
+ [("text_inputs", FluxTextInputStep()), ("additional_inputs", FluxInputsDynamicStep())]
+)
+
+
+class FluxImg2ImgInputStep(SequentialPipelineBlocks):
+ model_name = "flux"
+ block_classes = FluxImg2ImgBlocks.values()
+ block_names = FluxImg2ImgBlocks.keys()
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the img2img denoising step. It:\n"
+ " - makes sure the text embeddings and the additional inputs (`image_latents`) have a consistent batch size.\n"
+ " - updates height/width based on `image_latents` and patchifies `image_latents`."
+ )
+
+
+class FluxAutoInputStep(AutoPipelineBlocks):
+ block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
+ block_names = ["img2img", "text2image"]
+ block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
- "Auto Modular pipeline for text-to-image using Flux.\n"
- + "- for text-to-image generation, all you need to provide is `prompt`"
+ "Input step that standardizes the inputs for the denoising step, e.g. makes sure inputs have a consistent batch size and are patchified.\n"
+ "This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ + " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `FluxTextInputStep` (text2image) is used when `image_latents` is not provided.\n"
)
+# inputs: Flux Kontext
+
+FluxKontextBlocks = InsertableDict(
+ [
+ ("set_resolution", FluxKontextSetResolutionStep()),
+ ("text_inputs", FluxTextInputStep()),
+ ("additional_inputs", FluxKontextInputsDynamicStep()),
+ ]
+)
+
+
+class FluxKontextInputStep(SequentialPipelineBlocks):
+ model_name = "flux-kontext"
+ block_classes = FluxKontextBlocks.values()
+ block_names = FluxKontextBlocks.keys()
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for both the text2img and img2img denoising steps. It:\n"
+ " - makes sure the text embeddings and the additional inputs (`image_latents`) have a consistent batch size.\n"
+ " - updates height/width based on `image_latents` and patchifies `image_latents`."
+ )
+
+
+class FluxKontextAutoInputStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextInputStep, FluxTextInputStep]
+ block_names = ["img2img", "text2img"]
+ block_trigger_inputs = ["image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Input step that standardizes the inputs for the denoising step, e.g. makes sure inputs have a consistent batch size and are patchified.\n"
+ "This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ + " - `FluxKontextInputStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `FluxKontextInputStep` can also handle the text2image task when `image_latents` isn't present."
+ )
+
+
+class FluxCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "flux"
+ block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ + "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings."
+ )
+
+
+class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "flux-kontext"
+ block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ + "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings."
+ )
+
+
+# Auto blocks (text2image and img2img)
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("image_encoder", FluxAutoVaeEncoderStep()),
+ ("denoise", FluxCoreDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
+
+AUTO_BLOCKS_KONTEXT = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("image_encoder", FluxKontextAutoVaeEncoderStep()),
+ ("denoise", FluxKontextCoreDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
+
+
+class FluxAutoBlocks(SequentialPipelineBlocks):
+ model_name = "flux"
+
+ block_classes = AUTO_BLOCKS.values()
+ block_names = AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
+ + "- for text-to-image generation, all you need to provide is `prompt`\n"
+ + "- for image-to-image generation, you need to provide either `image` or `image_latents`"
+ )
+
+
+class FluxKontextAutoBlocks(FluxAutoBlocks):
+ model_name = "flux-kontext"
+
+ block_classes = AUTO_BLOCKS_KONTEXT.values()
+ block_names = AUTO_BLOCKS_KONTEXT.keys()
+
+
TEXT2IMAGE_BLOCKS = InsertableDict(
[
- ("text_encoder", FluxTextEncoderStep),
- ("input", FluxInputStep),
- ("prepare_latents", FluxPrepareLatentsStep),
- # Setting it after preparation of latents because we rely on `latents`
- # to calculate `img_seq_len` for `shift`.
- ("set_timesteps", FluxSetTimestepsStep),
- ("denoise", FluxDenoiseStep),
- ("decode", FluxDecodeStep),
+ ("text_encoder", FluxTextEncoderStep()),
+ ("input", FluxTextInputStep()),
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ("denoise", FluxDenoiseStep()),
+ ("decode", FluxDecodeStep()),
]
)
-
-AUTO_BLOCKS = InsertableDict(
+IMAGE2IMAGE_BLOCKS = InsertableDict(
[
- ("text_encoder", FluxTextEncoderStep),
- ("before_denoise", FluxAutoBeforeDenoiseStep),
- ("denoise", FluxAutoDenoiseStep),
- ("decode", FluxAutoDecodeStep),
+ ("text_encoder", FluxTextEncoderStep()),
+ ("vae_encoder", FluxVaeEncoderDynamicStep()),
+ ("input", FluxImg2ImgInputStep()),
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
+ ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ("denoise", FluxDenoiseStep()),
+ ("decode", FluxDecodeStep()),
]
)
+FLUX_KONTEXT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("vae_encoder", FluxVaeEncoderDynamicStep(sample_mode="argmax")),
+ ("input", FluxKontextInputStep()),
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
+ ("denoise", FluxKontextDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
-ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
+ALL_BLOCKS = {
+ "text2image": TEXT2IMAGE_BLOCKS,
+ "img2img": IMAGE2IMAGE_BLOCKS,
+ "auto": AUTO_BLOCKS,
+ "auto_kontext": AUTO_BLOCKS_KONTEXT,
+ "kontext": FLUX_KONTEXT_BLOCKS,
+}
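+
+# Editor's sketch (hedged): how one of these presets might be turned into a runnable pipeline.
+# `init_pipeline` and the repo id below are assumptions based on how other modular pipelines are
+# used, not something introduced by this file:
+#   blocks = FluxKontextAutoBlocks()
+#   pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-Kontext-dev")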
diff --git a/src/diffusers/modular_pipelines/flux/modular_pipeline.py b/src/diffusers/modular_pipelines/flux/modular_pipeline.py
index 3cd5df0c70..d8158f5d4f 100644
--- a/src/diffusers/modular_pipelines/flux/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/flux/modular_pipeline.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from ...loaders import FluxLoraLoaderMixin
+from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -21,17 +21,15 @@ from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
+class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversionLoaderMixin):
"""
A ModularPipeline for Flux.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
+ default_blocks_name = "FluxAutoBlocks"
+
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@@ -57,3 +55,13 @@ class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
+
+
+class FluxKontextModularPipeline(FluxModularPipeline):
+ """
+ A ModularPipeline for Flux Kontext.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "FluxKontextAutoBlocks"
diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py
new file mode 100644
index 0000000000..a405aebee2
--- /dev/null
+++ b/src/diffusers/modular_pipelines/mellon_node_utils.py
@@ -0,0 +1,763 @@
+import json
+import logging
+import os
+
+# Simple typed wrapper for parameter overrides
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from huggingface_hub import create_repo, hf_hub_download
+from huggingface_hub.utils import (
+ EntryNotFoundError,
+ HfHubHTTPError,
+ RepositoryNotFoundError,
+ RevisionNotFoundError,
+ validate_hf_hub_args,
+)
+
+from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, PushToHubMixin, extract_commit_hash
+from .modular_pipeline import ModularPipelineBlocks
+
+
+logger = logging.getLogger(__name__)
+
+
+SUPPORTED_NODE_TYPES = {"controlnet", "vae_encoder", "denoise", "text_encoder", "decoder"}
+
+
+# Mellon Input Parameters (runtime parameters, not models)
+MELLON_INPUT_PARAMS = {
+ # controlnet
+ "control_image": {
+ "label": "Control Image",
+ "type": "image",
+ "display": "input",
+ },
+ "controlnet_conditioning_scale": {
+ "label": "Scale",
+ "type": "float",
+ "default": 0.5,
+ "min": 0,
+ "max": 1,
+ },
+ "control_guidance_end": {
+ "label": "End",
+ "type": "float",
+ "default": 1.0,
+ "min": 0,
+ "max": 1,
+ },
+ "control_guidance_start": {
+ "label": "Start",
+ "type": "float",
+ "default": 0.0,
+ "min": 0,
+ "max": 1,
+ },
+ "controlnet": {
+ "label": "Controlnet",
+ "type": "custom_controlnet",
+ "display": "input",
+ },
+ "embeddings": {
+ "label": "Text Embeddings",
+ "display": "input",
+ "type": "embeddings",
+ },
+ "image": {
+ "label": "Image",
+ "type": "image",
+ "display": "input",
+ },
+ "negative_prompt": {
+ "label": "Negative Prompt",
+ "type": "string",
+ "default": "",
+ "display": "textarea",
+ },
+ "prompt": {
+ "label": "Prompt",
+ "type": "string",
+ "default": "",
+ "display": "textarea",
+ },
+ "guidance_scale": {
+ "label": "Guidance Scale",
+ "type": "float",
+ "display": "slider",
+ "default": 5,
+ "min": 1.0,
+ "max": 30.0,
+ "step": 0.1,
+ },
+ "height": {
+ "label": "Height",
+ "type": "int",
+ "default": 1024,
+ "min": 64,
+ "step": 8,
+ },
+ "image_latents": {
+ "label": "Image Latents",
+ "type": "latents",
+ "display": "input",
+ "onChange": {False: ["height", "width"], True: ["strength"]},
+ },
+ "latents": {
+ "label": "Latents",
+ "type": "latents",
+ "display": "input",
+ },
+ "num_inference_steps": {
+ "label": "Steps",
+ "type": "int",
+ "display": "slider",
+ "default": 25,
+ "min": 1,
+ "max": 100,
+ },
+ "seed": {
+ "label": "Seed",
+ "type": "int",
+ "display": "random",
+ "default": 0,
+ "min": 0,
+ "max": 4294967295,
+ },
+ "strength": {
+ "label": "Strength",
+ "type": "float",
+ "default": 0.5,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ },
+ "width": {
+ "label": "Width",
+ "type": "int",
+ "default": 1024,
+ "min": 64,
+ "step": 8,
+ },
+ "ip_adapter": {
+ "label": "IP Adapter",
+ "type": "custom_ip_adapter",
+ "display": "input",
+ },
+}
+
+# Mellon Model Parameters (diffusers_auto_model types)
+MELLON_MODEL_PARAMS = {
+ "scheduler": {
+ "label": "Scheduler",
+ "display": "input",
+ "type": "diffusers_auto_model",
+ },
+ "text_encoders": {
+ "label": "Text Encoders",
+ "type": "diffusers_auto_models",
+ "display": "input",
+ },
+ "unet": {
+ "label": "Unet",
+ "display": "input",
+ "type": "diffusers_auto_model",
+ "onSignal": {
+ "action": "signal",
+ "target": "guider",
+ },
+ },
+ "guider": {
+ "label": "Guider",
+ "display": "input",
+ "type": "custom_guider",
+ "onChange": {False: ["guidance_scale"], True: []},
+ },
+ "vae": {
+ "label": "VAE",
+ "display": "input",
+ "type": "diffusers_auto_model",
+ },
+ "controlnet": {
+ "label": "Controlnet Model",
+ "type": "diffusers_auto_model",
+ "display": "input",
+ },
+}
+
+# Mellon Output Parameters (display = "output")
+MELLON_OUTPUT_PARAMS = {
+ "embeddings": {
+ "label": "Text Embeddings",
+ "display": "output",
+ "type": "embeddings",
+ },
+ "images": {
+ "label": "Images",
+ "type": "image",
+ "display": "output",
+ },
+ "image_latents": {
+ "label": "Image Latents",
+ "type": "latents",
+ "display": "output",
+ },
+ "latents": {
+ "label": "Latents",
+ "type": "latents",
+ "display": "output",
+ },
+ "latents_preview": {
+ "label": "Latents Preview",
+ "display": "output",
+ "type": "latent",
+ },
+ "controlnet_out": {
+ "label": "Controlnet",
+ "display": "output",
+ "type": "controlnet",
+ },
+}
+
+
+# Default param selections per supported node_type
+# from MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS.
+NODE_TYPE_PARAMS_MAP = {
+ "controlnet": {
+ "inputs": [
+ "control_image",
+ "controlnet_conditioning_scale",
+ "control_guidance_start",
+ "control_guidance_end",
+ "height",
+ "width",
+ ],
+ "model_inputs": [
+ "controlnet",
+ "vae",
+ ],
+ "outputs": [
+ "controlnet",
+ ],
+ "block_names": ["controlnet_vae_encoder"],
+ },
+ "denoise": {
+ "inputs": [
+ "embeddings",
+ "width",
+ "height",
+ "seed",
+ "num_inference_steps",
+ "guidance_scale",
+ "image_latents",
+ "strength",
+ # custom adapters coming in as inputs
+ "controlnet",
+ # ip_adapter is optional and custom; include if available
+ "ip_adapter",
+ ],
+ "model_inputs": [
+ "unet",
+ "guider",
+ "scheduler",
+ ],
+ "outputs": [
+ "latents",
+ "latents_preview",
+ ],
+ "block_names": ["denoise"],
+ },
+ "vae_encoder": {
+ "inputs": [
+ "image",
+ "width",
+ "height",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "image_latents",
+ ],
+ "block_names": ["vae_encoder"],
+ },
+ "text_encoder": {
+ "inputs": [
+ "prompt",
+ "negative_prompt",
+ # optional image prompt input supported in embeddings node
+ "image",
+ ],
+ "model_inputs": [
+ "text_encoders",
+ ],
+ "outputs": [
+ "embeddings",
+ ],
+ "block_names": ["text_encoder"],
+ },
+ "decoder": {
+ "inputs": [
+ "latents",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "images",
+ ],
+ "block_names": ["decode"],
+ },
+}
+
+
+@dataclass(frozen=True)
+class MellonParam:
+ name: str
+ label: str
+ type: str
+ display: Optional[str] = None
+ default: Any = None
+ min: Optional[float] = None
+ max: Optional[float] = None
+ step: Optional[float] = None
+ options: Any = None
+ value: Any = None
+ fieldOptions: Optional[Dict[str, Any]] = None
+ onChange: Any = None
+ onSignal: Any = None
+ _map_to_input: Any = None # the block input name this parameter maps to
+
+ def to_dict(self) -> Dict[str, Any]:
+ data = asdict(self)
+ return {k: v for k, v in data.items() if not k.startswith("_") and v is not None}
+
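+# Editor's sketch (illustrative values): `to_dict` drops private ("_"-prefixed) and None fields, so
+#   MellonParam(name="strength", label="Strength", type="float", default=0.5, min=0.0, max=1.0).to_dict()
+# returns {"name": "strength", "label": "Strength", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0}.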
+
+@dataclass
+class MellonNodeConfig(PushToHubMixin):
+ """
+ A MellonNodeConfig is a base class for building Mellon node UIs with modular diffusers.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ inputs: List[Union[str, MellonParam]]
+ model_inputs: List[Union[str, MellonParam]]
+ outputs: List[Union[str, MellonParam]]
+ blocks_names: list[str]
+ node_type: str
+ config_name = "mellon_config.json"
+
+ def __post_init__(self):
+ if isinstance(self.inputs, list):
+ self.inputs = self._resolve_params_list(self.inputs, MELLON_INPUT_PARAMS)
+ if isinstance(self.model_inputs, list):
+ self.model_inputs = self._resolve_params_list(self.model_inputs, MELLON_MODEL_PARAMS)
+ if isinstance(self.outputs, list):
+ self.outputs = self._resolve_params_list(self.outputs, MELLON_OUTPUT_PARAMS)
+
+ @staticmethod
+ def _resolve_params_list(
+ params: List[Union[str, MellonParam]], default_map: Dict[str, Dict[str, Any]]
+ ) -> Dict[str, Dict[str, Any]]:
+ def _resolve_param(
+ param: Union[str, MellonParam], default_params_map: Dict[str, Dict[str, Any]]
+ ) -> Tuple[str, Dict[str, Any]]:
+ if isinstance(param, str):
+ if param not in default_params_map:
+ raise ValueError(f"Unknown param '{param}', please define a `MellonParam` object instead")
+ return param, default_params_map[param].copy()
+ elif isinstance(param, MellonParam):
+ param_dict = param.to_dict()
+ param_name = param_dict.pop("name")
+ return param_name, param_dict
+ else:
+ raise ValueError(
+ f"Unknown param type '{type(param)}', please use a string or a `MellonParam` object instead"
+ )
+
+ resolved = {}
+ for p in params:
+ logger.info(f" Resolving param: {p}")
+ name, cfg = _resolve_param(p, default_map)
+ if name in resolved:
+ raise ValueError(f"Duplicate param '{name}'")
+ resolved[name] = cfg
+ return resolved
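+
+ # Editor's sketch (illustrative): string entries are resolved against the default maps above, so
+ #   MellonNodeConfig(inputs=["latents"], model_inputs=["vae"], outputs=["images"],
+ #                    blocks_names=["decode"], node_type="decoder")
+ # ends up with inputs == {"latents": MELLON_INPUT_PARAMS["latents"].copy()} (and similarly for
+ # model_inputs/outputs) after __post_init__ runs.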
+
+ @classmethod
+ @validate_hf_hub_args
+ def load_mellon_config(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ return_unused_kwargs=False,
+ return_commit_hash=False,
+ **kwargs,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ Load a Mellon node configuration.
+ Load a model or scheduler configuration.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
+ [`~ConfigMixin.save_config`].
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether unused keyword arguments of the config are returned.
+ return_commit_hash (`bool`, *optional*, defaults to `False`):
+ Whether the `commit_hash` of the loaded configuration is returned.
+
+ Returns:
+ `dict`:
+ A dictionary of all the parameters stored in a JSON configuration file.
+
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ local_dir = kwargs.pop("local_dir", None)
+ local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ local_dir=local_dir,
+ local_dir_use_symlinks=local_dir_use_symlinks,
+ )
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `token` or log in with `hf auth login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HfHubHTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+ try:
+ with open(config_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ config_dict = json.loads(text)
+
+ commit_hash = extract_commit_hash(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ if not (return_unused_kwargs or return_commit_hash):
+ return config_dict
+
+ outputs = (config_dict,)
+
+ if return_unused_kwargs:
+ outputs += (kwargs,)
+
+ if return_commit_hash:
+ outputs += (commit_hash,)
+
+ return outputs
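+
+ # Editor's sketch (hypothetical repo id): loading and re-hydrating a node definition could look like
+ #   config_dict = MellonNodeConfig.load_mellon_config("your-org/your-mellon-node")
+ #   node = MellonNodeConfig.from_mellon_dict(config_dict)
+ # where the repo (or local directory) is expected to contain a `mellon_config.json` file.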
+
+ def save_mellon_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save the Mellon node definition to a JSON file.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"Mellon node definition saved in {output_config_file}")
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ private = kwargs.pop("private", None)
+ create_pr = kwargs.pop("create_pr", False)
+ token = kwargs.pop("token", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+ subfolder = kwargs.pop("subfolder", None)
+
+ self._upload_folder(
+ save_directory,
+ repo_id,
+ token=token,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ subfolder=subfolder,
+ )
+
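+ # Illustrative usage sketch (comments only, not executed): saving a Mellon node definition
+ # locally and optionally pushing it to the Hub. `node_config` stands for any MellonNodeConfig
+ # instance and "my-mellon-node" is a placeholder directory / repo name.
+ #
+ #     node_config.save_mellon_config("my-mellon-node")
+ #     node_config.save_mellon_config("my-mellon-node", push_to_hub=True, repo_id="my-mellon-node")
+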
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save the Mellon schema dictionary to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file to save a configuration instance's parameters.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string of the Mellon schema dict.
+
+ Returns:
+ `str`: String containing the Mellon schema of this configuration instance in JSON format.
+ """
+
+ mellon_dict = self.to_mellon_dict()
+ return json.dumps(mellon_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_mellon_dict(self) -> Dict[str, Any]:
+ """Return a JSON-serializable dict focusing on the Mellon schema fields only.
+
+ params is a single flat dict composed as: {**inputs, **model_inputs, **outputs}.
+ """
+ # inputs/model_inputs/outputs are already normalized dicts
+ merged_params = {}
+ merged_params.update(self.inputs or {})
+ merged_params.update(self.model_inputs or {})
+ merged_params.update(self.outputs or {})
+
+ return {
+ "node_type": self.node_type,
+ "blocks_names": self.blocks_names,
+ "params": merged_params,
+ }
+
+ @classmethod
+ def from_mellon_dict(cls, mellon_dict: Dict[str, Any]) -> "MellonNodeConfig":
+ """Create a config from a Mellon schema dict produced by to_mellon_dict().
+
+ Splits the flat params dict back into inputs/model_inputs/outputs based on each param's fields: params with
+ `display == "output"` become outputs, params typed as `diffusers_auto_model`/`diffusers_auto_models` become
+ model_inputs, and all remaining params are treated as inputs.
+ """
+ flat_params = mellon_dict.get("params", {})
+
+ inputs: Dict[str, Any] = {}
+ model_inputs: Dict[str, Any] = {}
+ outputs: Dict[str, Any] = {}
+
+ for param_name, param_dict in flat_params.items():
+ if param_dict.get("display", "") == "output":
+ outputs[param_name] = param_dict
+ elif param_dict.get("type", "") in ("diffusers_auto_model", "diffusers_auto_models"):
+ model_inputs[param_name] = param_dict
+ else:
+ inputs[param_name] = param_dict
+
+ return cls(
+ inputs=inputs,
+ model_inputs=model_inputs,
+ outputs=outputs,
+ blocks_names=mellon_dict.get("blocks_names", []),
+ node_type=mellon_dict.get("node_type"),
+ )
+
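+ # Illustrative round-trip sketch (comments only, not executed). `from_mellon_dict` splits the
+ # flat `params` dict produced by `to_mellon_dict` back into inputs/model_inputs/outputs using
+ # each param's `display`/`type` fields, so the round trip is lossless as long as those fields
+ # are set. `node_config` is a placeholder MellonNodeConfig instance.
+ #
+ #     mellon_dict = node_config.to_mellon_dict()
+ #     restored = MellonNodeConfig.from_mellon_dict(mellon_dict)
+ #     assert restored.node_type == node_config.node_type
+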
+ # YiYi Notes: not used yet
+ @classmethod
+ def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNodeConfig":
+ """
+ Create an instance from a ModularPipelineBlocks object. The node_type must have a preset in
+ NODE_TYPE_PARAMS_MAP; the preset lists are then extended with the blocks' required inputs and expected
+ components.
+ """
+ if node_type not in NODE_TYPE_PARAMS_MAP:
+ raise ValueError(f"Node type {node_type} not supported")
+
+ blocks_names = list(blocks.sub_blocks.keys())
+
+ default_node_config = NODE_TYPE_PARAMS_MAP[node_type]
+ inputs_list: List[Union[str, MellonParam]] = default_node_config.get("inputs", [])
+ model_inputs_list: List[Union[str, MellonParam]] = default_node_config.get("model_inputs", [])
+ outputs_list: List[Union[str, MellonParam]] = default_node_config.get("outputs", [])
+
+ for required_input_name in blocks.required_inputs:
+ if required_input_name not in inputs_list:
+ inputs_list.append(
+ MellonParam(
+ name=required_input_name, label=required_input_name, type=required_input_name, display="input"
+ )
+ )
+
+ for component_spec in blocks.expected_components:
+ if component_spec.name not in model_inputs_list:
+ model_inputs_list.append(
+ MellonParam(
+ name=component_spec.name,
+ label=component_spec.name,
+ type="diffusers_auto_model",
+ display="input",
+ )
+ )
+
+ return cls(
+ inputs=inputs_list,
+ model_inputs=model_inputs_list,
+ outputs=outputs_list,
+ blocks_names=blocks_names,
+ node_type=node_type,
+ )
+
+
+# Minimal modular registry for Mellon node configs
+class ModularMellonNodeRegistry:
+ """Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig."""
+
+ def __init__(self):
+ self._registry = {}
+ self._initialized = False
+
+ def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]):
+ if not self._initialized:
+ _initialize_registry(self)
+ self._registry[pipeline_cls] = node_params
+
+ def get(self, pipeline_cls: type) -> Dict[str, MellonNodeConfig]:
+ if not self._initialized:
+ _initialize_registry(self)
+ return self._registry.get(pipeline_cls, None)
+
+ def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]:
+ if not self._initialized:
+ _initialize_registry(self)
+ return self._registry
+
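+ # Illustrative usage sketch (comments only, not executed): the registry maps a modular pipeline
+ # class to its preset node configs. The pipeline class and the "denoise" node type below are
+ # just examples; `get()` returns a dict of node_type -> MellonNodeConfig, or None if nothing is
+ # registered for that class.
+ #
+ #     registry = ModularMellonNodeRegistry()
+ #     node_configs = registry.get(StableDiffusionXLModularPipeline)
+ #     if node_configs is not None:
+ #         denoise_config = node_configs.get("denoise")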
+
+def _register_preset_node_types(
+ pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry
+):
+ """Register all node-type presets for a given pipeline class from a params map."""
+ node_configs = {}
+ for node_type, spec in params_map.items():
+ node_config = MellonNodeConfig(
+ inputs=spec.get("inputs", []),
+ model_inputs=spec.get("model_inputs", []),
+ outputs=spec.get("outputs", []),
+ blocks_names=spec.get("block_names", []),
+ node_type=node_type,
+ )
+ node_configs[node_type] = node_config
+ registry.register(pipeline_cls, node_configs)
+
+
+def _initialize_registry(registry: ModularMellonNodeRegistry):
+ """Initialize the registry and register all available pipeline configs."""
+ print("Initializing registry")
+
+ registry._initialized = True
+
+ try:
+ from .qwenimage.modular_pipeline import QwenImageModularPipeline
+ from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP
+
+ _register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry)
+ except Exception as e:
+ raise Exception("Failed to register QwenImageModularPipeline") from e
+
+ try:
+ from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline
+ from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP
+
+ _register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry)
+ except Exception as e:
+ raise Exception("Failed to register StableDiffusionXLModularPipeline") from e
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
index 0ef1d59f4d..cfbca48a98 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -29,11 +29,7 @@ from typing_extensions import Self
from ..configuration_utils import ConfigMixin, FrozenDict
from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
-from ..utils import (
- PushToHubMixin,
- is_accelerate_available,
- logging,
-)
+from ..utils import PushToHubMixin, is_accelerate_available, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from .components_manager import ComponentsManager
@@ -45,8 +41,6 @@ from .modular_pipeline_utils import (
OutputParam,
format_components,
format_configs,
- format_inputs_short,
- format_intermediates_short,
make_doc_string,
)
@@ -57,19 +51,16 @@ if is_accelerate_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+# map a regular pipeline name to its modular pipeline class name
MODULAR_PIPELINE_MAPPING = OrderedDict(
[
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
- ]
-)
-
-MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
- [
- ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
- ("WanModularPipeline", "WanAutoBlocks"),
- ("FluxModularPipeline", "FluxAutoBlocks"),
+ ("flux-kontext", "FluxKontextModularPipeline"),
+ ("qwenimage", "QwenImageModularPipeline"),
+ ("qwenimage-edit", "QwenImageEditModularPipeline"),
+ ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
]
)
@@ -80,139 +71,68 @@ class PipelineState:
[`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks.
"""
- inputs: Dict[str, Any] = field(default_factory=dict)
- intermediates: Dict[str, Any] = field(default_factory=dict)
- input_kwargs: Dict[str, List[str]] = field(default_factory=dict)
- intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict)
+ values: Dict[str, Any] = field(default_factory=dict)
+ kwargs_mapping: Dict[str, List[str]] = field(default_factory=dict)
- def set_input(self, key: str, value: Any, kwargs_type: str = None):
+ def set(self, key: str, value: Any, kwargs_type: str = None):
"""
- Add an input to the immutable pipeline state, i.e, pipeline_state.inputs.
-
- The kwargs_type parameter allows you to associate inputs with specific input types. For example, if you call
- set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), this input will be automatically fetched when a
- pipeline block has "guider_kwargs" in its expected_inputs list.
+ Add a value to the pipeline state.
Args:
- key (str): The key for the input
- value (Any): The input value
- kwargs_type (str): The kwargs_type with which the input is associated
+ key (str): The key for the value
+ value (Any): The value to store
+ kwargs_type (str): The kwargs_type with which the value is associated
"""
- self.inputs[key] = value
+ self.values[key] = value
+
if kwargs_type is not None:
- if kwargs_type not in self.input_kwargs:
- self.input_kwargs[kwargs_type] = [key]
+ if kwargs_type not in self.kwargs_mapping:
+ self.kwargs_mapping[kwargs_type] = [key]
else:
- self.input_kwargs[kwargs_type].append(key)
+ self.kwargs_mapping[kwargs_type].append(key)
- def set_intermediate(self, key: str, value: Any, kwargs_type: str = None):
+ def get(self, keys: Union[str, List[str]], default: Any = None) -> Union[Any, Dict[str, Any]]:
"""
- Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates.
-
- The kwargs_type parameter allows you to associate intermediate values with specific input types. For example,
- if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), this intermediate value will be
- automatically fetched when a pipeline block has "latents_kwargs" in its expected_intermediate_inputs list.
+ Get one or multiple values from the pipeline state.
Args:
- key (str): The key for the intermediate value
- value (Any): The intermediate value
- kwargs_type (str): The kwargs_type with which the intermediate value is associated
- """
- self.intermediates[key] = value
- if kwargs_type is not None:
- if kwargs_type not in self.intermediate_kwargs:
- self.intermediate_kwargs[kwargs_type] = [key]
- else:
- self.intermediate_kwargs[kwargs_type].append(key)
-
- def get_input(self, key: str, default: Any = None) -> Any:
- """
- Get an input from the pipeline state.
-
- Args:
- key (str): The key for the input
- default (Any): The default value to return if the input is not found
+ keys (Union[str, List[str]]): Key or list of keys for the values
+ default (Any): The default value to return if not found
Returns:
- Any: The input value
+ Union[Any, Dict[str, Any]]: Single value if keys is str, dictionary of values if keys is list
"""
- value = self.inputs.get(key, default)
- if value is not None:
- return deepcopy(value)
+ if isinstance(keys, str):
+ return self.values.get(keys, default)
+ return {key: self.values.get(key, default) for key in keys}
- def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
+ def get_by_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
"""
- Get multiple inputs from the pipeline state.
-
- Args:
- keys (List[str]): The keys for the inputs
- default (Any): The default value to return if the input is not found
-
- Returns:
- Dict[str, Any]: Dictionary of inputs with matching keys
- """
- return {key: self.inputs.get(key, default) for key in keys}
-
- def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
- """
- Get all inputs with matching kwargs_type.
+ Get all values with matching kwargs_type.
Args:
kwargs_type (str): The kwargs_type to filter by
Returns:
- Dict[str, Any]: Dictionary of inputs with matching kwargs_type
+ Dict[str, Any]: Dictionary of values with matching kwargs_type
"""
- input_names = self.input_kwargs.get(kwargs_type, [])
- return self.get_inputs(input_names)
-
- def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
- """
- Get all intermediates with matching kwargs_type.
-
- Args:
- kwargs_type (str): The kwargs_type to filter by
-
- Returns:
- Dict[str, Any]: Dictionary of intermediates with matching kwargs_type
- """
- intermediate_names = self.intermediate_kwargs.get(kwargs_type, [])
- return self.get_intermediates(intermediate_names)
-
- def get_intermediate(self, key: str, default: Any = None) -> Any:
- """
- Get an intermediate value from the pipeline state.
-
- Args:
- key (str): The key for the intermediate value
- default (Any): The default value to return if the intermediate value is not found
-
- Returns:
- Any: The intermediate value
- """
- return self.intermediates.get(key, default)
-
- def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
- """
- Get multiple intermediate values from the pipeline state.
-
- Args:
- keys (List[str]): The keys for the intermediate values
- default (Any): The default value to return if the intermediate value is not found
-
- Returns:
- Dict[str, Any]: Dictionary of intermediate values with matching keys
- """
- return {key: self.intermediates.get(key, default) for key in keys}
+ value_names = self.kwargs_mapping.get(kwargs_type, [])
+ return self.get(value_names)
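+ # Illustrative usage sketch (comments only, not executed): all values live in one flat dict and
+ # can optionally be grouped under a kwargs_type. The keys, values and the "denoiser_input_fields"
+ # group below are placeholders.
+ #
+ #     state = PipelineState()
+ #     state.set("prompt", "a cat")
+ #     state.set("prompt_embeds", embeds, kwargs_type="denoiser_input_fields")
+ #     state.get("prompt")                           # -> "a cat"
+ #     state.get(["prompt", "missing"])              # -> {"prompt": "a cat", "missing": None}
+ #     state.get_by_kwargs("denoiser_input_fields")  # -> {"prompt_embeds": embeds}
+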
def to_dict(self) -> Dict[str, Any]:
"""
Convert PipelineState to a dictionary.
-
- Returns:
- Dict[str, Any]: Dictionary containing all attributes of the PipelineState
"""
- return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates}
+ return {**self.__dict__}
+
+ def __getattr__(self, name):
+ """
+ Allow attribute access to state values. If an attribute is not found on the object, look it up in the
+ `values` dict.
+ """
+ if name in self.values:
+ return self.values[name]
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __repr__(self):
def format_value(v):
@@ -223,21 +143,10 @@ class PipelineState:
else:
return repr(v)
- inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items())
- intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items())
+ values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items())
+ kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items())
- # Format input_kwargs and intermediate_kwargs
- input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items())
- intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items())
-
- return (
- f"PipelineState(\n"
- f" inputs={{\n{inputs}\n }},\n"
- f" intermediates={{\n{intermediates}\n }},\n"
- f" input_kwargs={{\n{input_kwargs_str}\n }},\n"
- f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n"
- f")"
- )
+ return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)"
@dataclass
@@ -317,16 +226,12 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks,
LoopSequentialPipelineBlocks
- [`ModularPipelineBlocks`] provides method to load and save the defination of pipeline blocks.
+ [`ModularPipelineBlocks`] provides methods to load and save the definition of pipeline blocks.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING]
+ > This is an experimental feature and is likely to change in the future.
"""
- config_name = "config.json"
+ config_name = "modular_config.json"
model_name = None
@classmethod
@@ -338,6 +243,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return expected_modules, optional_parameters
+ def __init__(self):
+ self.sub_blocks = InsertableDict()
+
+ @property
+ def description(self) -> str:
+ """Description of the block. Must be implemented by subclasses."""
+ return ""
+
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@@ -346,11 +259,40 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def expected_configs(self) -> List[ConfigSpec]:
return []
+ @property
+ def inputs(self) -> List[InputParam]:
+ """List of input parameters. Must be implemented by subclasses."""
+ return []
+
+ def _get_required_inputs(self):
+ input_names = []
+ for input_param in self.inputs:
+ if input_param.required:
+ input_names.append(input_param.name)
+
+ return input_names
+
+ @property
+ def required_inputs(self) -> List[str]:
+ return self._get_required_inputs()
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ """List of intermediate output parameters. Must be implemented by subclasses."""
+ return []
+
+ def _get_outputs(self):
+ return self.intermediate_outputs
+
+ @property
+ def outputs(self) -> List[OutputParam]:
+ return self._get_outputs()
+
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
- trust_remote_code: Optional[bool] = None,
+ trust_remote_code: bool = False,
**kwargs,
):
hub_kwargs_names = [
@@ -370,7 +312,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
- if not (has_remote_code and trust_remote_code):
+ if not has_remote_code and trust_remote_code:
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
@@ -427,6 +369,63 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
)
return modular_pipeline
+ def get_block_state(self, state: PipelineState) -> dict:
+ """Get all inputs and intermediates in one dictionary"""
+ data = {}
+ state_inputs = self.inputs
+
+ # Check inputs
+ for input_param in state_inputs:
+ if input_param.name:
+ value = state.get(input_param.name)
+ if input_param.required and value is None:
+ raise ValueError(f"Required input '{input_param.name}' is missing")
+ elif value is not None or (value is None and input_param.name not in data):
+ data[input_param.name] = value
+
+ elif input_param.kwargs_type:
+ # if kwargs_type is provided, get all inputs with matching kwargs_type
+ if input_param.kwargs_type not in data:
+ data[input_param.kwargs_type] = {}
+ inputs_kwargs = state.get_by_kwargs(input_param.kwargs_type)
+ if inputs_kwargs:
+ for k, v in inputs_kwargs.items():
+ if v is not None:
+ data[k] = v
+ data[input_param.kwargs_type][k] = v
+
+ return BlockState(**data)
+
+ def set_block_state(self, state: PipelineState, block_state: BlockState):
+ for output_param in self.intermediate_outputs:
+ if not hasattr(block_state, output_param.name):
+ raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
+ param = getattr(block_state, output_param.name)
+ state.set(output_param.name, param, output_param.kwargs_type)
+
+ for input_param in self.inputs:
+ if input_param.name and hasattr(block_state, input_param.name):
+ param = getattr(block_state, input_param.name)
+ # Only add if the value is different from what's in the state
+ current_value = state.get(input_param.name)
+ if current_value is not param: # Using identity comparison to check if object was modified
+ state.set(input_param.name, param, input_param.kwargs_type)
+
+ elif input_param.kwargs_type:
+ # if it is a kwargs type, e.g. "denoiser_input_fields", it likely maps to a list of parameters,
+ # so we first need to find out which inputs belong to it and loop through them.
+ intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
+ for param_name, current_value in intermediate_kwargs.items():
+ if param_name is None:
+ continue
+
+ if not hasattr(block_state, param_name):
+ continue
+
+ param = getattr(block_state, param_name)
+ if current_value is not param: # Using identity comparison to check if object was modified
+ state.set(param_name, param, input_param.kwargs_type)
+
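+ # Illustrative flow sketch (comments only, not executed): inside a block's __call__ the typical
+ # pattern is to pull a BlockState view of the pipeline state, work on it, then write it back.
+ # `block`, `state` and the attribute names are placeholders.
+ #
+ #     block_state = block.get_block_state(state)  # raises if a required input is missing
+ #     block_state.latents = block_state.latents * 0.5
+ #     block.set_block_state(state, block_state)   # writes intermediate_outputs plus any inputs
+ #                                                 # whose objects were replaced (identity check)
+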
@staticmethod
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
"""
@@ -493,162 +492,22 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return list(combined_dict.values())
-
-class PipelineBlock(ModularPipelineBlocks):
- """
- A Pipeline Block is the basic building block of a Modular Pipeline.
-
- This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
- library implements for all the pipeline blocks (such as loading or saving etc.)
-
-
-
- This is an experimental feature and is likely to change in the future.
-
-
-
- Args:
- description (str, optional): A description of the block, defaults to None. Define as a property in subclasses.
- expected_components (List[ComponentSpec], optional):
- A list of components that are expected to be used in the block, defaults to []. To override, define as a
- property in subclasses.
- expected_configs (List[ConfigSpec], optional):
- A list of configs that are expected to be used in the block, defaults to []. To override, define as a
- property in subclasses.
- inputs (List[InputParam], optional):
- A list of inputs that are expected to be used in the block, defaults to []. To override, define as a
- property in subclasses.
- intermediate_inputs (List[InputParam], optional):
- A list of intermediate inputs that are expected to be used in the block, defaults to []. To override,
- define as a property in subclasses.
- intermediate_outputs (List[OutputParam], optional):
- A list of intermediate outputs that are expected to be used in the block, defaults to []. To override,
- define as a property in subclasses.
- outputs (List[OutputParam], optional):
- A list of outputs that are expected to be used in the block, defaults to []. To override, define as a
- property in subclasses.
- required_inputs (List[str], optional):
- A list of required inputs that are expected to be used in the block, defaults to []. To override, define as
- a property in subclasses.
- required_intermediate_inputs (List[str], optional):
- A list of required intermediate inputs that are expected to be used in the block, defaults to []. To
- override, define as a property in subclasses.
- required_intermediate_outputs (List[str], optional):
- A list of required intermediate outputs that are expected to be used in the block, defaults to []. To
- override, define as a property in subclasses.
- """
-
- model_name = None
-
- def __init__(self):
- self.sub_blocks = InsertableDict()
+ @property
+ def input_names(self) -> List[str]:
+ return [input_param.name for input_param in self.inputs]
@property
- def description(self) -> str:
- """Description of the block. Must be implemented by subclasses."""
- # raise NotImplementedError("description method must be implemented in subclasses")
- return "TODO: add a description"
+ def intermediate_output_names(self) -> List[str]:
+ return [output_param.name for output_param in self.intermediate_outputs]
@property
- def expected_components(self) -> List[ComponentSpec]:
- return []
-
- @property
- def expected_configs(self) -> List[ConfigSpec]:
- return []
-
- @property
- def inputs(self) -> List[InputParam]:
- """List of input parameters. Must be implemented by subclasses."""
- return []
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- """List of intermediate input parameters. Must be implemented by subclasses."""
- return []
-
- @property
- def intermediate_outputs(self) -> List[OutputParam]:
- """List of intermediate output parameters. Must be implemented by subclasses."""
- return []
-
- def _get_outputs(self):
- return self.intermediate_outputs
-
- # YiYi TODO: is it too easy for user to unintentionally override these properties?
- # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
- @property
- def outputs(self) -> List[OutputParam]:
- return self._get_outputs()
-
- def _get_required_inputs(self):
- input_names = []
- for input_param in self.inputs:
- if input_param.required:
- input_names.append(input_param.name)
- return input_names
-
- @property
- def required_inputs(self) -> List[str]:
- return self._get_required_inputs()
-
- def _get_required_intermediate_inputs(self):
- input_names = []
- for input_param in self.intermediate_inputs:
- if input_param.required:
- input_names.append(input_param.name)
- return input_names
-
- # YiYi TODO: maybe we do not need this, it is only used in docstring,
- # intermediate_inputs is by default required, unless you manually handle it inside the block
- @property
- def required_intermediate_inputs(self) -> List[str]:
- return self._get_required_intermediate_inputs()
-
- def __call__(self, pipeline, state: PipelineState) -> PipelineState:
- raise NotImplementedError("__call__ method must be implemented in subclasses")
-
- def __repr__(self):
- class_name = self.__class__.__name__
- base_class = self.__class__.__bases__[0].__name__
-
- # Format description with proper indentation
- desc_lines = self.description.split("\n")
- desc = []
- # First line with "Description:" label
- desc.append(f" Description: {desc_lines[0]}")
- # Subsequent lines with proper indentation
- if len(desc_lines) > 1:
- desc.extend(f" {line}" for line in desc_lines[1:])
- desc = "\n".join(desc) + "\n"
-
- # Components section - use format_components with add_empty_lines=False
- expected_components = getattr(self, "expected_components", [])
- components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
- components = " " + components_str.replace("\n", "\n ")
-
- # Configs section - use format_configs with add_empty_lines=False
- expected_configs = getattr(self, "expected_configs", [])
- configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
- configs = " " + configs_str.replace("\n", "\n ")
-
- # Inputs section
- inputs_str = format_inputs_short(self.inputs)
- inputs = "Inputs:\n " + inputs_str
-
- # Intermediates section
- intermediates_str = format_intermediates_short(
- self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
- )
- intermediates = f"Intermediates:\n{intermediates_str}"
-
- return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)"
+ def output_names(self) -> List[str]:
+ return [output_param.name for output_param in self.outputs]
@property
def doc(self):
return make_doc_string(
self.inputs,
- self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
@@ -656,82 +515,6 @@ class PipelineBlock(ModularPipelineBlocks):
expected_configs=self.expected_configs,
)
- # YiYi TODO: input and inteermediate inputs with same name? should warn?
- def get_block_state(self, state: PipelineState) -> dict:
- """Get all inputs and intermediates in one dictionary"""
- data = {}
-
- # Check inputs
- for input_param in self.inputs:
- if input_param.name:
- value = state.get_input(input_param.name)
- if input_param.required and value is None:
- raise ValueError(f"Required input '{input_param.name}' is missing")
- elif value is not None or (value is None and input_param.name not in data):
- data[input_param.name] = value
- elif input_param.kwargs_type:
- # if kwargs_type is provided, get all inputs with matching kwargs_type
- if input_param.kwargs_type not in data:
- data[input_param.kwargs_type] = {}
- inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
- if inputs_kwargs:
- for k, v in inputs_kwargs.items():
- if v is not None:
- data[k] = v
- data[input_param.kwargs_type][k] = v
-
- # Check intermediates
- for input_param in self.intermediate_inputs:
- if input_param.name:
- value = state.get_intermediate(input_param.name)
- if input_param.required and value is None:
- raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
- elif value is not None or (value is None and input_param.name not in data):
- data[input_param.name] = value
- elif input_param.kwargs_type:
- # if kwargs_type is provided, get all intermediates with matching kwargs_type
- if input_param.kwargs_type not in data:
- data[input_param.kwargs_type] = {}
- intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
- if intermediate_kwargs:
- for k, v in intermediate_kwargs.items():
- if v is not None:
- if k not in data:
- data[k] = v
- data[input_param.kwargs_type][k] = v
- return BlockState(**data)
-
- def set_block_state(self, state: PipelineState, block_state: BlockState):
- for output_param in self.intermediate_outputs:
- if not hasattr(block_state, output_param.name):
- raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
- param = getattr(block_state, output_param.name)
- state.set_intermediate(output_param.name, param, output_param.kwargs_type)
-
- for input_param in self.intermediate_inputs:
- if hasattr(block_state, input_param.name):
- param = getattr(block_state, input_param.name)
- # Only add if the value is different from what's in the state
- current_value = state.get_intermediate(input_param.name)
- if current_value is not param: # Using identity comparison to check if object was modified
- state.set_intermediate(input_param.name, param, input_param.kwargs_type)
-
- for input_param in self.intermediate_inputs:
- if input_param.name and hasattr(block_state, input_param.name):
- param = getattr(block_state, input_param.name)
- # Only add if the value is different from what's in the state
- current_value = state.get_intermediate(input_param.name)
- if current_value is not param: # Using identity comparison to check if object was modified
- state.set_intermediate(input_param.name, param, input_param.kwargs_type)
- elif input_param.kwargs_type:
- # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
- # we need to first find out which inputs are and loop through them.
- intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
- for param_name, current_value in intermediate_kwargs.items():
- param = getattr(block_state, param_name)
- if current_value is not param: # Using identity comparison to check if object was modified
- state.set_intermediate(param_name, param, input_param.kwargs_type)
-
class AutoPipelineBlocks(ModularPipelineBlocks):
"""
@@ -740,11 +523,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING]
+ > This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
@@ -758,8 +537,11 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
def __init__(self):
sub_blocks = InsertableDict()
- for block_name, block_cls in zip(self.block_names, self.block_classes):
- sub_blocks[block_name] = block_cls()
+ for block_name, block in zip(self.block_names, self.block_classes):
+ if inspect.isclass(block):
+ sub_blocks[block_name] = block()
+ else:
+ sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
raise ValueError(
@@ -821,22 +603,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
return list(required_by_all)
- # YiYi TODO: maybe we do not need this, it is only used in docstring,
- # intermediate_inputs is by default required, unless you manually handle it inside the block
- @property
- def required_intermediate_inputs(self) -> List[str]:
- if None not in self.block_trigger_inputs:
- return []
- first_block = next(iter(self.sub_blocks.values()))
- required_by_all = set(getattr(first_block, "required_intermediate_inputs", set()))
-
- # Intersect with required inputs from all other blocks
- for block in list(self.sub_blocks.values())[1:]:
- block_required = set(getattr(block, "required_intermediate_inputs", set()))
- required_by_all.intersection_update(block_required)
-
- return list(required_by_all)
-
# YiYi TODO: add test for this
@property
def inputs(self) -> List[Tuple[str, Any]]:
@@ -850,18 +616,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
input_param.required = False
return combined_inputs
- @property
- def intermediate_inputs(self) -> List[str]:
- named_inputs = [(name, block.intermediate_inputs) for name, block in self.sub_blocks.items()]
- combined_inputs = self.combine_inputs(*named_inputs)
- # mark Required inputs only if that input is required by all the blocks
- for input_param in combined_inputs:
- if input_param.name in self.required_intermediate_inputs:
- input_param.required = True
- else:
- input_param.required = False
- return combined_inputs
-
@property
def intermediate_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
@@ -880,15 +634,12 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
block = self.trigger_to_block_map.get(None)
for input_name in self.block_trigger_inputs:
- if input_name is not None and state.get_input(input_name) is not None:
- block = self.trigger_to_block_map[input_name]
- break
- elif input_name is not None and state.get_intermediate(input_name) is not None:
+ if input_name is not None and state.get(input_name) is not None:
block = self.trigger_to_block_map[input_name]
break
if block is None:
- logger.warning(f"skipping auto block: {self.__class__.__name__}")
+ logger.info(f"skipping auto block: {self.__class__.__name__}")
return pipeline, state
try:
@@ -1014,7 +765,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
def doc(self):
return make_doc_string(
self.inputs,
- self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
@@ -1031,11 +781,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING]
+ > This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
@@ -1051,7 +797,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
@property
def model_name(self):
- return next(iter(self.sub_blocks.values())).model_name
+ return next((block.model_name for block in self.sub_blocks.values() if block.model_name is not None), None)
@property
def expected_components(self):
@@ -1072,7 +818,9 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
return expected_configs
@classmethod
- def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks":
+ def from_blocks_dict(
+ cls, blocks_dict: Dict[str, Any], description: Optional[str] = None
+ ) -> "SequentialPipelineBlocks":
"""Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
Args:
@@ -1094,14 +842,49 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
instance.block_classes = [block.__class__ for block in sub_blocks.values()]
instance.block_names = list(sub_blocks.keys())
instance.sub_blocks = sub_blocks
+
+ if description is not None:
+ instance.description = description
+
return instance
def __init__(self):
sub_blocks = InsertableDict()
- for block_name, block_cls in zip(self.block_names, self.block_classes):
- sub_blocks[block_name] = block_cls()
+ for block_name, block in zip(self.block_names, self.block_classes):
+ if inspect.isclass(block):
+ sub_blocks[block_name] = block()
+ else:
+ sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
+ def _get_inputs(self):
+ inputs = []
+ outputs = set()
+
+ # Go through all blocks in order
+ for block in self.sub_blocks.values():
+ # Add inputs that aren't in outputs yet
+ for inp in block.inputs:
+ if inp.name not in outputs and inp.name not in {input.name for input in inputs}:
+ inputs.append(inp)
+
+ # Only add outputs if the block cannot be skipped
+ should_add_outputs = True
+ if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
+ should_add_outputs = False
+
+ if should_add_outputs:
+ # Add this block's outputs
+ block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
+ outputs.update(block_intermediate_outputs)
+
+ return inputs
+
+ # YiYi TODO: add test for this
+ @property
+ def inputs(self) -> List[InputParam]:
+ return self._get_inputs()
+
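+ # Illustrative example (comments only): values produced by an earlier block are not re-listed as
+ # pipeline inputs. With two hypothetical blocks
+ #     text_encoder: inputs [prompt]                        -> intermediate_outputs [prompt_embeds]
+ #     denoise:      inputs [prompt_embeds, num_inference_steps]
+ # the combined `inputs` property is [prompt, num_inference_steps]; prompt_embeds only exists as an
+ # intermediate produced inside the sequence.
+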
@property
def required_inputs(self) -> List[str]:
# Get the first block from the dictionary
@@ -1115,65 +898,11 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
return list(required_by_any)
- # YiYi TODO: maybe we do not need this, it is only used in docstring,
- # intermediate_inputs is by default required, unless you manually handle it inside the block
- @property
- def required_intermediate_inputs(self) -> List[str]:
- required_intermediate_inputs = []
- for input_param in self.intermediate_inputs:
- if input_param.required:
- required_intermediate_inputs.append(input_param.name)
- return required_intermediate_inputs
-
- # YiYi TODO: add test for this
- @property
- def inputs(self) -> List[Tuple[str, Any]]:
- return self.get_inputs()
-
- def get_inputs(self):
- named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
- combined_inputs = self.combine_inputs(*named_inputs)
- # mark Required inputs only if that input is required any of the blocks
- for input_param in combined_inputs:
- if input_param.name in self.required_inputs:
- input_param.required = True
- else:
- input_param.required = False
- return combined_inputs
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return self.get_intermediate_inputs()
-
- def get_intermediate_inputs(self):
- inputs = []
- outputs = set()
- added_inputs = set()
-
- # Go through all blocks in order
- for block in self.sub_blocks.values():
- # Add inputs that aren't in outputs yet
- for inp in block.intermediate_inputs:
- if inp.name not in outputs and inp.name not in added_inputs:
- inputs.append(inp)
- added_inputs.add(inp.name)
-
- # Only add outputs if the block cannot be skipped
- should_add_outputs = True
- if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
- should_add_outputs = False
-
- if should_add_outputs:
- # Add this block's outputs
- block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
- outputs.update(block_intermediate_outputs)
- return inputs
-
@property
def intermediate_outputs(self) -> List[str]:
named_outputs = []
for name, block in self.sub_blocks.items():
- inp_names = {inp.name for inp in block.intermediate_inputs}
+ inp_names = {inp.name for inp in block.inputs}
# so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
# filter out them here so they do not end up as intermediate_outputs
if name not in inp_names:
@@ -1391,7 +1120,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
def doc(self):
return make_doc_string(
self.inputs,
- self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
@@ -1408,11 +1136,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING]
+ > This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
@@ -1441,16 +1165,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""List of input parameters. Must be implemented by subclasses."""
return []
- @property
- def loop_intermediate_inputs(self) -> List[InputParam]:
- """List of intermediate input parameters. Must be implemented by subclasses."""
- return []
-
- @property
- def loop_intermediate_outputs(self) -> List[OutputParam]:
- """List of intermediate output parameters. Must be implemented by subclasses."""
- return []
-
@property
def loop_required_inputs(self) -> List[str]:
input_names = []
@@ -1460,12 +1174,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
return input_names
@property
- def loop_required_intermediate_inputs(self) -> List[str]:
- input_names = []
- for input_param in self.loop_intermediate_inputs:
- if input_param.required:
- input_names.append(input_param.name)
- return input_names
+ def loop_intermediate_outputs(self) -> List[OutputParam]:
+ """List of intermediate output parameters. Must be implemented by subclasses."""
+ return []
# modified from SequentialPipelineBlocks to include loop_expected_components
@property
@@ -1493,43 +1204,16 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs.append(config)
return expected_configs
- # modified from SequentialPipelineBlocks to include loop_inputs
- def get_inputs(self):
- named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
- named_inputs.append(("loop", self.loop_inputs))
- combined_inputs = self.combine_inputs(*named_inputs)
- # mark Required inputs only if that input is required any of the blocks
- for input_param in combined_inputs:
- if input_param.name in self.required_inputs:
- input_param.required = True
- else:
- input_param.required = False
- return combined_inputs
-
- @property
- # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
- def inputs(self):
- return self.get_inputs()
-
- # modified from SequentialPipelineBlocks to include loop_intermediate_inputs
- @property
- def intermediate_inputs(self):
- intermediates = self.get_intermediate_inputs()
- intermediate_names = [input.name for input in intermediates]
- for loop_intermediate_input in self.loop_intermediate_inputs:
- if loop_intermediate_input.name not in intermediate_names:
- intermediates.append(loop_intermediate_input)
- return intermediates
-
- # modified from SequentialPipelineBlocks
- def get_intermediate_inputs(self):
+ def _get_inputs(self):
inputs = []
+ inputs.extend(self.loop_inputs)
outputs = set()
- # Go through all blocks in order
- for block in self.sub_blocks.values():
+ for name, block in self.sub_blocks.items():
# Add inputs that aren't in outputs yet
- inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs)
+ for inp in block.inputs:
+ if inp.name not in outputs and inp not in inputs:
+ inputs.append(inp)
# Only add outputs if the block cannot be skipped
should_add_outputs = True
@@ -1540,8 +1224,20 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
# Add this block's outputs
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
outputs.update(block_intermediate_outputs)
+
+ for input_param in inputs:
+ if input_param.name in self.required_inputs:
+ input_param.required = True
+ else:
+ input_param.required = False
+
return inputs
+ @property
+ # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
+ def inputs(self):
+ return self._get_inputs()
+
# modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block
@property
def required_inputs(self) -> List[str]:
@@ -1559,19 +1255,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
return list(required_by_any)
- # YiYi TODO: maybe we do not need this, it is only used in docstring,
- # intermediate_inputs is by default required, unless you manually handle it inside the block
- @property
- def required_intermediate_inputs(self) -> List[str]:
- required_intermediate_inputs = []
- for input_param in self.intermediate_inputs:
- if input_param.required:
- required_intermediate_inputs.append(input_param.name)
- for input_param in self.loop_intermediate_inputs:
- if input_param.required:
- required_intermediate_inputs.append(input_param.name)
- return required_intermediate_inputs
-
# YiYi TODO: this need to be thought about more
# modified from SequentialPipelineBlocks to include loop_intermediate_outputs
@property
@@ -1590,8 +1273,11 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def __init__(self):
sub_blocks = InsertableDict()
- for block_name, block_cls in zip(self.block_names, self.block_classes):
- sub_blocks[block_name] = block_cls()
+ for block_name, block in zip(self.block_names, self.block_classes):
+ if inspect.isclass(block):
+ sub_blocks[block_name] = block()
+ else:
+ sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
@classmethod
@@ -1637,80 +1323,10 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def __call__(self, components, state: PipelineState) -> PipelineState:
raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
- def get_block_state(self, state: PipelineState) -> dict:
- """Get all inputs and intermediates in one dictionary"""
- data = {}
-
- # Check inputs
- for input_param in self.inputs:
- if input_param.name:
- value = state.get_input(input_param.name)
- if input_param.required and value is None:
- raise ValueError(f"Required input '{input_param.name}' is missing")
- elif value is not None or (value is None and input_param.name not in data):
- data[input_param.name] = value
- elif input_param.kwargs_type:
- # if kwargs_type is provided, get all inputs with matching kwargs_type
- if input_param.kwargs_type not in data:
- data[input_param.kwargs_type] = {}
- inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
- if inputs_kwargs:
- for k, v in inputs_kwargs.items():
- if v is not None:
- data[k] = v
- data[input_param.kwargs_type][k] = v
-
- # Check intermediates
- for input_param in self.intermediate_inputs:
- if input_param.name:
- value = state.get_intermediate(input_param.name)
- if input_param.required and value is None:
- raise ValueError(f"Required intermediate input '{input_param.name}' is missing.")
- elif value is not None or (value is None and input_param.name not in data):
- data[input_param.name] = value
- elif input_param.kwargs_type:
- # if kwargs_type is provided, get all intermediates with matching kwargs_type
- if input_param.kwargs_type not in data:
- data[input_param.kwargs_type] = {}
- intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
- if intermediate_kwargs:
- for k, v in intermediate_kwargs.items():
- if v is not None:
- if k not in data:
- data[k] = v
- data[input_param.kwargs_type][k] = v
- return BlockState(**data)
-
- def set_block_state(self, state: PipelineState, block_state: BlockState):
- for output_param in self.intermediate_outputs:
- if not hasattr(block_state, output_param.name):
- raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
- param = getattr(block_state, output_param.name)
- state.set_intermediate(output_param.name, param, output_param.kwargs_type)
-
- for input_param in self.intermediate_inputs:
- if input_param.name and hasattr(block_state, input_param.name):
- param = getattr(block_state, input_param.name)
- # Only add if the value is different from what's in the state
- current_value = state.get_intermediate(input_param.name)
- if current_value is not param: # Using identity comparison to check if object was modified
- state.set_intermediate(input_param.name, param, input_param.kwargs_type)
- elif input_param.kwargs_type:
- # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
- # we need to first find out which inputs are and loop through them.
- intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
- for param_name, current_value in intermediate_kwargs.items():
- if not hasattr(block_state, param_name):
- continue
- param = getattr(block_state, param_name)
- if current_value is not param: # Using identity comparison to check if object was modified
- state.set_intermediate(param_name, param, input_param.kwargs_type)
-
@property
def doc(self):
return make_doc_string(
self.inputs,
- self.intermediate_inputs,
self.outputs,
self.description,
class_name=self.__class__.__name__,
@@ -1798,16 +1414,12 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
# YiYi TODO:
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
# 2. do we need ConfigSpec? the are basically just key/val kwargs
-# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_default_components(), load_components()
+# 3. improve docstring and potentially add validator for methods where we accept kwargs to be passed to from_pretrained/save_pretrained/load_components()
class ModularPipeline(ConfigMixin, PushToHubMixin):
"""
Base class for all Modular pipelines.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING]
+ > This is an experimental feature and is likely to change in the future.
Args:
blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
@@ -1815,6 +1427,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
config_name = "modular_model_index.json"
hf_device_map = None
+ default_blocks_name = None
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
def __init__(
@@ -1839,9 +1452,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Args:
blocks: `ModularPipelineBlocks` instance. If None, will attempt to load
default blocks based on the pipeline class name.
- pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided,
- will load component specs (only for from_pretrained components) and config values from the saved
- modular_model_index.json file.
+ pretrained_model_name_or_path: Path to a pretrained pipeline configuration. Can be None if the pipeline
+ does not require any additional loading config. If provided, will first try to load component specs
+ (only for from_pretrained components) and config values from `modular_model_index.json`, then
+ fall back to `model_index.json` for compatibility with standard non-modular repositories.
components_manager:
Optional ComponentsManager for managing multiple component cross different pipelines and apply
offloading strategies.
@@ -1867,14 +1481,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- Components with default_creation_method="from_config" are created immediately, its specs are not included
in config dict and will not be saved in `modular_model_index.json`
- Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
- `load_default_components()`/`load_components()`
+ `load_components()` (with or without specific component names)
- The pipeline's config dict is populated with component specs (only for from_pretrained components) and
config values, which will be saved as `modular_model_index.json` during `save_pretrained`
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
`_blocks_class_name` in the config dict
"""
if blocks is None:
- blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__)
+ blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
@@ -1890,18 +1504,70 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# update component_specs and config_specs from modular_repo
if pretrained_model_name_or_path is not None:
- config_dict = self.load_config(pretrained_model_name_or_path, **kwargs)
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
- for name, value in config_dict.items():
- # all the components in modular_model_index.json are from_pretrained components
- if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
- library, class_name, component_spec_dict = value
- component_spec = self._dict_to_component_spec(name, component_spec_dict)
- component_spec.default_creation_method = "from_pretrained"
- self._component_specs[name] = component_spec
+ load_config_kwargs = {
+ "cache_dir": cache_dir,
+ "force_download": force_download,
+ "proxies": proxies,
+ "token": token,
+ "local_files_only": local_files_only,
+ "revision": revision,
+ }
+ # try to load modular_model_index.json
+ try:
+ config_dict = self.load_config(pretrained_model_name_or_path, **load_config_kwargs)
+ except EnvironmentError as e:
+ logger.debug(f"modular_model_index.json not found: {e}")
+ config_dict = None
- elif name in self._config_specs:
- self._config_specs[name].default = value
+ # update component_specs and config_specs based on modular_model_index.json
+ if config_dict is not None:
+ for name, value in config_dict.items():
+ # all the components in modular_model_index.json are from_pretrained components
+ if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
+ library, class_name, component_spec_dict = value
+ component_spec = self._dict_to_component_spec(name, component_spec_dict)
+ component_spec.default_creation_method = "from_pretrained"
+ self._component_specs[name] = component_spec
+
+ elif name in self._config_specs:
+ self._config_specs[name].default = value
+
+ # if modular_model_index.json is not found, try to load model_index.json
+ else:
+ logger.debug(" loading config from model_index.json")
+ try:
+ from diffusers import DiffusionPipeline
+
+ config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
+ except EnvironmentError as e:
+ logger.debug(f" model_index.json not found in the repo: {e}")
+ config_dict = None
+
+ # update component_specs and config_specs based on model_index.json
+ if config_dict is not None:
+ for name, value in config_dict.items():
+ if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
+ library, class_name = value
+ component_spec_dict = {
+ "repo": pretrained_model_name_or_path,
+ "subfolder": name,
+ "type_hint": (library, class_name),
+ }
+ component_spec = self._dict_to_component_spec(name, component_spec_dict)
+ component_spec.default_creation_method = "from_pretrained"
+ self._component_specs[name] = component_spec
+ elif name in self._config_specs:
+ self._config_specs[name].default = value
+
+ if len(kwargs) > 0:
+ logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.")
register_components_dict = {}
for name, component_spec in self._component_specs.items():
@@ -1930,111 +1596,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
params[input_param.name] = input_param.default
return params
- def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
- """
- Execute the pipeline by running the pipeline blocks with the given inputs.
-
- Args:
- state (`PipelineState`, optional):
- PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
- created based on the user inputs and the pipeline blocks's requirement.
- output (`str` or `List[str]`, optional):
- Optional specification of what to return:
- - None: Returns the complete `PipelineState` with all inputs and intermediates (default)
- - str: Returns a specific intermediate value from the state (e.g. `output="image"`)
- - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
- "latents"]`)
-
-
- Examples:
- ```python
- # Get complete pipeline state
- state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
- print(state.intermediates) # All intermediate outputs
-
- # Get specific output
- image = pipeline(prompt="A beautiful sunset", output="image")
-
- # Get multiple specific outputs
- results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
- image, latents = results["image"], results["latents"]
-
- # Continue from previous state
- state = pipeline(prompt="A beautiful sunset")
- new_state = pipeline(state=state, output="image") # Continue processing
- ```
-
- Returns:
- - If `output` is None: Complete `PipelineState` containing all inputs and intermediates
- - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
- - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
- `output=["image", "latents"]`)
- """
- if state is None:
- state = PipelineState()
-
- # Make a copy of the input kwargs
- passed_kwargs = kwargs.copy()
-
- # Add inputs to state, using defaults if not provided in the kwargs or the state
- # if same input already in the state, will override it if provided in the kwargs
-
- intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
- for expected_input_param in self.blocks.inputs:
- name = expected_input_param.name
- default = expected_input_param.default
- kwargs_type = expected_input_param.kwargs_type
- if name in passed_kwargs:
- if name not in intermediate_inputs:
- state.set_input(name, passed_kwargs.pop(name), kwargs_type)
- else:
- state.set_input(name, passed_kwargs[name], kwargs_type)
- elif name not in state.inputs:
- state.set_input(name, default, kwargs_type)
-
- for expected_intermediate_param in self.blocks.intermediate_inputs:
- name = expected_intermediate_param.name
- kwargs_type = expected_intermediate_param.kwargs_type
- if name in passed_kwargs:
- state.set_intermediate(name, passed_kwargs.pop(name), kwargs_type)
-
- # Warn about unexpected inputs
- if len(passed_kwargs) > 0:
- warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
- # Run the pipeline
- with torch.no_grad():
- try:
- _, state = self.blocks(self, state)
- except Exception:
- error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
- logger.error(error_msg)
- raise
-
- if output is None:
- return state
-
- elif isinstance(output, str):
- return state.get_intermediate(output)
-
- elif isinstance(output, (list, tuple)):
- return state.get_intermediates(output)
- else:
- raise ValueError(f"Output '{output}' is not a valid output type")
-
- def load_default_components(self, **kwargs):
- """
- Load from_pretrained components using the loading specs in the config dict.
-
- Args:
- **kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc.
- """
- names = [
- name
- for name in self._component_specs.keys()
- if self._component_specs[name].default_creation_method == "from_pretrained"
- ]
- self.load_components(names=names, **kwargs)
-
@classmethod
@validate_hf_hub_args
def from_pretrained(
@@ -2050,8 +1611,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
- Path to a pretrained pipeline configuration. If provided, will load component specs (only for
- from_pretrained components) and config values from the modular_model_index.json file.
+ Path to a pretrained pipeline configuration. It will first try to load the config from
+ `modular_model_index.json`, then fall back to `model_index.json` for compatibility with standard
+ non-modular repositories. If the repo does not contain any pipeline config, it is set to None
+ during initialization.
trust_remote_code (`bool`, optional):
Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
pipeline blocks based on the custom code in `pretrained_model_name_or_path`
@@ -2067,7 +1630,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
blocks = ModularPipelineBlocks.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
- except EnvironmentError:
+ except EnvironmentError as e:
+ logger.debug(f"EnvironmentError: {e}")
blocks = None
cache_dir = kwargs.pop("cache_dir", None)
@@ -2087,11 +1651,35 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
}
try:
+ # try to load modular_model_index.json
config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
+ except EnvironmentError as e:
+ logger.debug(f" modular_model_index.json not found in the repo: {e}")
+ config_dict = None
+
+ if config_dict is not None:
pipeline_class = _get_pipeline_class(cls, config=config_dict)
- except EnvironmentError:
- pipeline_class = cls
- pretrained_model_name_or_path = None
+ else:
+ try:
+ logger.debug(" try to load model_index.json")
+ from diffusers import DiffusionPipeline
+ from diffusers.pipelines.auto_pipeline import _get_model
+
+ config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
+ except EnvironmentError as e:
+ logger.debug(f" model_index.json not found in the repo: {e}")
+
+ if config_dict is not None:
+ logger.debug(" try to determine the modular pipeline class from model_index.json")
+ standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
+ model_name = _get_model(standard_pipeline_class.__name__)
+ pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
+ diffusers_module = importlib.import_module("diffusers")
+ pipeline_class = getattr(diffusers_module, pipeline_class_name)
+ else:
+ # there is no config for the modular pipeline; assume the pipeline blocks do not need any from_pretrained components
+ pipeline_class = cls
+ pretrained_model_name_or_path = None
pipeline = pipeline_class(
blocks=blocks,
@@ -2162,8 +1750,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- non from_pretrained components are created during __init__ and registered as the object itself
- Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
loader.update_components(guider=guider_spec)
- - (from_pretrained) Components are loaded with the `load_default_components()` method: e.g.
- loader.load_default_components(names=["unet"])
+ - (from_pretrained) Components are loaded with the `load_components()` method: e.g.
+ loader.load_components(names=["unet"]) or loader.load_components() to load all default components
Args:
**kwargs: Keyword arguments where keys are component names and values are component objects.
@@ -2429,17 +2017,31 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, component in passed_components.items():
current_component_spec = self._component_specs[name]
- # warn if type changed
+ # log if type changed
if current_component_spec.type_hint is not None and not isinstance(
component, current_component_spec.type_hint
):
- logger.warning(
+ logger.info(
f"ModularPipeline.update_components: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
)
# update _component_specs based on the new component
- new_component_spec = ComponentSpec.from_component(name, component)
- if new_component_spec.default_creation_method != current_component_spec.default_creation_method:
+ if component is None:
+ new_component_spec = current_component_spec
+ if hasattr(self, name) and getattr(self, name) is not None:
+ logger.warning(f"ModularPipeline.update_components: setting {name} to None (spec unchanged)")
+ elif current_component_spec.default_creation_method == "from_pretrained" and not (
+ hasattr(component, "_diffusers_load_id") and component._diffusers_load_id is not None
+ ):
logger.warning(
+ f"ModularPipeline.update_components: {name} has no valid _diffusers_load_id. "
+ f"This will result in empty loading spec, use ComponentSpec.load() for proper specs"
+ )
+ new_component_spec = ComponentSpec(name=name, type_hint=type(component))
+ else:
+ new_component_spec = ComponentSpec.from_component(name, component)
+
+ if new_component_spec.default_creation_method != current_component_spec.default_creation_method:
+ logger.info(
f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}."
)
@@ -2460,7 +2062,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
if current_component_spec.type_hint is not None and not isinstance(
created_components[name], current_component_spec.type_hint
):
- logger.warning(
+ logger.info(
f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
)
# update _component_specs based on the user passed component_spec
@@ -2475,13 +2077,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
self.register_to_config(**config_to_register)
# YiYi TODO: support map for additional from_pretrained kwargs
- # YiYi/Dhruv TODO: consolidate load_components and load_default_components?
- def load_components(self, names: Union[List[str], str], **kwargs):
+ def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
"""
Load selected components from specs.
Args:
- names: List of component names to load; by default will not load any components
+ names: List of component names to load. If None, will load all components with
+ default_creation_method == "from_pretrained". If provided as a list or string, will load only the
+ specified components.
**kwargs: additional kwargs to be passed to `from_pretrained()`. Can be:
- a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
- a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
@@ -2489,7 +2092,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
`variant`, `revision`, etc.
"""
- if isinstance(names, str):
+ if names is None:
+ names = [
+ name
+ for name in self._component_specs.keys()
+ if self._component_specs[name].default_creation_method == "from_pretrained"
+ ]
+ elif isinstance(names, str):
names = [names]
elif not isinstance(names, list):
raise ValueError(f"Invalid type for names: {type(names)}")
@@ -2547,12 +2156,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
-
-
- If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
- the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
-
-
+ > [!TIP]
+ > If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
+ > the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
Here are the ways to call `to`:
@@ -2839,3 +2444,88 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
type_hint=type_hint,
**spec_dict,
)
+
+ def set_progress_bar_config(self, **kwargs):
+ for sub_block_name, sub_block in self.blocks.sub_blocks.items():
+ if hasattr(sub_block, "set_progress_bar_config"):
+ sub_block.set_progress_bar_config(**kwargs)
+
+ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
+ """
+ Execute the pipeline by running the pipeline blocks with the given inputs.
+
+ Args:
+ state (`PipelineState`, optional):
+ PipelineState instance that contains inputs and intermediate values. If None, a new `PipelineState` will
+ be created based on the user inputs and the pipeline blocks' requirements.
+ output (`str` or `List[str]`, optional):
+ Optional specification of what to return:
+ - None: Returns the complete `PipelineState` with all inputs and intermediates (default)
+ - str: Returns a specific intermediate value from the state (e.g. `output="image"`)
+ - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
+ "latents"]`)
+
+
+ Examples:
+ ```python
+ # Get complete pipeline state
+ state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
+ print(state.intermediates) # All intermediate outputs
+
+ # Get specific output
+ image = pipeline(prompt="A beautiful sunset", output="image")
+
+ # Get multiple specific outputs
+ results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
+ image, latents = results["image"], results["latents"]
+
+ # Continue from previous state
+ state = pipeline(prompt="A beautiful sunset")
+ new_state = pipeline(state=state, output="image") # Continue processing
+ ```
+
+ Returns:
+ - If `output` is None: Complete `PipelineState` containing all inputs and intermediates
+ - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
+ - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
+ `output=["image", "latents"]`)
+ """
+ if state is None:
+ state = PipelineState()
+
+ # Make a copy of the input kwargs
+ passed_kwargs = kwargs.copy()
+
+ # Add inputs to state, using defaults if not provided in the kwargs or the state
+ # if same input already in the state, will override it if provided in the kwargs
+ for expected_input_param in self.blocks.inputs:
+ name = expected_input_param.name
+ default = expected_input_param.default
+ kwargs_type = expected_input_param.kwargs_type
+ if name in passed_kwargs:
+ state.set(name, passed_kwargs.pop(name), kwargs_type)
+ elif name not in state.values:
+ state.set(name, default, kwargs_type)
+
+ # Warn about unexpected inputs
+ if len(passed_kwargs) > 0:
+ warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
+ # Run the pipeline
+ with torch.no_grad():
+ try:
+ _, state = self.blocks(self, state)
+ except Exception:
+ error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
+ logger.error(error_msg)
+ raise
+
+ if output is None:
+ return state
+
+ if isinstance(output, str):
+ return state.get(output)
+
+ elif isinstance(output, (list, tuple)):
+ return state.get(output)
+ else:
+ raise ValueError(f"Output '{output}' is not a valid output type")
diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
index f2fc015e94..b151268686 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
@@ -209,7 +209,7 @@ class ComponentSpec:
# Get all loading fields in order
loading_fields = cls.loading_fields()
- result = {f: None for f in loading_fields}
+ result = dict.fromkeys(loading_fields)
if load_id == "null":
return result
@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
def make_doc_string(
inputs,
- intermediate_inputs,
outputs,
description="",
class_name=None,
@@ -664,7 +663,7 @@ def make_doc_string(
output += configs_str + "\n\n"
# Add inputs section
- output += format_input_params(inputs + intermediate_inputs, indent_level=2)
+ output += format_input_params(inputs, indent_level=2)
# Add outputs section
output += "\n\n"
diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py
index fb9a03c755..f7ee1dd309 100644
--- a/src/diffusers/modular_pipelines/node_utils.py
+++ b/src/diffusers/modular_pipelines/node_utils.py
@@ -351,11 +351,7 @@ class ModularNode(ConfigMixin):
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
around a ModularPipelineBlocks object.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING]
+ > This is an experimental feature and is likely to change in the future.
"""
config_name = "node_config.json"
@@ -384,14 +380,14 @@ class ModularNode(ConfigMixin):
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
- # "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers
+ # "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
- # it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
+ # it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
for inp in inputs:
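
For reference, this is what the per-input param dicts described in the comments above look like in practice. The values are made up; only the keys and the "no type means custom param" behavior come from the comments.

```python
# Explicit spec: the node-side name may differ from the diffusers input name.
prompt_param = {
    "name": "text_input",
    "label": "Prompt",
    "type": "string",
    "default": "a bear sitting in a chair drinking a milkshake",
    "display": "textarea",
}

# No "type" key: the input becomes a "custom" param of its own type,
# e.g. ModularNode(..., scheduler=scheduler_param) yields
# {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}.
scheduler_param = {"name": "scheduler"}
```
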
diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py
new file mode 100644
index 0000000000..ae4ec4799f
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py
@@ -0,0 +1,89 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["encoders"] = ["QwenImageTextEncoderStep"]
+ _import_structure["modular_blocks"] = [
+ "ALL_BLOCKS",
+ "AUTO_BLOCKS",
+ "CONTROLNET_BLOCKS",
+ "EDIT_AUTO_BLOCKS",
+ "EDIT_BLOCKS",
+ "EDIT_INPAINT_BLOCKS",
+ "EDIT_PLUS_AUTO_BLOCKS",
+ "EDIT_PLUS_BLOCKS",
+ "IMAGE2IMAGE_BLOCKS",
+ "INPAINT_BLOCKS",
+ "TEXT2IMAGE_BLOCKS",
+ "QwenImageAutoBlocks",
+ "QwenImageEditAutoBlocks",
+ "QwenImageEditPlusAutoBlocks",
+ ]
+ _import_structure["modular_pipeline"] = [
+ "QwenImageEditModularPipeline",
+ "QwenImageEditPlusModularPipeline",
+ "QwenImageModularPipeline",
+ ]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .encoders import (
+ QwenImageTextEncoderStep,
+ )
+ from .modular_blocks import (
+ ALL_BLOCKS,
+ AUTO_BLOCKS,
+ CONTROLNET_BLOCKS,
+ EDIT_AUTO_BLOCKS,
+ EDIT_BLOCKS,
+ EDIT_INPAINT_BLOCKS,
+ EDIT_PLUS_AUTO_BLOCKS,
+ EDIT_PLUS_BLOCKS,
+ IMAGE2IMAGE_BLOCKS,
+ INPAINT_BLOCKS,
+ TEXT2IMAGE_BLOCKS,
+ QwenImageAutoBlocks,
+ QwenImageEditAutoBlocks,
+ QwenImageEditPlusAutoBlocks,
+ )
+ from .modular_pipeline import (
+ QwenImageEditModularPipeline,
+ QwenImageEditPlusModularPipeline,
+ QwenImageModularPipeline,
+ )
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
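
A small smoke test of the new lazy subpackage (assuming torch and transformers are installed; whether the auto blocks class takes constructor arguments is not shown in this diff, so the no-arg call below is an assumption):

```python
from diffusers.modular_pipelines.qwenimage import TEXT2IMAGE_BLOCKS, QwenImageAutoBlocks

blocks = QwenImageAutoBlocks()      # assumed to be constructible without arguments
print(blocks)                       # summary of sub-blocks, inputs and outputs
print(list(TEXT2IMAGE_BLOCKS))      # ordered block names of the text-to-image preset (assumed dict-like)
```
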
diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
new file mode 100644
index 0000000000..fdec95dc50
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -0,0 +1,725 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils.torch_utils import randn_tensor, unwrap_module
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# modified from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+def get_timesteps(scheduler, num_inference_steps, strength):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = scheduler.timesteps[t_start * scheduler.order :]
+ if hasattr(scheduler, "set_begin_index"):
+ scheduler.set_begin_index(t_start * scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+
+# Prepare Latents steps
+
+
+class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Prepare initial random noise for the generation process"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="generator"),
+ InputParam(
+ name="batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ name="dtype",
+ required=True,
+ type_hint=torch.dtype,
+ description="The dtype of the model inputs, can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="latents",
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ height=block_state.height,
+ width=block_state.width,
+ vae_scale_factor=components.vae_scale_factor,
+ )
+
+ device = components._execution_device
+ batch_size = block_state.batch_size * block_state.num_images_per_prompt
+
+ # we can update the height and width here since they are only used to generate the initial noise latents
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+
+ shape = (batch_size, components.num_channels_latents, 1, latent_height, latent_width)
+ if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ block_state.latents = randn_tensor(
+ shape, generator=block_state.generator, device=device, dtype=block_state.dtype
+ )
+ block_state.latents = components.pachifier.pack_latents(block_state.latents)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial random noised, can be generated in prepare latent step.",
+ ),
+ InputParam(
+ name="image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
+ ),
+ InputParam(
+ name="timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="initial_noise",
+ type_hint=torch.Tensor,
+ description="The initial random noised used for inpainting denoising.",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(image_latents, latents):
+ if image_latents.shape[0] != latents.shape[0]:
+ raise ValueError(
+ f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
+ )
+
+ if image_latents.ndim != 3:
+ raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ image_latents=block_state.image_latents,
+ latents=block_state.latents,
+ )
+
+ # prepare latent timestep
+ latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
+
+ # make copy of initial_noise
+ block_state.initial_noise = block_state.latents
+
+ # scale noise
+ block_state.latents = components.scheduler.scale_noise(
+ block_state.image_latents, latent_timestep, block_state.latents
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that creates mask latents from preprocessed mask_image by interpolating to latent space."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name="processed_mask_image",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The processed mask to use for the inpainting process.",
+ ),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="dtype", required=True),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+
+ height_latents = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ width_latents = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+
+ block_state.mask = torch.nn.functional.interpolate(
+ block_state.processed_mask_image,
+ size=(height_latents, width_latents),
+ )
+
+ block_state.mask = block_state.mask.unsqueeze(2)
+ block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1)
+ block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype)
+
+ block_state.mask = components.pachifier.pack_latents(block_state.mask)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# Set Timesteps steps
+
+
+class QwenImageSetTimestepsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the the scheduler's timesteps for text-to-image generation. Should be run after prepare latents step."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="num_inference_steps", default=50),
+ InputParam(name="sigmas"),
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process, used to calculate the image sequence length.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ sigmas = (
+ np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
+ if block_state.sigmas is None
+ else block_state.sigmas
+ )
+
+ mu = calculate_shift(
+ image_seq_len=block_state.latents.shape[1],
+ base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
+ max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
+ base_shift=components.scheduler.config.get("base_shift", 0.5),
+ max_shift=components.scheduler.config.get("max_shift", 1.15),
+ )
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ scheduler=components.scheduler,
+ num_inference_steps=block_state.num_inference_steps,
+ device=device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ components.scheduler.set_begin_index(0)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the the scheduler's timesteps for image-to-image generation, and inpainting. Should be run after prepare latents step."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="num_inference_steps", default=50),
+ InputParam(name="sigmas"),
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process, used to calculate the image sequence length.",
+ ),
+ InputParam(name="strength", default=0.9),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="timesteps",
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ sigmas = (
+ np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
+ if block_state.sigmas is None
+ else block_state.sigmas
+ )
+
+ mu = calculate_shift(
+ image_seq_len=block_state.latents.shape[1],
+ base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
+ max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
+ base_shift=components.scheduler.config.get("base_shift", 0.5),
+ max_shift=components.scheduler.config.get("max_shift", 1.15),
+ )
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ scheduler=components.scheduler,
+ num_inference_steps=block_state.num_inference_steps,
+ device=device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ block_state.timesteps, block_state.num_inference_steps = get_timesteps(
+ scheduler=components.scheduler,
+ num_inference_steps=block_state.num_inference_steps,
+ strength=block_state.strength,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# other inputs for denoiser
+
+## RoPE inputs for denoiser
+
+
+class QwenImageRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds_mask"),
+ InputParam(name="negative_prompt_embeds_mask"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="img_shapes",
+ type_hint=List[List[Tuple[int, int, int]]],
+ description="The shapes of the images latents, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="negative_txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.img_shapes = [
+ [
+ (
+ 1,
+ block_state.height // components.vae_scale_factor // 2,
+ block_state.width // components.vae_scale_factor // 2,
+ )
+ ]
+ * block_state.batch_size
+ ]
+ block_state.txt_seq_lens = (
+ block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+ )
+ block_state.negative_txt_seq_lens = (
+ block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+ if block_state.negative_prompt_embeds_mask is not None
+ else None
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after prepare_latents step"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="batch_size", required=True),
+ InputParam(name="image_height", required=True),
+ InputParam(name="image_width", required=True),
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds_mask"),
+ InputParam(name="negative_prompt_embeds_mask"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="img_shapes",
+ type_hint=List[List[Tuple[int, int, int]]],
+ description="The shapes of the images latents, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation",
+ ),
+ OutputParam(
+ name="negative_txt_seq_lens",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
+ ),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # for edit, image size can be different from the target size (height/width)
+
+ block_state.img_shapes = [
+ [
+ (
+ 1,
+ block_state.height // components.vae_scale_factor // 2,
+ block_state.width // components.vae_scale_factor // 2,
+ ),
+ (
+ 1,
+ block_state.image_height // components.vae_scale_factor // 2,
+ block_state.image_width // components.vae_scale_factor // 2,
+ ),
+ ]
+ ] * block_state.batch_size
+
+ block_state.txt_seq_lens = (
+ block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+ )
+ block_state.negative_txt_seq_lens = (
+ block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+ if block_state.negative_prompt_embeds_mask is not None
+ else None
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+## ControlNet inputs for denoiser
+class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("controlnet", QwenImageControlNetModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "step that prepare inputs for controlnet. Insert before the Denoise Step, after set_timesteps step."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("control_guidance_start", default=0.0),
+ InputParam("control_guidance_end", default=1.0),
+ InputParam("controlnet_conditioning_scale", default=1.0),
+ InputParam("control_image_latents", required=True),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ controlnet = unwrap_module(components.controlnet)
+
+ # control_guidance_start/control_guidance_end (align format)
+ if not isinstance(block_state.control_guidance_start, list) and isinstance(
+ block_state.control_guidance_end, list
+ ):
+ block_state.control_guidance_start = len(block_state.control_guidance_end) * [
+ block_state.control_guidance_start
+ ]
+ elif not isinstance(block_state.control_guidance_end, list) and isinstance(
+ block_state.control_guidance_start, list
+ ):
+ block_state.control_guidance_end = len(block_state.control_guidance_start) * [
+ block_state.control_guidance_end
+ ]
+ elif not isinstance(block_state.control_guidance_start, list) and not isinstance(
+ block_state.control_guidance_end, list
+ ):
+ mult = (
+ len(block_state.control_image_latents) if isinstance(controlnet, QwenImageMultiControlNetModel) else 1
+ )
+ block_state.control_guidance_start, block_state.control_guidance_end = (
+ mult * [block_state.control_guidance_start],
+ mult * [block_state.control_guidance_end],
+ )
+
+ # controlnet_conditioning_scale (align format)
+ if isinstance(controlnet, QwenImageMultiControlNetModel) and isinstance(
+ block_state.controlnet_conditioning_scale, float
+ ):
+ block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * mult
+
+ # controlnet_keep
+ block_state.controlnet_keep = []
+ for i in range(len(block_state.timesteps)):
+ keeps = [
+ 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
+ for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
+ ]
+ block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, QwenImageControlNetModel) else keeps)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
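
The `mu` produced by `calculate_shift` above is what `retrieve_timesteps(..., mu=mu)` forwards to the scheduler's `set_timesteps`. A worked example of the arithmetic (the function is restated so the numbers can be checked standalone; the image sizes are illustrative):

```python
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    # linear interpolation between base_shift and max_shift over the latent sequence length
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# 512x512 image, 8x VAE, 2x2 packing -> 32 * 32 = 1024 latent tokens
print(round(calculate_shift(1024), 2))  # 0.63 -> milder shift for short sequences
# 1024x1024 image -> 64 * 64 = 4096 latent tokens
print(round(calculate_shift(4096), 2))  # 1.15 -> max_shift for the longest sequence
```
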
diff --git a/src/diffusers/modular_pipelines/qwenimage/decoders.py b/src/diffusers/modular_pipelines/qwenimage/decoders.py
new file mode 100644
index 0000000000..6c82fe989e
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/decoders.py
@@ -0,0 +1,203 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import InpaintProcessor, VaeImageProcessor
+from ...models import AutoencoderKLQwenImage
+from ...utils import logging
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+
+
+logger = logging.get_logger(__name__)
+
+
+class QwenImageDecoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Step that decodes the latents to images"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(
+ name="latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to decode, can be generated in the denoise step",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
+ description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # YiYi Notes: remove support for output_type="latents", we can just skip the decode/encode step in modular
+ block_state.latents = components.pachifier.unpack_latents(
+ block_state.latents, block_state.height, block_state.width
+ )
+ block_state.latents = block_state.latents.to(components.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(components.vae.config.latents_mean)
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
+ .to(block_state.latents.device, block_state.latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+ 1, components.vae.config.z_dim, 1, 1, 1
+ ).to(block_state.latents.device, block_state.latents.dtype)
+ block_state.latents = block_state.latents / latents_std + latents_mean
+ block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0][:, :, 0]
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "postprocess the generated image"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("images", required=True, description="the generated image from decoders step"),
+ InputParam(
+ name="output_type",
+ default="pil",
+ type_hint=str,
+ description="The type of the output images, can be 'pil', 'np', 'pt'",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(output_type):
+ if output_type not in ["pil", "np", "pt"]:
+ raise ValueError(f"Invalid output_type: {output_type}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.output_type)
+
+ block_state.images = components.image_processor.postprocess(
+ image=block_state.images,
+ output_type=block_state.output_type,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "postprocess the generated image, optional apply the mask overally to the original image.."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_mask_processor",
+ InpaintProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("images", required=True, description="the generated image from decoders step"),
+ InputParam(
+ name="output_type",
+ default="pil",
+ type_hint=str,
+ description="The type of the output images, can be 'pil', 'np', 'pt'",
+ ),
+ InputParam("mask_overlay_kwargs"),
+ ]
+
+ @staticmethod
+ def check_inputs(output_type, mask_overlay_kwargs):
+ if output_type not in ["pil", "np", "pt"]:
+ raise ValueError(f"Invalid output_type: {output_type}")
+
+ if mask_overlay_kwargs and output_type != "pil":
+ raise ValueError("only support output_type 'pil' for mask overlay")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.output_type, block_state.mask_overlay_kwargs)
+
+ if block_state.mask_overlay_kwargs is None:
+ mask_overlay_kwargs = {}
+ else:
+ mask_overlay_kwargs = block_state.mask_overlay_kwargs
+
+ block_state.images = components.image_mask_processor.postprocess(
+ image=block_state.images,
+ **mask_overlay_kwargs,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
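
The decoder step above has to undo the VAE's per-channel latent normalization before calling `vae.decode`; note that `latents_std` is stored as a reciprocal, so the division is really a multiplication by the true std. A minimal torch sketch with made-up statistics:

```python
import torch

latents = torch.randn(1, 16, 1, 64, 64)                 # illustrative (batch, z_dim, frame, h, w) latent
latents_mean = torch.zeros(1, 16, 1, 1, 1)              # made-up per-channel mean
latents_std = 1.0 / torch.full((1, 16, 1, 1, 1), 2.0)   # reciprocal of a made-up per-channel std

denormalized = latents / latents_std + latents_mean     # equivalent to latents * std + mean
assert torch.allclose(denormalized, latents * 2.0)
```
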
diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py
new file mode 100644
index 0000000000..d0704ee6e0
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py
@@ -0,0 +1,668 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Tuple
+
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...models import QwenImageControlNetModel, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging
+from ..modular_pipeline import BlockState, LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepares the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # one timestep
+ block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
+ block_state.latent_model_input = block_state.latents
+ return components, block_state
+
+
+class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepares the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial image latents to use for the denoising process. Can be encoded in vae_encoder step and packed in prepare_image_latents step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # one timestep
+
+ block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_latents], dim=1)
+ block_state.timestep = t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype)
+ return components, block_state
+
+
+class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("controlnet", QwenImageControlNetModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that runs the controlnet before the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "control_image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "controlnet_conditioning_scale",
+ type_hint=float,
+ description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "controlnet_keep",
+ required=True,
+ type_hint=List[float],
+ description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description=(
+ "All conditional model inputs for the denoiser. "
+ "It should contain prompt_embeds/negative_prompt_embeds, txt_seq_lens/negative_txt_seq_lens."
+ ),
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: int):
+ # cond_scale for the timestep (controlnet input)
+ if isinstance(block_state.controlnet_keep[i], list):
+ block_state.cond_scale = [
+ c * s for c, s in zip(block_state.controlnet_conditioning_scale, block_state.controlnet_keep[i])
+ ]
+ else:
+ controlnet_cond_scale = block_state.controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
+
+ # run controlnet for the guidance batch
+ controlnet_block_samples = components.controlnet(
+ hidden_states=block_state.latent_model_input,
+ controlnet_cond=block_state.control_image_latents,
+ conditioning_scale=block_state.cond_scale,
+ timestep=block_state.timestep / 1000,
+ img_shapes=block_state.img_shapes,
+ encoder_hidden_states=block_state.prompt_embeds,
+ encoder_hidden_states_mask=block_state.prompt_embeds_mask,
+ txt_seq_lens=block_state.txt_seq_lens,
+ return_dict=False,
+ )
+
+ block_state.additional_cond_kwargs["controlnet_block_samples"] = controlnet_block_samples
+
+ return components, block_state
+
+
+class QwenImageLoopDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that denoises the latent input using the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("transformer", QwenImageTransformer2DModel),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
+ ),
+ InputParam(
+ "img_shapes",
+ required=True,
+ type_hint=List[Tuple[int, int]],
+ description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ guider_input_fields = {
+ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+ "encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
+ "txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
+ }
+
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+ guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.transformer)
+ cond_kwargs = guider_state_batch.as_dict()
+ cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+
+ # YiYi TODO: add cache context
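+ # scheduler timesteps are on a 0-1000 scale; the transformer expects them normalized to [0, 1]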
+ guider_state_batch.noise_pred = components.transformer(
+ hidden_states=block_state.latent_model_input,
+ timestep=block_state.timestep / 1000,
+ img_shapes=block_state.img_shapes,
+ attention_kwargs=block_state.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ **block_state.additional_cond_kwargs,
+ )[0]
+
+ components.guider.cleanup_models(components.transformer)
+
+ guider_output = components.guider(guider_state)
+
+ # apply guidance rescale
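+ # rescale the guided prediction so its per-token norm matches that of the conditional
+ # prediction, counteracting the norm inflation introduced by classifier-free guidance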
+ pred_cond_norm = torch.norm(guider_output.pred_cond, dim=-1, keepdim=True)
+ pred_norm = torch.norm(guider_output.pred, dim=-1, keepdim=True)
+ block_state.noise_pred = guider_output.pred * (pred_cond_norm / pred_norm)
+
+ return components, block_state
+
+
+class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that denoises the latent input using the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("transformer", QwenImageTransformer2DModel),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents to use for the denoising process. Can be generated in prepare_latents step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="denoiser_input_fields",
+ description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
+ ),
+ InputParam(
+ "img_shapes",
+ required=True,
+ type_hint=List[Tuple[int, int]],
+ description="The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ guider_input_fields = {
+ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+ "encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
+ "txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
+ }
+
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+ guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.transformer)
+ cond_kwargs = guider_state_batch.as_dict()
+ cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+
+ # YiYi TODO: add cache context
+ guider_state_batch.noise_pred = components.transformer(
+ hidden_states=block_state.latent_model_input,
+ timestep=block_state.timestep / 1000,
+ img_shapes=block_state.img_shapes,
+ attention_kwargs=block_state.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ **block_state.additional_cond_kwargs,
+ )[0]
+
+ components.guider.cleanup_models(components.transformer)
+
+ guider_output = components.guider(guider_state)
+
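+ # the Edit variant's sequence also contains the reference image tokens; keep only the first
+ # latents.size(1) tokens, i.e. the prediction for the latents actually being denoised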
+ pred = guider_output.pred[:, : block_state.latents.size(1)]
+ pred_cond = guider_output.pred_cond[:, : block_state.latents.size(1)]
+
+ # apply guidance rescale
+ pred_cond_norm = torch.norm(pred_cond, dim=-1, keepdim=True)
+ pred_norm = torch.norm(pred, dim=-1, keepdim=True)
+ block_state.noise_pred = pred * (pred_cond_norm / pred_norm)
+
+ return components, block_state
+
+
+class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that updates the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents."),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred,
+ t,
+ block_state.latents,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ block_state.latents = block_state.latents.to(latents_dtype)
+
+ return components, block_state
+
+
+class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that updates the latents using mask and image_latents for inpainting. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `QwenImageDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "mask",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image latents to use for the inpainting process. Can be generated in inpaint prepare latents step.",
+ ),
+ InputParam(
+ "initial_noise",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial noise to use for the inpainting process. Can be generated in inpaint prepare latents step.",
+ ),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
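+ # re-noise the clean image latents to the next timestep's noise level, then keep them in the
+ # unmasked region and keep the freshly denoised latents in the masked region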
+ block_state.init_latents_proper = block_state.image_latents
+ if i < len(block_state.timesteps) - 1:
+ block_state.noise_timestep = block_state.timesteps[i + 1]
+ block_state.init_latents_proper = components.scheduler.scale_noise(
+ block_state.init_latents_proper, torch.tensor([block_state.noise_timestep]), block_state.initial_noise
+ )
+
+ block_state.latents = (
+ 1 - block_state.mask
+ ) * block_state.init_latents_proper + block_state.mask * block_state.latents
+
+ return components, block_state
+
+
+class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Pipeline block that iteratively denoises the latents over `timesteps`. "
+ "The specific steps within each iteration can be customized with the `sub_blocks` attribute."
+ )
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ]
+
+ @property
+ def loop_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
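+ # warmup steps account for schedulers with order > 1 that consume several timesteps per
+ # inference step; here they only affect how often the progress bar is updated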
+ block_state.num_warmup_steps = max(
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+ )
+
+ block_state.additional_cond_kwargs = {}
+
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+ for i, t in enumerate(block_state.timesteps):
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
+ if i == len(block_state.timesteps) - 1 or (
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# composing the denoising loops
+class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoises the latents. \n"
+ "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
+ "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports text2image and image2image tasks for QwenImage."
+ )
+
+
+# composing the inpainting denoising loops
+class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ QwenImageLoopAfterDenoiserInpaint,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoises the latents. \n"
+ "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
+ "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiserInpaint`\n"
+ "This block supports inpainting tasks for QwenImage."
+ )
+
+
+# composing the controlnet denoising loops
+class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopBeforeDenoiserControlNet,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "before_denoiser_controlnet", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoises the latents. \n"
+ "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
+ "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopBeforeDenoiserControlNet`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports text2img/img2img tasks with controlnet for QwenImage."
+ )
+
+
+# composing the controlnet denoising loops
+class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageLoopBeforeDenoiser,
+ QwenImageLoopBeforeDenoiserControlNet,
+ QwenImageLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ QwenImageLoopAfterDenoiserInpaint,
+ ]
+ block_names = [
+ "before_denoiser",
+ "before_denoiser_controlnet",
+ "denoiser",
+ "after_denoiser",
+ "after_denoiser_inpaint",
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoises the latents. \n"
+ "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
+ "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+ " - `QwenImageLoopBeforeDenoiser`\n"
+ " - `QwenImageLoopBeforeDenoiserControlNet`\n"
+ " - `QwenImageLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiserInpaint`\n"
+ "This block supports inpainting tasks with controlnet for QwenImage."
+ )
+
+
+# composing the denoising loops
+class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageEditLoopBeforeDenoiser,
+ QwenImageEditLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoises the latents. \n"
+ "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
+ "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+ " - `QwenImageEditLoopBeforeDenoiser`\n"
+ " - `QwenImageEditLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ "This block supports QwenImage Edit."
+ )
+
+
+class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
+ block_classes = [
+ QwenImageEditLoopBeforeDenoiser,
+ QwenImageEditLoopDenoiser,
+ QwenImageLoopAfterDenoiser,
+ QwenImageLoopAfterDenoiserInpaint,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser", "after_denoiser_inpaint"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoises the latents. \n"
+ "Its loop logic is defined in the `QwenImageDenoiseLoopWrapper.__call__` method. \n"
+ "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+ " - `QwenImageEditLoopBeforeDenoiser`\n"
+ " - `QwenImageEditLoopDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiser`\n"
+ " - `QwenImageLoopAfterDenoiserInpaint`\n"
+ "This block supports inpainting tasks for QwenImage Edit."
+ )
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
new file mode 100644
index 0000000000..04fb3fdc94
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -0,0 +1,1079 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, Union
+
+import PIL
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...image_processor import InpaintProcessor, VaeImageProcessor, is_valid_image, is_valid_image_imagelist
+from ...models import AutoencoderKLQwenImage, QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
+from ...utils import logging
+from ...utils.torch_utils import unwrap_module
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+ return split_result
+
+
+def get_qwen_prompt_embeds(
+ text_encoder,
+ tokenizer,
+ prompt: Union[str, List[str]] = None,
+ prompt_template_encode: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ prompt_template_encode_start_idx: int = 34,
+ tokenizer_max_length: int = 1024,
+ device: Optional[torch.device] = None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = prompt_template_encode
+ drop_idx = prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = tokenizer(
+ txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+
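+ # drop the tokens belonging to the chat template prefix (drop_idx), then right-pad every
+ # prompt in the batch to the longest remaining sequence and build a matching attention mask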
+ split_hidden_states = _extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+
+def get_qwen_prompt_embeds_edit(
+ text_encoder,
+ processor,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
+ prompt_template_encode_start_idx: int = 64,
+ device: Optional[torch.device] = None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = prompt_template_encode
+ drop_idx = prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+
+ model_inputs = processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+
+def get_qwen_prompt_embeds_edit_plus(
+ text_encoder,
+ processor,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None,
+ prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+ prompt_template_encode_start_idx: int = 64,
+ device: Optional[torch.device] = None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if isinstance(image, list):
+ base_img_prompt = ""
+ for i, img in enumerate(image):
+ base_img_prompt += img_template_encode.format(i + 1)
+ elif image is not None:
+ base_img_prompt = img_template_encode.format(1)
+ else:
+ base_img_prompt = ""
+
+ template = prompt_template_encode
+
+ drop_idx = prompt_template_encode_start_idx
+ txt = [template.format(base_img_prompt + e) for e in prompt]
+
+ model_inputs = processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+ outputs = text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(device=device)
+ return prompt_embeds, encoder_attention_mask
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Modified from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._encode_vae_image
+def encode_vae_image(
+ image: torch.Tensor,
+ vae: AutoencoderKLQwenImage,
+ generator: torch.Generator,
+ device: torch.device,
+ dtype: torch.dtype,
+ latent_channels: int = 16,
+ sample_mode: str = "argmax",
+):
+ if not isinstance(image, torch.Tensor):
+ raise ValueError(f"Expected image to be a tensor, got {type(image)}.")
+
+ # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width
+ if image.dim() == 4:
+ image = image.unsqueeze(2)
+ elif image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
+
+ image = image.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
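+ # normalize into the latent space expected by the transformer using the per-channel
+ # mean/std statistics stored in the VAE config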
+ latents_mean = (
+ torch.tensor(vae.config.latents_mean)
+ .view(1, latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(vae.config.latents_std)
+ .view(1, latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ image_latents = (image_latents - latents_mean) / latents_std
+
+ return image_latents
+
+
+class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ def __init__(self, input_name: str = "image", output_name: str = "resized_image"):
+ """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
+
+ This block resizes an input image tensor and exposes the resized result under configurable input and output
+ names. Use this when you need to wire the resize step to different image fields (e.g., "image",
+ "control_image")
+
+ Args:
+ input_name (str, optional): Name of the image field to read from the
+ pipeline state. Defaults to "image".
+ output_name (str, optional): Name of the resized image field to write
+ back to the pipeline state. Defaults to "resized_image".
+ """
+ if not isinstance(input_name, str) or not isinstance(output_name, str):
+ raise ValueError(
+ f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
+ )
+ self._image_input_name = input_name
+ self._resized_image_output_name = output_name
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Image Resize step that resizes the {self._image_input_name} input to the target area (1024 * 1024) while maintaining the aspect ratio."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_resize_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ name=self._image_input_name, required=True, type_hint=torch.Tensor, description="The image to resize"
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name=self._resized_image_output_name, type_hint=List[PIL.Image.Image], description="The resized images"
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ images = getattr(block_state, self._image_input_name)
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if is_valid_image(images):
+ images = [images]
+
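+ # pick a (width, height) close to the 1024 * 1024 target area that preserves the aspect ratio
+ # of the first image; every image in the list is resized to this common size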
+ image_width, image_height = images[0].size
+ calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_width / image_height)
+
+ resized_images = [
+ components.image_resize_processor.resize(image, height=calculated_height, width=calculated_width)
+ for image in images
+ ]
+
+ setattr(block_state, self._resized_image_output_name, resized_images)
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
+ model_name = "qwenimage"
+
+ def __init__(
+ self,
+ input_name: str = "image",
+ output_name: str = "resized_image",
+ vae_image_output_name: str = "vae_image",
+ ):
+ """Create a configurable step that resizes input images for QwenImage Edit Plus.
+
+ This block resizes an input image or a list of input images to a 384 * 384 target area (maintaining the
+ aspect ratio) for the VL text encoder, and passes the original images through unchanged as the VAE image
+ output for later preprocessing. The results are exposed under configurable input and output names, so the
+ step can be wired to different image fields (e.g., "image", "control_image")
+
+ Args:
+ input_name (str, optional): Name of the image field to read from the
+ pipeline state. Defaults to "image".
+ output_name (str, optional): Name of the resized image field to write
+ back to the pipeline state. Defaults to "resized_image".
+ vae_image_output_name (str, optional): Name of the image field
+ to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus
+ processes the input image(s) differently for the VL and the VAE.
+ """
+ if not isinstance(input_name, str) or not isinstance(output_name, str):
+ raise ValueError(
+ f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
+ )
+ self.condition_image_size = 384 * 384
+ self._image_input_name = input_name
+ self._resized_image_output_name = output_name
+ self._vae_image_output_name = vae_image_output_name
+ super().__init__()
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return super().intermediate_outputs + [
+ OutputParam(
+ name=self._vae_image_output_name,
+ type_hint=List[PIL.Image.Image],
+ description="The images to be processed which will be further used by the VAE encoder.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ images = getattr(block_state, self._image_input_name)
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if (
+ not isinstance(images, torch.Tensor)
+ and isinstance(images, PIL.Image.Image)
+ and not isinstance(images, list)
+ ):
+ images = [images]
+
+ # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
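+ # the VL text encoder receives images resized to a ~384 * 384 area (aspect ratio preserved),
+ # while the original, un-resized images are forwarded as `vae_image` for the VAE encoder step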
+ condition_images = []
+ vae_images = []
+ for img in images:
+ image_width, image_height = img.size
+ condition_width, condition_height, _ = calculate_dimensions(
+ self.condition_image_size, image_width / image_height
+ )
+ condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
+ vae_images.append(img)
+
+ setattr(block_state, self._resized_image_output_name, condition_images)
+ setattr(block_state, self._vae_image_output_name, vae_images)
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageTextEncoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generates text embeddings to guide the image generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration, description="The text encoder to use"),
+ ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer to use"),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="prompt_template_encode",
+ default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ ),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=34),
+ ConfigSpec(name="tokenizer_max_length", default=1024),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
+ InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+ InputParam(
+ name="max_sequence_length", type_hint=int, description="The max sequence length to use", default=1024
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The prompt embeddings",
+ ),
+ OutputParam(
+ name="prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The encoder attention mask",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings mask",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(prompt, negative_prompt, max_sequence_length):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if (
+ negative_prompt is not None
+ and not isinstance(negative_prompt, str)
+ and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ self.check_inputs(block_state.prompt, block_state.negative_prompt, block_state.max_sequence_length)
+
+ block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds(
+ components.text_encoder,
+ components.tokenizer,
+ prompt=block_state.prompt,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ tokenizer_max_length=components.config.tokenizer_max_length,
+ device=device,
+ )
+
+ block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
+ block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]
+
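+ # negative embeddings are only computed when the guider needs unconditional inputs
+ # (e.g. classifier-free guidance with guidance_scale > 1)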
+ if components.requires_unconditional_embeds:
+ negative_prompt = block_state.negative_prompt or ""
+ block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
+ components.text_encoder,
+ components.tokenizer,
+ prompt=negative_prompt,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ tokenizer_max_length=components.config.tokenizer_max_length,
+ device=device,
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds[
+ :, : block_state.max_sequence_length
+ ]
+ block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask[
+ :, : block_state.max_sequence_length
+ ]
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that processes both prompt and image together to generate text embeddings for guiding image generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", Qwen2_5_VLForConditionalGeneration),
+ ComponentSpec("processor", Qwen2VLProcessor),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 4.0}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="prompt_template_encode",
+ default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n",
+ ),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="prompt", required=True, type_hint=str, description="The prompt to encode"),
+ InputParam(name="negative_prompt", type_hint=str, description="The negative prompt to encode"),
+ InputParam(
+ name="resized_image",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The image prompt to encode; it should be resized first using the resize step",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The prompt embeddings",
+ ),
+ OutputParam(
+ name="prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The encoder attention mask",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings",
+ ),
+ OutputParam(
+ name="negative_prompt_embeds_mask",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="The negative prompt embeddings mask",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(prompt, negative_prompt):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if (
+ negative_prompt is not None
+ and not isinstance(negative_prompt, str)
+ and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.prompt, block_state.negative_prompt)
+
+ device = components._execution_device
+
+ block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit(
+ components.text_encoder,
+ components.processor,
+ prompt=block_state.prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+
+ if components.requires_unconditional_embeds:
+ negative_prompt = block_state.negative_prompt or " "
+ block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
+ components.text_encoder,
+ components.processor,
+ prompt=negative_prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
+ model_name = "qwenimage"
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="prompt_template_encode",
+ default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ ),
+ ConfigSpec(
+ name="img_template_encode",
+ default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+ ),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.prompt, block_state.negative_prompt)
+
+ device = components._execution_device
+
+ block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus(
+ components.text_encoder,
+ components.processor,
+ prompt=block_state.prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ img_template_encode=components.config.img_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+
+ if components.requires_unconditional_embeds:
+ negative_prompt = block_state.negative_prompt or " "
+ block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
+ get_qwen_prompt_embeds_edit_plus(
+ components.text_encoder,
+ components.processor,
+ prompt=negative_prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ img_template_encode=components.config.img_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step for the inpainting task. It processes the image and mask inputs together. Images can be resized first using QwenImageEditResizeDynamicStep."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_mask_processor",
+ InpaintProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("mask_image", required=True),
+ InputParam("resized_image"),
+ InputParam("image"),
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("padding_mask_crop"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="processed_image"),
+ OutputParam(name="processed_mask_image"),
+ OutputParam(
+ name="mask_overlay_kwargs",
+ type_hint=Dict,
+ description="The kwargs for the postprocess step to apply the mask overlay",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.resized_image is None and block_state.image is None:
+ raise ValueError("resized_image and image cannot be None at the same time")
+
+ if block_state.resized_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ else:
+ width, height = block_state.resized_image[0].size
+ image = block_state.resized_image
+
+ block_state.processed_image, block_state.processed_mask_image, block_state.mask_overlay_kwargs = (
+ components.image_mask_processor.preprocess(
+ image=image,
+ mask=block_state.mask_image,
+ height=height,
+ width=width,
+ padding_mask_crop=block_state.padding_mask_crop,
+ )
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step. Images can be resized first using QwenImageEditResizeDynamicStep."
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="processed_image"),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.resized_image is None and block_state.image is None:
+ raise ValueError("resized_image and image cannot be None at the same time")
+
+ if block_state.resized_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ else:
+ width, height = block_state.resized_image[0].size
+ image = block_state.resized_image
+
+ block_state.processed_image = components.image_processor.preprocess(
+ image=image,
+ height=height,
+ width=width,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
+ model_name = "qwenimage-edit-plus"
+ vae_image_size = 1024 * 1024
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus does not reuse the VL-resized image here; it preprocesses the separate vae_image input instead."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.vae_image is None and block_state.image is None:
+ raise ValueError("`vae_image` and `image` cannot be None at the same time")
+
+ if block_state.vae_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ block_state.processed_image = components.image_processor.preprocess(
+ image=image, height=height, width=width
+ )
+ else:
+ width, height = block_state.vae_image[0].size
+ image = block_state.vae_image
+
+ block_state.processed_image = components.image_processor.preprocess(
+ image=image, height=height, width=width
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ def __init__(
+ self,
+ input_name: str = "processed_image",
+ output_name: str = "image_latents",
+ ):
+ """Initialize a VAE encoder step for converting images to latent representations.
+
+ Both the input and output names are configurable, so this block can be configured to process different image
+ inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
+
+ Args:
+ input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
+ Examples: "processed_image" or "processed_control_image"
+ output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
+ Examples: "image_latents" or "control_image_latents"
+
+ Examples:
+ # Basic usage with default settings (includes image processor)
+ QwenImageVaeEncoderDynamicStep()
+
+ # Custom input/output names for control image
+ QwenImageVaeEncoderDynamicStep(
+ input_name="processed_control_image", output_name="control_image_latents"
+ )
+ """
+ self._image_input_name = input_name
+ self._image_latents_output_name = output_name
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ]
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(self._image_input_name, required=True),
+ InputParam("generator"),
+ ]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ self._image_latents_output_name,
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ device = components._execution_device
+ dtype = components.vae.dtype
+
+ image = getattr(block_state, self._image_input_name)
+
+ # Encode image into latents
+ image_latents = encode_vae_image(
+ image=image,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ )
+ setattr(block_state, self._image_latents_output_name, image_latents)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ return "VAE Encoder step that converts `control_image` into latent representations control_image_latents.\n"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [
+ ComponentSpec("vae", AutoencoderKLQwenImage),
+ ComponentSpec("controlnet", QwenImageControlNetModel),
+ ComponentSpec(
+ "control_image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam("control_image", required=True),
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("generator"),
+ ]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "control_image_latents",
+ type_hint=torch.Tensor,
+ description="The latents representing the control image",
+ )
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor)
+
+ device = components._execution_device
+ dtype = components.vae.dtype
+
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+
+ controlnet = unwrap_module(components.controlnet)
+ if isinstance(controlnet, QwenImageMultiControlNetModel) and not isinstance(block_state.control_image, list):
+ block_state.control_image = [block_state.control_image]
+
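+ # with a multi-controlnet, each control image is preprocessed and VAE-encoded separately and
+ # the resulting latents are collected into a list (one entry per control image)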
+ if isinstance(controlnet, QwenImageMultiControlNetModel):
+ block_state.control_image_latents = []
+ for control_image_ in block_state.control_image:
+ control_image_ = components.control_image_processor.preprocess(
+ image=control_image_,
+ height=height,
+ width=width,
+ )
+
+ control_image_latents_ = encode_vae_image(
+ image=control_image_,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ sample_mode="sample",
+ )
+ block_state.control_image_latents.append(control_image_latents_)
+
+ elif isinstance(controlnet, QwenImageControlNetModel):
+ control_image = components.control_image_processor.preprocess(
+ image=block_state.control_image,
+ height=height,
+ width=width,
+ )
+ block_state.control_image_latents = encode_vae_image(
+ image=control_image,
+ vae=components.vae,
+ generator=block_state.generator,
+ device=device,
+ dtype=dtype,
+ latent_channels=components.num_channels_latents,
+ sample_mode="sample",
+ )
+
+ else:
+ raise ValueError(
+ f"Expected controlnet to be a QwenImageControlNetModel or QwenImageMultiControlNetModel, got {type(controlnet)}"
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py
new file mode 100644
index 0000000000..2b229c040b
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -0,0 +1,443 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Tuple
+
+import torch
+
+from ...models import QwenImageMultiControlNetModel
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
+
+
+def repeat_tensor_to_batch_size(
+ input_name: str,
+ input_tensor: torch.Tensor,
+ batch_size: int,
+ num_images_per_prompt: int = 1,
+) -> torch.Tensor:
+ """Repeat tensor elements to match the final batch size.
+
+ This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
+ by repeating each element along dimension 0.
+
+ The input tensor must have batch size 1 or batch_size. The function will:
+ - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
+ - If batch size equals batch_size: repeat each element num_images_per_prompt times
+
+ Args:
+ input_name (str): Name of the input tensor (used for error messages)
+ input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
+ batch_size (int): The base batch size (number of prompts)
+ num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
+
+ Returns:
+ torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
+
+ Raises:
+ ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
+
+ Examples:
+        tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
+        repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
+        # repeated -> tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]), shape: [4, 3]
+
+        tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
+        repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
+        # repeated -> tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]), shape: [4, 3]
+ """
+ # make sure input is a tensor
+ if not isinstance(input_tensor, torch.Tensor):
+ raise ValueError(f"`{input_name}` must be a tensor")
+
+ # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
+ if input_tensor.shape[0] == 1:
+ repeat_by = batch_size * num_images_per_prompt
+ elif input_tensor.shape[0] == batch_size:
+ repeat_by = num_images_per_prompt
+ else:
+ raise ValueError(
+            f"`{input_name}` must have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
+ )
+
+ # expand the tensor to match the batch_size * num_images_per_prompt
+ input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
+
+ return input_tensor
+
+
+def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]:
+ """Calculate image dimensions from latent tensor dimensions.
+
+ This function converts latent space dimensions to image space dimensions by multiplying the latent height and width
+ by the VAE scale factor.
+
+ Args:
+ latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
+ Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
+ vae_scale_factor (int): The scale factor used by the VAE to compress images.
+ Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
+
+ Returns:
+ Tuple[int, int]: The calculated image dimensions as (height, width)
+
+ Raises:
+ ValueError: If latents tensor doesn't have 4 or 5 dimensions
+
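+    Example:
+        # a latent of shape [1, 16, 64, 64] with vae_scale_factor=8 corresponds to a 512x512 image
+        height, width = calculate_dimension_from_latents(torch.zeros(1, 16, 64, 64), vae_scale_factor=8)
+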
+ """
+ # make sure the latents are not packed
+ if latents.ndim != 4 and latents.ndim != 5:
+ raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")
+
+ latent_height, latent_width = latents.shape[-2:]
+
+ height = latent_height * vae_scale_factor
+ width = latent_width * vae_scale_factor
+
+ return height, width
+
+
+class QwenImageTextInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+ summary_section = (
+ "Text input processing step that standardizes text embeddings for the pipeline.\n"
+ "This step:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
+ )
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps."
+
+ return summary_section + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
+ InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
+ InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
+ InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
+ ]
+
+ @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(
+ prompt_embeds,
+ prompt_embeds_mask,
+ negative_prompt_embeds,
+ negative_prompt_embeds_mask,
+ ):
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None")
+
+ if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None:
+ raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`")
+
+ if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]:
+ raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
+
+ elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]:
+ raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`")
+
+ elif (
+ negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]
+ ):
+ raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`")
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(
+ prompt_embeds=block_state.prompt_embeds,
+ prompt_embeds_mask=block_state.prompt_embeds_mask,
+ negative_prompt_embeds=block_state.negative_prompt_embeds,
+ negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask,
+ )
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+
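+        # duplicate each prompt's embedding `num_images_per_prompt` times:
+        # [batch_size, seq_len, dim] -> [batch_size * num_images_per_prompt, seq_len, dim]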
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len
+ )
+
+ if block_state.negative_prompt_embeds is not None:
+ _, seq_len, _ = block_state.negative_prompt_embeds.shape
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class QwenImageInputsDynamicStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["image_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
+        """Initialize a configurable step that standardizes the inputs for the denoising step.
+
+        This step handles multiple common tasks to prepare inputs for the denoising step:
+        1. For image latent inputs: updates height/width if they are None, patchifies the latents, and expands the batch size
+        2. For additional batch inputs: only expands batch dimensions to match the final batch size
+
+ This is a dynamic block that allows you to configure which inputs to process.
+
+ Args:
+ image_latent_inputs (List[str], optional): Names of image latent tensors to process.
+ These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
+ list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
+ additional_batch_inputs (List[str], optional):
+ Names of additional conditional input tensors to expand batch size. These tensors will only have their
+ batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
+ Defaults to []. Examples: ["processed_mask_image"]
+
+ Examples:
+            # Configure to process image_latents (default behavior)
+            QwenImageInputsDynamicStep()
+
+            # Configure to process multiple image latent inputs
+            QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
+
+            # Configure to process image latents and additional batch inputs
+            QwenImageInputsDynamicStep(
+                image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+            )
+ """
+ if not isinstance(image_latent_inputs, list):
+ image_latent_inputs = [image_latent_inputs]
+ if not isinstance(additional_batch_inputs, list):
+ additional_batch_inputs = [additional_batch_inputs]
+
+ self._image_latent_inputs = image_latent_inputs
+ self._additional_batch_inputs = additional_batch_inputs
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ # Functionality section
+ summary_section = (
+ "Input processing step that:\n"
+ " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
+ " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
+ )
+
+ # Inputs info
+ inputs_info = ""
+ if self._image_latent_inputs or self._additional_batch_inputs:
+ inputs_info = "\n\nConfigured inputs:"
+ if self._image_latent_inputs:
+ inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
+ if self._additional_batch_inputs:
+ inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+ return summary_section + inputs_info + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ # Add image latent inputs
+ for image_latent_input_name in self._image_latent_inputs:
+ inputs.append(InputParam(name=image_latent_input_name))
+
+ # Add additional batch inputs
+ for input_name in self._additional_batch_inputs:
+ inputs.append(InputParam(name=input_name))
+
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
+ OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
+ ]
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
+ ]
+
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate height/width from latents
+ height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
+ # 2. Patchify the image latent tensor
+ image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class QwenImageControlNetInputsStep(ModularPipelineBlocks):
+ model_name = "qwenimage"
+
+ @property
+ def description(self) -> str:
+        return "Prepare the `control_image_latents` input for controlnet. Insert after all the other input steps."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="control_image_latents", required=True),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ if isinstance(components.controlnet, QwenImageMultiControlNetModel):
+ control_image_latents = []
+ # loop through each control_image_latents
+ for i, control_image_latents_ in enumerate(block_state.control_image_latents):
+ # 1. update height/width if not provided
+ height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor)
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ # 2. pack
+ control_image_latents_ = components.pachifier.pack_latents(control_image_latents_)
+
+ # 3. repeat to match the batch size
+ control_image_latents_ = repeat_tensor_to_batch_size(
+ input_name=f"control_image_latents[{i}]",
+ input_tensor=control_image_latents_,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ control_image_latents.append(control_image_latents_)
+
+ block_state.control_image_latents = control_image_latents
+
+ else:
+ # 1. update height/width if not provided
+ height, width = calculate_dimension_from_latents(
+ block_state.control_image_latents, components.vae_scale_factor
+ )
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ # 2. pack
+ block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents)
+
+ # 3. repeat to match the batch size
+ block_state.control_image_latents = repeat_tensor_to_batch_size(
+ input_name="control_image_latents",
+ input_tensor=block_state.control_image_latents,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
new file mode 100644
index 0000000000..83bfcb3da4
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -0,0 +1,1036 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+ QwenImageControlNetBeforeDenoiserStep,
+ QwenImageCreateMaskLatentsStep,
+ QwenImageEditRoPEInputsStep,
+ QwenImagePrepareLatentsStep,
+ QwenImagePrepareLatentsWithStrengthStep,
+ QwenImageRoPEInputsStep,
+ QwenImageSetTimestepsStep,
+ QwenImageSetTimestepsWithStrengthStep,
+)
+from .decoders import QwenImageDecoderStep, QwenImageInpaintProcessImagesOutputStep, QwenImageProcessImagesOutputStep
+from .denoise import (
+ QwenImageControlNetDenoiseStep,
+ QwenImageDenoiseStep,
+ QwenImageEditDenoiseStep,
+ QwenImageEditInpaintDenoiseStep,
+ QwenImageInpaintControlNetDenoiseStep,
+ QwenImageInpaintDenoiseStep,
+ QwenImageLoopBeforeDenoiserControlNet,
+)
+from .encoders import (
+ QwenImageControlNetVaeEncoderStep,
+ QwenImageEditPlusProcessImagesInputStep,
+ QwenImageEditPlusResizeDynamicStep,
+ QwenImageEditPlusTextEncoderStep,
+ QwenImageEditResizeDynamicStep,
+ QwenImageEditTextEncoderStep,
+ QwenImageInpaintProcessImagesInputStep,
+ QwenImageProcessImagesInputStep,
+ QwenImageTextEncoderStep,
+ QwenImageVaeEncoderDynamicStep,
+)
+from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
+
+
+logger = logging.get_logger(__name__)
+
+# 1. QwenImage
+
+## 1.1 QwenImage/text2image
+
+#### QwenImage/decode
+#### (standard decode step works for most tasks except for inpaint)
+QwenImageDecodeBlocks = InsertableDict(
+ [
+ ("decode", QwenImageDecoderStep()),
+ ("postprocess", QwenImageProcessImagesOutputStep()),
+ ]
+)
+
+
+class QwenImageDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageDecodeBlocks.values()
+ block_names = QwenImageDecodeBlocks.keys()
+
+ @property
+ def description(self):
+        return "Decode step that decodes the latents into images and postprocesses the generated image."
+
+
+#### QwenImage/text2image presets
+TEXT2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("input", QwenImageTextInputsStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ("denoise", QwenImageDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
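+
+# Illustrative note: a preset dict such as TEXT2IMAGE_BLOCKS is meant to be assembled into runnable
+# blocks, e.g. via `SequentialPipelineBlocks.from_blocks_dict` (assumed here from the modular
+# pipeline utilities; the exact helper name may differ):
+#
+#     t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)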
+
+
+## 1.2 QwenImage/inpaint
+
+#### QwenImage/inpaint vae encoder
+QwenImageInpaintVaeEncoderBlocks = InsertableDict(
+ [
+ (
+ "preprocess",
+ QwenImageInpaintProcessImagesInputStep,
+ ), # image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
+ ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintVaeEncoderBlocks.values()
+ block_names = QwenImageInpaintVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step is used for processing image and mask inputs for inpainting tasks. It:\n"
+ " - Resizes the image to the target size, based on `height` and `width`.\n"
+ " - Processes and updates `image` and `mask_image`.\n"
+ " - Creates `image_latents`."
+ )
+
+
+#### QwenImage/inpaint inputs
+QwenImageInpaintInputBlocks = InsertableDict(
+ [
+ ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
+ (
+ "additional_inputs",
+ QwenImageInputsDynamicStep(
+ image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
+ ),
+ ),
+ ]
+)
+
+
+class QwenImageInpaintInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintInputBlocks.values()
+ block_names = QwenImageInpaintInputBlocks.keys()
+
+ @property
+ def description(self):
+        return (
+            "Input step that prepares the inputs for the inpainting denoising step. It:\n"
+            " - makes sure the text embeddings and the additional inputs (`image_latents` and `processed_mask_image`) have a consistent batch size.\n"
+            " - updates height/width based on `image_latents` and patchifies `image_latents`."
+        )
+
+
+# QwenImage/inpaint prepare latents
+QwenImageInpaintPrepareLatentsBlocks = InsertableDict(
+ [
+ ("add_noise_to_latents", QwenImagePrepareLatentsWithStrengthStep()),
+ ("create_mask_latents", QwenImageCreateMaskLatentsStep()),
+ ]
+)
+
+
+class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintPrepareLatentsBlocks.values()
+ block_names = QwenImageInpaintPrepareLatentsBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
+            " - Adds noise to the image latents to create the latents input for the denoiser.\n"
+            " - Creates the patchified latents `mask` based on the processed mask image.\n"
+ )
+
+
+#### QwenImage/inpaint decode
+QwenImageInpaintDecodeBlocks = InsertableDict(
+ [
+ ("decode", QwenImageDecoderStep()),
+ ("postprocess", QwenImageInpaintProcessImagesOutputStep()),
+ ]
+)
+
+
+class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintDecodeBlocks.values()
+ block_names = QwenImageInpaintDecodeBlocks.keys()
+
+ @property
+ def description(self):
+        return "Decode step that decodes the latents into images and postprocesses the generated image, optionally applying the mask overlay to the original image."
+
+
+#### QwenImage/inpaint presets
+INPAINT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageInpaintVaeEncoderStep()),
+ ("input", QwenImageInpaintInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ("denoise", QwenImageInpaintDenoiseStep()),
+ ("decode", QwenImageInpaintDecodeStep()),
+ ]
+)
+
+
+## 1.3 QwenImage/img2img
+
+#### QwenImage/img2img vae encoder
+QwenImageImg2ImgVaeEncoderBlocks = InsertableDict(
+ [
+ ("preprocess", QwenImageProcessImagesInputStep()),
+ ("encode", QwenImageVaeEncoderDynamicStep()),
+ ]
+)
+
+
+class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ block_classes = QwenImageImg2ImgVaeEncoderBlocks.values()
+ block_names = QwenImageImg2ImgVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+        return "Vae encoder step that preprocesses and encodes the image inputs into their latent representations."
+
+
+#### QwenImage/img2img inputs
+QwenImageImg2ImgInputBlocks = InsertableDict(
+ [
+ ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
+ ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
+ ]
+)
+
+
+class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageImg2ImgInputBlocks.values()
+ block_names = QwenImageImg2ImgInputBlocks.keys()
+
+ @property
+ def description(self):
+        return (
+            "Input step that prepares the inputs for the img2img denoising step. It:\n"
+            " - makes sure the text embeddings and the additional inputs (`image_latents`) have a consistent batch size.\n"
+            " - updates height/width based on `image_latents` and patchifies `image_latents`."
+        )
+
+
+#### QwenImage/img2img presets
+IMAGE2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageImg2ImgVaeEncoderStep()),
+ ("input", QwenImageImg2ImgInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ("denoise", QwenImageDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+## 1.4 QwenImage/controlnet
+
+#### QwenImage/controlnet presets
+CONTROLNET_BLOCKS = InsertableDict(
+ [
+ ("controlnet_vae_encoder", QwenImageControlNetVaeEncoderStep()), # vae encoder step for control_image
+ ("controlnet_inputs", QwenImageControlNetInputsStep()), # additional input step for controlnet
+ (
+ "controlnet_before_denoise",
+ QwenImageControlNetBeforeDenoiserStep(),
+ ), # before denoise step (after set_timesteps step)
+ (
+ "controlnet_denoise_loop_before",
+ QwenImageLoopBeforeDenoiserControlNet(),
+ ), # controlnet loop step (insert before the denoiseloop_denoiser)
+ ]
+)
+
+
+## 1.5 QwenImage/auto encoders
+
+
+#### for inpaint and img2img tasks
+class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
+ block_names = ["inpaint", "img2img"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
+ + " - if `mask_image` or `image` is not provided, step will be skipped."
+ )
+
+
+# for controlnet tasks
+class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetVaeEncoderStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image"]
+
+ @property
+ def description(self):
+ return (
+            "Vae encoder step that encodes the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
+ + " - if `control_image` is not provided, step will be skipped."
+ )
+
+
+## 1.6 QwenImage/auto inputs
+
+
+# text2image/inpaint/img2img
+class QwenImageAutoInputStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintInputStep, QwenImageImg2ImgInputStep, QwenImageTextInputsStep]
+ block_names = ["inpaint", "img2img", "text2image"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents", None]
+
+ @property
+ def description(self):
+ return (
+            "Input step that standardizes the inputs for the denoising step, e.g. makes sure inputs have a consistent batch size and are patchified. \n"
+ " This is an auto pipeline block that works for text2image/inpaint/img2img tasks.\n"
+ + " - `QwenImageInpaintInputStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `QwenImageTextInputsStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
+ )
+
+
+# controlnet
+class QwenImageOptionalControlNetInputStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetInputsStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image_latents"]
+
+ @property
+ def description(self):
+ return (
+            "Controlnet input step that prepares the control_image_latents input.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetInputsStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ + " - if `control_image_latents` is not provided, step will be skipped."
+ )
+
+
+## 1.7 QwenImage/auto before denoise step
+# compose the steps into a BeforeDenoiseStep for text2image/img2img/inpaint tasks before combine into an auto step
+
+# QwenImage/text2image before denoise
+QwenImageText2ImageBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageText2ImageBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageText2ImageBeforeDenoiseBlocks.values()
+ block_names = QwenImageText2ImageBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs etc.) for the denoise step for the text2image task."
+
+
+# QwenImage/inpaint before denoise
+QwenImageInpaintBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageInpaintBeforeDenoiseBlocks.values()
+ block_names = QwenImageInpaintBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs etc.) for the denoise step for the inpaint task."
+
+
+# QwenImage/img2img before denoise
+QwenImageImg2ImgBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_img2img_latents", QwenImagePrepareLatentsWithStrengthStep()),
+ ("prepare_rope_inputs", QwenImageRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageImg2ImgBeforeDenoiseBlocks.values()
+ block_names = QwenImageImg2ImgBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs etc.) for the denoise step for the img2img task."
+
+
+# auto before_denoise step for text2image, inpaint, img2img tasks
+class QwenImageAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageInpaintBeforeDenoiseStep,
+ QwenImageImg2ImgBeforeDenoiseStep,
+ QwenImageText2ImageBeforeDenoiseStep,
+ ]
+ block_names = ["inpaint", "img2img", "text2image"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents", None]
+
+ @property
+ def description(self):
+ return (
+            "Before denoise step that prepares the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2img, inpainting, img2img tasks.\n"
+ + " - `QwenImageInpaintBeforeDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `QwenImageText2ImageBeforeDenoiseStep` (text2image) is used when both `processed_mask_image` and `image_latents` are not provided.\n"
+ )
+
+
+# auto before_denoise step for controlnet tasks
+class QwenImageOptionalControlNetBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [QwenImageControlNetBeforeDenoiserStep]
+ block_names = ["controlnet"]
+ block_trigger_inputs = ["control_image_latents"]
+
+ @property
+ def description(self):
+ return (
+            "Controlnet before denoise step that prepares the controlnet input.\n"
+ + "This is an auto pipeline block.\n"
+ + " - `QwenImageControlNetBeforeDenoiserStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ + " - if `control_image_latents` is not provided, step will be skipped."
+ )
+
+
+## 1.8 QwenImage/auto denoise
+
+
+# auto denoise step for controlnet tasks: works for all tasks with controlnet
+class QwenImageControlNetAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintControlNetDenoiseStep, QwenImageControlNetDenoiseStep]
+ block_names = ["inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["mask", None]
+
+ @property
+ def description(self):
+ return (
+ "Controlnet step during the denoising process. \n"
+ " This is an auto pipeline block that works for inpaint and text2image/img2img tasks with controlnet.\n"
+ + " - `QwenImageInpaintControlNetDenoiseStep` (inpaint) is used when `mask` is provided.\n"
+ + " - `QwenImageControlNetDenoiseStep` (text2image/img2img) is used when `mask` is not provided.\n"
+ )
+
+
+# auto denoise step for everything: works for all tasks with or without controlnet
+class QwenImageAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageControlNetAutoDenoiseStep,
+ QwenImageInpaintDenoiseStep,
+ QwenImageDenoiseStep,
+ ]
+ block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["control_image_latents", "mask", None]
+
+ @property
+ def description(self):
+ return (
+            "Denoise step that iteratively denoises the latents. \n"
+ " This is an auto pipeline block that works for inpaint/text2image/img2img tasks. It also works with controlnet\n"
+ + " - `QwenImageControlNetAutoDenoiseStep` (controlnet) is used when `control_image_latents` is provided.\n"
+ + " - `QwenImageInpaintDenoiseStep` (inpaint) is used when `mask` is provided and `control_image_latents` is not provided.\n"
+ + " - `QwenImageDenoiseStep` (text2image/img2img) is used when `mask` is not provided and `control_image_latents` is not provided.\n"
+ )
+
+
+## 1.9 QwenImage/auto decode
+# auto decode step for inpaint and text2image tasks
+
+
+class QwenImageAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
+ block_names = ["inpaint_decode", "decode"]
+ block_trigger_inputs = ["mask", None]
+
+ @property
+ def description(self):
+ return (
+            "Decode step that decodes the latents into images. \n"
+ " This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
+ + " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+ + " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
+ )
+
+
+class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageAutoInputStep,
+ QwenImageOptionalControlNetInputStep,
+ QwenImageAutoBeforeDenoiseStep,
+ QwenImageOptionalControlNetBeforeDenoiseStep,
+ QwenImageAutoDenoiseStep,
+ ]
+    block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `QwenImageOptionalControlNetInputStep` (controlnet_input) prepares the controlnet input.\n"
+ + " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n"
+            + " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+            + "This step supports text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
+ + " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings"
+ )
+
+
+## 1.10 QwenImage/auto block & presets
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageTextEncoderStep()),
+ ("vae_encoder", QwenImageAutoVaeEncoderStep()),
+ ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
+ ("denoise", QwenImageCoreDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+
+ block_classes = AUTO_BLOCKS.values()
+ block_names = AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ + "- for image-to-image generation, you need to provide `image`\n"
+ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ + "- to run the controlnet workflow, you need to provide `control_image`\n"
+ + "- for text-to-image generation, all you need to provide is `prompt`"
+ )
+
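+
+# A minimal usage sketch for `QwenImageAutoBlocks` (illustrative only; `init_pipeline`,
+# `load_default_components`, and the `output="images"` convention are assumptions about the
+# modular pipeline API rather than something defined in this file):
+#
+#     import torch
+#
+#     blocks = QwenImageAutoBlocks()
+#     pipe = blocks.init_pipeline("Qwen/Qwen-Image")
+#     pipe.load_default_components(torch_dtype=torch.bfloat16)
+#     images = pipe(prompt="a photo of a cat", output="images")
+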
+
+# 2. QwenImage-Edit
+
+## 2.1 QwenImage-Edit/edit
+
+#### QwenImage-Edit/edit vl encoder: take both image and text prompts
+QwenImageEditVLEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditResizeDynamicStep()),
+ ("encode", QwenImageEditTextEncoderStep()),
+ ]
+)
+
+
+class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditVLEncoderBlocks.values()
+ block_names = QwenImageEditVLEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+        return "QwenImage-Edit VL encoder step that encodes the image and text prompts together."
+
+
+#### QwenImage-Edit/edit vae encoder
+QwenImageEditVaeEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditResizeDynamicStep()), # edit has a different resize step
+ ("preprocess", QwenImageProcessImagesInputStep()), # resized_image -> processed_image
+ ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditVaeEncoderBlocks.values()
+ block_names = QwenImageEditVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+        return "Vae encoder step that encodes the image inputs into their latent representations."
+
+
+#### QwenImage-Edit/edit input
+QwenImageEditInputBlocks = InsertableDict(
+ [
+ ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
+ ("additional_inputs", QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"])),
+ ]
+)
+
+
+class QwenImageEditInputStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditInputBlocks.values()
+ block_names = QwenImageEditInputBlocks.keys()
+
+ @property
+ def description(self):
+        return (
+            "Input step that prepares the inputs for the edit denoising step. It:\n"
+            " - makes sure the text embeddings and the additional inputs (`image_latents`) have a consistent batch size.\n"
+            " - updates height/width based on `image_latents` and patchifies `image_latents`."
+        )
+
+
+#### QwenImage/edit presets
+EDIT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditVaeEncoderStep()),
+ ("input", QwenImageEditInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ("denoise", QwenImageEditDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+## 2.2 QwenImage-Edit/edit inpaint
+
+#### QwenImage-Edit/edit inpaint vae encoder: the difference from regular inpaint is the resize step
+QwenImageEditInpaintVaeEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditResizeDynamicStep()), # image -> resized_image
+ (
+ "preprocess",
+ QwenImageInpaintProcessImagesInputStep,
+ ), # resized_image, mask_image -> processed_image, processed_mask_image, mask_overlay_kwargs
+ (
+ "encode",
+ QwenImageVaeEncoderDynamicStep(input_name="processed_image", output_name="image_latents"),
+ ), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditInpaintVaeEncoderBlocks.values()
+ block_names = QwenImageEditInpaintVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return (
+ "This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
+            " - resizes the image to a target area of 1024 * 1024 while maintaining the aspect ratio.\n"
+            " - processes the resized image and mask image.\n"
+            " - creates the image latents."
+ )
+
+
+#### QwenImage-Edit/edit inpaint presets
+EDIT_INPAINT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditInpaintVaeEncoderStep()),
+ ("input", QwenImageInpaintInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ("denoise", QwenImageEditInpaintDenoiseStep()),
+ ("decode", QwenImageInpaintDecodeStep()),
+ ]
+)
+
+
+## 2.3 QwenImage-Edit/auto encoders
+
+
+class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageEditInpaintVaeEncoderStep,
+ QwenImageEditVaeEncoderStep,
+ ]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+            "Vae encoder step that encodes the image inputs into their latent representations. \n"
+ " This is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
+ + " - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
+ + " - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
+ + " - if `mask_image` or `image` is not provided, step will be skipped."
+ )
+
+
+## 2.4 QwenImage-Edit/auto inputs
+class QwenImageEditAutoInputStep(AutoPipelineBlocks):
+ block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the edit denoising step.\n"
+ + " It is an auto pipeline block that works for edit and edit_inpaint tasks.\n"
+ + " - `QwenImageInpaintInputStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageEditInputStep` (edit) is used when `image_latents` is provided.\n"
+ + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
+ )
+
+
+## 2.5 QwenImage-Edit/auto before denoise
+# compose the steps into a BeforeDenoiseStep for edit and edit_inpaint tasks before combine into an auto step
+
+#### QwenImage-Edit/edit before denoise
+QwenImageEditBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageEditBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditBeforeDenoiseBlocks.values()
+ block_names = QwenImageEditBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs etc.) for the denoise step for the edit task."
+
+
+#### QwenImage-Edit/edit inpaint before denoise
+QwenImageEditInpaintBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsWithStrengthStep()),
+ ("prepare_inpaint_latents", QwenImageInpaintPrepareLatentsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ]
+)
+
+
+class QwenImageEditInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditInpaintBeforeDenoiseBlocks.values()
+ block_names = QwenImageEditInpaintBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, rope inputs etc.) for the denoise step for the edit inpaint task."
+
+
+# auto before_denoise step for edit and edit_inpaint tasks
+class QwenImageEditAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditInpaintBeforeDenoiseStep,
+ QwenImageEditBeforeDenoiseStep,
+ ]
+ block_names = ["edit_inpaint", "edit"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
+
+ @property
+ def description(self):
+ return (
+            "Before denoise step that prepares the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ + "This is an auto pipeline block that works for edit (img2img) and edit inpaint tasks.\n"
+ + " - `QwenImageEditInpaintBeforeDenoiseStep` (edit_inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ + " - if `image_latents` or `processed_mask_image` is not provided, step will be skipped."
+ )
+
+
+## 2.6 QwenImage-Edit/auto denoise
+
+
+class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks):
+ model_name = "qwenimage-edit"
+
+ block_classes = [QwenImageEditInpaintDenoiseStep, QwenImageEditDenoiseStep]
+ block_names = ["inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
+
+ @property
+ def description(self):
+ return (
+            "Denoise step that iteratively denoises the latents. \n"
+ + "This block supports edit (img2img) and edit inpaint tasks for QwenImage Edit. \n"
+ + " - `QwenImageEditInpaintDenoiseStep` (inpaint) is used when `processed_mask_image` is provided.\n"
+ + " - `QwenImageEditDenoiseStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - if `processed_mask_image` or `image_latents` is not provided, step will be skipped."
+ )
+
+
+## 2.7 QwenImage-Edit/auto blocks & presets
+
+
+class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditAutoInputStep,
+ QwenImageEditAutoBeforeDenoiseStep,
+ QwenImageEditAutoDenoiseStep,
+ ]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `QwenImageEditAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+            + "This step supports edit (img2img) and edit inpainting workflows for QwenImage Edit:\n"
+ + " - When `processed_mask_image` is provided, it will be used for edit inpainting task.\n"
+ + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n"
+ )
+
+
+EDIT_AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditVLEncoderStep()),
+ ("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
+ ("denoise", QwenImageEditCoreDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = EDIT_AUTO_BLOCKS.values()
+ block_names = EDIT_AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
+ + "- for edit (img2img) generation, you need to provide `image`\n"
+ + "- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ )
+
+
+#################### QwenImage Edit Plus #####################
+
+# 3. QwenImage-Edit Plus
+
+## 3.1 QwenImage-Edit Plus / edit
+
+#### QwenImage-Edit Plus vl encoder: take both image and text prompts
+QwenImageEditPlusVLEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditPlusResizeDynamicStep()),
+ ("encode", QwenImageEditPlusTextEncoderStep()),
+ ]
+)
+
+
+class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditPlusVLEncoderBlocks.values()
+ block_names = QwenImageEditPlusVLEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+        return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together."
+
+
+#### QwenImage-Edit Plus vae encoder
+QwenImageEditPlusVaeEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step
+ ("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image
+ ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
+ block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+        return "Vae encoder step that encodes the image inputs into their latent representations."
+
+
+#### QwenImage Edit Plus presets
+EDIT_PLUS_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditPlusVLEncoderStep()),
+ ("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
+ ("input", QwenImageEditInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ("denoise", QwenImageEditDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+# auto before_denoise step for edit tasks
+class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = [QwenImageEditBeforeDenoiseStep]
+ block_names = ["edit"]
+ block_trigger_inputs = ["image_latents"]
+
+ @property
+ def description(self):
+ return (
+            "Before denoise step that prepares the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ + "This is an auto pipeline block that works for edit (img2img) task.\n"
+ + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ + " - if `image_latents` is not provided, step will be skipped."
+ )
+
+
+## 3.2 QwenImage-Edit Plus/auto encoders
+
+
+class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageEditPlusVaeEncoderStep,
+ ]
+ block_names = ["edit"]
+ block_trigger_inputs = ["image"]
+
+ @property
+ def description(self):
+ return (
+            "Vae encoder step that encodes the image inputs into their latent representations. \n"
+ " This is an auto pipeline block that works for edit task.\n"
+ + " - `QwenImageEditPlusVaeEncoderStep` (edit) is used when `image` is provided.\n"
+ + " - if `image` is not provided, step will be skipped."
+ )
+
+
+## 3.3 QwenImage-Edit/auto blocks & presets
+
+
+class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = [
+ QwenImageEditAutoInputStep,
+ QwenImageEditPlusAutoBeforeDenoiseStep,
+ QwenImageEditAutoDenoiseStep,
+ ]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `QwenImageEditPlusAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+            + "This step supports the edit (img2img) workflow for QwenImage Edit Plus:\n"
+ + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n"
+ )
+
+
+EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditPlusVLEncoderStep()),
+ ("vae_encoder", QwenImageEditPlusAutoVaeEncoderStep()),
+ ("denoise", QwenImageEditPlusCoreDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
+ block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+            "Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n"
+ + "- for edit (img2img) generation, you need to provide `image`\n"
+ )
+
+
+# 3. all block presets supported in QwenImage, QwenImage-Edit, QwenImage-Edit Plus
+
+
+ALL_BLOCKS = {
+ "text2image": TEXT2IMAGE_BLOCKS,
+ "img2img": IMAGE2IMAGE_BLOCKS,
+ "edit": EDIT_BLOCKS,
+ "edit_inpaint": EDIT_INPAINT_BLOCKS,
+ "edit_plus": EDIT_PLUS_BLOCKS,
+ "inpaint": INPAINT_BLOCKS,
+ "controlnet": CONTROLNET_BLOCKS,
+ "auto": AUTO_BLOCKS,
+ "edit_auto": EDIT_AUTO_BLOCKS,
+ "edit_plus_auto": EDIT_PLUS_AUTO_BLOCKS,
+}
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
new file mode 100644
index 0000000000..d9e30864f6
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
@@ -0,0 +1,208 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import QwenImageLoraLoaderMixin
+from ..modular_pipeline import ModularPipeline
+
+
+class QwenImagePachifier(ConfigMixin):
+ """
+ A class to pack and unpack latents for QwenImage.
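+
+    Shape example (with the default patch_size=2):
+        pack_latents:   [1, 16, 64, 64] -> [1, (64 // 2) * (64 // 2), 16 * 2 * 2] == [1, 1024, 64]
+        unpack_latents: [1, 1024, 64] with height=512, width=512 and vae_scale_factor=8 -> [1, 16, 1, 64, 64]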
+ """
+
+ config_name = "config.json"
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 2,
+ ):
+ super().__init__()
+
+ def pack_latents(self, latents):
+ if latents.ndim != 4 and latents.ndim != 5:
+ raise ValueError(f"Latents must have 4 or 5 dimensions, but got {latents.ndim}")
+
+ if latents.ndim == 4:
+ latents = latents.unsqueeze(2)
+
+ batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width = latents.shape
+ patch_size = self.config.patch_size
+
+ if latent_height % patch_size != 0 or latent_width % patch_size != 0:
+ raise ValueError(
+ f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
+ )
+
+ latents = latents.view(
+ batch_size,
+ num_channels_latents,
+ latent_height // patch_size,
+ patch_size,
+ latent_width // patch_size,
+ patch_size,
+ )
+ latents = latents.permute(
+ 0, 2, 4, 1, 3, 5
+ ) # Batch_size, num_patches_height, num_patches_width, num_channels_latents, patch_size, patch_size
+ latents = latents.reshape(
+ batch_size,
+ (latent_height // patch_size) * (latent_width // patch_size),
+ num_channels_latents * patch_size * patch_size,
+ )
+
+ return latents
+
+ def unpack_latents(self, latents, height, width, vae_scale_factor=8):
+ if latents.ndim != 3:
+ raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
+
+ batch_size, num_patches, channels = latents.shape
+ patch_size = self.config.patch_size
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = patch_size * (int(height) // (vae_scale_factor * patch_size))
+ width = patch_size * (int(width) // (vae_scale_factor * patch_size))
+
+ latents = latents.view(
+ batch_size,
+ height // patch_size,
+ width // patch_size,
+ channels // (patch_size * patch_size),
+ patch_size,
+ patch_size,
+ )
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width)
+
+ return latents
+
+
+class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
+ """
+ A ModularPipeline for QwenImage.
+
+    > [!WARNING] This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "QwenImageAutoBlocks"
+
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ return 128
+
+ @property
+ def vae_scale_factor(self):
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if hasattr(self, "transformer") and self.transformer is not None:
+ num_channels_latents = self.transformer.config.in_channels // 4
+ return num_channels_latents
+
+ @property
+ def is_guidance_distilled(self):
+ is_guidance_distilled = False
+ if hasattr(self, "transformer") and self.transformer is not None:
+ is_guidance_distilled = self.transformer.config.guidance_embeds
+ return is_guidance_distilled
+
+ @property
+ def requires_unconditional_embeds(self):
+ requires_unconditional_embeds = False
+
+ if hasattr(self, "guider") and self.guider is not None:
+ requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
+
+ return requires_unconditional_embeds
+
+
+class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
+ """
+ A ModularPipeline for QwenImage-Edit.
+
+    > [!WARNING] This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "QwenImageEditAutoBlocks"
+
+    # YiYi TODO: QwenImage-Edit should not provide a default height/width; they should be derived from the resized input image (after adjustment) produced by the resize step.
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ return 128
+
+ @property
+ def vae_scale_factor(self):
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if hasattr(self, "transformer") and self.transformer is not None:
+ num_channels_latents = self.transformer.config.in_channels // 4
+ return num_channels_latents
+
+ @property
+ def is_guidance_distilled(self):
+ is_guidance_distilled = False
+ if hasattr(self, "transformer") and self.transformer is not None:
+ is_guidance_distilled = self.transformer.config.guidance_embeds
+ return is_guidance_distilled
+
+ @property
+ def requires_unconditional_embeds(self):
+ requires_unconditional_embeds = False
+
+ if hasattr(self, "guider") and self.guider is not None:
+ requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
+
+ return requires_unconditional_embeds
+
+
+class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline):
+ """
+ A ModularPipeline for QwenImage-Edit Plus.
+
+    > [!WARNING] This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "QwenImageEditPlusAutoBlocks"
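For reference, a minimal sketch (not part of this diff) of the packing arithmetic implemented by `QwenImagePachifier` above: with the default `patch_size=2`, `pack_latents` flattens latents of shape (B, C, F, H, W) into (B, (H/2)*(W/2), C*4), and `unpack_latents` reverses this given the pixel-space height/width. The import path follows the new file added here.

import torch

from diffusers.modular_pipelines.qwenimage.modular_pipeline import QwenImagePachifier

pachifier = QwenImagePachifier()  # patch_size defaults to 2

# (batch, channels, frames, latent_height, latent_width)
latents = torch.randn(1, 16, 1, 64, 64)

packed = pachifier.pack_latents(latents)
print(packed.shape)  # torch.Size([1, 1024, 64])  -> (B, (64/2)*(64/2), 16*2*2)

# height/width are pixel-space sizes; with vae_scale_factor=8 they map back to the 64x64 latent grid
unpacked = pachifier.unpack_latents(packed, height=512, width=512, vae_scale_factor=8)
print(unpacked.shape)  # torch.Size([1, 16, 1, 64, 64])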
diff --git a/src/diffusers/modular_pipelines/qwenimage/node_utils.py b/src/diffusers/modular_pipelines/qwenimage/node_utils.py
new file mode 100644
index 0000000000..3230ece68a
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/node_utils.py
@@ -0,0 +1,95 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# mellon nodes
+QwenImage_NODE_TYPES_PARAMS_MAP = {
+ "controlnet": {
+ "inputs": [
+ "control_image",
+ "controlnet_conditioning_scale",
+ "control_guidance_start",
+ "control_guidance_end",
+ "height",
+ "width",
+ ],
+ "model_inputs": [
+ "controlnet",
+ "vae",
+ ],
+ "outputs": [
+ "controlnet_out",
+ ],
+ "block_names": ["controlnet_vae_encoder"],
+ },
+ "denoise": {
+ "inputs": [
+ "embeddings",
+ "width",
+ "height",
+ "seed",
+ "num_inference_steps",
+ "guidance_scale",
+ "image_latents",
+ "strength",
+ "controlnet",
+ ],
+ "model_inputs": [
+ "unet",
+ "guider",
+ "scheduler",
+ ],
+ "outputs": [
+ "latents",
+ "latents_preview",
+ ],
+ "block_names": ["denoise"],
+ },
+ "vae_encoder": {
+ "inputs": [
+ "image",
+ "width",
+ "height",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "image_latents",
+ ],
+ },
+ "text_encoder": {
+ "inputs": [
+ "prompt",
+ "negative_prompt",
+ ],
+ "model_inputs": [
+ "text_encoders",
+ ],
+ "outputs": [
+ "embeddings",
+ ],
+ },
+ "decoder": {
+ "inputs": [
+ "latents",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "images",
+ ],
+ },
+}
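Since the map above is plain data, a Mellon node integration can introspect it directly. An illustrative lookup (not code from the PR):

from diffusers.modular_pipelines.qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP

# Print, for each node type, the user-facing inputs, model-level inputs, and outputs it declares.
for node_type, params in QwenImage_NODE_TYPES_PARAMS_MAP.items():
    print(node_type)
    print("  inputs:      ", params["inputs"])
    print("  model inputs:", params["model_inputs"])
    print("  outputs:     ", params["outputs"])
    print("  blocks:      ", params.get("block_names", []))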
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
index c56f4af1b8..70cbf0c1c7 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -22,12 +22,12 @@ from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
-from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from ...models.controlnets.multicontrolnet import MultiControlNetModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor, unwrap_module
from ..modular_pipeline import (
- PipelineBlock,
+ ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
@@ -195,7 +195,7 @@ def prepare_latents_img2img(
return latents
-class StableDiffusionXLInputStep(PipelineBlock):
+class StableDiffusionXLInputStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock):
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam(
"prompt_embeds",
required=True,
@@ -267,37 +262,37 @@ class StableDiffusionXLInputStep(PipelineBlock):
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+                kwargs_type="denoiser_input_fields",  # already in intermediates state but declare here again for denoiser_input_fields
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+                kwargs_type="denoiser_input_fields",  # already in intermediates state but declare here again for denoiser_input_fields
description="negative text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+                kwargs_type="denoiser_input_fields",  # already in intermediates state but declare here again for denoiser_input_fields
description="pooled text embeddings used to guide the image generation",
),
OutputParam(
"negative_pooled_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+                kwargs_type="denoiser_input_fields",  # already in intermediates state but declare here again for denoiser_input_fields
description="negative pooled text embeddings used to guide the image generation",
),
OutputParam(
"ip_adapter_embeds",
type_hint=List[torch.Tensor],
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+                kwargs_type="denoiser_input_fields",  # already in intermediates state but declare here again for denoiser_input_fields
description="image embeddings for IP-Adapter",
),
OutputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+                kwargs_type="denoiser_input_fields",  # already in intermediates state but declare here again for denoiser_input_fields
description="negative image embeddings for IP-Adapter",
),
]
@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
return components, state
-class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
+class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
InputParam("denoising_start"),
# YiYi TODO: do we need num_images_per_prompt here?
InputParam("num_images_per_prompt", default=1),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam(
"batch_size",
required=True,
@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
return components, state
-class StableDiffusionXLSetTimestepsStep(PipelineBlock):
+class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
return components, state
-class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
+class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
"`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
"`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam("generator"),
InputParam(
"batch_size",
@@ -744,8 +729,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=None,
is_strength_max=True,
add_noise=True,
- return_noise=False,
- return_image_latents=False,
):
shape = (
batch_size,
@@ -768,7 +751,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
if image.shape[1] == 4:
image_latents = image.to(device=device, dtype=dtype)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
- elif return_image_latents or (latents is None and not is_strength_max):
+ elif latents is None and not is_strength_max:
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(components, image=image, generator=generator)
image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
@@ -786,13 +769,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = image_latents.to(device)
- outputs = (latents,)
-
- if return_noise:
- outputs += (noise,)
-
- if return_image_latents:
- outputs += (image_latents,)
+ outputs = (latents, noise, image_latents)
return outputs
@@ -864,7 +841,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
- block_state.latents, block_state.noise = self.prepare_latents_inpaint(
+ block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint(
components,
block_state.batch_size * block_state.num_images_per_prompt,
components.num_channels_latents,
@@ -878,8 +855,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
timestep=block_state.latent_timestep,
is_strength_max=block_state.is_strength_max,
add_noise=block_state.add_noise,
- return_noise=True,
- return_image_latents=False,
)
# 7. Prepare mask latent variables
@@ -900,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
return components, state
-class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
+class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -920,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
InputParam("latents"),
InputParam("num_images_per_prompt", default=1),
InputParam("denoising_start"),
- ]
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- return [
InputParam("generator"),
InputParam(
"latent_timestep",
@@ -981,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
return components, state
-class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
+class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1002,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
InputParam("width"),
InputParam("latents"),
InputParam("num_images_per_prompt", default=1),
- ]
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- return [
InputParam("generator"),
InputParam(
"batch_size",
@@ -1092,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
return components, state
-class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
+class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1129,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("num_images_per_prompt", default=1),
InputParam("aesthetic_score", default=6.0),
InputParam("negative_aesthetic_score", default=2.0),
- ]
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- return [
InputParam(
"latents",
required=True,
@@ -1160,13 +1120,13 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
OutputParam(
"add_time_ids",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="The time ids to condition the denoising process",
),
OutputParam(
"negative_add_time_ids",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="The negative time ids to condition the denoising process",
),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
@@ -1316,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
return components, state
-class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
+class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1345,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("crops_coords_top_left", default=(0, 0)),
InputParam("negative_crops_coords_top_left", default=(0, 0)),
InputParam("num_images_per_prompt", default=1),
- ]
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- return [
InputParam(
"latents",
required=True,
@@ -1376,13 +1331,13 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
OutputParam(
"add_time_ids",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="The time ids to condition the denoising process",
),
OutputParam(
"negative_add_time_ids",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="The negative time ids to condition the denoising process",
),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
@@ -1499,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
return components, state
-class StableDiffusionXLControlNetInputStep(PipelineBlock):
+class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1527,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam(
"latents",
required=True,
@@ -1718,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
return components, state
-class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
+class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -1747,11 +1697,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1),
- ]
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- return [
InputParam(
"latents",
required=True,
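Throughout this file the separate `intermediate_inputs` property is folded into `inputs`, and blocks now subclass `ModularPipelineBlocks` instead of `PipelineBlock`. Below is a hedged sketch of what a user-defined block looks like under the merged declaration; the block itself, its scaling logic, and the `intermediate_outputs`/`get_block_state` usage are illustrative assumptions patterned on the surrounding code, not part of this PR.

import torch
from typing import List

from diffusers.modular_pipelines.modular_pipeline import ModularPipelineBlocks, PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import InputParam, OutputParam


class ScaleLatentsStep(ModularPipelineBlocks):  # hypothetical example block
    model_name = "stable-diffusion-xl"

    @property
    def description(self) -> str:
        return "Example block that rescales latents produced by an earlier block."

    @property
    def inputs(self) -> List[InputParam]:
        # Values that used to be declared under `intermediate_inputs` (e.g. `latents`)
        # are now listed here alongside regular user inputs.
        return [
            InputParam("latents_scale", default=1.0),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Latents generated by an earlier block",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="Rescaled latents")]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState):
        block_state = self.get_block_state(state)
        block_state.latents = block_state.latents * block_state.latents_scale
        self.set_block_state(state, block_state)
        return components, state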
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
index e9f627636e..feb78e1ef1 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ..modular_pipeline import (
- PipelineBlock,
+ ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class StableDiffusionXLDecodeStep(PipelineBlock):
+class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
- )
+ ),
]
@property
@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
return components, state
-class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
+class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -184,11 +179,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
InputParam("image"),
InputParam("mask_image"),
InputParam("padding_mask_crop"),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
index 7fe4a472ee..8a80257473 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
@@ -25,7 +25,7 @@ from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
- PipelineBlock,
+ ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi experimenting composible denoise loop
# loop step (1): prepare latent input for denoiser
-class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
+class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
)
@property
- def intermediate_inputs(self) -> List[str]:
+ def inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
# loop step (1): prepare latent input for denoiser (with inpainting)
-class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
+class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
)
@property
- def intermediate_inputs(self) -> List[str]:
+ def inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
# loop step (2): denoise the latents with guidance
-class StableDiffusionXLLoopDenoiser(PipelineBlock):
+class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam(
"num_inference_steps",
required=True,
@@ -188,14 +183,14 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.",
),
InputParam(
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
- "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+                    " Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state."
),
),
]
@@ -243,13 +238,13 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
components.guider.cleanup_models(components.unet)
# Perform guidance
- block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+ block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
# loop step (2): denoise the latents with guidance (with controlnet)
-class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
+class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("cross_attention_kwargs"),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam(
"controlnet_cond",
required=True,
@@ -317,14 +307,14 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
- "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+                    " Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state."
),
),
InputParam(
@@ -443,13 +433,13 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
components.guider.cleanup_models(components.unet)
# Perform guidance
- block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+ block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
# loop step (3): scheduler step to update latents
-class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
+class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam("generator"),
]
@@ -507,7 +492,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
t,
block_state.latents,
**block_state.extra_step_kwargs,
- **block_state.scheduler_step_kwargs,
return_dict=False,
)[0]
@@ -520,7 +504,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
# loop step (3): scheduler step to update latents (with inpainting)
-class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
+class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -542,11 +526,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("eta", default=0.0),
- ]
-
- @property
- def intermediate_inputs(self) -> List[str]:
- return [
InputParam("generator"),
InputParam(
"timesteps",
@@ -610,7 +589,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
t,
block_state.latents,
**block_state.extra_step_kwargs,
- **block_state.scheduler_step_kwargs,
return_dict=False,
)[0]
@@ -660,7 +638,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
]
@property
- def loop_intermediate_inputs(self) -> List[InputParam]:
+ def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
@@ -717,7 +695,7 @@ class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
- "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `StableDiffusionXLLoopBeforeDenoiser`\n"
" - `StableDiffusionXLLoopDenoiser`\n"
" - `StableDiffusionXLLoopAfterDenoiser`\n"
@@ -739,7 +717,7 @@ class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper
return (
"Denoise step that iteratively denoise the latents with controlnet. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
- "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `StableDiffusionXLLoopBeforeDenoiser`\n"
" - `StableDiffusionXLControlNetLoopDenoiser`\n"
" - `StableDiffusionXLLoopAfterDenoiser`\n"
@@ -761,7 +739,7 @@ class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
return (
"Denoise step that iteratively denoise the latents(for inpainting task only). \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
- "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
" - `StableDiffusionXLLoopDenoiser`\n"
" - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
@@ -783,7 +761,7 @@ class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoop
return (
"Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n"
"Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
- "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
" - `StableDiffusionXLControlNetLoopDenoiser`\n"
" - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
index bd0e962140..90b254b6f5 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
@@ -35,7 +35,7 @@ from ...utils import (
scale_lora_layers,
unscale_lora_layers,
)
-from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline
@@ -57,7 +57,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
-class StableDiffusionXLIPAdapterStep(PipelineBlock):
+class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
return components, state
-class StableDiffusionXLTextEncoderStep(PipelineBlock):
+class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -258,25 +258,25 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="negative text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="pooled text embeddings used to guide the image generation",
),
OutputParam(
"negative_pooled_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="negative pooled text embeddings used to guide the image generation",
),
]
@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
return components, state
-class StableDiffusionXLVaeEncoderStep(PipelineBlock):
+class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
- ]
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam(
@@ -668,12 +663,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
- block_state.image = components.image_processor.preprocess(
+ image = components.image_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
)
- block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
-
- block_state.batch_size = block_state.image.shape[0]
+ image = image.to(device=block_state.device, dtype=block_state.dtype)
+ block_state.batch_size = image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
@@ -682,16 +676,14 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
- block_state.image_latents = self._encode_vae_image(
- components, image=block_state.image, generator=block_state.generator
- )
+ block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
self.set_block_state(state, block_state)
return components, state
-class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
+class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl"
@property
@@ -726,11 +718,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
- ]
-
- @property
- def intermediate_inputs(self) -> List[InputParam]:
- return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"),
]
@@ -860,34 +847,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
block_state.crops_coords = None
block_state.resize_mode = "default"
- block_state.image = components.image_processor.preprocess(
+ image = components.image_processor.preprocess(
block_state.image,
height=block_state.height,
width=block_state.width,
crops_coords=block_state.crops_coords,
resize_mode=block_state.resize_mode,
)
- block_state.image = block_state.image.to(dtype=torch.float32)
+ image = image.to(dtype=torch.float32)
- block_state.mask = components.mask_processor.preprocess(
+ mask = components.mask_processor.preprocess(
block_state.mask_image,
height=block_state.height,
width=block_state.width,
resize_mode=block_state.resize_mode,
crops_coords=block_state.crops_coords,
)
- block_state.masked_image = block_state.image * (block_state.mask < 0.5)
+ block_state.masked_image = image * (mask < 0.5)
- block_state.batch_size = block_state.image.shape[0]
- block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
- block_state.image_latents = self._encode_vae_image(
- components, image=block_state.image, generator=block_state.generator
- )
+ block_state.batch_size = image.shape[0]
+ image = image.to(device=block_state.device, dtype=block_state.dtype)
+ block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
- block_state.mask,
+ mask,
block_state.masked_image,
block_state.batch_size,
block_state.height,
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
index c9033856bc..68b5e33755 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
@@ -82,19 +82,17 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
# before_denoise: text2img
class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
- StableDiffusionXLInputStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
]
- block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
+ block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is a sequential pipeline blocks:\n"
- + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
@@ -104,19 +102,17 @@ class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
# before_denoise: img2img
class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
- StableDiffusionXLInputStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
]
- block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
+ block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
+ "This is a sequential pipeline blocks:\n"
- + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
@@ -126,19 +122,17 @@ class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
# before_denoise: inpainting
class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
- StableDiffusionXLInputStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
]
- block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
+ block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for inpainting task.\n"
+ "This is a sequential pipeline blocks:\n"
- + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
@@ -255,25 +249,48 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
)
+class StableDiffusionXLCoreDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLInputStep,
+ StableDiffusionXLAutoBeforeDenoiseStep,
+ StableDiffusionXLAutoControlNetInputStep,
+ StableDiffusionXLAutoDenoiseStep,
+ ]
+ block_names = ["input", "before_denoise", "controlnet_input", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `StableDiffusionXLInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `StableDiffusionXLAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `StableDiffusionXLAutoControlNetInputStep` (controlnet_input) prepares the controlnet input.\n"
+ + " - `StableDiffusionXLAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+            + "This step supports text-to-image, image-to-image, and inpainting, with or without controlnet/controlnet_union/ip_adapter for Stable Diffusion XL:\n"
+ + "- for image-to-image generation, you need to provide `image_latents`\n"
+ + "- for inpainting, you need to provide `mask_image` and `image_latents`\n"
+ + "- to run the controlnet workflow, you need to provide `control_image`\n"
+ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
+ + "- to run the ip_adapter workflow, you need to load ip_adapter into your unet and provide `ip_adapter_embeds`\n"
+ + "- for text-to-image generation, all you need to provide is prompt embeddings\n"
+ )
+
+
# ip-adapter, controlnet, text2img, img2img, inpainting
class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
- StableDiffusionXLAutoBeforeDenoiseStep,
- StableDiffusionXLAutoControlNetInputStep,
- StableDiffusionXLAutoDenoiseStep,
+ StableDiffusionXLCoreDenoiseStep,
StableDiffusionXLAutoDecodeStep,
]
block_names = [
"text_encoder",
"ip_adapter",
- "image_encoder",
- "before_denoise",
- "controlnet_input",
+ "vae_encoder",
"denoise",
- "decoder",
+ "decode",
]
@property
@@ -321,7 +338,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
- ("image_encoder", StableDiffusionXLVaeEncoderStep),
+ ("vae_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
@@ -334,7 +351,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
INPAINT_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
- ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
+ ("vae_encoder", StableDiffusionXLInpaintVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
@@ -361,10 +378,8 @@ AUTO_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
- ("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
- ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
- ("controlnet_input", StableDiffusionXLAutoControlNetInputStep),
- ("denoise", StableDiffusionXLAutoDenoiseStep),
+ ("vae_encoder", StableDiffusionXLAutoVaeEncoderStep),
+ ("denoise", StableDiffusionXLCoreDenoiseStep),
("decode", StableDiffusionXLAutoDecodeStep),
]
)
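A short sketch of assembling the reorganized preset; it assumes `SequentialPipelineBlocks` is re-exported from `diffusers.modular_pipelines` and that its `from_blocks_dict` helper (as used in the modular-pipelines documentation) accepts the `InsertableDict` presets defined in this file. This is not code from the PR.

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import AUTO_BLOCKS

# The "denoise" entry now resolves to StableDiffusionXLCoreDenoiseStep, which chains
# input -> before_denoise -> controlnet_input -> denoise internally.
blocks = SequentialPipelineBlocks.from_blocks_dict(AUTO_BLOCKS)
print(blocks)  # text_encoder / ip_adapter / vae_encoder / denoise / decode sub-blocks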
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
index fc030fae56..f2a4c96073 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
@@ -47,13 +47,11 @@ class StableDiffusionXLModularPipeline(
"""
A ModularPipeline for Stable Diffusion XL.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+    > [!WARNING] This is an experimental feature and is likely to change in the future.
"""
+ default_blocks_name = "StableDiffusionXLAutoBlocks"
+
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@@ -76,6 +74,7 @@ class StableDiffusionXLModularPipeline(
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
+ # YiYi TODO: change to num_channels_latents
@property
def num_channels_unet(self):
num_channels_unet = 4
@@ -247,10 +246,6 @@ SDXL_INPUTS_SCHEMA = {
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
-}
-
-
-SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
@@ -271,13 +266,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
- "latents": InputParam(
- "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
- ),
- "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
- "num_inference_steps": InputParam(
- "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
- ),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py
new file mode 100644
index 0000000000..3e788bf947
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py
@@ -0,0 +1,99 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+SDXL_NODE_TYPES_PARAMS_MAP = {
+ "controlnet": {
+ "inputs": [
+ "control_image",
+ "controlnet_conditioning_scale",
+ "control_guidance_start",
+ "control_guidance_end",
+ "height",
+ "width",
+ ],
+ "model_inputs": [
+ "controlnet",
+ ],
+ "outputs": [
+ "controlnet_out",
+ ],
+ "block_names": [None],
+ },
+ "denoise": {
+ "inputs": [
+ "embeddings",
+ "width",
+ "height",
+ "seed",
+ "num_inference_steps",
+ "guidance_scale",
+ "image_latents",
+ "strength",
+ # custom adapters coming in as inputs
+ "controlnet",
+ # ip_adapter is optional and custom; include if available
+ "ip_adapter",
+ ],
+ "model_inputs": [
+ "unet",
+ "guider",
+ "scheduler",
+ ],
+ "outputs": [
+ "latents",
+ "latents_preview",
+ ],
+ "block_names": ["denoise"],
+ },
+ "vae_encoder": {
+ "inputs": [
+ "image",
+ "width",
+ "height",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "image_latents",
+ ],
+ "block_names": ["vae_encoder"],
+ },
+ "text_encoder": {
+ "inputs": [
+ "prompt",
+ "negative_prompt",
+ ],
+ "model_inputs": [
+ "text_encoders",
+ ],
+ "outputs": [
+ "embeddings",
+ ],
+ "block_names": ["text_encoder"],
+ },
+ "decoder": {
+ "inputs": [
+ "latents",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "images",
+ ],
+ "block_names": ["decode"],
+ },
+}
diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py
index ef65b64537..d48f678edd 100644
--- a/src/diffusers/modular_pipelines/wan/before_denoise.py
+++ b/src/diffusers/modular_pipelines/wan/before_denoise.py
@@ -20,7 +20,7 @@ import torch
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
-from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
@@ -94,7 +94,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class WanInputStep(PipelineBlock):
+class WanInputStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -146,13 +146,13 @@ class WanInputStep(PipelineBlock):
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+                kwargs_type="denoiser_input_fields",  # already in intermediates state but declare here again for denoiser_input_fields
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+                kwargs_type="denoiser_input_fields",  # already in intermediates state but declare here again for denoiser_input_fields
description="negative text embeddings used to guide the image generation",
),
]
@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
return components, state
-class WanSetTimestepsStep(PipelineBlock):
+class WanSetTimestepsStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
return components, state
-class WanPrepareLatentsStep(PipelineBlock):
+class WanPrepareLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py
index 4fadeed4b9..8c751172d8 100644
--- a/src/diffusers/modular_pipelines/wan/decoders.py
+++ b/src/diffusers/modular_pipelines/wan/decoders.py
@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLWan
from ...utils import logging
from ...video_processor import VideoProcessor
-from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class WanDecodeStep(PipelineBlock):
+class WanDecodeStep(ModularPipelineBlocks):
model_name = "wan"
@property
diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py
index 76c5cda5f9..66c51493bd 100644
--- a/src/diffusers/modular_pipelines/wan/denoise.py
+++ b/src/diffusers/modular_pipelines/wan/denoise.py
@@ -24,7 +24,7 @@ from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
- PipelineBlock,
+ ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class WanLoopDenoiser(PipelineBlock):
+class WanLoopDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -79,11 +79,11 @@ class WanLoopDenoiser(PipelineBlock):
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds. "
- "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ "Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@@ -127,12 +127,12 @@ class WanLoopDenoiser(PipelineBlock):
components.guider.cleanup_models(components.transformer)
# Perform guidance
- block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+ block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
-class WanLoopAfterDenoiser(PipelineBlock):
+class WanLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -171,7 +171,6 @@ class WanLoopAfterDenoiser(PipelineBlock):
block_state.noise_pred.float(),
t,
block_state.latents.float(),
- **block_state.scheduler_step_kwargs,
return_dict=False,
)[0]
@@ -254,7 +253,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
- "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports both text2vid tasks."
diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
index b2ecfd1aa6..cb2fc24238 100644
--- a/src/diffusers/modular_pipelines/wan/encoders.py
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import is_ftfy_available, logging
-from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
@@ -51,7 +51,7 @@ def prompt_clean(text):
return text
-class WanTextEncoderStep(PipelineBlock):
+class WanTextEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -89,13 +89,13 @@ class WanTextEncoderStep(PipelineBlock):
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="negative text embeddings used to guide the image generation",
),
]
diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
index 4d86e0d08e..e4adf3d151 100644
--- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
@@ -30,13 +30,11 @@ class WanModularPipeline(
"""
A ModularPipeline for Wan.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+    > [!WARNING] This is an experimental feature and is likely to change in the future.
"""
+ default_blocks_name = "WanAutoBlocks"
+
@property
def default_height(self):
return self.default_sample_height * self.vae_scale_factor_spatial
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index aab7664fd2..190c7871d2 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -127,6 +127,7 @@ else:
"AnimateDiffVideoToVideoPipeline",
"AnimateDiffVideoToVideoControlNetPipeline",
]
+ _import_structure["bria"] = ["BriaPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -284,6 +285,7 @@ else:
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
+ _import_structure["lucy"] = ["LucyEditPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
@@ -387,7 +389,16 @@ else:
"SkyReelsV2ImageToVideoPipeline",
"SkyReelsV2Pipeline",
]
- _import_structure["qwenimage"] = ["QwenImagePipeline"]
+ _import_structure["qwenimage"] = [
+ "QwenImagePipeline",
+ "QwenImageImg2ImgPipeline",
+ "QwenImageInpaintPipeline",
+ "QwenImageEditPipeline",
+ "QwenImageEditPlusPipeline",
+ "QwenImageEditInpaintPipeline",
+ "QwenImageControlNetInpaintPipeline",
+ "QwenImageControlNetPipeline",
+ ]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -547,6 +558,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
+ from .bria import BriaPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
@@ -672,6 +684,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
+ from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
@@ -704,7 +717,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
- from .qwenimage import QwenImagePipeline
+ from .qwenimage import (
+ QwenImageControlNetInpaintPipeline,
+ QwenImageControlNetPipeline,
+ QwenImageEditInpaintPipeline,
+ QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
+ QwenImageImg2ImgPipeline,
+ QwenImageInpaintPipeline,
+ QwenImagePipeline,
+ )
from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
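Each new pipeline is registered twice in `pipelines/__init__.py`: once in `_import_structure` for the lazy path and once under the `TYPE_CHECKING`/slow-import branch, so static analysis and runtime imports stay in sync. A minimal sketch of the resulting imports, assuming torch and transformers are installed so the optional-dependency guards resolve to the real classes:

```py
# Minimal sketch, assuming torch and transformers are installed so the
# optional-dependency guards resolve to real classes rather than dummy objects.
from diffusers.pipelines import (
    BriaPipeline,
    LucyEditPipeline,
    QwenImageEditPipeline,
    QwenImageImg2ImgPipeline,
)

# diffusers.pipelines is a _LazyModule, so each submodule listed in
# _import_structure is imported only when its name is first accessed
# (here, at the from-import above).
print(BriaPipeline.__name__, QwenImageImg2ImgPipeline.__name__)
```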
diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py
index 0993c8b912..3be0129088 100644
--- a/src/diffusers/pipelines/allegro/pipeline_allegro.py
+++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py
@@ -651,6 +651,12 @@ class AllegroPipeline(DiffusionPipeline):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -658,6 +664,12 @@ class AllegroPipeline(DiffusionPipeline):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -666,6 +678,12 @@ class AllegroPipeline(DiffusionPipeline):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -673,6 +691,12 @@ class AllegroPipeline(DiffusionPipeline):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
@@ -760,7 +784,7 @@ class AllegroPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
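The Allegro VAE helpers now route through `deprecate(...)` and point users at the underlying VAE methods. A minimal migration sketch; the checkpoint id is illustrative:

```py
# Migration sketch for the deprecated VAE helpers; the checkpoint id is illustrative.
import torch
from diffusers import AllegroPipeline

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)

# Old style: still works until the stated removal version, but now emits a deprecation warning.
pipe.enable_vae_tiling()

# New style: call the VAE directly, as the deprecation message suggests.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
```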
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
index 260669ddaf..56d3190275 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -971,7 +971,7 @@ class AnimateDiffSDXLPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
index 546ae9239a..b6b40cd6e6 100644
--- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -17,7 +17,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index 0af2e1fe36..452fc3c01b 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -34,6 +34,7 @@ from transformers import (
from ...models import AutoencoderKL
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
+ deprecate,
is_accelerate_available,
is_accelerate_version,
is_librosa_available,
@@ -228,6 +229,12 @@ class AudioLDM2Pipeline(DiffusionPipeline):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
@@ -236,6 +243,12 @@ class AudioLDM2Pipeline(DiffusionPipeline):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index 7ff9925c45..6251ca4435 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -497,7 +497,7 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index ebabf17995..8a32d4c367 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -91,6 +91,15 @@ from .pag import (
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+from .qwenimage import (
+ QwenImageControlNetPipeline,
+ QwenImageEditInpaintPipeline,
+ QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
+ QwenImageImg2ImgPipeline,
+ QwenImageInpaintPipeline,
+ QwenImagePipeline,
+)
from .sana import SanaPipeline
from .stable_cascade import StableCascadeCombinedPipeline, StableCascadeDecoderPipeline
from .stable_diffusion import (
@@ -150,6 +159,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("cogview4-control", CogView4ControlPipeline),
+ ("qwenimage", QwenImagePipeline),
+ ("qwenimage-controlnet", QwenImageControlNetPipeline),
]
)
@@ -174,6 +185,9 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
("flux-kontext", FluxKontextPipeline),
+ ("qwenimage", QwenImageImg2ImgPipeline),
+ ("qwenimage-edit", QwenImageEditPipeline),
+ ("qwenimage-edit-plus", QwenImageEditPlusPipeline),
]
)
@@ -195,6 +209,8 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("flux-controlnet", FluxControlNetInpaintPipeline),
("flux-control", FluxControlInpaintPipeline),
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
+ ("qwenimage", QwenImageInpaintPipeline),
+ ("qwenimage-edit", QwenImageEditInpaintPipeline),
]
)
@@ -393,12 +409,8 @@ class AutoPipelineForText2Image(ConfigMixin):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
- auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with `hf auth login`.
Examples:
@@ -688,12 +700,8 @@ class AutoPipelineForImage2Image(ConfigMixin):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
- auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with `hf auth login`.
Examples:
@@ -998,12 +1006,8 @@ class AutoPipelineForInpainting(ConfigMixin):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
- auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in with `hf auth login`.
Examples:
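With the QwenImage family added to the text2image, image2image, and inpainting mappings, the auto classes can route a QwenImage checkpoint to the right pipeline. A sketch of the routing, where the `"Qwen/Qwen-Image"` checkpoint id is an assumption:

```py
# Sketch of the new auto-pipeline routing; "Qwen/Qwen-Image" is assumed to be a
# checkpoint whose config resolves to the "qwenimage" mapping entries added above.
import torch
from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

# Reuse the loaded components for the img2img task via the new AUTO_IMAGE2IMAGE entry.
pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe)
print(type(pipe).__name__, type(pipe_img2img).__name__)
```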
diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
index 928698e442..b061ac2636 100644
--- a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
@@ -14,7 +14,6 @@
from typing import Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import nn
from transformers import BertTokenizer
from transformers.activations import QuickGELUActivation as QuickGELU
diff --git a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
index 439dc511a0..705d930b59 100644
--- a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
+++ b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
@@ -19,11 +19,7 @@ from transformers import CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
from .blip_image_processing import BlipImageProcessor
@@ -228,7 +224,7 @@ class BlipDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by random sampling.
+ tensor will be generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
diff --git a/src/diffusers/pipelines/bria/__init__.py b/src/diffusers/pipelines/bria/__init__.py
new file mode 100644
index 0000000000..60e319ac79
--- /dev/null
+++ b/src/diffusers/pipelines/bria/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_bria"] = ["BriaPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_bria import BriaPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
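The new `bria` package uses the same lazy-module pattern as the other pipeline packages: `_import_structure` declares the exports and the module object is replaced by a `_LazyModule`, so `pipeline_bria` is imported only on first attribute access. A small sketch of the observable behavior, assuming the optional dependencies are installed:

```py
# Observable effect of the lazy-module pattern (assuming torch and transformers are installed).
import diffusers.pipelines.bria as bria

# The module object is a _LazyModule; pipeline_bria has not been imported yet.
print(type(bria).__name__)

# First attribute access triggers the real import declared in _import_structure.
print(bria.BriaPipeline)
```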
diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py
new file mode 100644
index 0000000000..ebddfb0c0e
--- /dev/null
+++ b/src/diffusers/pipelines/bria/pipeline_bria.py
@@ -0,0 +1,729 @@
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin
+from ...models import AutoencoderKL
+from ...models.transformers.transformer_bria import BriaTransformer2DModel
+from ...pipelines import DiffusionPipeline
+from ...pipelines.bria.pipeline_output import BriaPipelineOutput
+from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+from ...schedulers import (
+ DDIMScheduler,
+ EulerAncestralDiscreteScheduler,
+ FlowMatchEulerDiscreteScheduler,
+ KarrasDiffusionSchedulers,
+)
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import BriaPipeline
+
+ >>> pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.2", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # BRIA's T5 text encoder is sensitive to precision: cast it to bfloat16 but keep the final layer in float32.
+ >>> pipe.text_encoder = pipe.text_encoder.to(dtype=torch.bfloat16)
+ >>> for block in pipe.text_encoder.encoder.block:
+ ... block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
+
+ >>> # BRIA's VAE is not supported in mixed precision, so we keep it in float32.
+ >>> if pipe.vae.config.shift_factor == 0:
+ ... pipe.vae.to(dtype=torch.float32)
+
+ >>> prompt = "Photorealistic food photography of a stack of fluffy pancakes on a white plate, with maple syrup being poured over them. On top of the pancakes are the words 'BRIA 3.2' in bold, yellow, 3D letters. The background is dark and out of focus."
+ >>> image = pipe(prompt).images[0]
+ >>> image.save("bria.png")
+ ```
+"""
+
+
+def is_ng_none(negative_prompt):
+ return (
+ negative_prompt is None
+ or negative_prompt == ""
+ or (isinstance(negative_prompt, list) and negative_prompt[0] is None)
+ or (type(negative_prompt) == list and negative_prompt[0] == "")
+ )
+
+
+def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
+ sigmas = timesteps / num_train_timesteps
+
+ inds = [int(ind) for ind in np.linspace(0, num_train_timesteps - 1, num_inference_steps)]
+ new_sigmas = sigmas[inds]
+ return new_sigmas
+
+
+class BriaPipeline(DiffusionPipeline):
+ r"""
+ Based on FluxPipeline with several changes:
+ - no pooled embeddings
+ - zero padding is used for prompts
+ - no guidance embedding, since this is not a distilled model
+
+ Args:
+ transformer ([`BriaTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. Bria uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`T5TokenizerFast`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ transformer: BriaTransformer2DModel,
+ scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
+ vae: AutoencoderKL,
+ text_encoder: T5EncoderModel,
+ tokenizer: T5TokenizerFast,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ )
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.default_sample_size = 64  # with patchification, a 128x128 latent corresponds to a 1024x1024 image
+
+ if self.vae.config.shift_factor is None:
+ self.vae.config.shift_factor = 0
+ self.vae.to(dtype=torch.float32)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 128,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ if not is_ng_none(negative_prompt):
+ negative_prompt = (
+ batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ else:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device)
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
+
+ return prompt_embeds, negative_prompt_embeds, text_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @attention_kwargs.setter
+ def attention_kwargs(self, value):
+ self._attention_kwargs = value
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 128,
+ device: Optional[torch.device] = None,
+ ):
+ tokenizer = self.tokenizer
+ text_encoder = self.text_encoder
+ device = device or text_encoder.device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+ prompt_embeds_list = []
+ for p in prompt:
+ text_inputs = tokenizer(
+ p,
+ # padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ # Concat zeros to max_sequence
+ b, seq_len, dim = prompt_embeds.shape
+ if seq_len < max_sequence_length:
+ padding = torch.zeros(
+ (b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+ )
+ prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=0)
+ prompt_embeds = prompt_embeds.to(device=device)
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, max_sequence_length, -1)
+ prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
+ return prompt_embeds
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // self.vae_scale_factor)
+ width = 2 * (int(width) // self.vae_scale_factor)
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ @staticmethod
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+ return latents
+
+ @staticmethod
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
+ latent_image_ids = latent_image_ids.reshape(
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 30,
+ timesteps: List[int] = None,
+ guidance_scale: float = 5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 128,
+ clip_value: Union[None, float] = None,
+ normalize: bool = False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 30):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.bria.BriaPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.bria.BriaPipelineOutput`] or `tuple`: [`~pipelines.bria.BriaPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ prompt_embeds=prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self.attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4  # due to patch size 2, we divide by 4
+ latents, latent_image_ids = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ if (
+ isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler)
+ and self.scheduler.config["use_dynamic_shifting"]
+ ):
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+ image_seq_len = latents.shape[1]
+
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ timesteps,
+ sigmas,
+ mu=mu,
+ )
+ else:
+ # 4. Prepare timesteps
+ # Sample from training sigmas
+ if isinstance(self.scheduler, DDIMScheduler) or isinstance(
+ self.scheduler, EulerAncestralDiscreteScheduler
+ ):
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, None, None
+ )
+ else:
+ sigmas = get_original_sigmas(
+ num_train_timesteps=self.scheduler.config.num_train_timesteps,
+ num_inference_steps=num_inference_steps,
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ if len(latent_image_ids.shape) == 3:
+ latent_image_ids = latent_image_ids[0]
+ if len(text_ids.shape) == 3:
+ text_ids = text_ids[0]
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # This predicts "v" for flow matching or eps for diffusion
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ cfg_noise_pred_text = noise_pred_text.std()
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if normalize:
+ noise_pred = noise_pred * (0.7 * (cfg_noise_pred_text / noise_pred.std())) + 0.3 * noise_pred
+
+ if clip_value:
+ assert clip_value > 0
+ noise_pred = noise_pred.clip(-clip_value, clip_value)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return BriaPipelineOutput(images=image)
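`_pack_latents` folds each 2x2 latent patch into the channel dimension before the transformer, and `_unpack_latents` undoes it for VAE decoding. A standalone round-trip check of the two static methods; the shapes mirror `prepare_latents`, and the `vae_scale_factor` of 16 is an assumption matching the pipeline's fallback value:

```py
# Round-trip check of the static packing helpers; shapes mirror prepare_latents and the
# vae_scale_factor of 16 is an assumption matching the pipeline's fallback value.
import torch
from diffusers.pipelines.bria.pipeline_bria import BriaPipeline

batch, channels, lat_h, lat_w = 1, 4, 64, 64
vae_scale_factor = 16

latents = torch.randn(batch, channels, lat_h, lat_w)
packed = BriaPipeline._pack_latents(latents, batch, channels, lat_h, lat_w)
print(packed.shape)  # torch.Size([1, 1024, 16]) == (batch, (lat_h//2)*(lat_w//2), channels*4)

# _unpack_latents expects pixel-space height/width: (lat_h // 2) patches per side,
# each covering vae_scale_factor pixels.
pixel_h = (lat_h // 2) * vae_scale_factor
pixel_w = (lat_w // 2) * vae_scale_factor
unpacked = BriaPipeline._unpack_latents(packed, pixel_h, pixel_w, vae_scale_factor)
torch.testing.assert_close(unpacked, latents)
```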
diff --git a/src/diffusers/pipelines/bria/pipeline_output.py b/src/diffusers/pipelines/bria/pipeline_output.py
new file mode 100644
index 0000000000..54eed06233
--- /dev/null
+++ b/src/diffusers/pipelines/bria/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class BriaPipelineOutput(BaseOutput):
+ """
+ Output class for Bria pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. The PIL images or NumPy array represent the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
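Because `BriaPipelineOutput` subclasses `BaseOutput`, it behaves like both a dataclass and an ordered mapping. A tiny sketch:

```py
# BriaPipelineOutput is a BaseOutput subclass, so it supports attribute access,
# key access, and tuple conversion alike.
import numpy as np
from diffusers.pipelines.bria.pipeline_output import BriaPipelineOutput

out = BriaPipelineOutput(images=np.zeros((1, 64, 64, 3), dtype=np.uint8))
assert out.images is out["images"]
assert out.to_tuple()[0] is out.images
```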
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index 3a34ec2a42..5482035b3a 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -25,6 +25,7 @@ from ...models import AutoencoderKL, ChromaTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -237,7 +238,7 @@ class ChromaPipeline(
# Chroma requires the attention mask to include one padding token
seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
- attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+ attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
prompt_embeds = self.text_encoder(
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
@@ -245,7 +246,7 @@ class ChromaPipeline(
dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- attention_mask = attention_mask.to(dtype=dtype, device=device)
+ attention_mask = attention_mask.to(device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -508,6 +509,12 @@ class ChromaPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -515,6 +522,12 @@ class ChromaPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -523,6 +536,12 @@ class ChromaPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -530,6 +549,12 @@ class ChromaPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
@@ -580,10 +605,9 @@ class ChromaPipeline(
# Extend the prompt attention mask to account for image tokens in the final sequence
attention_mask = torch.cat(
- [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+ [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
dim=1,
)
- attention_mask = attention_mask.to(dtype)
return attention_mask
@@ -663,11 +687,11 @@ class ChromaPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
- Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
- a model to generate images more aligned with `prompt` at the expense of lower image quality.
-
- Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
- the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -676,7 +700,7 @@ class ChromaPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
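The Chroma text-encoding change keeps the attention mask boolean end to end instead of casting it to the text-encoder dtype, and the image-token extension is likewise created as a bool tensor. A small sketch reproducing the padding-aware mask construction on toy data:

```py
# Toy reproduction of the padding-aware mask built in the Chroma text-encoding path:
# one padding position past each prompt's length stays visible, and the mask stays boolean.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])
batch_size, seq_len = attention_mask.shape

seq_lengths = attention_mask.sum(dim=1)
mask_indices = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
print(attention_mask)
# tensor([[ True,  True,  True,  True, False],
#         [ True,  True,  True, False, False]])

# Image tokens appended later are likewise boolean, avoiding the float cast removed above.
image_tokens = torch.ones(batch_size, 3, dtype=torch.bool)
full_mask = torch.cat([attention_mask, image_tokens], dim=1)
print(full_mask.dtype)  # torch.bool
```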
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index e169db4a4d..9afd4b9e15 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -25,6 +25,7 @@ from ...models import AutoencoderKL, ChromaTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -542,6 +543,12 @@ class ChromaImg2ImgPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -549,6 +556,12 @@ class ChromaImg2ImgPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -557,6 +570,12 @@ class ChromaImg2ImgPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -564,6 +583,12 @@ class ChromaImg2ImgPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
@@ -724,12 +749,12 @@ class ChromaImg2ImgPipeline(
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
- guidance_scale (`float`, *optional*, defaults to 5.0):
- Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
- a model to generate images more aligned with `prompt` at the expense of lower image quality.
-
- Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
- the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
strength (`float, *optional*, defaults to 0.9):
Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
be used as a starting point, adding more noise to it the larger the strength. The number of denoising
@@ -744,7 +769,7 @@ class ChromaImg2ImgPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index 3c5994172c..4ac33b24bb 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -571,7 +571,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
index cf6ccebc47..c1335839f8 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -616,7 +616,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
control_video_latents (`torch.Tensor`, *optional*):
Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for
controlled video generation. If not provided, `control_video` must be provided.
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index d1f02ca9c9..c523c9adec 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -28,11 +28,7 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
@@ -671,7 +667,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 230c8ca296..897dc6d1b7 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -641,7 +641,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index f2f852c213..304a5c5ad0 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -466,7 +466,7 @@ class CogView3PlusPipeline(DiffusionPipeline):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index d8374b694f..22510f5d9d 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -466,7 +466,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
index ac8d786f04..e26b7ba415 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
@@ -499,7 +499,7 @@ class CogView4ControlPipeline(DiffusionPipeline):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py
index 644bd811f6..3e6c149d7f 100644
--- a/src/diffusers/pipelines/consisid/pipeline_consisid.py
+++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py
@@ -733,7 +733,7 @@ class ConsisIDPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
index dec448f3f4..1fbdeb1f27 100644
--- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
+++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
@@ -18,11 +18,7 @@ import torch
from ...models import UNet2DModel
from ...schedulers import CMStochasticIterativeScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
index 598e3b5b6d..e0f1879405 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -20,11 +20,7 @@ from transformers import CLIPTokenizer
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
@@ -279,7 +275,7 @@ class BlipDiffusionControlNetPipeline(DeprecatedPipelineMixin, DiffusionPipeline
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by random sampling.
+ tensor will be generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 41303d9c5c..6de8e5747b 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -146,16 +146,13 @@ class StableDiffusionControlNetInpaintPipeline(
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
- <Tip>
-
- This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
+ > [!TIP] > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting >
([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
- as well as default text-to-image Stable Diffusion checkpoints
+ > as well as default text-to-image Stable Diffusion checkpoints >
([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
- Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on
- those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
-
- </Tip>
+ > Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned
+ on > those, such as
+ [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
Args:
vae ([`AutoencoderKL`]):
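The hunk above swaps the HTML-style `<Tip>`/`</Tip>` markers for GitHub-flavored `> [!TIP]` admonitions. As an editorial illustration only (this is not the tool the PR used), a substitution of that shape can be sketched in Python; the function name and regex are assumptions:

```python
import re

def tip_to_admonition(docstring: str) -> str:
    """Illustrative sketch: rewrite <Tip>...</Tip> blocks as GitHub admonitions."""

    def repl(match: re.Match) -> str:
        # <Tip warning={true}> becomes a WARNING, plain <Tip> becomes a TIP.
        kind = "WARNING" if "warning" in match.group(1) else "TIP"
        body = [line.strip() for line in match.group(2).splitlines() if line.strip()]
        quoted = "\n".join(f"    > {line}" for line in body)
        return f"    > [!{kind}]\n{quoted}"

    return re.sub(r"<Tip([^>]*)>(.*?)</Tip>", repl, docstring, flags=re.DOTALL)
```

The awkward wrapping visible in the added lines (trailing `>` characters mid-sentence) appears to come from the repository's automatic docstring re-wrapping running after a substitution of this kind, not from the admonition syntax itself.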
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 4aa2a62a53..397ab15715 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -1326,7 +1326,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 526e1ffcb2..4d4845c5a0 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -1197,7 +1197,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
index 7fa59395a8..fb58b22211 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -1310,7 +1310,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
index 65e2fe6617..8fedb6d860 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
@@ -1185,7 +1185,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
index 1de1d4bde7..d4c6f336df 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
@@ -394,12 +394,8 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
- <Tip warning={true}>
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
- </Tip>
+ > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
+ removed in a > future release.
Examples:
diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
index 9b9adf4901..2b5684de95 100644
--- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
@@ -27,11 +27,7 @@ from ...models import AutoencoderKL, HunyuanDiT2DControlNetModel, HunyuanDiT2DMo
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index e31e3a0178..c763411ab5 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -918,7 +918,7 @@ class StableDiffusion3ControlNetPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index 000e080d3a..c33cf979c6 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -973,7 +973,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
index f9034a5844..d000d87e6a 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
@@ -880,7 +880,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -1151,7 +1151,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index eda950998d..397fbc0d85 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -1000,11 +1000,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
- <Tip warning={true}>
-
- This API is 🧪 experimental.
-
- </Tip>
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -1021,11 +1017,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
- <Tip warning={true}>
-
- This API is 🧪 experimental.
-
- </Tip>
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 2beb0be57b..034a022641 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import torch
-import torch.utils.checkpoint
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ....image_processor import VaeImageProcessor
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index adfd899e76..2f54f4fc98 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -16,7 +16,6 @@ import inspect
from typing import Callable, List, Optional, Union
import torch
-import torch.utils.checkpoint
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer
from ....image_processor import VaeImageProcessor
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 7211fb5693..5041e352f7 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -32,6 +32,7 @@ from ...models import AutoencoderKL, FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -310,7 +311,7 @@ class FluxPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -545,6 +546,12 @@ class FluxPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -552,6 +559,12 @@ class FluxPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -560,6 +573,12 @@ class FluxPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -567,6 +586,12 @@ class FluxPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
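The four hunks above follow one pattern: each pipeline-level VAE helper now emits a deprecation warning through `deprecate()` and then delegates to the VAE exactly as before. A minimal sketch of that shim, using a hypothetical `ToyPipeline` instead of the real classes:

```python
from diffusers.utils import deprecate


class ToyPipeline:
    """Hypothetical stand-in used only to illustrate the shim added in this PR."""

    def __init__(self, vae):
        self.vae = vae

    def enable_vae_slicing(self):
        depr_message = (
            f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method "
            "will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
        )
        # Warn once, then keep the old behavior by forwarding to the VAE.
        deprecate("enable_vae_slicing", "0.40.0", depr_message)
        self.vae.enable_slicing()
```

Existing callers keep working until the stated removal version; they only see a warning pointing at the replacement.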
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
index 5a057f94cf..848d7bd392 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -26,6 +26,7 @@ from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -324,7 +325,7 @@ class FluxControlPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -496,6 +497,12 @@ class FluxControlPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -503,6 +510,12 @@ class FluxControlPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -511,6 +524,12 @@ class FluxControlPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -518,6 +537,12 @@ class FluxControlPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
@@ -674,7 +699,7 @@ class FluxControlPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
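From the caller's side, the deprecation only changes which object the helper is invoked on. A usage sketch; the checkpoint id is illustrative and any Flux checkpoint behaves the same way:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Deprecated: pipe.enable_vae_slicing() / pipe.enable_vae_tiling() still work but warn.
# Preferred, per the deprecation message: call the VAE directly.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
```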
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
index 8d5439daf6..262345c75a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
@@ -335,7 +335,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -712,7 +712,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
index 872bcf177c..6915a83a7c 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
@@ -35,6 +35,7 @@ from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -374,7 +375,7 @@ class FluxControlInpaintPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -577,6 +578,12 @@ class FluxControlInpaintPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -584,6 +591,12 @@ class FluxControlInpaintPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -592,6 +605,12 @@ class FluxControlInpaintPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -599,6 +618,12 @@ class FluxControlInpaintPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -838,7 +863,7 @@ class FluxControlInpaintPipeline(
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
- latents tensor will ge generated by `mask_image`.
+ latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -870,7 +895,7 @@ class FluxControlInpaintPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 1438d4a902..507ec68734 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -341,7 +341,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -764,7 +764,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
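The other recurring change in the Flux pipelines is that `encode_prompt` no longer requires `prompt_2`; it now defaults to `None` and falls back to `prompt`. A short sketch, reusing the illustrative checkpoint from above and assuming FluxPipeline's existing three-value return:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# prompt_2 can now be omitted entirely; it falls back to `prompt` internally.
prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
    prompt="a photo of a forest at dawn",
    num_images_per_prompt=1,
)
```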
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index 52e15de53b..582c7bbad8 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -335,7 +335,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index d1e874d0b8..f7f34ef231 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -346,7 +346,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
index ddfb284eaf..5cb9c82204 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -26,6 +26,7 @@ from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -419,7 +420,7 @@ class FluxFillPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -633,6 +634,12 @@ class FluxFillPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -640,6 +647,12 @@ class FluxFillPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -648,6 +661,12 @@ class FluxFillPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -655,6 +674,12 @@ class FluxFillPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
@@ -775,7 +800,7 @@ class FluxFillPipeline(
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
- latents tensor will ge generated by `mask_image`.
+ latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -807,7 +832,7 @@ class FluxFillPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index 1c4cf3b1cd..ab9140dae9 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -33,6 +33,7 @@ from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -333,7 +334,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -613,6 +614,12 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
@@ -621,6 +628,12 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
@@ -630,6 +643,12 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
@@ -638,6 +657,12 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -787,7 +812,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index eeacd9b19b..3bfe82cf43 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -337,7 +337,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -834,7 +834,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
- latents tensor will ge generated by `mask_image`.
+ latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -873,7 +873,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
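The inpainting docstrings above spell out how `mask_image` relates to the optional pre-computed `mask_image_latent` and `latents` inputs: when the latents are not supplied, they are derived from `mask_image` and sampled with the provided `generator`. A minimal end-to-end sketch of that common path; the checkpoint id and image URLs are placeholders:

```python
import torch
from diffusers import FluxInpaintPipeline
from diffusers.utils import load_image

pipe = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

init_image = load_image("https://example.com/dog.png")       # placeholder URL
mask_image = load_image("https://example.com/dog_mask.png")  # white = region to repaint

result = pipe(
    prompt="a corgi wearing sunglasses",
    image=init_image,
    mask_image=mask_image,  # mask latents are derived from this when mask_image_latent is not given
    strength=0.85,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
result.save("inpainted.png")
```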
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
index 3c78aeaf36..94ae460afc 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
@@ -32,6 +32,7 @@ from ...models import AutoencoderKL, FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -358,7 +359,7 @@ class FluxKontextPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -614,6 +615,12 @@ class FluxKontextPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
@@ -622,6 +629,12 @@ class FluxKontextPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
@@ -631,6 +644,12 @@ class FluxKontextPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
@@ -639,6 +658,12 @@ class FluxKontextPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -808,7 +833,7 @@ class FluxKontextPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
index 6dc621901c..b6f957981e 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
@@ -22,6 +22,7 @@ from ...models import AutoencoderKL, FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -391,7 +392,7 @@ class FluxKontextInpaintPipeline(
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -688,6 +689,12 @@ class FluxKontextInpaintPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
@@ -696,6 +703,12 @@ class FluxKontextInpaintPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
@@ -705,6 +718,12 @@ class FluxKontextInpaintPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
@@ -713,6 +732,12 @@ class FluxKontextInpaintPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -1029,7 +1054,7 @@ class FluxKontextInpaintPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
index b5ccfb31a3..e79db337b2 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py
@@ -292,7 +292,7 @@ class FluxPriorReduxPipeline(DiffusionPipeline):
def encode_prompt(
self,
prompt: Union[str, List[str]],
- prompt_2: Union[str, List[str]],
+ prompt_2: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
index 695f54f3d9..b6af23bca8 100644
--- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
+++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
@@ -522,6 +522,12 @@ class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -529,6 +535,12 @@ class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -537,6 +549,12 @@ class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -544,6 +562,12 @@ class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def check_inputs(
@@ -789,7 +813,7 @@ class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
index d8c3548946..b50a6ae3ed 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py
@@ -24,7 +24,7 @@ from ...image_processor import PipelineImageInput
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -463,6 +463,12 @@ class HunyuanSkyreelsImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoa
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -470,6 +476,12 @@ class HunyuanSkyreelsImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoa
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -478,6 +490,12 @@ class HunyuanSkyreelsImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoa
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -485,6 +503,12 @@ class HunyuanSkyreelsImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoa
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index 76b288ed0b..5c8e295eaf 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -23,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -420,6 +420,12 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -427,6 +433,12 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -435,6 +447,12 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -442,6 +460,12 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py
index 40d6534655..8006514f47 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py
@@ -33,7 +33,7 @@ from ...image_processor import PipelineImageInput
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoFramepackTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -570,6 +570,12 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -577,6 +583,12 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -585,6 +597,12 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -592,6 +610,12 @@ class HunyuanVideoFramepackPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMix
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
index b9246e2eb2..aa04e65097 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
@@ -30,7 +30,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -598,6 +598,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -605,6 +611,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -613,6 +625,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -620,6 +638,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@property
diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
index c7f84866fe..e2f935aaf4 100644
--- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
+++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -27,11 +27,7 @@ from ...models import AutoencoderKL, HunyuanDiT2DModel
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index 89fea89337..33529f5d09 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -21,11 +21,7 @@ from transformers import (
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler, DDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
@@ -291,7 +287,7 @@ class KandinskyPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index 90d4042ae2..7286bcbee1 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -271,7 +271,7 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -502,7 +502,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -742,7 +742,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
index 998fc777c0..f5e41d499d 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -23,11 +23,7 @@ from transformers import (
from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index 5645d2a56e..731fce4998 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -28,11 +28,7 @@ from transformers import (
from ... import __version__
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_encoder import MultilingualCLIP
@@ -469,7 +465,7 @@ class KandinskyInpaintPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
index 8781d706ed..10ea8005c9 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
@@ -212,7 +212,7 @@ class KandinskyPriorPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
negative_prior_prompt (`str`, *optional*):
The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
@@ -437,7 +437,7 @@ class KandinskyPriorPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
index 3ecc0ebd5b..429253e998 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
@@ -175,7 +175,7 @@ class KandinskyV22Pipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
index e0b88b41e8..fc2083247b 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
@@ -262,7 +262,7 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -512,7 +512,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
@@ -749,7 +749,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
index b9f98f5458..c5faae8279 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
@@ -211,7 +211,7 @@ class KandinskyV22ControlnetPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
index 22171849bb..a61673293e 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
@@ -356,7 +356,7 @@ class KandinskyV22InpaintPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
index 68954c2dc8..bc67847831 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -6,11 +6,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline
@@ -171,7 +167,7 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
negative_prior_prompt (`str`, *optional*):
The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
@@ -412,7 +408,7 @@ class KandinskyV22PriorPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
index 13ea2ad6af..b586d16611 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
@@ -6,11 +6,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..kandinsky import KandinskyPriorPipelineOutput
from ..pipeline_utils import DiffusionPipeline
@@ -195,7 +191,7 @@ class KandinskyV22PriorEmb2EmbPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
negative_prior_prompt (`str`, *optional*):
The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
`guidance_scale` is less than `1`).
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py
index 1fa9f6ce1d..948f73ed91 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py
@@ -749,7 +749,7 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionLor
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
index e3cf4f2276..67d49b9a8c 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
@@ -900,7 +900,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index bc50835d19..f1bf4701e3 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -17,7 +17,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
index 273e97f1ec..631539e5c6 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -4,7 +4,6 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
-import torch.utils.checkpoint
from ...models import UNet2DModel, VQModel
from ...schedulers import (
diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py
index 0e60d5c7ac..4d42a7049e 100644
--- a/src/diffusers/pipelines/latte/pipeline_latte.py
+++ b/src/diffusers/pipelines/latte/pipeline_latte.py
@@ -679,7 +679,7 @@ class LattePipeline(DiffusionPipeline):
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index 341ccabaa1..5b61aaf9b6 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -722,6 +722,12 @@ class LEditsPPPipelineStableDiffusion(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -729,6 +735,12 @@ class LEditsPPPipelineStableDiffusion(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -737,6 +749,12 @@ class LEditsPPPipelineStableDiffusion(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -744,6 +762,12 @@ class LEditsPPPipelineStableDiffusion(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
@torch.no_grad()
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
index ac64844f6f..c1f9a98f06 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
@@ -44,6 +44,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler, DPMSolverMultistepScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
@@ -770,6 +771,12 @@ class LEditsPPPipelineStableDiffusionXL(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -777,6 +784,12 @@ class LEditsPPPipelineStableDiffusionXL(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -785,6 +798,12 @@ class LEditsPPPipelineStableDiffusionXL(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -792,6 +811,12 @@ class LEditsPPPipelineStableDiffusionXL(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
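Because the wrappers above keep working until the removal version, migrating code mostly means watching for the new warnings. A hedged sketch of surfacing them, assuming the `deprecate` helper used in these hunks emits a standard `FutureWarning` (as its existing call sites do); the checkpoint id is a placeholder:

```python
# Hedged sketch: surfacing the new deprecation warnings during migration.
import warnings

import torch
from diffusers import LEditsPPPipelineStableDiffusionXL

pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    pipe.enable_vae_tiling()  # deprecated wrapper; still forwards to pipe.vae.enable_tiling()

for w in caught:
    print(w.category.__name__, str(w.message))
# Expected, per the messages added in this diff: a pointer to
# `pipe.vae.enable_tiling()` and the 0.40.0 removal version.
```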
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 77ba751700..bd23e657c4 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -601,7 +601,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index 217478f418..537588f67c 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -938,7 +938,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 8793d81377..694378b4f0 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -665,7 +665,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
index 0d11f6e76c..9acff105e5 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
@@ -18,7 +18,7 @@ import torch
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLLTXVideo
-from ...utils import get_logger
+from ...utils import deprecate, get_logger
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -180,6 +180,12 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -187,6 +193,12 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -195,6 +207,12 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -202,6 +220,12 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
diff --git a/src/diffusers/pipelines/lucy/__init__.py b/src/diffusers/pipelines/lucy/__init__.py
new file mode 100644
index 0000000000..580e1f37f3
--- /dev/null
+++ b/src/diffusers/pipelines/lucy/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_lucy_edit import LucyEditPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
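The new package follows the repository's lazy-import convention. A hedged sketch of what this `__init__.py` provides, assuming torch and transformers are installed and that the parent `diffusers.pipelines` package registers the new `lucy` subpackage (that wiring is not shown in this hunk):

```python
# Hedged sketch of the lazy-import behaviour set up above.
import diffusers.pipelines.lucy as lucy

# Attribute access on the _LazyModule triggers the actual import of
# pipeline_lucy_edit, so start-up cost is only paid when the class is used.
LucyEditPipeline = lucy.LucyEditPipeline
print(LucyEditPipeline.__name__)  # "LucyEditPipeline"

# Without the optional dependencies, the same attribute would instead resolve to a
# dummy object from diffusers.utils.dummy_torch_and_transformers_objects, which
# typically raises an informative error when instantiated.
```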
diff --git a/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py
new file mode 100644
index 0000000000..69f69d5768
--- /dev/null
+++ b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py
@@ -0,0 +1,735 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Decart AI Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modifications by Decart AI Team:
+# - Based on pipeline_wan.py, but with support for receiving a condition video appended to the channel dimension.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import regex as re
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import LucyPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> from typing import List
+
+ >>> import torch
+ >>> from PIL import Image
+
+ >>> from diffusers import AutoencoderKLWan, LucyEditPipeline
+ >>> from diffusers.utils import export_to_video, load_video
+
+ >>> # Arguments
+ >>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4"
+ >>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights."
+ >>> negative_prompt = ""
+ >>> num_frames = 81
+ >>> height = 480
+ >>> width = 832
+
+
+ >>> # Load video
+ >>> def convert_video(video: List[Image.Image]) -> List[Image.Image]:
+ ... video = video[:num_frames]
+ ... video = [video[i].resize((width, height)) for i in range(num_frames)]
+ ... return video
+
+
+ >>> video = load_video(url, convert_method=convert_video)
+
+ >>> # Load model
+ >>> model_id = "decart-ai/Lucy-Edit-Dev"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # Generate video
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... video=video,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+
+ >>> # Export video
+ >>> export_to_video(output, "output.mp4", fps=24)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for video-to-video generation using Lucy Edit.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer_2 ([`WanTransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+ two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+ stages. If not provided, only `transformer` is used.
+ boundary_ratio (`float`, *optional*, defaults to `None`):
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer", "transformer_2"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer: Optional[WanTransformer3DModel] = None,
+ transformer_2: Optional[WanTransformer3DModel] = None,
+ boundary_ratio: Optional[float] = None,
+ expand_timesteps: bool = False, # Wan2.2 ti2v
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ transformer_2=transformer_2,
+ )
+ self.register_to_config(boundary_ratio=boundary_ratio)
+ self.register_to_config(expand_timesteps=expand_timesteps)
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
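+        # Each sequence is trimmed to its true (unpadded) length and then zero-padded back to `max_sequence_length`,
+        # so padding positions hold exact zeros rather than encoder outputs for pad tokens.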
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ video,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ guidance_scale_2=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+ if video is None:
+ raise ValueError("`video` is required, received None.")
+
+ def prepare_latents(
+ self,
+ video: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_latent_frames = (
+ (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
+ )
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
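+        # Illustrative example: an 81-frame 480x832 input with temporal/spatial scale factors 4/8 and
+        # num_channels_latents=16 gives latents of shape (batch_size, 16, 21, 60, 104).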
+ # Prepare noise latents
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # Prepare condition latents
+ condition_latents = [
+ retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video
+ ]
+
+ condition_latents = torch.cat(condition_latents, dim=0).to(dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, dtype
+ )
+
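+        # Normalize the encoded video with the VAE's per-channel statistics so the condition latents live in the
+        # same normalized latent space as the noise latents.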
+ condition_latents = (condition_latents - latents_mean) * latents_std
+
+ # Check shapes
+ assert latents.shape == condition_latents.shape, (
+ f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input."
+ )
+
+ return latents, condition_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ video: List[Image.Image],
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ video (`List[Image.Image]`):
+ The video to use as the condition for the video generation.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+                The output format of the generated video. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during inference with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~LucyPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned
+                where the first element is the generated video frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ video,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
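+            # e.g. with a temporal scale factor of 4, a requested num_frames=83 is adjusted to 81 so that
+            # num_frames - 1 stays divisible by the VAE's temporal compression.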
+
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
+ self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = (
+ self.transformer.config.out_channels
+ if self.transformer is not None
+ else self.transformer_2.config.out_channels
+ )
+ video = self.video_processor.preprocess_video(video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+ latents, condition_latents = self.prepare_latents(
+ video,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
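+        # The all-ones mask is only consumed by the `expand_timesteps` branch below to build per-token timesteps.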
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
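+        # e.g. boundary_ratio=0.9 with num_train_timesteps=1000 gives boundary_timestep=900: `transformer` denoises
+        # timesteps >= 900 (high noise) and `transformer_2` the remaining timesteps < 900 (low noise).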
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
+
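+                # Concatenate the noisy latents with the clean condition latents along the channel dimension, so
+                # the transformer receives 2 * num_channels_latents input channels.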
+ latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype)
+ if self.config.expand_timesteps:
+ # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ # batch_size, seq_len
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ timestep = t.expand(latents.shape[0])
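+                # `timestep` is a per-token (batch_size, seq_len) tensor when `expand_timesteps` is enabled
+                # (Wan2.2 ti2v) and a (batch_size,) tensor otherwise.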
+
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
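+                    # Classifier-free guidance: move from the unconditional prediction toward the conditional one,
+                    # scaled by `current_guidance_scale`.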
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # update the progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
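+            # Invert the per-channel latent normalization (scale by std, shift by mean) so the VAE can decode.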
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LucyPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/lucy/pipeline_output.py b/src/diffusers/pipelines/lucy/pipeline_output.py
new file mode 100644
index 0000000000..cf9ea91fd1
--- /dev/null
+++ b/src/diffusers/pipelines/lucy/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class LucyPipelineOutput(BaseOutput):
+ r"""
+ Output class for Lucy pipelines.
+
+ Args:
+        frames (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
+            List of video outputs. It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 2067444fa0..b59c265646 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -697,7 +697,7 @@ class LuminaPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
index 0fa0fe9773..937803edbc 100644
--- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
+++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
@@ -433,6 +433,12 @@ class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -440,6 +446,12 @@ class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -448,6 +460,12 @@ class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -455,6 +473,12 @@ class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
@@ -564,7 +588,7 @@ class Lumina2Pipeline(DiffusionPipeline, Lumina2LoraLoaderMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
index da991aefbd..92ec16fd45 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
@@ -86,15 +86,14 @@ class MarigoldDepthOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted depth maps with values in the range [0, 1]. The shape is $numimages \times 1 \times height \times
- width$ for `torch.Tensor` or $numimages \times height \times width \times 1$ for `np.ndarray`.
+ Predicted depth maps with values in the range [0, 1]. The shape is `numimages × 1 × height × width` for
+ `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
- \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
- for `np.ndarray`.
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
+ height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
- The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+ The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
"""
prediction: Union[np.ndarray, torch.Tensor]
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py b/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
index c809de18f4..bef9ca77c7 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
@@ -99,17 +99,17 @@ class MarigoldIntrinsicsOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted image intrinsics with values in the range [0, 1]. The shape is $(numimages * numtargets) \times 3
- \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times height \times width
- \times 3$ for `np.ndarray`, where `numtargets` corresponds to the number of predicted target modalities of
- the intrinsic image decomposition.
+ Predicted image intrinsics with values in the range [0, 1]. The shape is `(numimages * numtargets) × 3 ×
+ height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for `np.ndarray`,
+ where `numtargets` corresponds to the number of predicted target modalities of the intrinsic image
+ decomposition.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $(numimages *
- numtargets) \times 3 \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times
- height \times width \times 3$ for `np.ndarray`.
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `(numimages *
+ numtargets) × 3 × height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for
+ `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
- The shape is $(numimages * numensemble) \times (numtargets * 4) \times latentheight \times latentwidth$.
+ The shape is `(numimages * numensemble) × (numtargets * 4) × latentheight × latentwidth`.
"""
prediction: Union[np.ndarray, torch.Tensor]
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
index 192ed590a4..485a39c995 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
@@ -81,15 +81,14 @@ class MarigoldNormalsOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted normals with values in the range [-1, 1]. The shape is $numimages \times 3 \times height \times
- width$ for `torch.Tensor` or $numimages \times height \times width \times 3$ for `np.ndarray`.
+ Predicted normals with values in the range [-1, 1]. The shape is `numimages × 3 × height × width` for
+ `torch.Tensor` or `numimages × height × width × 3` for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
- \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
- for `np.ndarray`.
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
+ height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
- The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+ The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
"""
prediction: Union[np.ndarray, torch.Tensor]
diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index 3c0f908296..5874a92c6f 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -23,11 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import Mochi1LoraLoaderMixin
from ...models import AutoencoderKLMochi, MochiTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -396,6 +392,12 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -403,6 +405,12 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -411,6 +419,12 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -418,6 +432,12 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -534,7 +554,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index 1254b6725f..090cb46aac 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -23,7 +23,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import OmniGenTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, is_torchvision_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, is_torchvision_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -235,6 +235,12 @@ class OmniGenPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -242,6 +248,12 @@ class OmniGenPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -250,6 +262,12 @@ class OmniGenPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -257,6 +275,12 @@ class OmniGenPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
@@ -366,7 +390,7 @@ class OmniGenPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
index 4c02b3dd6d..3daaac328c 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -150,17 +150,13 @@ class StableDiffusionControlNetPAGInpaintPipeline(
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-
-
- This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
- ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
- default text-to-image Stable Diffusion checkpoints
- ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
- Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
+    > [!TIP]
+    > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
+    > ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
+    > default text-to-image Stable Diffusion checkpoints
+    > ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
+    > Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
-
-
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
index 913a647fae..a6df1b22c8 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
@@ -1199,7 +1199,7 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
index 3a08408662..d156eac8f3 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
@@ -28,11 +28,7 @@ from ...models.attention_processor import PAGCFGHunyuanAttnProcessor2_0, PAGHuny
from ...models.embeddings import get_2d_rotary_pos_embed
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import DDPMScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pag_utils import PAGMixin
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
index ed8e33e2ba..1368358db6 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
@@ -769,7 +769,7 @@ class KolorsPAGPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
index d9d6d14a38..9031877b5b 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -644,7 +644,7 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
index 8dbae13a3f..9e91ccbe80 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -29,6 +29,7 @@ from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGI
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
+ deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
@@ -190,6 +191,12 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -197,6 +204,12 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -205,6 +218,12 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -212,6 +231,12 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def encode_prompt(
@@ -703,7 +728,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
index 96796f53b0..acb4e52340 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -761,7 +761,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
index 202120dc2c..e1819a79fb 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
@@ -822,7 +822,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
index 4504684133..6b62ddcc7c 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
@@ -948,7 +948,7 @@ class StableDiffusionXLPAGPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
index 8c355a5fb1..b6422b2364 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
@@ -1111,7 +1111,7 @@ class StableDiffusionXLPAGImg2ImgPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index 7d42d1876a..2e12a4a97f 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -1251,7 +1251,7 @@ class StableDiffusionXLPAGInpaintPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 3e22c9a845..61435b80ca 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -158,11 +158,7 @@ def prepare_mask_and_masked_image(image, mask):
class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
_last_supported_version = "0.33.1"
r"""
-
-
- 🧪 This is an experimental feature!
-
-
+    > [!WARNING]
+    > 🧪 This is an experimental feature!
Pipeline for image-guided image inpainting using Stable Diffusion.
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index ea2c0763d9..2724c764c7 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -276,12 +276,8 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
Can be used to overwrite load and saveable variables (the pipeline components) of the specific pipeline
class. The overwritten components are passed directly to the pipelines `__init__` method.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`.
-
-
+        > [!TIP]
+        > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log in
+        > with `hf auth login`.
Examples:
@@ -312,6 +308,11 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> dpm_params["scheduler"] = dpmpp_state
```
"""
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
+
cache_dir = kwargs.pop("cache_dir", None)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index b5ac6cc301..dd542145d3 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -19,12 +19,12 @@ import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
+import httpx
import requests
import torch
from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
-from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
+from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
from packaging import version
-from requests.exceptions import HTTPError
from .. import __version__
from ..utils import (
@@ -48,10 +48,12 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizerBase
- from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
+ if is_transformers_version("<=", "4.56.2"):
+ from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
+
if is_accelerate_available():
import accelerate
from accelerate import dispatch_model
@@ -112,7 +114,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
]
if is_transformers_available():
- weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+ if is_transformers_version("<=", "4.56.2"):
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -191,7 +195,9 @@ def filter_model_files(filenames):
]
if is_transformers_available():
- weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+ if is_transformers_version("<=", "4.56.2"):
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
@@ -212,7 +218,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
]
if is_transformers_available():
- weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+ if is_transformers_version("<=", "4.56.2"):
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -613,6 +621,9 @@ def _assign_components_to_devices(
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
+ # TODO: separate out different device_map methods when it gets to it.
+ if device_map != "balanced":
+ return device_map
# To avoid circular import problem.
from diffusers import pipelines
@@ -827,6 +838,9 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False
+ if is_transformers_model and is_transformers_version(">=", "4.57.0"):
+ loading_kwargs.pop("offload_state_dict")
+
if (
quantization_config is not None
and isinstance(quantization_config, PipelineQuantizationConfig)
@@ -1099,7 +1113,7 @@ def _download_dduf_file(
if not local_files_only:
try:
info = model_info(pretrained_model_name, token=token, revision=revision)
- except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
+ except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 22efaccec1..392d5fb3fe 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -23,6 +23,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
+import httpx
import numpy as np
import PIL.Image
import requests
@@ -36,9 +37,8 @@ from huggingface_hub import (
read_dduf_file,
snapshot_download,
)
-from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
+from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
from packaging import version
-from requests.exceptions import HTTPError
from tqdm.auto import tqdm
from typing_extensions import Self
@@ -57,6 +57,7 @@ from ..utils import (
PushToHubMixin,
_get_detailed_type,
_is_valid_type,
+ deprecate,
is_accelerate_available,
is_accelerate_version,
is_hpu_available,
@@ -108,7 +109,7 @@ LIBRARIES = []
for library in LOADABLE_CLASSES:
LIBRARIES.append(library)
-SUPPORTED_DEVICE_MAP = ["balanced"]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
logger = logging.get_logger(__name__)
@@ -371,12 +372,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
-
-
- If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
- the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
-
-
+ > [!TIP] > If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is.
+ Otherwise, the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
Here are the ways to call `to`:
@@ -504,6 +501,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
+ if dtype in (torch.bfloat16, None) and kwargs.pop("sdp_on_bf16", True):
+ if hasattr(torch._C, "_set_math_sdp_allow_fp16_bf16_reduction"):
+ torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
+ logger.warning(
+ "Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`"
+ )
+
module_names, _ = self._get_signature_keys(self)
modules = [getattr(self, n, None) for n in module_names]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
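The HPU branch added above flips on math-SDP BF16 reduction whenever a pipeline is moved to `hpu` with a bfloat16 (or unset) dtype. A minimal usage sketch, assuming an HPU-enabled PyTorch build; the checkpoint name is purely illustrative:

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative checkpoint; any pipeline follows the same pattern.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.bfloat16
)

# Default: SDP with BF16 reduction is enabled on HPU and a warning is logged.
pipe.to("hpu")

# Opt out of the BF16 math-SDP path, as the warning message suggests.
pipe.to("hpu", sdp_on_bf16=False)
```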
@@ -619,11 +623,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
`torch.float32` is used.
custom_pipeline (`str`, *optional*):
-
-
- 🧪 This is an experimental feature and may change in the future.
-
-
+ > [!WARNING] > 🧪 This is an experimental feature and may change in the future.
Can be either:
@@ -708,12 +708,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
dduf_file(`str`, *optional*):
Load weights from the specified dduf file.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
- auth login`.
-
-
+ > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in
+ with `hf auth login`.
Examples:
@@ -988,12 +984,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
_maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
# 7.1 device_map shenanigans
- if final_device_map is not None and len(final_device_map) > 0:
- component_device = final_device_map.get(name, None)
- if component_device is not None:
- current_device_map = {"": component_device}
- else:
- current_device_map = None
+ if final_device_map is not None:
+ if isinstance(final_device_map, dict) and len(final_device_map) > 0:
+ component_device = final_device_map.get(name, None)
+ if component_device is not None:
+ current_device_map = {"": component_device}
+ else:
+ current_device_map = None
+ elif isinstance(final_device_map, str):
+ current_device_map = final_device_map
# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
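Together with the `SUPPORTED_DEVICE_MAP` and `_get_final_device_map` changes above, a plain device string now flows through `from_pretrained` as the per-component device map. A hedged sketch of both forms, assuming the string matches the accelerator returned by `get_device()` (for example `"cuda"`):

```python
import torch
from diffusers import DiffusionPipeline

# Existing behavior: shard components across the available accelerators.
pipe_balanced = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    device_map="balanced",
)

# New string form: place every component directly on the named accelerator.
pipe_cuda = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    device_map="cuda",
)
```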
@@ -1331,6 +1330,133 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
offload_buffers = len(model._parameters) > 0
cpu_offload(model, device, offload_buffers=offload_buffers)
+ def enable_group_offload(
+ self,
+ onload_device: torch.device,
+ offload_device: torch.device = torch.device("cpu"),
+ offload_type: str = "block_level",
+ num_blocks_per_group: Optional[int] = None,
+ non_blocking: bool = False,
+ use_stream: bool = False,
+ record_stream: bool = False,
+ low_cpu_mem_usage=False,
+ offload_to_disk_path: Optional[str] = None,
+ exclude_modules: Optional[Union[str, List[str]]] = None,
+ ) -> None:
+ r"""
+ Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is,
+ and where it is beneficial, we need to first provide some context on how other supported offloading methods
+ work.
+
+ Typically, offloading is done at two levels:
+ - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
+ works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator
+ device when needed for computation. This method is more memory-efficient than keeping all components on the
+ accelerator, but the memory requirements are still quite high. For this method to work, one needs memory
+ equivalent to the size of the model in its runtime dtype plus the size of the largest intermediate activation
+ tensors to be able to complete the forward pass.
+ - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method.
+ It works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
+ onloading only the leaves to the accelerator device for computation. This uses the lowest amount of accelerator
+ memory, but can be slower due to the excessive number of device synchronizations.
+
+ Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers
+ (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
+ offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations
+ is reduced.
+
+ Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability
+ to overlap data transfer and computation to reduce the overall execution time compared to sequential
+ offloading. This is enabled using layer prefetching with streams, i.e., the layer that is to be executed next
+ starts onloading to the accelerator device while the current layer is being executed - this increases the
+ memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made
+ much faster when using streams.
+
+ Args:
+ onload_device (`torch.device`):
+ The device to which the group of modules are onloaded.
+ offload_device (`torch.device`, defaults to `torch.device("cpu")`):
+ The device to which the group of modules are offloaded. This should typically be the CPU. Default is
+ CPU.
+ offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
+ The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
+ "block_level".
+ offload_to_disk_path (`str`, *optional*, defaults to `None`):
+ The path to the directory where parameters will be offloaded. Setting this option can be useful in
+ environments with limited RAM where a reasonable speed-memory trade-off is desired.
+ num_blocks_per_group (`int`, *optional*):
+ The number of blocks per group when using offload_type="block_level". This is required when using
+ offload_type="block_level".
+ non_blocking (`bool`, defaults to `False`):
+ If True, offloading and onloading is done with non-blocking data transfer.
+ use_stream (`bool`, defaults to `False`):
+ If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
+ overlapping computation and data transfer.
+ record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
+ as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
+ the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for
+ more details.
+ low_cpu_mem_usage (`bool`, defaults to `False`):
+ If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them.
+ This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be
+ useful when the CPU memory is a bottleneck but may counteract the benefits of using streams.
+ exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading.
+
+ Example:
+ ```python
+ >>> from diffusers import DiffusionPipeline
+ >>> import torch
+
+ >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+
+ >>> pipe.enable_group_offload(
+ ... onload_device=torch.device("cuda"),
+ ... offload_device=torch.device("cpu"),
+ ... offload_type="leaf_level",
+ ... use_stream=True,
+ ... )
+ >>> image = pipe("a beautiful sunset").images[0]
+ ```
+ """
+ from ..hooks import apply_group_offloading
+
+ if isinstance(exclude_modules, str):
+ exclude_modules = [exclude_modules]
+ elif exclude_modules is None:
+ exclude_modules = []
+
+ unknown = set(exclude_modules) - self.components.keys()
+ if unknown:
+ logger.info(
+ f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."
+ )
+
+ group_offload_kwargs = {
+ "onload_device": onload_device,
+ "offload_device": offload_device,
+ "offload_type": offload_type,
+ "num_blocks_per_group": num_blocks_per_group,
+ "non_blocking": non_blocking,
+ "use_stream": use_stream,
+ "record_stream": record_stream,
+ "low_cpu_mem_usage": low_cpu_mem_usage,
+ "offload_to_disk_path": offload_to_disk_path,
+ }
+ for name, component in self.components.items():
+ if name not in exclude_modules and isinstance(component, torch.nn.Module):
+ if hasattr(component, "enable_group_offload"):
+ component.enable_group_offload(**group_offload_kwargs)
+ else:
+ apply_group_offloading(module=component, **group_offload_kwargs)
+
+ if exclude_modules:
+ for module_name in exclude_modules:
+ module = getattr(self, module_name, None)
+ if module is not None and isinstance(module, torch.nn.Module):
+ module.to(onload_device)
+ logger.debug(f"Placed `{module_name}` on {onload_device} device as it was in `exclude_modules`.")
+
def reset_device_map(self):
r"""
Resets the device maps (if any) to None.
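Beyond the docstring example in `enable_group_offload` above, the pipeline-level wrapper also accepts `exclude_modules`; a short sketch (reusing the docstring's model ID) that keeps the text encoder resident on the accelerator while the remaining components are group-offloaded:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)

# Group-offload every torch.nn.Module component except the text encoder,
# which is simply placed on the onload device.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    exclude_modules=["text_encoder"],
)
image = pipe("a beautiful sunset").images[0]
```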
@@ -1370,11 +1496,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
- A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory
must contain a file called `pipeline.py` that defines the custom pipeline.
-
-
- 🧪 This is an experimental feature and may change in the future.
-
-
+ > [!WARNING] > 🧪 This is an experimental feature and may change in the future.
For more information on how to load and create custom pipelines, take a look at [How to contribute a
community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline).
@@ -1428,12 +1550,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
`os.PathLike`:
A path to the downloaded pipeline.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login
-
-
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log in
+ with `hf auth login`.
"""
cache_dir = kwargs.pop("cache_dir", None)
@@ -1478,7 +1596,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if not local_files_only:
try:
info = model_info(pretrained_model_name, token=token, revision=revision)
- except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
+ except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
@@ -1706,6 +1824,36 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
return signature_types
+ @property
+ def parameters(self) -> Dict[str, Any]:
+ r"""
+ The `self.parameters` property can be useful to run different pipelines with the same weights and
+ configurations without reallocating additional memory.
+
+ Returns (`dict`):
+ A dictionary containing all the optional parameters needed to initialize the pipeline.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import (
+ ... StableDiffusionPipeline,
+ ... StableDiffusionImg2ImgPipeline,
+ ... StableDiffusionInpaintPipeline,
+ ... )
+
+ >>> text2img = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
+ >>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components, **text2img.parameters)
+ >>> inpaint = StableDiffusionInpaintPipeline(**text2img.components, **text2img.parameters)
+ ```
+ """
+ expected_modules, optional_parameters = self._get_signature_keys(self)
+ pipeline_parameters = {
+ k: self.config[k] for k in self.config.keys() if not k.startswith("_") and k in optional_parameters
+ }
+
+ return pipeline_parameters
+
@property
def components(self) -> Dict[str, Any]:
r"""
@@ -1776,12 +1924,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed
up during training is not guaranteed.
-
-
- ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
- precedent.
-
-
+ > [!WARNING] > ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient
+ attention takes precedence.
Parameters:
attention_op (`Callable`, *optional*):
@@ -1837,13 +1981,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
in slices to compute attention in several steps. For more than one attention head, the computation is performed
sequentially over each head. This is useful to save some memory in exchange for a small speed decrease.
-
-
- ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch
- 2.0 or xFormers. These attention computations are already very memory efficient so you won't need to enable
- this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs!
-
-
+ > [!WARNING] > ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA)
+ from PyTorch 2.0 or xFormers. These attention computations are already very memory efficient so you won't
+ need to enable this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious
+ slowdowns!
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
@@ -2041,6 +2182,12 @@ class StableDiffusionMixin:
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -2048,6 +2195,12 @@ class StableDiffusionMixin:
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -2056,6 +2209,12 @@ class StableDiffusionMixin:
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -2063,6 +2222,12 @@ class StableDiffusionMixin:
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
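The four deprecations above all point to the same migration; a sketch of the before/after calls (checkpoint name is illustrative):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Deprecated pipeline-level helpers (warn until removal in 0.40.0):
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()

# Preferred replacements, as pointed to by the deprecation messages:
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.vae.disable_slicing()
pipe.vae.disable_tiling()
```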
@@ -2096,11 +2261,7 @@ class StableDiffusionMixin:
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
@@ -2125,11 +2286,7 @@ class StableDiffusionMixin:
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index bd69746be3..1d718a4852 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -755,7 +755,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index c14036cf94..bb169ac5c4 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -700,7 +700,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/qwenimage/__init__.py b/src/diffusers/pipelines/qwenimage/__init__.py
index 963732ded0..2400632ba2 100644
--- a/src/diffusers/pipelines/qwenimage/__init__.py
+++ b/src/diffusers/pipelines/qwenimage/__init__.py
@@ -24,6 +24,13 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modeling_qwenimage"] = ["ReduxImageEncoder"]
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
+ _import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"]
+ _import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"]
+ _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
+ _import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
+ _import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
+ _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
+ _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -33,6 +40,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_qwenimage import QwenImagePipeline
+ from .pipeline_qwenimage_controlnet import QwenImageControlNetPipeline
+ from .pipeline_qwenimage_controlnet_inpaint import QwenImageControlNetInpaintPipeline
+ from .pipeline_qwenimage_edit import QwenImageEditPipeline
+ from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
+ from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
+ from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
+ from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
else:
import sys
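With the registrations above, the new QwenImage variants resolve through the lazy import machinery; a quick sanity-check sketch (top-level `diffusers` re-exports are assumed to follow, as they do for other pipelines):

```python
from diffusers.pipelines.qwenimage import (
    QwenImageControlNetInpaintPipeline,
    QwenImageControlNetPipeline,
    QwenImageEditInpaintPipeline,
    QwenImageEditPipeline,
    QwenImageEditPlusPipeline,
    QwenImageImg2ImgPipeline,
    QwenImageInpaintPipeline,
    QwenImagePipeline,
)

# Lazy imports resolve to real classes on first attribute access.
print(QwenImageEditPipeline.__name__)  # QwenImageEditPipeline
```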
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
index 03f6f73b44..33dc2039b9 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -23,7 +23,7 @@ from ...image_processor import VaeImageProcessor
from ...loaders import QwenImageLoraLoaderMixin
from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import QwenImagePipelineOutput
@@ -201,7 +201,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
txt = [template.format(e) for e in prompt]
txt_tokens = self.tokenizer(
txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
- ).to(self.device)
+ ).to(device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
@@ -253,6 +253,9 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+ prompt_embeds = prompt_embeds[:, :max_sequence_length]
+ prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -316,20 +319,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
if max_sequence_length is not None and max_sequence_length > 1024:
raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
- @staticmethod
- def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height, width, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
- latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
- latent_image_ids = latent_image_ids.reshape(
- latent_image_id_height * latent_image_id_width, latent_image_id_channels
- )
-
- return latent_image_ids.to(device=device, dtype=dtype)
-
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
@@ -359,6 +348,12 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -366,6 +361,12 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -374,6 +375,12 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -381,6 +388,12 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def prepare_latents(
@@ -402,8 +415,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
shape = (batch_size, 1, num_channels_latents, height, width)
if latents is not None:
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
- return latents.to(device=device, dtype=dtype), latent_image_ids
+ return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -414,9 +426,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
- return latents, latent_image_ids
+ return latents
@property
def guidance_scale(self):
@@ -449,7 +459,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
- guidance_scale: float = 1.0,
+ guidance_scale: Optional[float] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -476,7 +486,12 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 1.0):
- When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2
+ of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled
+ by setting `true_cfg_scale > 1` and providing a `negative_prompt`. A higher guidance scale encourages
+ the model to generate images that are closely linked to the text `prompt`, usually at the expense of
+ lower image quality.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -488,12 +503,16 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
- guidance_scale (`float`, *optional*, defaults to 3.5):
- Guidance scale as defined in [Classifier-Free Diffusion
- Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
- `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
- the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance-distilled models. Unlike traditional classifier-free guidance,
+ where the guidance scale is applied at inference time through noise prediction rescaling,
+ guidance-distilled models take the guidance scale directly as an input to the forward pass. Guidance
+ is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate
+ images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+ This parameter exists to support future guidance-distilled models; it is ignored for models that are
+ not guidance-distilled. To enable traditional classifier-free guidance, pass `true_cfg_scale > 1.0`
+ and a `negative_prompt` (even an empty negative prompt like " " enables classifier-free guidance
+ computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -573,6 +592,16 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
@@ -594,7 +623,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
- latents, latent_image_ids = self.prepare_latents(
+ latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
@@ -604,7 +633,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
generator,
latents,
)
- img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+ img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
@@ -627,10 +656,17 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
self._num_timesteps = len(timesteps)
# handle guidance
- if self.transformer.config.guidance_embeds:
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
- else:
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None
if self.attention_kwargs is None:
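The reworked guidance handling above separates the two knobs: `true_cfg_scale` drives classifier-free guidance (and requires a negative prompt), while `guidance_scale` is only consumed by guidance-distilled checkpoints and must be provided for them. A minimal sketch of the classifier-free path with the base checkpoint:

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Classifier-free guidance needs true_cfg_scale > 1 *and* a negative prompt
# (even an empty " "); guidance_scale stays None for non-distilled checkpoints.
image = pipe(
    prompt="a cozy cabin in a snowy forest, golden hour lighting",
    negative_prompt=" ",
    true_cfg_scale=4.0,
    num_inference_steps=50,
).images[0]
image.save("qwenimage_cfg.png")
```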
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
new file mode 100644
index 0000000000..5111096d93
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
@@ -0,0 +1,998 @@
+# Copyright 2025 Qwen-Image Team, InstantX Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...models.controlnets.controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import load_image
+ >>> from diffusers import QwenImageControlNetModel, QwenImageMultiControlNetModel, QwenImageControlNetPipeline
+
+ >>> # QwenImageControlNetModel
+ >>> controlnet = QwenImageControlNetModel.from_pretrained(
+ ... "InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe = QwenImageControlNetPipeline.from_pretrained(
+ ... "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
+ >>> negative_prompt = " "
+ >>> control_image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png"
+ ... )
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(
+ ... prompt,
+ ... negative_prompt=negative_prompt,
+ ... control_image=control_image,
+ ... controlnet_conditioning_scale=1.0,
+ ... num_inference_steps=30,
+ ... true_cfg_scale=4.0,
+ ... ).images[0]
+ >>> image.save("qwenimage_cn_union.png")
+
+ >>> # QwenImageMultiControlNetModel
+ >>> controlnet = QwenImageControlNetModel.from_pretrained(
+ ... "InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16
+ ... )
+ >>> controlnet = QwenImageMultiControlNetModel([controlnet])
+ >>> pipe = QwenImageControlNetPipeline.from_pretrained(
+ ... "Qwen/Qwen-Image", controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "Aesthetics art, traditional asian pagoda, elaborate golden accents, sky blue and white color palette, swirling cloud pattern, digital illustration, east asian architecture, ornamental rooftop, intricate detailing on building, cultural representation."
+ >>> negative_prompt = " "
+ >>> control_image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/canny.png"
+ ... )
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(
+ ... prompt,
+ ... negative_prompt=negative_prompt,
+ ... control_image=[control_image, control_image],
+ ... controlnet_conditioning_scale=[0.5, 0.5],
+ ... num_inference_steps=30,
+ ... true_cfg_scale=4.0,
+ ... ).images[0]
+ >>> image.save("qwenimage_cn_union_multi.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The QwenImage pipeline for ControlNet-guided text-to-image generation.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), the multimodal text
+ encoder used to produce prompt embeddings.
+ tokenizer ([`Qwen2Tokenizer`]):
+ Tokenizer of class [`Qwen2Tokenizer`], matching the text encoder.
+ controlnet ([`QwenImageControlNetModel`] or [`QwenImageMultiControlNetModel`]):
+ Provides additional conditioning to the `transformer` during the denoising process.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ controlnet: Union[QwenImageControlNetModel, QwenImageMultiControlNetModel],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_image: PipelineImageInput = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+            true_cfg_scale (`float`, *optional*, defaults to 4.0):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2
+                of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled
+                by setting `true_cfg_scale > 1` and providing a `negative_prompt`. A higher guidance scale encourages
+                the model to generate images that are closely linked to the text `prompt`, usually at the expense of
+                lower image quality.
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+            guidance_scale (`float`, *optional*, defaults to None):
+                A guidance scale value for guidance-distilled models. Unlike traditional classifier-free guidance,
+                where the guidance scale is applied at inference time through noise prediction rescaling, guidance-
+                distilled models take the guidance scale directly as an input during the forward pass. Guidance is
+                enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate
+                images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+                This parameter exists to support future guidance-distilled models and is ignored for models that are
+                not guidance-distilled. To enable traditional classifier-free guidance, pass `true_cfg_scale > 1.0`
+                and a `negative_prompt` (even an empty negative prompt like " " enables classifier-free guidance
+                computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 3. Prepare control image
+ num_channels_latents = self.transformer.config.in_channels // 4
+ if isinstance(self.controlnet, QwenImageControlNetModel):
+ control_image = self.prepare_image(
+ image=control_image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+ height, width = control_image.shape[-2:]
+
+ if control_image.ndim == 4:
+ control_image = control_image.unsqueeze(2)
+
+ # vae encode
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(
+ device
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device
+ )
+
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
+ control_image = (control_image - latents_mean) * latents_std
+
+ control_image = control_image.permute(0, 2, 1, 3, 4)
+
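+            # Pack the 5D control latents into the (B, (H//2) * (W//2), C * 4) token-sequence layout used for
+            # the denoised latents, so the ControlNet receives spatially aligned image tokens.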
+ # pack
+ control_image = self._pack_latents(
+ control_image,
+ batch_size=control_image.shape[0],
+ num_channels_latents=num_channels_latents,
+ height=control_image.shape[3],
+ width=control_image.shape[4],
+ ).to(dtype=prompt_embeds.dtype, device=device)
+
+ else:
+ if isinstance(self.controlnet, QwenImageMultiControlNetModel):
+ control_images = []
+ for control_image_ in control_image:
+ control_image_ = self.prepare_image(
+ image=control_image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
+
+ height, width = control_image_.shape[-2:]
+
+ if control_image_.ndim == 4:
+ control_image_ = control_image_.unsqueeze(2)
+
+ # vae encode
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)
+ ).to(device)
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
+ 1, self.vae.config.z_dim, 1, 1, 1
+ ).to(device)
+
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
+ control_image_ = (control_image_ - latents_mean) * latents_std
+
+ control_image_ = control_image_.permute(0, 2, 1, 3, 4)
+
+ # pack
+ control_image_ = self._pack_latents(
+ control_image_,
+ batch_size=control_image_.shape[0],
+ num_channels_latents=num_channels_latents,
+ height=control_image_.shape[3],
+ width=control_image_.shape[4],
+ ).to(dtype=prompt_embeds.dtype, device=device)
+
+ control_images.append(control_image_)
+
+ control_image = control_images
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
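+        # Each img_shapes entry is (num_frames, latent_height // 2, latent_width // 2): the 2x2-patch grid of
+        # packed image tokens that the transformer and ControlNet use to recover the spatial layout of the sequence.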
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
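+        # `mu` dynamically shifts the flow-matching sigma schedule based on the packed image sequence length;
+        # it is passed to the scheduler through `retrieve_timesteps` below.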
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
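+        # For each step, controlnet_keep is 1.0 while the step index falls inside the
+        # [control_guidance_start, control_guidance_end] window and 0.0 outside it, which zeroes the
+        # ControlNet conditioning scale outside that window.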
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # controlnet
+ controlnet_block_samples = self.controlnet(
+ hidden_states=latents,
+ controlnet_cond=control_image,
+ conditioning_scale=cond_scale,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ return_dict=False,
+ )
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
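+                    # True classifier-free guidance: extrapolate from the unconditional towards the conditional
+                    # prediction, then rescale the result to the conditional prediction's per-token norm so the
+                    # overall magnitude stays stable at high guidance scales.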
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
new file mode 100644
index 0000000000..102a813ab5
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
@@ -0,0 +1,941 @@
+# Copyright 2025 Qwen-Image Team, The InstantX Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...models.controlnets.controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers.utils import load_image
+ >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline
+
+ >>> base_model_path = "Qwen/Qwen-Image"
+ >>> controlnet_model_path = "InstantX/Qwen-Image-ControlNet-Inpainting"
+ >>> controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)
+ >>> pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
+ ... base_model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
+ ... ).to("cuda")
+ >>> image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png"
+ ... )
+ >>> mask_image = load_image(
+ ... "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png"
+ ... )
+ >>> prompt = "一辆绿色的出租车行驶在路上"
+ >>> result = pipe(
+ ... prompt=prompt,
+ ... control_image=image,
+ ... control_mask=mask_image,
+ ... controlnet_conditioning_scale=1.0,
+ ... width=mask_image.size[0],
+ ... height=mask_image.size[1],
+ ... true_cfg_scale=4.0,
+ ... ).images[0]
+        >>> result.save("qwenimage_controlnet_inpaint.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+    The QwenImage ControlNet pipeline for text-guided image inpainting.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        vae ([`AutoencoderKLQwenImage`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+            [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), the multimodal text encoder
+            used to compute prompt embeddings.
+        tokenizer (`Qwen2Tokenizer`):
+            Tokenizer of class
+            [Qwen2Tokenizer](https://huggingface.co/docs/transformers/en/model_doc/qwen2).
+        controlnet ([`QwenImageControlNetModel`]):
+            Provides additional inpainting conditioning (masked-image latents and mask) to the transformer during
+            the denoising process.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ controlnet: QwenImageControlNetModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ do_resize=True,
+ do_convert_grayscale=True,
+ do_normalize=False,
+ do_binarize=True,
+ )
+
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
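+        # 34 corresponds to the tokens of the fixed template prefix above; `_get_qwen_prompt_embeds` drops these
+        # leading hidden states so the prompt embeddings cover only the user prompt.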
+ self.default_sample_size = 128
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(self.device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and
+        allows processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ if do_classifier_free_guidance and not guess_mode:
+ image = torch.cat([image] * 2)
+
+ return image
+
+ def prepare_image_with_mask(
+ self,
+ image,
+ mask,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
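+        # Build the ControlNet conditioning for inpainting: preprocess image and mask, blank out the masked
+        # pixels, VAE-encode the masked image, downsample the inverted mask to the latent resolution, then
+        # concatenate latents and mask (16 + 1 channels) and pack them into the transformer's token layout.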
+ if isinstance(image, torch.Tensor):
+ pass
+ else:
+ image = self.image_processor.preprocess(image, height=height, width=width)
+
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+ image = image.to(device=device, dtype=dtype) # (bsz, 3, height_ori, width_ori)
+
+ # Prepare mask
+ if isinstance(mask, torch.Tensor):
+ pass
+ else:
+ mask = self.mask_processor.preprocess(mask, height=height, width=width)
+ mask = mask.repeat_interleave(repeat_by, dim=0)
+ mask = mask.to(device=device, dtype=dtype) # (bsz, 1, height_ori, width_ori)
+
+ if image.ndim == 4:
+ image = image.unsqueeze(2)
+
+ if mask.ndim == 4:
+ mask = mask.unsqueeze(2)
+
+ # Get masked image
+ masked_image = image.clone()
+ masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1 # (bsz, 3, 1, height_ori, width_ori)
+
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(device)
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device
+ )
+
+ # Encode to latents
+ image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
+ image_latents = (image_latents - latents_mean) * latents_std
+ image_latents = image_latents.to(dtype) # torch.Size([1, 16, 1, height_ori//8, width_ori//8])
+
+ mask = torch.nn.functional.interpolate(
+ mask, size=(image_latents.shape[-3], image_latents.shape[-2], image_latents.shape[-1])
+ )
+ mask = 1 - mask # torch.Size([1, 1, 1, height_ori//8, width_ori//8])
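+        # After inversion, 1 marks regions to keep and 0 marks the region to inpaint (assuming a white input
+        # mask over the area to be repainted).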
+
+ control_image = torch.cat(
+ [image_latents, mask], dim=1
+ ) # torch.Size([1, 16+1, 1, height_ori//8, width_ori//8])
+
+ control_image = control_image.permute(0, 2, 1, 3, 4) # torch.Size([1, 1, 16+1, height_ori//8, width_ori//8])
+
+ # pack
+ control_image = self._pack_latents(
+ control_image,
+ batch_size=control_image.shape[0],
+ num_channels_latents=control_image.shape[2],
+ height=control_image.shape[3],
+ width=control_image.shape[4],
+ )
+
+ if do_classifier_free_guidance and not guess_mode:
+ control_image = torch.cat([control_image] * 2)
+
+ return control_image
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 1.0,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ control_image: PipelineImageInput = None,
+ control_mask: PipelineImageInput = None,
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+            true_cfg_scale (`float`, *optional*, defaults to 4.0):
+                When greater than 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+            guidance_scale (`float`, *optional*, defaults to 1.0):
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
+                of the [Imagen Paper](https://huggingface.co/papers/2205.11487). In this pipeline it is passed to the
+                transformer only for guidance-distilled models (`guidance_embeds=True`) and is ignored otherwise; use
+                `true_cfg_scale` together with `negative_prompt` to enable classifier-free guidance.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(control_image) if isinstance(self.controlnet, QwenImageMultiControlNetModel) else 1
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 3. Prepare control image
+ num_channels_latents = self.transformer.config.in_channels // 4
+ if isinstance(self.controlnet, QwenImageControlNetModel):
+ control_image = self.prepare_image_with_mask(
+ image=control_image,
+ mask=control_mask,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=self.vae.dtype,
+ )
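+            # control_image is now a packed token sequence built from the masked-image latents plus the
+            # inverted, downsampled mask (16 + 1 channels before packing).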
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ controlnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if isinstance(controlnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+ else:
+ controlnet_cond_scale = controlnet_conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+ # controlnet
+ controlnet_block_samples = self.controlnet(
+ hidden_states=latents,
+ controlnet_cond=control_image.to(dtype=latents.dtype, device=device),
+ conditioning_scale=cond_scale,
+ timestep=timestep / 1000,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ return_dict=False,
+ )
+
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ img_shapes=img_shapes,
+ txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
+ controlnet_block_samples=controlnet_block_samples,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
new file mode 100644
index 0000000000..ed37b238c8
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
@@ -0,0 +1,899 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import QwenImageEditPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ... ).convert("RGB")
+ >>> prompt = (
+ ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
+ ... )
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
+ >>> image.save("qwenimage_edit.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
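+# Worked example (illustrative, not part of the original): with the defaults above,
+# image_seq_len=1024 gives mu = 0.5 + (1.15 - 0.5) / (4096 - 256) * (1024 - 256) = 0.63.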
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
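+# Minimal usage sketch (illustrative): pass `num_inference_steps`, or override the
+# schedule with custom `timesteps` or `sigmas` (only one of the two), e.g.
+#   timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cuda")
+#   timesteps, n = retrieve_timesteps(scheduler, sigmas=[1.0, 0.75, 0.5, 0.25], device="cuda")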
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def calculate_dimensions(target_area, ratio):
+ width = math.sqrt(target_area * ratio)
+ height = width / ratio
+
+ width = round(width / 32) * 32
+ height = round(height / 32) * 32
+
+ return width, height, None
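+# Worked example (illustrative): calculate_dimensions(1024 * 1024, 16 / 9) returns
+# (1376, 768, None); width and height are snapped to multiples of 32.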
+
+
+class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The Qwen-Image-Edit pipeline for image editing.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLQwenImage`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), the vision-language
+ model used to encode the text prompt together with the input image.
+ tokenizer ([`Qwen2Tokenizer`]):
+ Tokenizer of class [`Qwen2Tokenizer`].
+ processor ([`Qwen2VLProcessor`]):
+ Multimodal processor used to prepare the text prompt and the input image for the text encoder.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ processor: Qwen2VLProcessor,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ processor=processor,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height have to be
+ # divisible by the patch size, so the VAE scale factor is multiplied by the patch size to account for this.
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 1024
+
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 64
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
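+ # Keeps only positions where `mask` is nonzero and returns a tuple with one
+ # (valid_len_i, hidden_dim) tensor per batch element.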
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+
+ model_inputs = self.processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = self.text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
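+ # Drop the first `drop_idx` positions, which belong to the fixed prompt template
+ # rather than the user prompt, before re-padding everything to a common length.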
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ image (`torch.Tensor`, *optional*):
+ image to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
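+ # Shape sketch (illustrative): with 16 latent channels and a 64x64 latent grid,
+ # _pack_latents produces (batch, 32 * 32, 64) patch tokens and _unpack_latents
+ # restores (batch, 16, 1, 64, 64).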
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
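+ # Normalize with the VAE's per-channel latent statistics so the encoded image
+ # matches the latent distribution the transformer was trained on.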
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ image_latents = (image_latents - latents_mean) / latents_std
+
+ return image_latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ image_latents = None
+ if image is not None:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latent_height, image_latent_width = image_latents.shape[3:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents, image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if passing latents directly they are not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
+ setting `true_cfg_scale > 1` and providing a `negative_prompt`. A higher guidance scale encourages the
+ model to generate images that are closely linked to the text `prompt`, usually at the expense of lower
+ image quality.
+ height (`int`, *optional*):
+ The height in pixels of the generated image. If not provided, it is derived from the input image's
+ aspect ratio so that the output covers roughly a 1024x1024 pixel area.
+ width (`int`, *optional*):
+ The width in pixels of the generated image. If not provided, it is derived from the input image's
+ aspect ratio so that the output covers roughly a 1024x1024 pixel area.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
+ generate images that are closely linked to the text `prompt`, usually at the expense of lower image
+ quality. This parameter exists to support future guidance-distilled models and is ignored when the
+ model is not guidance-distilled. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+ image_size = image[0].size if isinstance(image, list) else image.size
+ calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+ height = height or calculated_height
+ width = width or calculated_width
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ image = self.image_processor.resize(image, calculated_height, calculated_width)
+ prompt_image = image
+ image = self.image_processor.preprocess(image, calculated_height, calculated_width)
+ image = image.unsqueeze(2)
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ image=prompt_image,
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ image=prompt_image,
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents = self.prepare_latents(
+ image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [
+ [
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+ (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
+ ]
+ ] * batch_size
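+ # Each sample gets two (frames, packed_height, packed_width) entries: one for the
+ # latents being generated and one for the encoded reference-image latents.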
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
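+ # The reference-image latents are concatenated along the sequence dimension so the
+ # transformer can attend to them; only the first `latents.size(1)` tokens are kept
+ # from the prediction below.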
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
new file mode 100644
index 0000000000..d54d1881fa
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
@@ -0,0 +1,1130 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import QwenImageEditInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageEditInpaintPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image = pipe(
+ ... prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=1.0, num_inference_steps=50
+ ... ).images[0]
+ >>> image.save("qwenimage_inpainting.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.calculate_dimensions
+def calculate_dimensions(target_area, ratio):
+ width = math.sqrt(target_area * ratio)
+ height = width / ratio
+
+ width = round(width / 32) * 32
+ height = round(height / 32) * 32
+
+ return width, height, None
+
+
+class QwenImageEditInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The Qwen-Image-Edit pipeline for mask-guided image editing (inpainting).
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLQwenImage`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), the vision-language
+ model used to encode the text prompt together with the input image.
+ tokenizer ([`Qwen2Tokenizer`]):
+ Tokenizer of class [`Qwen2Tokenizer`].
+ processor ([`Qwen2VLProcessor`]):
+ Multimodal processor used to prepare the text prompt and the input image for the text encoder.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ processor: Qwen2VLProcessor,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ processor=processor,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height have to be
+ # divisible by the patch size, so the VAE scale factor is multiplied by the patch size to account for this.
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+ self.vl_processor = processor
+ self.tokenizer_max_length = 1024
+
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 64
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+
+ model_inputs = self.processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = self.text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ image (`torch.Tensor`, *optional*):
+ image to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ image_latents.device, image_latents.dtype
+ )
+
+ image_latents = (image_latents - latents_mean) * latents_std
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
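+ # worked example (assuming scheduler.order == 1): num_inference_steps=50 and strength=0.6 give
+ # init_timestep=30 and t_start=20, so only the final 30 timesteps of the schedule are used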
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it.
+ if image.dim() == 4:
+ image = image.unsqueeze(2)
+ elif image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W']
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W']
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device)
+ latents = noise
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents, noise, image_latents
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if masked_image.dim() == 4:
+ masked_image = masked_image.unsqueeze(2)
+ elif masked_image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.")
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == self.latent_channels:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy arrays and pytorch tensors, the expected value range is `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
+ enabled by setting `true_cfg_scale > 1` and providing a `negative_prompt`. Higher guidance scale
+ encourages generating images that are closely linked to the text `prompt`, usually at the expense of
+ lower image quality.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)`, or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+ 1)`, or `(H, W)`.
+ masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
+ `Tensor` representing an image batch to mask `image`, generated by the VAE. If not provided, the mask
+ latents tensor will be generated from `mask_image`.
+ height (`int`, *optional*):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ratio as the image that contains all masked areas, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+ the image is large and contains information irrelevant for inpainting, such as background.
+ strength (`float`, *optional*, defaults to 0.6):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages generating images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter exists in the pipeline to support future guidance-distilled models when they become available. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+ image_size = image[0].size if isinstance(image, list) else image.size
+ calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+
+ # the output height and width are always derived from the input image's aspect ratio
+ height = calculated_height
+ width = calculated_width
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
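+ # e.g. with a typical vae_scale_factor of 8, width and height are floored to multiples of 16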
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Preprocess image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ image = self.image_processor.resize(image, calculated_height, calculated_width)
+ original_image = image
+ prompt_image = image
+ image = self.image_processor.preprocess(
+ image,
+ height=calculated_height,
+ width=calculated_width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+ image = image.to(dtype=torch.float32)
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ image=prompt_image,
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ image=prompt_image,
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
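+ # image_seq_len is the number of packed latent tokens; it sets the dynamic shift `mu` for the flow-matching schedule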
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, noise, image_latents = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is None:
+ masked_image = image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
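+ # per-sample packed latent grid shapes (frames, height patches, width patches): one entry for the noisy
+ # latents and one for the reference image latents; the transformer uses these to build its image positional embeddings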
+ img_shapes = [
+ [
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+ (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
+ ]
+ ] * batch_size
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
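+ # unpadded text sequence lengths per sample, derived from the prompt embedding masks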
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
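+ # concatenate the noisy latents with the packed reference image latents along the token dimension;
+ # the prediction is sliced back to the noisy-latent tokens after the transformer call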
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
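+ # true classifier-free guidance: combine conditional and unconditional predictions, then rescale the
+ # result to keep the per-token norm of the conditional prediction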
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # inpainting: preserve the original image content outside the mask
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
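+ # re-noise the clean image latents to the noise level of the next timestep so they can be
+ # blended with the partially denoised latents below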
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if padding_mask_crop is not None:
+ image = [
+ self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image
+ ]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
new file mode 100644
index 0000000000..ec203edf16
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -0,0 +1,883 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import QwenImageEditPlusPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ... ).convert("RGB")
+ >>> prompt = (
+ ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
+ ... )
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
+ >>> image.save("qwenimage_edit_plus.png")
+ ```
+"""
+
+CONDITION_IMAGE_SIZE = 384 * 384
+VAE_IMAGE_SIZE = 1024 * 1024
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
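+ # linear interpolation: mu equals base_shift when image_seq_len == base_seq_len and max_shift when
+ # image_seq_len == max_seq_len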
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def calculate_dimensions(target_area, ratio):
+ width = math.sqrt(target_area * ratio)
+ height = width / ratio
+
+ width = round(width / 32) * 32
+ height = round(height / 32) * 32
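+ # e.g. a target_area of 1024 * 1024 with a 16:9 ratio gives roughly 1365x768, which snaps to 1376x768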
+
+ return width, height
+
+
+class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The Qwen-Image-Edit Plus pipeline for image editing with one or more reference images.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLQwenImage`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), a multimodal large
+ language model used to encode the text prompt together with the input images.
+ tokenizer (`Qwen2Tokenizer`):
+ Tokenizer of class [`Qwen2Tokenizer`].
+ processor (`Qwen2VLProcessor`):
+ Processor of class [`Qwen2VLProcessor`], used to prepare the text and image inputs for the text encoder.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ processor: Qwen2VLProcessor,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ processor=processor,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height have to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
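+ # e.g. with vae_scale_factor=8, a 1024x1024 image yields 128x128 latents, which pack into a 64x64 grid
+ # of tokens (4096 tokens), each of dimension 4 * z_dim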
+ self.tokenizer_max_length = 1024
+
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 64
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+ if isinstance(image, list):
+ base_img_prompt = ""
+ for i, img in enumerate(image):
+ base_img_prompt += img_prompt_template.format(i + 1)
+ elif image is not None:
+ base_img_prompt = img_prompt_template.format(1)
+ else:
+ base_img_prompt = ""
+
+ template = self.prompt_template_encode
+
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(base_img_prompt + e) for e in prompt]
+
+ model_inputs = self.processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = self.text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
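+ # drop the hidden states of the fixed chat-template prefix so that only the user text and image
+ # placeholder tokens remain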
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ image (`torch.Tensor`, *optional*):
+ image to be encoded
+ device (`torch.device`, *optional*):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ image_latents = (image_latents - latents_mean) / latents_std
+
+ return image_latents
+
+ def prepare_latents(
+ self,
+ images,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ image_latents = None
+ if images is not None:
+ if not isinstance(images, list):
+ images = [images]
+ all_image_latents = []
+ for image in images:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latent_height, image_latent_width = image_latents.shape[3:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ all_image_latents.append(image_latents)
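+ # concatenate the packed latents of all reference images along the token (sequence) dimension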
+ image_latents = torch.cat(all_image_latents, dim=1)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents, image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy arrays and pytorch tensors, the expected value range is `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
+ enabled by setting `true_cfg_scale > 1` and providing a `negative_prompt`. Higher guidance scale
+ encourages generating images that are closely linked to the text `prompt`, usually at the expense of
+ lower image quality.
+ height (`int`, *optional*):
+ The height in pixels of the generated image. If not provided, it is calculated from the aspect ratio of
+ the input image so that the total output area is roughly 1024x1024 pixels.
+ width (`int`, *optional*):
+ The width in pixels of the generated image. If not provided, it is calculated from the aspect ratio of
+ the input image so that the total output area is roughly 1024x1024 pixels.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages generating images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter exists in the pipeline to support future guidance-distilled models when they become available. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+ image_size = image[-1].size if isinstance(image, list) else image.size
+ calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+ height = height or calculated_height
+ width = width or calculated_width
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
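+ # e.g., with vae_scale_factor=8, multiple_of=16, so a computed 1000x1000 target is snapped down to 992x992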
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Preprocess image
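+ # Each input image is resized twice: a smaller "condition" copy that is passed to the VL text encoder as
+ # visual context, and a VAE-sized copy that is encoded to latents and concatenated to the denoised tokens.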
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ if not isinstance(image, list):
+ image = [image]
+ condition_image_sizes = []
+ condition_images = []
+ vae_image_sizes = []
+ vae_images = []
+ for img in image:
+ image_width, image_height = img.size
+ condition_width, condition_height = calculate_dimensions(
+ CONDITION_IMAGE_SIZE, image_width / image_height
+ )
+ vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+ condition_image_sizes.append((condition_width, condition_height))
+ vae_image_sizes.append((vae_width, vae_height))
+ condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+ vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ image=condition_images,
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ image=condition_images,
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents = self.prepare_latents(
+ vae_images,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
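+ # Per-sample latent grid shapes: the target image grid followed by one grid per conditioning image,
+ # each given as (frames, height_patches, width_patches) and used by the transformer for positional information.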
+ img_shapes = [
+ [
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+ *[
+ (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+ for vae_width, vae_height in vae_image_sizes
+ ],
+ ]
+ ] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
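+ # Guidance-distilled checkpoints (config.guidance_embeds=True) take guidance_scale as a direct model input;
+ # for all other checkpoints it is ignored and guidance is controlled via true_cfg_scale instead.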
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
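+ # Keep only the target-image tokens; predictions for the appended conditioning-image tokens are discarded.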
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
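+ # Rescale the guided prediction so its per-token norm matches the conditional prediction's norm,
+ # which helps keep large true_cfg_scale values from inflating the latent magnitudes.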
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
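+ # Undo the encode-time normalization: latents_std holds the reciprocal std, so dividing by it
+ # multiplies by the true std before the channel means are added back.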
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
new file mode 100644
index 0000000000..cb4c5d8016
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
@@ -0,0 +1,874 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import QwenImageImg2ImgPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageImg2ImgPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+ >>> pipe = pipe.to("cuda")
+ >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ >>> init_image = load_image(url).resize((1024, 1024))
+ >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney"
+ >>> images = pipe(prompt=prompt, negative_prompt=" ", image=init_image, strength=0.95).images[0]
+ >>> images.save("qwenimage_img2img.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
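+ # Linearly interpolate the flow-matching shift mu between base_shift and max_shift as a function of the
+ # image token count; e.g., with the defaults above, image_seq_len=1024 gives mu ≈ 0.63.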
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The QwenImage pipeline for image-to-image generation.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLQwenImage`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer ([`Qwen2Tokenizer`]):
+ Tokenizer of class
+ [`Qwen2Tokenizer`](https://huggingface.co/docs/transformers/en/model_doc/qwen2).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height have to be
+ # divisible by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+ )
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
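+ # Keep only the hidden states at positions where the attention mask is 1 and split them back per prompt.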
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
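+ # Strip padding and the fixed chat-template prefix (drop_idx tokens) from each prompt, then re-pad the
+ # batch to the longest remaining sequence and build a matching attention mask.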
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ image_latents.device, image_latents.dtype
+ )
+
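+ # Standardize the encoded latents with the VAE's per-channel statistics; latents_std above is already the
+ # reciprocal, so the multiplication below divides by the true std.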
+ image_latents = (image_latents - latents_mean) * latents_std
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
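+ # e.g., num_inference_steps=50 with strength=0.6 gives init_timestep=30 and t_start=20, so only the last 30 timesteps are denoised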
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ prompt_embeds = prompt_embeds[:, :max_sequence_length]
+ prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ strength,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` must also be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` must also be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
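+ # [B, C, H, W] -> [B, (H/2)*(W/2), C*4]: fold every 2x2 spatial patch into the channel dimension and
+ # flatten the patch grid into a token sequence.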
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
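+ # [B, (H/2)*(W/2), C] -> [B, C/4, 1, H, W]: reverse of _pack_latents, restoring the 2x2 patches and adding
+ # a singleton frame dimension for the video-style VAE.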
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it.
+ if image.dim() == 4:
+ image = image.unsqueeze(2)
+ elif image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W']
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W']
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
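+ # Flow-matching forward process: mix the clean image latents with Gaussian noise at the starting timestep
+ # via scheduler.scale_noise before packing them into a token sequence.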
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a
+ list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+ a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`, but if passing latents directly they are not encoded again.
+ true_cfg_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
+ setting `true_cfg_scale > 1` and providing a `negative_prompt`. A higher guidance scale encourages the
+ model to generate images that are closely linked to the text `prompt`, usually at the expense of lower
+ image quality.
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. This defaults to 1024 for the best results.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. This defaults to 1024 for the best results.
+ strength (`float`, *optional*, defaults to 0.6):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to `None`):
+ A guidance scale value for guidance-distilled models. Unlike traditional classifier-free guidance,
+ where the guidance scale is applied at inference time by rescaling the noise predictions,
+ guidance-distilled models take the guidance scale directly as an input to the forward pass. Guidance
+ is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate
+ images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+ This parameter exists to support future guidance-distilled checkpoints and is ignored for models that
+ are not guidance-distilled. To enable traditional classifier-free guidance, pass `true_cfg_scale > 1.0`
+ together with a `negative_prompt` (even an empty negative prompt like " " enables the classifier-free
+ guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ strength,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess image
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
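+ # A single latent grid per sample, given as (frames, height_patches, width_patches) for the transformer's positional information.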
+ img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
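+ # Rescale the guided prediction so its per-token norm matches the conditional prediction's norm,
+ # which helps keep large true_cfg_scale values from inflating the latent magnitudes.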
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
new file mode 100644
index 0000000000..1915c27eb2
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
@@ -0,0 +1,1060 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import QwenImageInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageInpaintPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image = pipe(prompt=prompt, negative_prompt=" ", image=source, mask_image=mask, strength=0.85).images[0]
+ >>> image.save("qwenimage_inpainting.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The QwenImage pipeline for text-guided image inpainting.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLQwenImage`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer ([`Qwen2Tokenizer`]):
+ Tokenizer of class
+ [`Qwen2Tokenizer`](https://huggingface.co/docs/transformers/en/model_doc/qwen2).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height have to be
+ # divisible by the patch size. So the vae scale factor is multiplied by the patch size to account for this.
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
+ )
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
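+ # Mask preprocessing: converts to single-channel grayscale and binarizes to 0/1 without normalization,
+ # using the same scale factor as the image processor above.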
+ self.tokenizer_max_length = 1024
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ template = self.prompt_template_encode
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = self.tokenizer(
+ txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+ ).to(device)
+ encoder_hidden_states = self.text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_hidden_states.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ image_latents.device, image_latents.dtype
+ )
+
+ image_latents = (image_latents - latents_mean) * latents_std
+
+ return image_latents
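As an annotation (not part of the patch): the standardization above and its inverse in the decoding branch of `__call__` are simply `(z - mean) / std` and `z * std + mean` applied per latent channel. A scalar sketch with made-up numbers:

```python
# Per-channel latent standardization and its inverse (sketch; the numbers are made up).
z, mean, std = 1.25, 0.50, 2.00

normalized = (z - mean) * (1.0 / std)       # what _encode_vae_image returns -> 0.375
restored = normalized / (1.0 / std) + mean  # what __call__ undoes before vae.decode -> 1.25
assert abs(restored - z) < 1e-9
```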
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
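For orientation (annotation, not part of the patch): with the defaults used by `__call__` below, `strength=0.6` keeps the last 30 of 50 scheduler steps. A throwaway sketch of that arithmetic, assuming a first-order scheduler (`order == 1`):

```python
# How strength trims the schedule (sketch; scheduler.order assumed to be 1).
num_inference_steps, strength = 50, 0.6

init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20
print(t_start, num_inference_steps - t_start)  # 20 steps skipped, 30 steps actually run
```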
+
+    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+
+ prompt_embeds = prompt_embeds[:, :max_sequence_length]
+ prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
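As a sanity check (annotation, not part of the patch): a throwaway round trip through the two static methods above. It assumes `QwenImageInpaintPipeline` is exported from `diffusers` once this file lands, and uses illustrative latent sizes.

```python
import torch

from diffusers import QwenImageInpaintPipeline  # assumes the class above is exported

# Round-trip the 2x2 patch packing: (B, C, H, W) -> (B, H/2 * W/2, 4C) -> (B, C, 1, H, W).
batch, channels, latent_h, latent_w = 1, 16, 64, 64  # illustrative latent-space sizes
latents = torch.randn(batch, channels, latent_h, latent_w)

packed = QwenImageInpaintPipeline._pack_latents(latents, batch, channels, latent_h, latent_w)
print(packed.shape)  # torch.Size([1, 1024, 64])

# _unpack_latents takes pixel-space height/width; with vae_scale_factor=8 that is 8x the latent size.
unpacked = QwenImageInpaintPipeline._unpack_latents(packed, latent_h * 8, latent_w * 8, vae_scale_factor=8)
print(unpacked.shape)  # torch.Size([1, 16, 1, 64, 64])
```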
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image,
+ timestep,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it.
+ if image.dim() == 4:
+ image = image.unsqueeze(2)
+ elif image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator) # [B,z,1,H',W']
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latents = image_latents.transpose(1, 2) # [B,1,z,H',W']
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device)
+ latents = noise
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents, noise, image_latents
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if masked_image.dim() == 4:
+ masked_image = masked_image.unsqueeze(2)
+ elif masked_image.dim() != 5:
+ raise ValueError(f"Expected image dims 4 or 5, got {masked_image.dim()}.")
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == self.latent_channels:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ image: PipelineImageInput = None,
+ mask_image: PipelineImageInput = None,
+ masked_image_latents: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ padding_mask_crop: Optional[int] = None,
+ strength: float = 0.6,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
+                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+                latents as `image`, but if passing latents directly they are not encoded again.
+            true_cfg_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
+                setting `true_cfg_scale > 1` and providing a `negative_prompt`. A higher guidance scale encourages the
+                model to generate images that are closely linked to the text `prompt`, usually at the expense of lower
+                image quality.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+ 1)`, or `(H, W)`.
+            masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`):
+                `Tensor` representing a VAE-encoded masked image batch. If not provided, the masked image latents will
+                be computed from `image` and `mask_image`.
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+                with the same aspect ratio as the image that contains all of the masked area, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+                the image is large and contains information irrelevant for inpainting, such as background.
+            strength (`float`, *optional*, defaults to 0.6):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+                scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess image
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ original_image = image
+ init_image = self.image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ init_image = init_image.to(dtype=torch.float32)
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+
+ latents, noise, image_latents = self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ if masked_image_latents is None:
+ masked_image = init_image * (mask_condition < 0.5)
+ else:
+ masked_image = masked_image_latents
+
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ # for 64 channel transformer only.
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if padding_mask_crop is not None:
+ image = [
+ self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image
+ ]
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
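A hedged end-to-end usage sketch for the new pipeline (an annotation, not part of the patch). The checkpoint id `Qwen/Qwen-Image` and the demo image/mask URLs are assumptions borrowed from the existing QwenImage and inpainting docs; substitute whatever checkpoint and assets you actually use.

```python
import torch

from diffusers import QwenImageInpaintPipeline
from diffusers.utils import load_image

# Assumed checkpoint id; any QwenImage checkpoint with the same components should work.
pipe = QwenImageInpaintPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Standard inpainting demo assets used elsewhere in the diffusers docs.
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
image = load_image(img_url)
mask_image = load_image(mask_url)

result = pipe(
    prompt="a golden retriever sitting on a park bench, photorealistic",
    negative_prompt=" ",  # a non-empty negative prompt enables true CFG when true_cfg_scale > 1
    image=image,
    mask_image=mask_image,
    strength=0.85,
    true_cfg_scale=4.0,
    num_inference_steps=50,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
result.save("qwenimage_inpaint.png")
```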
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index 103f57a236..ac979305ca 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -30,6 +30,7 @@ from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
USE_PEFT_BACKEND,
+ deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
@@ -224,6 +225,12 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -231,6 +238,12 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -239,6 +252,12 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -246,6 +265,12 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def _get_gemma_prompt_embeds(
@@ -781,7 +806,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
index cdc602b964..55ed7b84eb 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py
@@ -30,6 +30,7 @@ from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
USE_PEFT_BACKEND,
+ deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
@@ -237,6 +238,12 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -244,6 +251,12 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -252,6 +265,12 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -259,6 +278,12 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
@@ -844,7 +869,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index e8f9d8368f..62b9788292 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -30,6 +30,7 @@ from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
USE_PEFT_BACKEND,
+ deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
@@ -175,6 +176,12 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -182,6 +189,12 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -190,6 +203,12 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -197,6 +216,12 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
@@ -663,7 +688,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
index bf290c3ced..8899ed84c4 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py
@@ -31,6 +31,7 @@ from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
USE_PEFT_BACKEND,
+ deprecate,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
@@ -183,6 +184,12 @@ class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.disable_vae_slicing
@@ -191,6 +198,12 @@ class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.enable_vae_tiling
@@ -200,6 +213,12 @@ class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -207,6 +226,12 @@ class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
@@ -736,7 +761,7 @@ class SanaSprintImg2ImgPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py
index 89d4d2dca5..07b382dfc4 100644
--- a/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py
@@ -18,7 +18,6 @@ from typing import Optional
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
index 8861ecae7d..b7faf097ab 100644
--- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
@@ -25,11 +25,7 @@ from transformers import (
from ...models import AutoencoderOobleck, StableAudioDiTModel
from ...models.embeddings import get_1d_rotary_pos_embed
from ...schedulers import EDMDPMSolverMultistepScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_stable_audio import StableAudioProjectionModel
@@ -134,6 +130,12 @@ class StableAudioPipeline(DiffusionPipeline):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
@@ -142,6 +144,12 @@ class StableAudioPipeline(DiffusionPipeline):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def encode_prompt(
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
index 6130a9873c..aa39983c4e 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
@@ -362,7 +362,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
index b705c7e6e5..b3dc23f2e5 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
@@ -237,7 +237,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
index b3b46af206..9e63b3489c 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
@@ -442,7 +442,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 1afa7698da..6befe77aa4 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -349,12 +349,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+                > [!WARNING]
+                > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
index 78e3ba239c..81656beba7 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
@@ -389,12 +389,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING]
+ > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
Examples:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
index b7e17ba681..5938fe232a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -103,11 +103,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
r"""
Flax-based pipeline for text-guided image inpainting using Stable Diffusion.
-
-
- 🧪 This is an experimental feature!
-
-
+ > [!WARNING] > 🧪 This is an experimental feature!
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -435,12 +431,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING]
+ > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
index 06c2076816..6ebe0986a1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -313,7 +313,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index 141d849ec3..158bcabbeb 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -378,7 +378,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
latents (`np.ndarray`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
index 882fa98b07..a765163175 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -398,7 +398,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`np.ndarray`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index afee3f61e9..1618f89a49 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -854,7 +854,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index fa1e0a4f32..7e97909f42 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -909,7 +909,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index 937f7195b2..bed596e57c 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -984,7 +984,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
1)`, or `(H, W)`.
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
- latents tensor will ge generated by `mask_image`.
+ latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
@@ -1033,7 +1033,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
index 87bd9f4444..65c25ffbe4 100644
--- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
@@ -249,11 +249,7 @@ class StableDiffusionDiffEditPipeline(
StableDiffusionLoraLoaderMixin,
):
r"""
-
-
- This is an experimental feature!
-
-
+ > [!WARNING] > This is an experimental feature!
Pipeline for text-guided image inpainting using Stable Diffusion and DiffEdit.
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 350a492826..feebd6adf8 100755
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -81,11 +81,7 @@ class StableDiffusionKDiffusionPipeline(
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
-
-
- This is an experimental pipeline and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental pipeline and is likely to change in the future.
Args:
vae ([`AutoencoderKL`]):
@@ -539,7 +535,7 @@ class StableDiffusionKDiffusionPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
index 3b57555071..766ca37d81 100644
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
@@ -652,7 +652,7 @@ class StableDiffusionXLKDiffusionPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 9ac64a0d84..b97cf6f1f6 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -937,7 +937,7 @@ class StableDiffusionXLPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index e63c7a55ce..44e8f4fe4b 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -1097,7 +1097,7 @@ class StableDiffusionXLImg2ImgPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index f0bc9b9bb3..18f8536a75 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -1251,7 +1251,7 @@ class StableDiffusionXLInpaintPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index b1379d1b29..58b0083617 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -695,7 +695,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 5c561721fc..1ce6987114 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -760,7 +760,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin, Fr
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 13183df47d..2802d690f3 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -971,7 +971,7 @@ class StableDiffusionXLAdapterPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
index a9fa43c1f5..288aae6c0d 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
@@ -1051,7 +1051,7 @@ class TextToVideoZeroSDXLPipeline(
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
motion_field_strength_x (`float`, *optional*, defaults to 12):
Strength of motion in generated video along x-axis. See the
[paper](https://huggingface.co/papers/2303.13439), Sect. 3.3.1.
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 40fd3b3373..f9298d5b86 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -232,6 +232,12 @@ class UniDiffuserPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing
@@ -240,6 +246,12 @@ class UniDiffuserPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_tiling
@@ -249,6 +261,12 @@ class UniDiffuserPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_tiling
@@ -257,6 +275,12 @@ class UniDiffuserPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
# Functions to manually set the mode
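Note on the hunks above: the UniDiffuser VAE slicing/tiling helpers now go through `deprecate` and point users at the VAE methods directly; the same pattern is applied to `VisualClozeGenerationPipeline` further down. A minimal sketch of the recommended call sites (the checkpoint id is assumed here, not taken from this diff):

```python
import torch
from diffusers import UniDiffuserPipeline

pipe = UniDiffuserPipeline.from_pretrained(
    "thu-ml/unidiffuser-v1", torch_dtype=torch.float16  # assumed checkpoint id
).to("cuda")

# Deprecated (emits a `deprecate` warning until removal in 0.40.0):
# pipe.enable_vae_slicing()
# pipe.enable_vae_tiling()

# Preferred: call the VAE directly.
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
```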
diff --git a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py
index 68130baad7..91a54e1ae8 100644
--- a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py
+++ b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py
@@ -22,11 +22,7 @@ from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversio
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import (
- is_torch_xla_available,
- logging,
- replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ..flux.pipeline_flux_fill import FluxFillPipeline as VisualClozeUpsamplingPipeline
from ..flux.pipeline_output import FluxPipelineOutput
from ..pipeline_utils import DiffusionPipeline
@@ -319,7 +315,7 @@ class VisualClozePipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py
index e7a1d4a4b2..e12995106b 100644
--- a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py
+++ b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py
@@ -24,6 +24,7 @@ from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
+ deprecate,
is_torch_xla_available,
logging,
replace_example_docstring,
@@ -524,6 +525,12 @@ class VisualClozeGenerationPipeline(
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
+ depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+ deprecate(
+ "enable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_slicing()
def disable_vae_slicing(self):
@@ -531,6 +538,12 @@ class VisualClozeGenerationPipeline(
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+ deprecate(
+ "disable_vae_slicing",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_slicing()
def enable_vae_tiling(self):
@@ -539,6 +552,12 @@ class VisualClozeGenerationPipeline(
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
+ depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+ deprecate(
+ "enable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.enable_tiling()
def disable_vae_tiling(self):
@@ -546,6 +565,12 @@ class VisualClozeGenerationPipeline(
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
+ depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+ deprecate(
+ "disable_vae_tiling",
+ "0.40.0",
+ depr_message,
+ )
self.vae.disable_tiling()
def _prepare_latents(self, image, mask, gen, vae_scale_factor, device, dtype):
@@ -736,7 +761,7 @@ class VisualClozeGenerationPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/visualcloze/visualcloze_utils.py b/src/diffusers/pipelines/visualcloze/visualcloze_utils.py
index 5d221bc1e8..efe5dff476 100644
--- a/src/diffusers/pipelines/visualcloze/visualcloze_utils.py
+++ b/src/diffusers/pipelines/visualcloze/visualcloze_utils.py
@@ -110,7 +110,7 @@ class VisualClozeProcessor(VaeImageProcessor):
new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width))
new_w = int(new_w / 16) * 16
new_h = int(new_h / 16) * 16
- processed_images[i][j] = self.height(processed_images[i][j], new_h, new_w)
+ processed_images[i][j] = self._resize_and_crop(processed_images[i][j], new_h, new_w)
# Convert to tensors and normalize
image_sizes = []
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
index e5f83dd401..2b1890afec 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -152,16 +152,26 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
- transformer ([`WanTransformer3DModel`]):
+ transformer ([`WanVACETransformer3DModel`]):
Conditional Transformer to denoise the input latents.
+ transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+ `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+ `transformer` is used.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ boundary_ratio (`float`, *optional*, defaults to `None`):
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer_2"]
def __init__(
self,
@@ -170,6 +180,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer_2: WanVACETransformer3DModel = None,
+ boundary_ratio: Optional[float] = None,
):
super().__init__()
@@ -178,9 +190,10 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
+ transformer_2=transformer_2,
scheduler=scheduler,
)
-
+ self.register_to_config(boundary_ratio=boundary_ratio)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -321,6 +334,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
video=None,
mask=None,
reference_images=None,
+ guidance_scale_2=None,
):
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
if height % base != 0 or width % base != 0:
@@ -332,6 +346,8 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
if prompt is not None and prompt_embeds is not None:
raise ValueError(
@@ -525,8 +541,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
else:
- mask = mask.to(dtype=vae_dtype)
- mask = torch.where(mask > 0.5, 1.0, 0.0)
+ mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
inactive = video * (1 - mask)
reactive = video * mask
inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
@@ -668,6 +683,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -729,6 +745,10 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -775,7 +795,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# Simplification of implementation for now
- if not isinstance(prompt, str):
+ if prompt is not None and not isinstance(prompt, str):
raise ValueError("Passing a list of prompts is not yet supported. This may be supported in the future.")
if num_videos_per_prompt != 1:
raise ValueError(
@@ -794,6 +814,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
video,
mask,
reference_images,
+ guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -803,7 +824,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -897,36 +922,53 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
+
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
+
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- control_hidden_states=conditioning_latents,
- control_hidden_states_scale=conditioning_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
control_hidden_states=conditioning_latents,
control_hidden_states_scale=conditioning_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
- noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ control_hidden_states=conditioning_latents,
+ control_hidden_states_scale=conditioning_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
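To make the new two-stage denoising path concrete, here is a rough usage sketch. It assumes a Wan 2.2-style VACE checkpoint that ships a second, low-noise transformer under a `transformer_2` subfolder; the repo id is a placeholder and `boundary_ratio=0.875` is only an illustrative value.

```python
import torch
from diffusers import WanVACEPipeline, WanVACETransformer3DModel

repo_id = "<wan2.2-vace-checkpoint>"  # placeholder; must ship a low-noise transformer under `transformer_2`

transformer_2 = WanVACETransformer3DModel.from_pretrained(
    repo_id, subfolder="transformer_2", torch_dtype=torch.bfloat16
)
pipe = WanVACEPipeline.from_pretrained(
    repo_id,
    transformer_2=transformer_2,
    boundary_ratio=0.875,  # timesteps >= boundary_ratio * num_train_timesteps are denoised by `transformer`
    torch_dtype=torch.bfloat16,
).to("cuda")

frames = pipe(
    prompt="a cat surfing a wave at sunset",
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,    # high-noise stage (`transformer`)
    guidance_scale_2=4.0,  # low-noise stage (`transformer_2`); defaults to `guidance_scale` when omitted
).frames[0]
```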
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
index 1a2d2e9c22..a976126da7 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
@@ -49,7 +49,7 @@ EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
- >>> from diffusers.utils import export_to_video
+ >>> from diffusers.utils import export_to_video, load_video
>>> from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
index b9b02a6dd3..bbdb60471f 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
@@ -263,7 +263,7 @@ class WuerstchenDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
index 00a88ce34e..c54c1fefe8 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
@@ -222,7 +222,7 @@ class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
index a32f09204d..e138b6e805 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
@@ -348,7 +348,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
+ tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
index ce214ae7bc..070bcd0b21 100644
--- a/src/diffusers/quantizers/auto.py
+++ b/src/diffusers/quantizers/auto.py
@@ -21,9 +21,11 @@ from typing import Dict, Optional, Union
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
+from .modelopt import NVIDIAModelOptQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
+ NVIDIAModelOptConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -39,6 +41,7 @@ AUTO_QUANTIZER_MAPPING = {
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
+ "modelopt": NVIDIAModelOptQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,6 +50,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
+ "modelopt": NVIDIAModelOptConfig,
}
@@ -137,6 +141,9 @@ class DiffusersAutoQuantizer:
if isinstance(quantization_config, dict):
quantization_config = cls.from_dict(quantization_config)
+ if isinstance(quantization_config, NVIDIAModelOptConfig):
+ quantization_config.check_model_patching()
+
if warning_msg != "":
warnings.warn(warning_msg)
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
index 357d920d29..24fc724b4c 100644
--- a/src/diffusers/quantizers/base.py
+++ b/src/diffusers/quantizers/base.py
@@ -209,6 +209,17 @@ class DiffusersQuantizer(ABC):
return model
+ def get_cuda_warm_up_factor(self):
+ """
+ The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda.
+ A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
+ we allocate half the memory of the weights residing in the empty model, etc...
+ """
+ # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+ # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+ # weight loading)
+ return 4
+
def _dequantize(self, model):
raise NotImplementedError(
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 3dd00b2ce3..2fba9986e8 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -429,8 +429,64 @@ def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
+# this part is from calcuis (gguf.org)
+# more info: https://github.com/calcuis/gguf-connector/blob/main/src/gguf_connector/quant2c.py
+
+
+def dequantize_blocks_IQ4_NL(blocks, block_size, type_size, dtype=None):
+ kvalues = torch.tensor(
+ [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+ dtype=torch.float32,
+ device=blocks.device,
+ )
+ n_blocks = blocks.shape[0]
+ d, qs = split_block_dims(blocks, 2)
+ d = d.view(torch.float16).to(dtype)
+ qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
+ [0, 4], device=blocks.device, dtype=torch.uint8
+ ).reshape((1, 1, 2, 1))
+ qs = (qs & 15).reshape((n_blocks, -1)).to(torch.int64)
+ kvalues = kvalues.view(1, 1, 16)
+ qs = qs.unsqueeze(-1)
+ qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], 16), 2, qs)
+ qs = qs.squeeze(-1).to(dtype)
+ return d * qs
+
+
+def dequantize_blocks_IQ4_XS(blocks, block_size, type_size, dtype=None):
+ kvalues = torch.tensor(
+ [-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113],
+ dtype=torch.float32,
+ device=blocks.device,
+ )
+ n_blocks = blocks.shape[0]
+ d, scales_h, scales_l, qs = split_block_dims(blocks, 2, 2, QK_K // 64)
+ d = d.view(torch.float16).to(dtype)
+ scales_h = scales_h.view(torch.int16)
+ scales_l = scales_l.reshape((n_blocks, -1, 1)) >> torch.tensor(
+ [0, 4], device=blocks.device, dtype=torch.uint8
+ ).reshape((1, 1, 2))
+ scales_h = scales_h.reshape((n_blocks, 1, -1)) >> torch.tensor(
+ [2 * i for i in range(QK_K // 32)], device=blocks.device, dtype=torch.uint8
+ ).reshape((1, -1, 1))
+ scales_l = scales_l.reshape((n_blocks, -1)) & 0x0F
+ scales_h = scales_h.reshape((n_blocks, -1)) & 0x03
+ scales = (scales_l | (scales_h << 4)) - 32
+ dl = (d * scales.to(dtype)).reshape((n_blocks, -1, 1))
+ shifts_q = torch.tensor([0, 4], device=blocks.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
+ qs = qs.reshape((n_blocks, -1, 1, 16)) >> shifts_q
+ qs = (qs & 15).reshape((n_blocks, -1, 32)).to(torch.int64)
+ kvalues = kvalues.view(1, 1, 1, 16)
+ qs = qs.unsqueeze(-1)
+ qs = torch.gather(kvalues.expand(qs.shape[0], qs.shape[1], qs.shape[2], 16), 3, qs)
+ qs = qs.squeeze(-1).to(dtype)
+ return (dl * qs).reshape(n_blocks, -1)
+
+
GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
dequantize_functions = {
+ gguf.GGMLQuantizationType.IQ4_NL: dequantize_blocks_IQ4_NL,
+ gguf.GGMLQuantizationType.IQ4_XS: dequantize_blocks_IQ4_XS,
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
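With `IQ4_NL` and `IQ4_XS` registered in `dequantize_functions`, checkpoints stored in those GGML types can be dequantized on the fly like the other supported quants. A minimal loading sketch (the file path below is a placeholder):

```python
import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = "<path-or-url-to-a-flux-iq4_nl.gguf>"  # placeholder

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```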
diff --git a/src/diffusers/quantizers/modelopt/__init__.py b/src/diffusers/quantizers/modelopt/__init__.py
new file mode 100644
index 0000000000..ae0951cb30
--- /dev/null
+++ b/src/diffusers/quantizers/modelopt/__init__.py
@@ -0,0 +1 @@
+from .modelopt_quantizer import NVIDIAModelOptQuantizer
diff --git a/src/diffusers/quantizers/modelopt/modelopt_quantizer.py b/src/diffusers/quantizers/modelopt/modelopt_quantizer.py
new file mode 100644
index 0000000000..534f752321
--- /dev/null
+++ b/src/diffusers/quantizers/modelopt/modelopt_quantizer.py
@@ -0,0 +1,190 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from ...utils import (
+ get_module_from_name,
+ is_accelerate_available,
+ is_nvidia_modelopt_available,
+ is_torch_available,
+ logging,
+)
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+ from ...models.modeling_utils import ModelMixin
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+if is_accelerate_available():
+ from accelerate.utils import set_module_tensor_to_device
+
+
+logger = logging.get_logger(__name__)
+
+
+class NVIDIAModelOptQuantizer(DiffusersQuantizer):
+ r"""
+ Diffusers Quantizer for TensorRT Model Optimizer
+ """
+
+ use_keep_in_fp32_modules = True
+ requires_calibration = False
+ required_packages = ["nvidia_modelopt"]
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ def validate_environment(self, *args, **kwargs):
+ if not is_nvidia_modelopt_available():
+ raise ImportError(
+ "Loading an nvidia-modelopt quantized model requires nvidia-modelopt library (`pip install nvidia-modelopt`)"
+ )
+
+ self.offload = False
+
+ device_map = kwargs.get("device_map", None)
+ if isinstance(device_map, dict):
+ if "cpu" in device_map.values() or "disk" in device_map.values():
+ if self.pre_quantized:
+ raise ValueError(
+ "You are attempting to perform cpu/disk offload with a pre-quantized modelopt model "
+ "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
+ )
+ else:
+ self.offload = True
+
+ def check_if_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ):
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ from modelopt.torch.quantization.utils import is_quantized
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if self.pre_quantized:
+ return True
+ elif is_quantized(module) and "weight" in tensor_name:
+ return True
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ *args,
+ **kwargs,
+ ):
+ """
+ Create the quantized parameter by calling .calibrate() after setting it to the module.
+ """
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ import modelopt.torch.quantization as mtq
+
+ dtype = kwargs.get("dtype", torch.float32)
+ module, tensor_name = get_module_from_name(model, param_name)
+ if self.pre_quantized:
+ module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
+ else:
+ set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
+ mtq.calibrate(
+ module, self.quantization_config.modelopt_config["algorithm"], self.quantization_config.forward_loop
+ )
+ mtq.compress(module)
+ module.weight.requires_grad = False
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if self.quantization_config.quant_type == "FP8":
+ target_dtype = torch.float8_e4m3fn
+ return target_dtype
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
+ if torch_dtype is None:
+ logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
+ torch_dtype = torch.float32
+ return torch_dtype
+
+ def get_conv_param_names(self, model: "ModelMixin") -> List[str]:
+ """
+ Get parameter names for all convolutional layers in a HuggingFace ModelMixin. Includes Conv1d/2d/3d and
+ ConvTranspose1d/2d/3d.
+ """
+ conv_types = (
+ nn.Conv1d,
+ nn.Conv2d,
+ nn.Conv3d,
+ nn.ConvTranspose1d,
+ nn.ConvTranspose2d,
+ nn.ConvTranspose3d,
+ )
+
+ conv_param_names = []
+ for name, module in model.named_modules():
+ if isinstance(module, conv_types):
+ for param_name, _ in module.named_parameters(recurse=False):
+ conv_param_names.append(f"{name}.{param_name}")
+
+ return conv_param_names
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "ModelMixin",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ import modelopt.torch.opt as mto
+
+ if self.pre_quantized:
+ return
+
+ modules_to_not_convert = self.quantization_config.modules_to_not_convert
+
+ if modules_to_not_convert is None:
+ modules_to_not_convert = []
+ if isinstance(modules_to_not_convert, str):
+ modules_to_not_convert = [modules_to_not_convert]
+ modules_to_not_convert.extend(keep_in_fp32_modules)
+ if self.quantization_config.disable_conv_quantization:
+ modules_to_not_convert.extend(self.get_conv_param_names(model))
+
+ for module in modules_to_not_convert:
+ self.quantization_config.modelopt_config["quant_cfg"]["*" + module + "*"] = {"enable": False}
+ self.quantization_config.modules_to_not_convert = modules_to_not_convert
+ mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)])
+ model.config.quantization_config = self.quantization_config
+
+ def _process_model_after_weight_loading(self, model, **kwargs):
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ from modelopt.torch.opt import ModeloptStateManager
+
+ if self.pre_quantized:
+ return model
+
+ for _, m in model.named_modules():
+ if hasattr(m, ModeloptStateManager._state_key) and m is not model:
+ ModeloptStateManager.remove_state(m)
+
+ return model
+
+ @property
+ def is_trainable(self):
+ return True
+
+ @property
+ def is_serializable(self):
+ self.quantization_config.check_model_patching(operation="saving")
+ return True
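A rough sketch of how the new quantizer is reached. It assumes `NVIDIAModelOptConfig` accepts the `quant_type` and `disable_conv_quantization` fields the quantizer reads above; the exact constructor signature is not shown in this diff, so treat the arguments as illustrative.

```python
import torch
from diffusers import FluxTransformer2DModel
from diffusers.quantizers.quantization_config import NVIDIAModelOptConfig

# Assumed constructor arguments, inferred from how the quantizer uses the config above.
quant_config = NVIDIAModelOptConfig(quant_type="FP8", disable_conv_quantization=True)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```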
diff --git a/src/diffusers/quantizers/pipe_quant_config.py b/src/diffusers/quantizers/pipe_quant_config.py
index 5d02de16fd..f75a337341 100644
--- a/src/diffusers/quantizers/pipe_quant_config.py
+++ b/src/diffusers/quantizers/pipe_quant_config.py
@@ -48,12 +48,15 @@ class PipelineQuantizationConfig:
self,
quant_backend: str = None,
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
- components_to_quantize: Optional[List[str]] = None,
+ components_to_quantize: Optional[Union[List[str], str]] = None,
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
):
self.quant_backend = quant_backend
# Initialize kwargs to be {} to set to the defaults.
self.quant_kwargs = quant_kwargs or {}
+ if components_to_quantize:
+ if isinstance(components_to_quantize, str):
+ components_to_quantize = [components_to_quantize]
self.components_to_quantize = components_to_quantize
self.quant_mapping = quant_mapping
self.config_mapping = {} # book-keeping Example: `{module_name: quant_config}`
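The normalization above means a single component can now be passed as a bare string instead of a one-element list. A short sketch (the checkpoint id and `quant_kwargs` follow the existing docs and are illustrative):

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize="transformer",  # previously required ["transformer"]
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
)
```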
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 871faf076e..5dd8f56717 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -21,18 +21,20 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
"""
import copy
+import dataclasses
import importlib.metadata
import inspect
import json
import os
-from dataclasses import dataclass
+import warnings
+from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import partial
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
from packaging import version
-from ..utils import is_torch_available, is_torchao_available, logging
+from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
if is_torch_available():
@@ -46,6 +48,7 @@ class QuantizationMethod(str, Enum):
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
+ MODELOPT = "modelopt"
if is_torchao_available():
@@ -268,7 +271,14 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
if bnb_4bit_quant_storage is None:
self.bnb_4bit_quant_storage = torch.uint8
elif isinstance(bnb_4bit_quant_storage, str):
- if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
+ if bnb_4bit_quant_storage not in [
+ "float16",
+ "float32",
+ "int8",
+ "uint8",
+ "float64",
+ "bfloat16",
+ ]:
raise ValueError(
"`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
)
@@ -434,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.
Args:
- quant_type (`str`):
+ quant_type (Union[`str`, AOBaseConfig]):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
@@ -456,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
+ - An AOBaseConfig instance: for more advanced configuration options.
modules_to_not_convert (`List[str]`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
@@ -469,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
```python
from diffusers import FluxTransformer2DModel, TorchAoConfig
+ # AOBaseConfig-based configuration
+ from torchao.quantization import Int8WeightOnlyConfig
+
+ quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
+
+ # String-based config
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
@@ -479,7 +496,12 @@ class TorchAoConfig(QuantizationConfigMixin):
```
"""
- def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]] = None, **kwargs) -> None:
+ def __init__(
+ self,
+ quant_type: Union[str, "AOBaseConfig"], # noqa: F821
+ modules_to_not_convert: Optional[List[str]] = None,
+ **kwargs,
+ ) -> None:
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
@@ -490,34 +512,103 @@ class TorchAoConfig(QuantizationConfigMixin):
else:
self.quant_type_kwargs = kwargs
- TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
- if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
- is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
- if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
+ self.post_init()
+
+ def post_init(self):
+ if not isinstance(self.quant_type, str):
+ if is_torchao_version("<=", "0.9.0"):
raise ValueError(
- f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
- f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
+ f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
+ f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
)
- raise ValueError(
- f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
- f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
- )
+ from torchao.quantization.quant_api import AOBaseConfig
- method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
- signature = inspect.signature(method)
- all_kwargs = {
- param.name
- for param in signature.parameters.values()
- if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
- }
- unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+ if not isinstance(self.quant_type, AOBaseConfig):
+ raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}")
- if len(unsupported_kwargs) > 0:
- raise ValueError(
- f'The quantization method "{quant_type}" does not support the following keyword arguments: '
- f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
- )
+ elif isinstance(self.quant_type, str):
+ TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
+
+ if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
+ is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
+ if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
+ raise ValueError(
+ f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
+ f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
+ )
+
+ raise ValueError(
+ f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
+ f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
+ signature = inspect.signature(method)
+ all_kwargs = {
+ param.name
+ for param in signature.parameters.values()
+ if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
+ }
+ unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+
+ if len(unsupported_kwargs) > 0:
+ raise ValueError(
+ f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
+ f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
+ )
+
+ def to_dict(self):
+ """Convert configuration to a dictionary."""
+ d = super().to_dict()
+
+ if isinstance(self.quant_type, str):
+ # Handle layout serialization if present
+ if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
+ if is_dataclass(d["quant_type_kwargs"]["layout"]):
+ d["quant_type_kwargs"]["layout"] = [
+ d["quant_type_kwargs"]["layout"].__class__.__name__,
+ dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
+ ]
+ if isinstance(d["quant_type_kwargs"]["layout"], list):
+ assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
+ assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
+ assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
+ else:
+ raise ValueError("layout must be a list")
+ else:
+ # Handle AOBaseConfig serialization
+ from torchao.core.config import config_to_dict
+
+ # For now we assume there is one config per transformer; in the future
+ # we may want to support a config per FQN (fully qualified module name).
+ d["quant_type"] = {"default": config_to_dict(self.quant_type)}
+
+ return d
+
+ @classmethod
+ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+ """Create configuration from a dictionary."""
+ if not is_torchao_version(">", "0.9.0"):
+ raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
+ config_dict = config_dict.copy()
+ quant_type = config_dict.pop("quant_type")
+
+ if isinstance(quant_type, str):
+ return cls(quant_type=quant_type, **config_dict)
+ # Check if we only have one key which is "default"
+ # In the future we may update this
+ assert len(quant_type) == 1 and "default" in quant_type, (
+ "Expected only one key 'default' in quant_type dictionary"
+ )
+ quant_type = quant_type["default"]
+
+ # Deserialize quant_type if needed
+ from torchao.core.config import config_from_dict
+
+ quant_type = config_from_dict(quant_type)
+
+ return cls(quant_type=quant_type, **config_dict)
@classmethod
def _get_torchao_quant_type_to_method(cls):
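As a quick orientation for this hunk, here is a minimal sketch of how the two accepted `quant_type` forms behave after the change. It assumes torchao > 0.9.0 is installed and only uses names already shown in the docstring above (`TorchAoConfig("int8wo")`, `Int8WeightOnlyConfig`); it is an illustration, not part of the patch.

```python
# Sketch only; assumes torchao > 0.9.0.
from torchao.quantization import Int8WeightOnlyConfig

from diffusers import TorchAoConfig

string_cfg = TorchAoConfig("int8wo")                 # validated against the internal quant-type table
object_cfg = TorchAoConfig(Int8WeightOnlyConfig())   # validated as an AOBaseConfig in post_init()

# AOBaseConfig instances are serialized under {"quant_type": {"default": ...}} by to_dict(),
# using torchao's config_to_dict helper; from_dict() reverses this via config_from_dict.
serialized = object_cfg.to_dict()
print(type(serialized["quant_type"]))  # dict
```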
@@ -667,8 +758,38 @@ class TorchAoConfig(QuantizationConfigMixin):
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
def get_apply_tensor_subclass(self):
- TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
- return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
+ """Create the appropriate quantization method based on configuration."""
+ if not isinstance(self.quant_type, str):
+ return self.quant_type
+ else:
+ methods = self._get_torchao_quant_type_to_method()
+ quant_type_kwargs = self.quant_type_kwargs.copy()
+ if (
+ not torch.cuda.is_available()
+ and is_torchao_available()
+ and self.quant_type == "int4_weight_only"
+ and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+ and quant_type_kwargs.get("layout", None) is None
+ ):
+ if torch.xpu.is_available():
+ if version.parse(importlib.metadata.version("torchao")) >= version.parse(
+ "0.11.0"
+ ) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
+ from torchao.dtypes import Int4XPULayout
+ from torchao.quantization.quant_primitives import ZeroPointDomain
+
+ quant_type_kwargs["layout"] = Int4XPULayout()
+ quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
+ else:
+ raise ValueError(
+ "TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
+ )
+ else:
+ from torchao.dtypes import Int4CPULayout
+
+ quant_type_kwargs["layout"] = Int4CPULayout()
+
+ return methods[self.quant_type](**quant_type_kwargs)
def __repr__(self):
r"""
@@ -724,3 +845,194 @@ class QuantoConfig(QuantizationConfigMixin):
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
+
+
+@dataclass
+class NVIDIAModelOptConfig(QuantizationConfigMixin):
+ """This is a config class to use nvidia modelopt for quantization.
+
+ Args:
+ quant_type (`str`):
+ The type of quantization to use, written as `<weight_quant>_<activation_quant>`. For example, `FP8_FP8`
+ uses FP8 for both weight and activation quantization, while a single value such as `FP8` only specifies the
+ weight quantizer. The supported options for each part are:
+ - FP8
+ - INT8
+ - INT4
+ - NF4
+ - NVFP4
+ modules_to_not_convert (`List[str]`, *optional*, default to `None`):
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have some
+ modules left in their original precision.
+ weight_only (`bool`, *optional*, default to `True`):
+ If set to `True`, the quantization will be applied only to the weights of the model.
+ channel_quantize (`int`, *optional*, default to `None`):
+ The channel quantization axis, useful for quantizing models across different axes.
+ block_quantize (`int`, *optional*, default to `None`):
+ The block size, useful to further quantize each channel/axes into blocks.
+ scale_channel_quantize (`int`, *optional*, default to `None`):
+ The scale channel quantization axis, useful for quantizing calculated scale across different axes.
+ scale_block_quantize (`int`, *optional*, default to `None`):
+ The scale block size, useful for quantizing each scale channel/axes into blocks.
+ algorithm (`str`, *optional*, default to `"max"`):
+ The algorithm to use for quantization, currently only supports `"max"`.
+ forward_loop (`Callable`, *optional*, default to `None`):
+ The forward loop function to use for calibration during quantization.
+ modelopt_config (`dict`, *optional*, default to `None`):
+ The modelopt config, useful for passing custom configs to modelopt.
+ disable_conv_quantization (`bool`, *optional*, default to `False`):
+ If set to `True`, the quantization will be disabled for convolutional layers.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional parameters which are to be used for calibration.
+ """
+
+ quanttype_to_numbits = {
+ "FP8": (4, 3),
+ "INT8": 8,
+ "INT4": 4,
+ "NF4": 4,
+ "NVFP4": (2, 1),
+ }
+ quanttype_to_scalingbits = {
+ "NF4": 8,
+ "NVFP4": (4, 3),
+ }
+
+ def __init__(
+ self,
+ quant_type: str,
+ modules_to_not_convert: Optional[List[str]] = None,
+ weight_only: bool = True,
+ channel_quantize: Optional[int] = None,
+ block_quantize: Optional[int] = None,
+ scale_channel_quantize: Optional[int] = None,
+ scale_block_quantize: Optional[int] = None,
+ algorithm: str = "max",
+ forward_loop: Optional[Callable] = None,
+ modelopt_config: Optional[dict] = None,
+ disable_conv_quantization: bool = False,
+ **kwargs,
+ ) -> None:
+ self.quant_method = QuantizationMethod.MODELOPT
+ self._normalize_quant_type(quant_type)
+ self.modules_to_not_convert = modules_to_not_convert
+ self.weight_only = weight_only
+ self.channel_quantize = channel_quantize
+ self.block_quantize = block_quantize
+ self.calib_cfg = {
+ "method": algorithm,
+ # add more options here if needed
+ }
+ self.forward_loop = forward_loop
+ self.scale_channel_quantize = scale_channel_quantize
+ self.scale_block_quantize = scale_block_quantize
+ self.modelopt_config = self.get_config_from_quant_type() if not modelopt_config else modelopt_config
+ self.disable_conv_quantization = disable_conv_quantization
+
+ def check_model_patching(self, operation: str = "loading"):
+ # ModelOpt imports diffusers internally. This is here to prevent circular imports
+ from modelopt.torch.opt.plugins.huggingface import _PATCHED_CLASSES
+
+ if len(_PATCHED_CLASSES) == 0:
+ warning_msg = (
+ f"Not {operation} weights in modelopt format. This might cause unreliable behavior."
+ "Please make sure to run the following code before loading/saving model weights:\n\n"
+ " from modelopt.torch.opt import enable_huggingface_checkpointing\n"
+ " enable_huggingface_checkpointing()\n"
+ )
+ warnings.warn(warning_msg)
+
+ def _normalize_quant_type(self, quant_type: str) -> str:
+ """
+ Validates and normalizes the quantization type string.
+
+ Splits the quant_type into weight and activation components, verifies them against supported types, and
+ replaces unsupported values with safe defaults.
+
+ Args:
+ quant_type (str): The input quantization type string (e.g., 'FP8_INT8').
+
+ Returns:
+ None. The normalized value (e.g., 'FP8_INT8' or 'FP8') is stored in `self.quant_type`.
+ """
+ parts = quant_type.split("_")
+ w_type = parts[0]
+ act_type = parts[1] if len(parts) > 1 else None
+ if len(parts) > 2:
+ logger.warning(f"Quantization type {quant_type} is not supported. Picking FP8_INT8 as default")
+ w_type = "FP8"
+ act_type = None
+ else:
+ if w_type not in NVIDIAModelOptConfig.quanttype_to_numbits:
+ logger.warning(f"Weight Quantization type {w_type} is not supported. Picking FP8 as default")
+ w_type = "FP8"
+ if act_type is not None and act_type not in NVIDIAModelOptConfig.quanttype_to_numbits:
+ logger.warning(f"Activation Quantization type {act_type} is not supported. Picking INT8 as default")
+ act_type = None
+ self.quant_type = w_type + ("_" + act_type if act_type is not None else "")
+
+ def get_config_from_quant_type(self) -> Dict[str, Any]:
+ """
+ Get the config from the quantization type.
+ """
+ import modelopt.torch.quantization as mtq
+
+ BASE_CONFIG = {
+ "quant_cfg": {
+ "*weight_quantizer": {"fake_quant": False},
+ "*input_quantizer": {},
+ "*output_quantizer": {"enable": False},
+ "*q_bmm_quantizer": {},
+ "*k_bmm_quantizer": {},
+ "*v_bmm_quantizer": {},
+ "*softmax_quantizer": {},
+ **mtq.config._default_disabled_quantizer_cfg,
+ },
+ "algorithm": self.calib_cfg,
+ }
+
+ quant_cfg = BASE_CONFIG["quant_cfg"]
+ if self.weight_only:
+ for k in quant_cfg:
+ if "*weight_quantizer" not in k and not quant_cfg[k]:
+ quant_cfg[k]["enable"] = False
+
+ parts = self.quant_type.split("_")
+ w_type = parts[0]
+ act_type = parts[1].replace("A", "") if len(parts) > 1 else None
+ for k in quant_cfg:
+ if k not in mtq.config._default_disabled_quantizer_cfg and "enable" not in quant_cfg[k]:
+ if k == "*input_quantizer":
+ if act_type is not None:
+ quant_cfg[k]["num_bits"] = NVIDIAModelOptConfig.quanttype_to_numbits[act_type]
+ continue
+ quant_cfg[k]["num_bits"] = NVIDIAModelOptConfig.quanttype_to_numbits[w_type]
+
+ if self.block_quantize is not None and self.channel_quantize is not None:
+ quant_cfg["*weight_quantizer"]["block_sizes"] = {self.channel_quantize: self.block_quantize}
+ quant_cfg["*input_quantizer"]["block_sizes"] = {
+ self.channel_quantize: self.block_quantize,
+ "type": "dynamic",
+ }
+ elif self.channel_quantize is not None:
+ quant_cfg["*weight_quantizer"]["axis"] = self.channel_quantize
+ quant_cfg["*input_quantizer"]["axis"] = self.channel_quantize
+ quant_cfg["*input_quantizer"]["type"] = "dynamic"
+
+ # Only fixed scaling sizes are supported for now in modelopt
+ if self.scale_channel_quantize is not None and self.scale_block_quantize is not None:
+ if w_type in NVIDIAModelOptConfig.quanttype_to_scalingbits:
+ quant_cfg["*weight_quantizer"]["block_sizes"].update(
+ {
+ "scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[w_type],
+ "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize},
+ }
+ )
+ if act_type and act_type in NVIDIAModelOptConfig.quanttype_to_scalingbits:
+ quant_cfg["*input_quantizer"]["block_sizes"].update(
+ {
+ "scale_bits": NVIDIAModelOptConfig.quanttype_to_scalingbits[act_type],
+ "scale_block_sizes": {self.scale_channel_quantize: self.scale_block_quantize},
+ }
+ )
+
+ return BASE_CONFIG
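As a quick illustration of the config resolution implemented above, the following sketch builds a weight-only FP8 config with per-channel block quantization. It is hypothetical usage: it assumes `nvidia-modelopt` is installed and that `NVIDIAModelOptConfig` is exported from `diffusers` the same way the other quantization configs are.

```python
from diffusers import NVIDIAModelOptConfig

cfg = NVIDIAModelOptConfig("FP8", weight_only=True, channel_quantize=-1, block_quantize=128)

print(cfg.quant_type)  # "FP8" (normalized by _normalize_quant_type)
# The weight quantizer picks up num_bits=(4, 3) for FP8 plus the requested block sizes;
# the input quantizer is disabled because weight_only=True and no activation type was given.
print(cfg.modelopt_config["quant_cfg"]["*weight_quantizer"])
```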
diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index c12513f061..2334c7af86 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -18,8 +18,10 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
"""
import importlib
+import re
import types
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from fnmatch import fnmatch
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from packaging import version
@@ -106,6 +108,21 @@ if (
_update_torch_safe_globals()
+def fuzzy_match_size(config_name: str) -> Optional[str]:
+ """
+ Extract the size digit from config class names like "Int4WeightOnlyConfig" (matched as "4weight"). Returns the
+ digit as a string if found, otherwise None.
+ """
+ config_name = config_name.lower()
+
+ str_match = re.search(r"(\d)weight", config_name)
+
+ if str_match:
+ return str_match.group(1)
+
+ return None
+
+
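A few illustrative inputs for the helper above (the class names are examples of torchao-style config names, not an exhaustive list):

```python
fuzzy_match_size("Int4WeightOnlyConfig")                   # -> "4"
fuzzy_match_size("Int8DynamicActivationInt8WeightConfig")  # -> "8"
fuzzy_match_size("Float8WeightOnlyConfig")                 # -> "8"
fuzzy_match_size("SomeConfigWithNoBitWidth")               # -> None
```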
logger = logging.get_logger(__name__)
@@ -175,8 +192,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
-
- if quant_type.startswith("int") or quant_type.startswith("uint"):
+ if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
@@ -196,24 +212,44 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
+ from accelerate.utils import CustomDtype
- if quant_type.startswith("int8") or quant_type.startswith("int4"):
- # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
- return torch.int8
- elif quant_type == "uintx_weight_only":
- return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
- elif quant_type.startswith("uint"):
- return {
- 1: torch.uint1,
- 2: torch.uint2,
- 3: torch.uint3,
- 4: torch.uint4,
- 5: torch.uint5,
- 6: torch.uint6,
- 7: torch.uint7,
- }[int(quant_type[4])]
- elif quant_type.startswith("float") or quant_type.startswith("fp"):
- return torch.bfloat16
+ if isinstance(quant_type, str):
+ if quant_type.startswith("int8"):
+ # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
+ return torch.int8
+ elif quant_type.startswith("int4"):
+ return CustomDtype.INT4
+ elif quant_type == "uintx_weight_only":
+ return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
+ elif quant_type.startswith("uint"):
+ return {
+ 1: torch.uint1,
+ 2: torch.uint2,
+ 3: torch.uint3,
+ 4: torch.uint4,
+ 5: torch.uint5,
+ 6: torch.uint6,
+ 7: torch.uint7,
+ }[int(quant_type[4])]
+ elif quant_type.startswith("float") or quant_type.startswith("fp"):
+ return torch.bfloat16
+
+ elif is_torchao_version(">", "0.9.0"):
+ from torchao.core.config import AOBaseConfig
+
+ quant_type = self.quantization_config.quant_type
+ if isinstance(quant_type, AOBaseConfig):
+ # Extract size digit using fuzzy match on the class name
+ config_name = quant_type.__class__.__name__
+ size_digit = fuzzy_match_size(config_name)
+
+ # Map the extracted digit to appropriate dtype
+ if size_digit == "4":
+ return CustomDtype.INT4
+ else:
+ # Default to int8
+ return torch.int8
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
return target_dtype
@@ -278,6 +314,46 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
+ def get_cuda_warm_up_factor(self):
+ """
+ This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+ - A factor of 2 means we pre-allocate the full memory footprint of the model.
+ - A factor of 4 means we pre-allocate half of that, and so on
+
+ However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+ the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
+ quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+ torch_dtype, not the actual bit-width of the quantized data.
+
+ To correct for this:
+ - Use a division factor of 8 for int4 weights
+ - Use a division factor of 4 for int8 weights
+ """
+ # Original mapping for non-AOBaseConfig types
+ # For the uint types, this is a best guess. Once these types become more used
+ # we can look into their nuances.
+ if is_torchao_version(">", "0.9.0"):
+ from torchao.core.config import AOBaseConfig
+
+ quant_type = self.quantization_config.quant_type
+ # The autoquant case falls through to the string-based mapping (`map_to_target_dtype`) below.
+ if isinstance(quant_type, AOBaseConfig):
+ # Extract size digit using fuzzy match on the class name
+ config_name = quant_type.__class__.__name__
+ size_digit = fuzzy_match_size(config_name)
+
+ if size_digit == "4":
+ return 8
+ else:
+ return 4
+
+ map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
+ quant_type = self.quantization_config.quant_type
+ for pattern, target_dtype in map_to_target_dtype.items():
+ if fnmatch(quant_type, pattern):
+ return target_dtype
+ raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
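The factor semantics described in the docstring can be sanity-checked with a little arithmetic. This is a sketch under the stated assumption that the warm-up allocates `2 / factor` of the naively computed footprint:

```python
numel = 1_000_000
bf16_bytes = numel * 2          # what numel * element_size() reports for a bf16-typed parameter
int4_bytes = numel // 2         # actual payload of 4-bit packed weights
int8_bytes = numel              # actual payload of 8-bit weights

assert 2 / 8 * bf16_bytes == int4_bytes   # factor 8 -> pre-allocate a quarter of the naive footprint
assert 2 / 4 * bf16_bytes == int8_bytes   # factor 4 -> pre-allocate half of the naive footprint
```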
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
diff --git a/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py b/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
index 6b968e7081..9206ee80a6 100644
--- a/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
+++ b/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
@@ -53,13 +53,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
-
-
- For more details on the parameters, see [Appendix E](https://huggingface.co/papers/2206.00364). The grid search
- values used to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5 of
- the paper.
-
-
+ > [!TIP] > For more details on the parameters, see [Appendix E](https://huggingface.co/papers/2206.00364). The grid
+ search > values used to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in
+ Table 5 of > the paper.
Args:
sigma_min (`float`, defaults to 0.02):
diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py
index 0f50622588..5d81d5eb8a 100644
--- a/src/diffusers/schedulers/scheduling_consistency_models.py
+++ b/src/diffusers/schedulers/scheduling_consistency_models.py
@@ -268,11 +268,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
Gets the scalings used in the consistency model parameterization (from Appendix C of the
[paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
-
-
- `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`.
-
-
+ > [!TIP] > `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`.
Args:
sigma (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
index 66ed296da8..b9567f2c47 100644
--- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
@@ -304,12 +304,8 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index d07ff8b200..8b523cd13f 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -630,12 +630,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index 9ec9588511..f1a1ac3d82 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -491,12 +491,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 8663210a62..1ae8249730 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -568,12 +568,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
index f1b38aaff5..e9ba695e1f 100644
--- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
@@ -370,12 +370,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index 2df7d560dd..2979ce193a 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -500,12 +500,8 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is
designed to discretize an integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both
- noise prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction
+ for both > noise prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index f0e162ea6b..a355c7bb1a 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -138,15 +138,11 @@ class SchedulerMixin(PushToHubMixin):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`. You can also activate the special >
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
firewalled environment.
-
-
"""
config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
index e6ac78f63e..0534e47d8a 100644
--- a/src/diffusers/schedulers/scheduling_utils_flax.py
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -22,9 +22,11 @@ import flax
import jax.numpy as jnp
from huggingface_hub.utils import validate_hf_hub_args
-from ..utils import BaseOutput, PushToHubMixin
+from ..utils import BaseOutput, PushToHubMixin, logging
+logger = logging.get_logger(__name__)
+
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -118,21 +120,18 @@ class FlaxSchedulerMixin(PushToHubMixin):
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
-
+ > [!TIP] > It is required to be logged in (`hf auth login`) when you want to use private or [gated >
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
- It is required to be logged in (`hf auth login`) when you want to use private or [gated
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
-
-
-
- Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
- use this method in a firewalled environment.
-
-
+ > [!TIP] > Activate the special
+ ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to > use this method in a
+ firewalled environment.
"""
+ logger.warning(
+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
+ )
config, kwargs = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder,
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index d33b80dba0..7a98fa3da1 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -339,7 +339,8 @@ def offload_models(
original_devices = [next(m.parameters()).device for m in modules]
else:
assert len(modules) == 1
- original_devices = modules[0].device
+ # For DiffusionPipeline, wrap the device in a list to make it iterable
+ original_devices = [modules[0].device]
# move to target device
for m in modules:
m.to(device)
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 75a2bdd13e..63932221b2 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -20,10 +20,12 @@ from packaging import version
from .. import __version__
from .constants import (
CONFIG_NAME,
+ DEFAULT_HF_PARALLEL_LOADING_WORKERS,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
+ HF_ENABLE_PARALLEL_LOADING,
HF_MODULES_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
MIN_PEFT_VERSION,
@@ -82,10 +84,13 @@ from .import_utils import (
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
+ is_kornia_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,
is_note_seq_available,
+ is_nvidia_modelopt_available,
+ is_nvidia_modelopt_version,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index f8f04cc03a..8b4d76f3cb 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -43,6 +43,10 @@ DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
DIFFUSERS_REQUEST_TIMEOUT = 60
DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
+DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
+HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
+DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
+DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
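These flags are evaluated once when the constants module is imported, so they have to be set in the environment before the first `import diffusers`. A minimal sketch (the description of what parallel loading covers is an assumption based on the constant names):

```python
import os

# Set before importing diffusers, since the constants above are read at import time.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "1"  # opt in to parallel weight/shard loading

import diffusers  # noqa: E402
```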
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
diff --git a/src/diffusers/utils/dummy_nvidia_modelopt_objects.py b/src/diffusers/utils/dummy_nvidia_modelopt_objects.py
new file mode 100644
index 0000000000..046b28223b
--- /dev/null
+++ b/src/diffusers/utils/dummy_nvidia_modelopt_objects.py
@@ -0,0 +1,17 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+from ..utils import DummyObject, requires_backends
+
+
+class NVIDIAModelOptConfig(metaclass=DummyObject):
+ _backends = ["nvidia_modelopt"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["nvidia_modelopt"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["nvidia_modelopt"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["nvidia_modelopt"])
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 35df559ce4..6e7d227979 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -62,6 +62,21 @@ class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class FrequencyDecoupledGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class PerturbedAttentionGuidance(metaclass=DummyObject):
_backends = ["torch"]
@@ -513,6 +528,21 @@ class AutoModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class BriaTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class CacheMixin(metaclass=DummyObject):
_backends = ["torch"]
@@ -618,6 +648,21 @@ class ConsistencyDecoderVAE(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class ContextParallelConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1023,6 +1068,21 @@ class OmniGenTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class ParallelConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1053,6 +1113,36 @@ class PriorTransformer(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class QwenImageControlNetModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class QwenImageMultiControlNetModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class QwenImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 293086631f..9ed6250452 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -17,6 +17,36 @@ class FluxAutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class FluxKontextAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxKontextModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class FluxModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -32,6 +62,96 @@ class FluxModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class QwenImageAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditPlusAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditPlusModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -362,6 +482,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class BriaPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class ChromaImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1517,6 +1652,21 @@ class LTXPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class LucyEditPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class Lumina2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1742,6 +1892,111 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class QwenImageControlNetInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageControlNetPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditPlusPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class QwenImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py
index 74ed240bf0..35bef51229 100644
--- a/src/diffusers/utils/dynamic_modules_utils.py
+++ b/src/diffusers/utils/dynamic_modules_utils.py
@@ -20,7 +20,6 @@ import json
import os
import re
import shutil
-import signal
import sys
import threading
from pathlib import Path
@@ -34,6 +33,7 @@ from packaging import version
from .. import __version__
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -159,52 +159,25 @@ def check_imports(filename):
return get_relative_imports(filename)
-def _raise_timeout_error(signum, frame):
- raise ValueError(
- "Loading this model requires you to execute custom code contained in the model repository on your local "
- "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
- )
-
-
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
- if trust_remote_code is None:
- if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
- prev_sig_handler = None
- try:
- prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
- signal.alarm(TIME_OUT_REMOTE_CODE)
- while trust_remote_code is None:
- answer = input(
- f"The repository for {model_name} contains custom code which must be executed to correctly "
- f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
- f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
- f"Do you wish to run the custom code? [y/N] "
- )
- if answer.lower() in ["yes", "y", "1"]:
- trust_remote_code = True
- elif answer.lower() in ["no", "n", "0", ""]:
- trust_remote_code = False
- signal.alarm(0)
- except Exception:
- # OS which does not support signal.SIGALRM
- raise ValueError(
- f"The repository for {model_name} contains custom code which must be executed to correctly "
- f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
- f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
- )
- finally:
- if prev_sig_handler is not None:
- signal.signal(signal.SIGALRM, prev_sig_handler)
- signal.alarm(0)
- elif has_remote_code:
- # For the CI which puts the timeout at 0
- _raise_timeout_error(None, None)
+ trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
+ if DIFFUSERS_DISABLE_REMOTE_CODE:
+ logger.warning(
+ "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
+ )
if has_remote_code and not trust_remote_code:
- raise ValueError(
- f"Loading {model_name} requires you to execute the configuration file in that"
- " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
- " set the option `trust_remote_code=True` to remove this error."
+ error_msg = f"The repository for {model_name} contains custom code. "
+ error_msg += (
+ "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
+ if DIFFUSERS_DISABLE_REMOTE_CODE
+ else "Pass `trust_remote_code=True` to allow loading remote code modules."
+ )
+ raise ValueError(error_msg)
+
+ elif has_remote_code and trust_remote_code:
+ logger.warning(
+ f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
)
return trust_remote_code
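From a user's perspective, the new kill switch behaves as in the hedged sketch below (the repo id is made up; the exact error text is the one constructed above):

```python
import os

os.environ["DIFFUSERS_DISABLE_REMOTE_CODE"] = "1"  # must be set before importing diffusers

from diffusers import DiffusionPipeline

# With remote code globally disabled, `trust_remote_code=True` is ignored and loading a repo
# that ships custom code raises a ValueError instead of ever downloading the module.
pipe = DiffusionPipeline.from_pretrained("someuser/pipeline-with-custom-code", trust_remote_code=True)
```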
@@ -274,6 +247,7 @@ def find_pipeline_class(loaded_module):
def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
+ subfolder: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
@@ -316,12 +290,8 @@ def get_cached_module_file(
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
-
-
- You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or [gated
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
+ > [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
+ [gated > models](https://huggingface.co/docs/hub/models-gated#gated-models).
Returns:
`str`: The path to the module inside the cache.
@@ -380,6 +350,7 @@ def get_cached_module_file(
resolved_module_file = hf_hub_download(
pretrained_model_name_or_path,
module_file,
+ subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -437,6 +408,7 @@ def get_cached_module_file(
get_cached_module_file(
pretrained_model_name_or_path,
f"{module_needed}.py",
+ subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -451,6 +423,7 @@ def get_cached_module_file(
def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
+ subfolder: Optional[str] = None,
class_name: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
@@ -463,12 +436,8 @@ def get_class_from_dynamic_module(
"""
Extracts a class from a module file, present in the local folder or repository of a model.
-
-
- Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
- therefore only be called on trusted repos.
-
-
+ > [!WARNING] > Calling this function will execute the code in the module file found locally or downloaded from the
+ Hub. It should > therefore only be called on trusted repos.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
@@ -503,12 +472,8 @@ def get_class_from_dynamic_module(
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
-
-
- You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or [gated
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
+ > [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
+ [gated > models](https://huggingface.co/docs/hub/models-gated#gated-models).
Returns:
`type`: The class, dynamically imported from the module.
@@ -524,6 +489,7 @@ def get_class_from_dynamic_module(
final_module = get_cached_module_file(
pretrained_model_name_or_path,
module_file,
+ subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
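A hypothetical call showing the new `subfolder` plumbing (repo id, file, and class names are made up for illustration):

```python
from diffusers.utils.dynamic_modules_utils import get_class_from_dynamic_module

# The module file is now resolved inside the given subfolder of the repo, and its relative
# imports are fetched from the same subfolder.
pipeline_cls = get_class_from_dynamic_module(
    "someuser/custom-pipeline-repo",
    module_file="my_pipeline.py",
    subfolder="pipeline",
    class_name="MyCustomPipeline",
)
```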
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index cf85488b7a..b6e99452aa 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -38,13 +38,13 @@ from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import (
EntryNotFoundError,
+ HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
is_jinja_available,
validate_hf_hub_args,
)
from packaging import version
-from requests import HTTPError
from .. import __version__
from .constants import (
@@ -316,7 +316,7 @@ def _get_model_file(
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
) from e
- except HTTPError as e:
+ except HfHubHTTPError as e:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{e}"
) from e
@@ -402,15 +402,17 @@ def _get_checkpoint_shard_files(
allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
ignore_patterns = ["*.json", "*.md"]
- # `model_info` call must guarded with the above condition.
- model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
- for shard_file in original_shard_filenames:
- shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
- if not shard_file_present:
- raise EnvironmentError(
- f"{shards_path} does not appear to have a file named {shard_file} which is "
- "required according to the checkpoint index."
- )
+
+ # If the repo doesn't have the required shards, error out early even before downloading anything.
+ if not local_files_only:
+ model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+ for shard_file in original_shard_filenames:
+ shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+ if not shard_file_present:
+ raise EnvironmentError(
+ f"{shards_path} does not appear to have a file named {shard_file} which is "
+ "required according to the checkpoint index."
+ )
try:
# Load from URL
@@ -430,13 +432,18 @@ def _get_checkpoint_shard_files(
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
- except HTTPError as e:
+ except HfHubHTTPError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
) from e
cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
+ for cached_file in cached_filenames:
+ if not os.path.isfile(cached_file):
+ raise EnvironmentError(
+ f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
+ )
return cached_filenames, sharded_metadata
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index d8b26bda46..97065267b0 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -21,6 +21,7 @@ import operator as op
import os
import sys
from collections import OrderedDict, defaultdict
+from functools import lru_cache as cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
@@ -70,10 +71,11 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
# Fallback for Python < 3.10
for dist in importlib_metadata.distributions():
_top_level_declared = (dist.read_text("top_level.txt") or "").split()
- _infered_opt_names = {
+ # Infer top-level package names from file structure
+ _inferred_opt_names = {
f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or [])
} - {None}
- _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names)
+ _top_level_inferred = filter(lambda name: "." not in name, _inferred_opt_names)
for pkg in _top_level_declared or _top_level_inferred:
_package_map[pkg].append(dist.metadata["Name"])
except Exception as _:
@@ -119,7 +121,7 @@ if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
_safetensors_available, _safetensors_version = _is_package_available("safetensors")
else:
- logger.info("Disabling Safetensors because USE_TF is set")
+ logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
_safetensors_available = False
_onnxruntime_version = "N/A"
@@ -224,6 +226,8 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_kornia_available, _kornia_version = _is_package_available("kornia")
+_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
def is_torch_available():
@@ -362,6 +366,10 @@ def is_optimum_quanto_available():
return _optimum_quanto_available
+def is_nvidia_modelopt_available():
+ return _nvidia_modelopt_available
+
+
def is_timm_available():
return _timm_available
@@ -398,6 +406,10 @@ def is_flash_attn_3_available():
return _flash_attn_3_available
+def is_kornia_available():
+ return _kornia_available
+
+
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -662,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
+@cache
def is_torch_version(operation: str, version: str):
"""
Compares the current PyTorch version to a given reference with an operation.
@@ -675,6 +688,7 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)
+@cache
def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.
@@ -690,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
return compare_versions(parse(_torch_xla_version), operation, version)
+@cache
def is_transformers_version(operation: str, version: str):
"""
Compares the current Transformers version to a given reference with an operation.
@@ -705,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
+@cache
def is_hf_hub_version(operation: str, version: str):
"""
Compares the current Hugging Face Hub version to a given reference with an operation.
@@ -720,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
return compare_versions(parse(_hf_hub_version), operation, version)
+@cache
def is_accelerate_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -735,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
return compare_versions(parse(_accelerate_version), operation, version)
+@cache
def is_peft_version(operation: str, version: str):
"""
Compares the current PEFT version to a given reference with an operation.
@@ -750,6 +768,7 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)
+@cache
def is_bitsandbytes_version(operation: str, version: str):
"""
Args:
@@ -764,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version)
+@cache
def is_gguf_version(operation: str, version: str):
"""
Compares the current GGUF version to a given reference with an operation.
@@ -779,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version)
+@cache
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
@@ -794,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
return compare_versions(parse(_torchao_version), operation, version)
+@cache
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
@@ -809,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version)
+@cache
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current optimum-quanto version to a given reference with an operation.
@@ -824,6 +847,23 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
+@cache
+def is_nvidia_modelopt_version(operation: str, version: str):
+ """
+ Compares the current Nvidia ModelOpt version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _nvidia_modelopt_available:
+ return False
+ return compare_versions(parse(_nvidia_modelopt_version), operation, version)
+
+
+@cache
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.
@@ -839,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
return compare_versions(parse(_xformers_version), operation, version)
+@cache
def is_sageattention_version(operation: str, version: str):
"""
Compares the current sageattention version to a given reference with an operation.
@@ -854,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
return compare_versions(parse(_sageattention_version), operation, version)
+@cache
def is_flash_attn_version(operation: str, version: str):
"""
Compares the current flash-attention version to a given reference with an operation.
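
As a hedged illustration of what the added `@cache` (an alias for `functools.lru_cache`) buys: each `(operation, version)` pair is parsed and compared once, and later calls are served from the cache. The snippet below only mimics the pattern; `_torch_version` is a stand-in value, not the detected install.

    import operator as op
    from functools import lru_cache as cache

    from packaging.version import parse

    _torch_version = "2.4.0"  # assumption: placeholder for the detected torch version

    STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

    @cache
    def is_torch_version(operation: str, version: str) -> bool:
        # parse() and the comparison run only on a cache miss; repeated checks are free.
        return STR_OPERATION_TO_FUNC[operation](parse(_torch_version), parse(version))

    assert is_torch_version(">=", "2.0.0")  # computed once
    assert is_torch_version(">=", "2.0.0")  # served from the lru_cache
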
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
new file mode 100644
index 0000000000..26d6e3972f
--- /dev/null
+++ b/src/diffusers/utils/kernels_utils.py
@@ -0,0 +1,23 @@
+from ..utils import get_logger
+from .import_utils import is_kernels_available
+
+
+logger = get_logger(__name__)
+
+
+_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+
+
+def _get_fa3_from_hub():
+ if not is_kernels_available():
+ return None
+ else:
+ from kernels import get_kernel
+
+ try:
+ # TODO: temporary revision for now. Remove when merged upstream into `main`.
+ flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
+ return flash_attn_3_hub
+ except Exception as e:
+ logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+ raise
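
A hedged usage sketch of the new helper: it returns `None` when the optional `kernels` package is absent, and otherwise fetches the flash-attn3 kernel from the Hub (network access assumed).

    from diffusers.utils.kernels_utils import _get_fa3_from_hub

    flash_attn_3_hub = _get_fa3_from_hub()
    if flash_attn_3_hub is None:
        print("`kernels` is not installed; falling back to a locally installed flash-attn, if any.")
    else:
        # The exact attributes exposed by the kernel object depend on the Hub repo.
        print(type(flash_attn_3_hub))
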
diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py
index 35691496a1..2b20f6120c 100644
--- a/src/diffusers/utils/outputs.py
+++ b/src/diffusers/utils/outputs.py
@@ -43,12 +43,8 @@ class BaseOutput(OrderedDict):
tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
Python dictionary.
-
-
- You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
- first.
-
-
+ > [!WARNING]
+ > You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple first.
"""
def __init_subclass__(cls) -> None:
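
The reformatted warning above can be shown in a few lines; a minimal sketch with a throwaway `BaseOutput` subclass (iterating a `BaseOutput` yields its keys, so direct unpacking does not return the tensors):

    from dataclasses import dataclass

    import torch

    from diffusers.utils import BaseOutput

    @dataclass
    class ExampleOutput(BaseOutput):
        sample: torch.Tensor

    out = ExampleOutput(sample=torch.zeros(2, 2))
    (sample,) = out.to_tuple()  # correct: yields the tensor
    (key,) = out                # unpacking the output itself only yields the key
    assert key == "sample"
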
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 651fa27294..12066ee3f8 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -197,20 +197,6 @@ def get_peft_kwargs(
"lora_bias": lora_bias,
}
- # Example: try load FusionX LoRA into Wan VACE
- exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
- if exclude_modules:
- if not is_peft_version(">=", "0.14.0"):
- msg = """
-It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
-version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
-peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
-https://github.com/huggingface/diffusers/issues/new
- """
- logger.debug(msg)
- else:
- lora_config_kwargs.update({"exclude_modules": exclude_modules})
-
return lora_config_kwargs
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
if warn_msg:
logger.warning(warn_msg)
-
-
-def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
- """
- Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
- `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
- doesn't exist in `peft_state_dict`.
- """
- if model_state_dict is None:
- return
- all_modules = set()
- string_to_replace = f"{adapter_name}." if adapter_name else ""
-
- for name in model_state_dict.keys():
- if string_to_replace:
- name = name.replace(string_to_replace, "")
- if "." in name:
- module_name = name.rsplit(".", 1)[0]
- all_modules.add(module_name)
-
- target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
- exclude_modules = list(all_modules - target_modules_set)
-
- return exclude_modules
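
With the automatic `exclude_modules` derivation removed above, anyone who still needs that behaviour can set it explicitly on the PEFT config; a hedged sketch (module names are placeholders, and `exclude_modules` requires peft >= 0.14.0):

    from peft import LoraConfig

    lora_config = LoraConfig(
        r=4,
        lora_alpha=4,
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
        # Hypothetical exclusion: modules present in the model but absent from the
        # LoRA state dict can be listed here by hand.
        exclude_modules=["proj_out"],
    )
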
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index a0307c108a..3297bb5fdc 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -38,6 +38,7 @@ from .import_utils import (
is_gguf_available,
is_kernels_available,
is_note_seq_available,
+ is_nvidia_modelopt_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -66,7 +67,10 @@ else:
global_rng = random.Random()
logger = get_logger(__name__)
-
+logger.warning(
+ "diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
+ "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
+)
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
@@ -635,6 +639,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
return decorator
+def require_modelopt_version_greater_or_equal(modelopt_version):
+ def decorator(test_case):
+ correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
+ version.parse(importlib.metadata.version("modelopt")).base_version
+ ) >= version.parse(modelopt_version)
+ return unittest.skipUnless(
+ correct_nvidia_modelopt_version, f"Test requires modelopt with version greater than or equal to {modelopt_version}."
+ )(test_case)
+
+ return decorator
+
+
def require_kernels_version_greater_or_equal(kernels_version):
def decorator(test_case):
correct_kernels_version = is_kernels_available() and version.parse(
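
A hedged sketch of how the new skip decorator added above is meant to be applied; the test body and the version string are illustrative, not taken from the test suite:

    import unittest

    from diffusers.utils.testing_utils import require_modelopt_version_greater_or_equal

    class HypotheticalModelOptTests(unittest.TestCase):
        @require_modelopt_version_greater_or_equal("0.27.0")  # assumed version floor
        def test_runs_only_with_recent_modelopt(self):
            # Skipped automatically unless nvidia-modelopt >= 0.27.0 is installed.
            self.assertTrue(True)
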
@@ -801,10 +817,9 @@ def export_to_ply(mesh, output_ply_path: str = None):
f.write(format.pack(*vertex))
if faces is not None:
- format = struct.Struct(" `add_adapter()` works."""
- scheduler_cls = self.scheduler_classes[0]
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
@@ -2342,105 +2219,50 @@ class PeftLoraLoaderMixinTests:
)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
- @require_peft_version_greater("0.13.2")
- def test_lora_exclude_modules(self):
- """
- Test to check if `exclude_modules` works or not. It works in the following way:
- we first create a pipeline and insert LoRA config into it. We then derive a `set`
- of modules to exclude by investigating its denoiser state dict and denoiser LoRA
- state dict.
-
- We then create a new LoRA config to include the `exclude_modules` and perform tests.
- """
- scheduler_cls = self.scheduler_classes[0]
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components).to(torch_device)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
- # only supported for `denoiser` now
- pipe_cp = copy.deepcopy(pipe)
- pipe_cp, _ = self.add_adapters_to_pipeline(
- pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
- )
- denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
- pipe_cp.to("cpu")
- del pipe_cp
-
- denoiser_lora_config.exclude_modules = denoiser_exclude_modules
- pipe, _ = self.add_adapters_to_pipeline(
- pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
- )
- output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
- self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
- pipe.unload_lora_weights()
- pipe.load_lora_weights(tmpdir)
-
- output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
- "LoRA should change outputs.",
- )
- self.assertTrue(
- np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
- "Lora outputs should match.",
- )
-
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ # First, delete adapter and compare.
+ pipe.delete_adapters(pipe.get_active_adapters()[0])
+ output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
+ self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
-
- # First, delete adapter and compare.
- pipe.delete_adapters(pipe.get_active_adapters()[0])
- output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
- self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
-
- # Then load adapter and compare.
- pipe.load_lora_weights(tmpdirname)
- output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
+ # Then load adapter and compare.
+ pipe.load_lora_weights(tmpdirname)
+ output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
@@ -2448,7 +2270,7 @@ class PeftLoraLoaderMixinTests:
onload_device = torch_device
offload_device = torch.device("cpu")
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -2465,9 +2287,8 @@ class PeftLoraLoaderMixinTests:
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
- components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
@@ -2483,6 +2304,10 @@ class PeftLoraLoaderMixinTests:
num_blocks_per_group=1,
use_stream=use_stream,
)
+ # Place other model-level components on `torch_device`.
+ for _, component in pipe.components.items():
+ if isinstance(component, torch.nn.Module):
+ component.to(torch_device)
group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
self.assertTrue(group_offload_hook_1 is not None)
output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2514,7 +2339,7 @@ class PeftLoraLoaderMixinTests:
@require_torch_accelerator
def test_lora_loading_model_cpu_offload(self):
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -2533,7 +2358,7 @@ class PeftLoraLoaderMixinTests:
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
# reinitialize the pipeline to mimic the inference workflow.
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.enable_model_cpu_offload(device=torch_device)
pipe.load_lora_weights(tmpdirname)
diff --git a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
index 4741648359..7eb830cd50 100644
--- a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
@@ -21,7 +21,8 @@ from parameterized import parameterized
from diffusers import AsymmetricAutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
@@ -34,7 +35,6 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_cosmos.py b/tests/models/autoencoders/test_models_autoencoder_cosmos.py
index bc0011a2f0..ceccc2364e 100644
--- a/tests/models/autoencoders/test_models_autoencoder_cosmos.py
+++ b/tests/models/autoencoders/test_models_autoencoder_cosmos.py
@@ -15,8 +15,8 @@
import unittest
from diffusers import AutoencoderKLCosmos
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_dc.py b/tests/models/autoencoders/test_models_autoencoder_dc.py
index 4d2c3dc663..56f172f1c8 100644
--- a/tests/models/autoencoders/test_models_autoencoder_dc.py
+++ b/tests/models/autoencoders/test_models_autoencoder_dc.py
@@ -16,12 +16,12 @@
import unittest
from diffusers import AutoencoderDC
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
index 40479991e9..6f91f8bfa9 100644
--- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
@@ -19,12 +19,12 @@ import torch
from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py
index 2c323e4f03..662a3f1b80 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl.py
@@ -21,7 +21,8 @@ from parameterized import parameterized
from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -34,7 +35,6 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
index 7ab9520ce6..739daf2a49 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
@@ -18,12 +18,12 @@ import unittest
import torch
from diffusers import AutoencoderKLCogVideoX
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
index 618a448eca..6cb427bff8 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
@@ -16,12 +16,12 @@
import unittest
from diffusers import AutoencoderKLTemporalDecoder
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
index c056930a5e..21ab3896c8 100644
--- a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
@@ -18,12 +18,12 @@ import unittest
import torch
from diffusers import AutoencoderKLLTXVideo
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_magvit.py b/tests/models/autoencoders/test_models_autoencoder_magvit.py
index c171176627..58cbfc05bd 100644
--- a/tests/models/autoencoders/test_models_autoencoder_magvit.py
+++ b/tests/models/autoencoders/test_models_autoencoder_magvit.py
@@ -16,8 +16,8 @@
import unittest
from diffusers import AutoencoderKLMagvit
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_mochi.py b/tests/models/autoencoders/test_models_autoencoder_mochi.py
index d646693c57..b8c5aaaa1e 100755
--- a/tests/models/autoencoders/test_models_autoencoder_mochi.py
+++ b/tests/models/autoencoders/test_models_autoencoder_mochi.py
@@ -16,12 +16,12 @@
import unittest
from diffusers import AutoencoderKLMochi
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
index 5c1d7c8b71..eb7bd50f4a 100644
--- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py
+++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
@@ -21,7 +21,8 @@ from datasets import load_dataset
from parameterized import parameterized
from diffusers import AutoencoderOobleck
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -29,7 +30,6 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py
index fba2c9eb1b..4d1dc69cfa 100644
--- a/tests/models/autoencoders/test_models_autoencoder_tiny.py
+++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py
@@ -21,7 +21,8 @@ import torch
from parameterized import parameterized
from diffusers import AutoencoderTiny
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -30,7 +31,6 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py
index c0af4f5834..cc9c888681 100644
--- a/tests/models/autoencoders/test_models_autoencoder_wan.py
+++ b/tests/models/autoencoders/test_models_autoencoder_wan.py
@@ -18,8 +18,8 @@ import unittest
import torch
from diffusers import AutoencoderKLWan
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
index cdce013cfb..7e44edba36 100644
--- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py
+++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
@@ -20,7 +20,9 @@ import numpy as np
import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_image,
@@ -28,8 +30,6 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/autoencoders/test_models_vae_flax.py b/tests/models/autoencoders/test_models_vae_flax.py
deleted file mode 100644
index 8fedb85ecc..0000000000
--- a/tests/models/autoencoders/test_models_vae_flax.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import unittest
-
-from diffusers import FlaxAutoencoderKL
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-from ..test_modeling_common_flax import FlaxModelTesterMixin
-
-
-if is_flax_available():
- import jax
-
-
-@require_flax
-class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
- model_class = FlaxAutoencoderKL
-
- @property
- def dummy_input(self):
- batch_size = 4
- num_channels = 3
- sizes = (32, 32)
-
- prng_key = jax.random.PRNGKey(0)
- image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))
-
- return {"sample": image, "prng_key": prng_key}
-
- def prepare_init_args_and_inputs_for_common(self):
- init_dict = {
- "block_out_channels": [32, 64],
- "in_channels": 3,
- "out_channels": 3,
- "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
- "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
- "latent_channels": 4,
- }
- inputs_dict = self.dummy_input
- return init_dict, inputs_dict
diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py
index e8ed98f44a..1c636b0817 100644
--- a/tests/models/autoencoders/test_models_vq.py
+++ b/tests/models/autoencoders/test_models_vq.py
@@ -18,13 +18,13 @@ import unittest
import torch
from diffusers import VQModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_manual_seed,
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
index d070f6ea33..ccf36b092b 100644
--- a/tests/models/test_attention_processor.py
+++ b/tests/models/test_attention_processor.py
@@ -7,7 +7,8 @@ import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
-from diffusers.utils.testing_utils import torch_device
+
+from ..testing_utils import torch_device
class AttnAddedKVProcessorTests(unittest.TestCase):
diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py
index ec8e01b4b7..eaeffa699d 100644
--- a/tests/models/test_layers_utils.py
+++ b/tests/models/test_layers_utils.py
@@ -24,7 +24,8 @@ from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformers.transformer_2d import Transformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_manual_seed,
require_torch_accelerator_with_fp64,
require_torch_version_greater_equal,
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 36b563ba9f..a44ef571c5 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -25,7 +25,6 @@ import traceback
import unittest
import unittest.mock as mock
import uuid
-import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
@@ -36,12 +35,11 @@ import safetensors.torch
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
-from huggingface_hub import ModelCard, delete_repo, snapshot_download
-from huggingface_hub.utils import is_jinja_available
+from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
+from huggingface_hub.utils import HfHubHTTPError, is_jinja_available
from parameterized import parameterized
-from requests.exceptions import HTTPError
-from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
+from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
@@ -59,7 +57,10 @@ from diffusers.utils import (
logging,
)
from diffusers.utils.hub_utils import _add_variant
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import get_torch_cuda_device_capability
+
+from ..others.test_utils import TOKEN, USER, is_staging_test
+from ..testing_utils import (
CaptureLogger,
_check_safetensors_serialization,
backend_empty_cache,
@@ -82,9 +83,6 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
-from diffusers.utils.torch_utils import get_torch_cuda_device_capability
-
-from ..others.test_utils import TOKEN, USER, is_staging_test
if is_peft_available():
@@ -244,8 +242,8 @@ class ModelUtilsTest(unittest.TestCase):
else:
_ = load_model(repo_id)
- warning_message = str(warning.warnings[0].message)
- self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message)
+ warning_messages = " ".join(str(w.message) for w in warning.warnings)
+ self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_messages)
# Local tests are already covered down below.
@parameterized.expand(
@@ -272,7 +270,7 @@ class ModelUtilsTest(unittest.TestCase):
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
- response_mock.raise_for_status.side_effect = HTTPError
+ response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
@@ -291,6 +289,56 @@ class ModelUtilsTest(unittest.TestCase):
if p1.data.ne(p2.data).sum() > 0:
assert False, "Parameters not the same!"
+ def test_local_files_only_with_sharded_checkpoint(self):
+ repo_id = "hf-internal-testing/tiny-flux-sharded"
+ error_response = mock.Mock(
+ status_code=500,
+ headers={},
+ raise_for_status=mock.Mock(side_effect=HfHubHTTPError("Server down", response=mock.Mock())),
+ json=mock.Mock(return_value={}),
+ )
+ client_mock = mock.Mock()
+ client_mock.get.return_value = error_response
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
+
+ with mock.patch("huggingface_hub.hf_api.get_session", return_value=client_mock):
+ # Should fail with local_files_only=False (network required)
+ # We would make a network call with model_info
+ with self.assertRaises(OSError):
+ FluxTransformer2DModel.from_pretrained(
+ repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=False
+ )
+
+ # Should succeed with local_files_only=True (uses cache)
+ # model_info call skipped
+ local_model = FluxTransformer2DModel.from_pretrained(
+ repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+ )
+
+ assert all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), local_model.parameters())), (
+ "Model parameters don't match!"
+ )
+
+ # Remove a shard file
+ cached_shard_file = try_to_load_from_cache(
+ repo_id, filename="transformer/diffusion_pytorch_model-00001-of-00002.safetensors", cache_dir=tmpdir
+ )
+ os.remove(cached_shard_file)
+
+ # Attempting to load from cache should raise an error
+ with self.assertRaises(OSError) as context:
+ FluxTransformer2DModel.from_pretrained(
+ repo_id, subfolder="transformer", cache_dir=tmpdir, local_files_only=True
+ )
+
+ # Verify error mentions the missing shard
+ error_msg = str(context.exception)
+ assert cached_shard_file in error_msg or "required according to the checkpoint index" in error_msg, (
+ f"Expected error about missing shard, got: {error_msg}"
+ )
+
@unittest.skip("Flaky behaviour on CI. Re-enable after migrating to new runners")
@unittest.skipIf(torch_device == "mps", reason="Test not supported for MPS.")
def test_one_request_upon_cached(self):
@@ -1428,6 +1476,41 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+ @require_torch_accelerator
+ def test_sharded_checkpoints_with_parallel_loading(self):
+ torch.manual_seed(0)
+ config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**config).eval()
+ model = model.to(torch_device)
+
+ base_output = model(**inputs_dict)
+
+ model_size = compute_module_persistent_sizes(model)[""]
+ max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small.
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
+ self.assertTrue(os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
+
+ # Now check if the right number of shards exists. First, let's get the number of shards.
+ # Since this number can be dependent on the model being tested, it's important that we calculate it
+ # instead of hardcoding it.
+ expected_num_shards = caculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
+ actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
+ self.assertTrue(actual_num_shards == expected_num_shards)
+
+ # Load with parallel loading
+ os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
+ new_model = self.model_class.from_pretrained(tmp_dir).eval()
+ new_model = new_model.to(torch_device)
+
+ torch.manual_seed(0)
+ if "generator" in inputs_dict:
+ _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ new_output = new_model(**inputs_dict)
+ self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+ # set to no.
+ os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no"
+
@require_torch_accelerator
def test_sharded_checkpoints_device_map(self):
if self.model_class._no_split_modules is None:
@@ -1971,6 +2054,7 @@ class TorchCompileTesterMixin:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
+ model.eval()
model = torch.compile(model, fullgraph=True)
with (
@@ -1988,6 +2072,7 @@ class TorchCompileTesterMixin:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
+ model.eval()
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
@@ -2010,7 +2095,6 @@ class TorchCompileTesterMixin:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
-
model.eval()
# TODO: Can test for other group offloading kwargs later if needed.
group_offload_kwargs = {
@@ -2023,11 +2107,11 @@ class TorchCompileTesterMixin:
}
model.enable_group_offload(**group_offload_kwargs)
model.compile()
+
with torch.no_grad():
_ = model(**inputs_dict)
_ = model(**inputs_dict)
- @require_torch_version_greater("2.7.1")
def test_compile_on_different_shapes(self):
if self.different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
@@ -2035,6 +2119,7 @@ class TorchCompileTesterMixin:
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
+ model.eval()
model = torch.compile(model, fullgraph=True, dynamic=True)
for height, width in self.different_shapes_for_compilation:
@@ -2042,6 +2127,26 @@ class TorchCompileTesterMixin:
inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**inputs_dict)
+ def test_compile_works_with_aot(self):
+ from torch._inductor.package import load_package
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict).to(torch_device)
+ exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
+ _ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
+ assert os.path.exists(package_path)
+ loaded_binary = load_package(package_path, run_single_threaded=True)
+
+ model.forward = loaded_binary
+
+ with torch.no_grad():
+ _ = model(**inputs_dict)
+ _ = model(**inputs_dict)
+
@slow
@require_torch_2
@@ -2267,14 +2372,15 @@ class LoraHotSwappingForModelTesterMixin:
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check possibility to ignore the error/warning
+ from diffusers.loaders.peft import logger
+
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always") # Capture all warnings
- model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
- self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
+ # note: assertNoLogs requires Python 3.10+
+ with self.assertNoLogs(logger, level="WARNING"):
+ model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
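
Outside the new test above, the same environment variable gates parallel shard loading at `from_pretrained` time; a hedged sketch (the checkpoint path is a placeholder for a locally sharded model):

    import os

    from diffusers import UNet2DConditionModel

    os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"  # opt in, mirroring the test
    model = UNet2DConditionModel.from_pretrained("path/to/local/sharded/checkpoint")
    os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no"   # restore the default afterwards
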
diff --git a/tests/models/test_modeling_common_flax.py b/tests/models/test_modeling_common_flax.py
deleted file mode 100644
index 8945aed7c9..0000000000
--- a/tests/models/test_modeling_common_flax.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import inspect
-
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-
-if is_flax_available():
- import jax
-
-
-@require_flax
-class FlaxModelTesterMixin:
- def test_output(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- model = self.model_class(**init_dict)
- variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
- jax.lax.stop_gradient(variables)
-
- output = model.apply(variables, inputs_dict["sample"])
-
- if isinstance(output, dict):
- output = output.sample
-
- self.assertIsNotNone(output)
- expected_shape = inputs_dict["sample"].shape
- self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
- def test_forward_with_norm_groups(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- init_dict["norm_num_groups"] = 16
- init_dict["block_out_channels"] = (16, 32)
-
- model = self.model_class(**init_dict)
- variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
- jax.lax.stop_gradient(variables)
-
- output = model.apply(variables, inputs_dict["sample"])
-
- if isinstance(output, dict):
- output = output.sample
-
- self.assertIsNotNone(output)
- expected_shape = inputs_dict["sample"].shape
- self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
-
- def test_deprecated_kwargs(self):
- has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
- has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
-
- if has_kwarg_in_model_class and not has_deprecated_kwarg:
- raise ValueError(
- f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
- " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
- " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
- " []`"
- )
-
- if not has_kwarg_in_model_class and has_deprecated_kwarg:
- raise ValueError(
- f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
- " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
- f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
- " from `_deprecated_kwargs = []`"
- )
diff --git a/tests/models/transformers/test_models_dit_transformer2d.py b/tests/models/transformers/test_models_dit_transformer2d.py
index 3070321673..473a876375 100644
--- a/tests/models/transformers/test_models_dit_transformer2d.py
+++ b/tests/models/transformers/test_models_dit_transformer2d.py
@@ -18,13 +18,13 @@ import unittest
import torch
from diffusers import DiTTransformer2DModel, Transformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
slow,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_pixart_transformer2d.py b/tests/models/transformers/test_models_pixart_transformer2d.py
index 38fada0b4b..17c400cf19 100644
--- a/tests/models/transformers/test_models_pixart_transformer2d.py
+++ b/tests/models/transformers/test_models_pixart_transformer2d.py
@@ -18,13 +18,13 @@ import unittest
import torch
from diffusers import PixArtTransformer2DModel, Transformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
slow,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_prior.py b/tests/models/transformers/test_models_prior.py
index 5d66aadb1b..af5ac4bbbd 100644
--- a/tests/models/transformers/test_models_prior.py
+++ b/tests/models/transformers/test_models_prior.py
@@ -21,7 +21,8 @@ import torch
from parameterized import parameterized
from diffusers import PriorTransformer
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -29,7 +30,6 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_allegro.py b/tests/models/transformers/test_models_transformer_allegro.py
index 8a0c475583..7c002f8781 100644
--- a/tests/models/transformers/test_models_transformer_allegro.py
+++ b/tests/models/transformers/test_models_transformer_allegro.py
@@ -17,11 +17,11 @@ import unittest
import torch
from diffusers import AllegroTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_aura_flow.py b/tests/models/transformers/test_models_transformer_aura_flow.py
index 8dff07373e..ae8c3b7234 100644
--- a/tests/models/transformers/test_models_transformer_aura_flow.py
+++ b/tests/models/transformers/test_models_transformer_aura_flow.py
@@ -18,8 +18,8 @@ import unittest
import torch
from diffusers import AuraFlowTransformer2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_bria.py b/tests/models/transformers/test_models_transformer_bria.py
new file mode 100644
index 0000000000..9056590edf
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_bria.py
@@ -0,0 +1,181 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import BriaTransformer2DModel
+from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
+from diffusers.models.embeddings import ImageProjection
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+def create_bria_ip_adapter_state_dict(model):
+ # "ip_adapter" (cross-attention weights)
+ ip_cross_attn_state_dict = {}
+ key_id = 0
+
+ for name in model.attn_processors.keys():
+ if name.startswith("single_transformer_blocks"):
+ continue
+
+ joint_attention_dim = model.config["joint_attention_dim"]
+ hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
+ sd = FluxIPAdapterJointAttnProcessor2_0(
+ hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
+ ).state_dict()
+ ip_cross_attn_state_dict.update(
+ {
+ f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+ f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+ f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
+ f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
+ }
+ )
+
+ key_id += 1
+
+ # "image_proj" (ImageProjection layer weights)
+
+ image_projection = ImageProjection(
+ cross_attention_dim=model.config["joint_attention_dim"],
+ image_embed_dim=model.config["pooled_projection_dim"],
+ num_image_text_embeds=4,
+ )
+
+ ip_image_projection_state_dict = {}
+ sd = image_projection.state_dict()
+ ip_image_projection_state_dict.update(
+ {
+ "proj.weight": sd["image_embeds.weight"],
+ "proj.bias": sd["image_embeds.bias"],
+ "norm.weight": sd["norm.weight"],
+ "norm.bias": sd["norm.bias"],
+ }
+ )
+
+ del sd
+ ip_state_dict = {}
+ ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+ return ip_state_dict
+
+
+class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = BriaTransformer2DModel
+ main_input_name = "hidden_states"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.8, 0.7, 0.7]
+
+ # Skip setting testing with default: AttnProcessor
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_latent_channels = 4
+ num_image_channels = 3
+ height = width = 4
+ sequence_length = 48
+ embedding_dim = 32
+
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
+ image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "img_ids": image_ids,
+ "txt_ids": text_ids,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (16, 4)
+
+ @property
+ def output_shape(self):
+ return (16, 4)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 8,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 32,
+ "pooled_projection_dim": None,
+ "axes_dims_rope": [0, 4, 4],
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_deprecated_inputs_img_txt_ids_3d(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output_1 = model(**inputs_dict).to_tuple()[0]
+
+ # update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
+ text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
+ image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
+
+ assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
+ assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
+
+ inputs_dict["txt_ids"] = text_ids_3d
+ inputs_dict["img_ids"] = image_ids_3d
+
+ with torch.no_grad():
+ output_2 = model(**inputs_dict).to_tuple()[0]
+
+ self.assertEqual(output_1.shape, output_2.shape)
+ self.assertTrue(
+ torch.allclose(output_1, output_2, atol=1e-5),
+ msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
+ )
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"BriaTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = BriaTransformer2DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
+
+
+class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
+ model_class = BriaTransformer2DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
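
For reference, the tiny configuration exercised by the new test file can be instantiated directly; a minimal CPU sketch that reuses the same shapes as `dummy_input` (assumes a diffusers build that ships `BriaTransformer2DModel`):

    import torch

    from diffusers import BriaTransformer2DModel

    # Same tiny config as prepare_init_args_and_inputs_for_common() above.
    model = BriaTransformer2DModel(
        patch_size=1,
        in_channels=4,
        num_layers=1,
        num_single_layers=1,
        attention_head_dim=8,
        num_attention_heads=2,
        joint_attention_dim=32,
        pooled_projection_dim=None,
        axes_dims_rope=[0, 4, 4],
    )

    hidden_states = torch.randn(1, 16, 4)           # (batch, height * width, latent channels)
    encoder_hidden_states = torch.randn(1, 48, 32)  # (batch, sequence length, embedding dim)
    txt_ids = torch.randn(48, 3)
    img_ids = torch.randn(16, 3)
    timestep = torch.tensor([1.0])

    with torch.no_grad():
        sample = model(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            img_ids=img_ids,
            txt_ids=txt_ids,
            timestep=timestep,
        ).to_tuple()[0]
    print(sample.shape)  # (batch, height * width, channels)
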
diff --git a/tests/models/transformers/test_models_transformer_chroma.py b/tests/models/transformers/test_models_transformer_chroma.py
index e9fd5a0bfb..92ac8198ed 100644
--- a/tests/models/transformers/test_models_transformer_chroma.py
+++ b/tests/models/transformers/test_models_transformer_chroma.py
@@ -20,8 +20,8 @@ import torch
from diffusers import ChromaTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py
index 54d1242bf7..f632add7e5 100644
--- a/tests/models/transformers/test_models_transformer_cogvideox.py
+++ b/tests/models/transformers/test_models_transformer_cogvideox.py
@@ -18,11 +18,11 @@ import unittest
import torch
from diffusers import CogVideoXTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py
index 57131dc3f1..d38d77531d 100644
--- a/tests/models/transformers/test_models_transformer_cogview3plus.py
+++ b/tests/models/transformers/test_models_transformer_cogview3plus.py
@@ -18,11 +18,11 @@ import unittest
import torch
from diffusers import CogView3PlusTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_cogview4.py b/tests/models/transformers/test_models_transformer_cogview4.py
index 798453e86d..084c3b7cea 100644
--- a/tests/models/transformers/test_models_transformer_cogview4.py
+++ b/tests/models/transformers/test_models_transformer_cogview4.py
@@ -17,8 +17,8 @@ import unittest
import torch
from diffusers import CogView4Transformer2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_consisid.py b/tests/models/transformers/test_models_transformer_consisid.py
index af2e1e6338..77fc172d07 100644
--- a/tests/models/transformers/test_models_transformer_consisid.py
+++ b/tests/models/transformers/test_models_transformer_consisid.py
@@ -18,11 +18,11 @@ import unittest
import torch
from diffusers import ConsisIDTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_cosmos.py b/tests/models/transformers/test_models_transformer_cosmos.py
index 7d26004d75..d7390e105c 100644
--- a/tests/models/transformers/test_models_transformer_cosmos.py
+++ b/tests/models/transformers/test_models_transformer_cosmos.py
@@ -17,8 +17,8 @@ import unittest
import torch
from diffusers import CosmosTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_easyanimate.py b/tests/models/transformers/test_models_transformer_easyanimate.py
index 0a255f4d4e..d7b90a47d9 100644
--- a/tests/models/transformers/test_models_transformer_easyanimate.py
+++ b/tests/models/transformers/test_models_transformer_easyanimate.py
@@ -18,8 +18,8 @@ import unittest
import torch
from diffusers import EasyAnimateTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 68b5c02bc0..3ab02f797b 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -20,8 +20,8 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
@@ -172,6 +172,35 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+ # The test exists for cases like
+ # https://github.com/huggingface/diffusers/issues/11874
+ @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+ def test_lora_exclude_modules(self):
+ from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
+
+ lora_rank = 4
+ target_module = "single_transformer_blocks.0.proj_out"
+ adapter_name = "foo"
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+
+ state_dict = model.state_dict()
+ target_mod_shape = state_dict[f"{target_module}.weight"].shape
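+ # Following the usual PEFT layout: lora_A maps in_features -> rank and lora_B maps rank -> out_features.
+ # The constant fills (22 and 33) make it easy to verify below that exactly these weights were loaded.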
+ lora_state_dict = {
+ f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
+ f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
+ }
+ # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
+ config = LoraConfig(
+ r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
+ )
+ inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
+ set_peft_model_state_dict(model, lora_state_dict, adapter_name)
+ retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
+ assert len(retrieved_lora_state_dict) == len(lora_state_dict)
+ assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
+ assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
+
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
diff --git a/tests/models/transformers/test_models_transformer_hidream.py b/tests/models/transformers/test_models_transformer_hidream.py
index fa0fa5123a..fdd5f8c7fd 100644
--- a/tests/models/transformers/test_models_transformer_hidream.py
+++ b/tests/models/transformers/test_models_transformer_hidream.py
@@ -18,11 +18,11 @@ import unittest
import torch
from diffusers import HiDreamImageTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_hunyuan_dit.py b/tests/models/transformers/test_models_transformer_hunyuan_dit.py
index 242ce1f283..d82a62d58e 100644
--- a/tests/models/transformers/test_models_transformer_hunyuan_dit.py
+++ b/tests/models/transformers/test_models_transformer_hunyuan_dit.py
@@ -18,11 +18,11 @@ import unittest
import torch
from diffusers import HunyuanDiT2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py
index b42a3cb5dc..385a5eefd5 100644
--- a/tests/models/transformers/test_models_transformer_hunyuan_video.py
+++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py
@@ -17,11 +17,11 @@ import unittest
import torch
from diffusers import HunyuanVideoTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py
index ddb79925a7..00a2b27e02 100644
--- a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py
+++ b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py
@@ -17,11 +17,11 @@ import unittest
import torch
from diffusers import HunyuanVideoFramepackTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_latte.py b/tests/models/transformers/test_models_transformer_latte.py
index db93421b44..7bf2c52e62 100644
--- a/tests/models/transformers/test_models_transformer_latte.py
+++ b/tests/models/transformers/test_models_transformer_latte.py
@@ -18,11 +18,11 @@ import unittest
import torch
from diffusers import LatteTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py
index 2c61658f58..e912463bbf 100644
--- a/tests/models/transformers/test_models_transformer_ltx.py
+++ b/tests/models/transformers/test_models_transformer_ltx.py
@@ -18,8 +18,8 @@ import unittest
import torch
from diffusers import LTXVideoTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_lumina.py b/tests/models/transformers/test_models_transformer_lumina.py
index d0103eb473..0024aa106c 100644
--- a/tests/models/transformers/test_models_transformer_lumina.py
+++ b/tests/models/transformers/test_models_transformer_lumina.py
@@ -18,11 +18,11 @@ import unittest
import torch
from diffusers import LuminaNextDiT2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_lumina2.py b/tests/models/transformers/test_models_transformer_lumina2.py
index 731e2ff3d5..4efae3d4b7 100644
--- a/tests/models/transformers/test_models_transformer_lumina2.py
+++ b/tests/models/transformers/test_models_transformer_lumina2.py
@@ -18,11 +18,11 @@ import unittest
import torch
from diffusers import Lumina2Transformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_mochi.py b/tests/models/transformers/test_models_transformer_mochi.py
index db65c03292..931b5874ee 100644
--- a/tests/models/transformers/test_models_transformer_mochi.py
+++ b/tests/models/transformers/test_models_transformer_mochi.py
@@ -18,8 +18,8 @@ import unittest
import torch
from diffusers import MochiTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_omnigen.py b/tests/models/transformers/test_models_transformer_omnigen.py
index 25f25a8d63..f1963ddb77 100644
--- a/tests/models/transformers/test_models_transformer_omnigen.py
+++ b/tests/models/transformers/test_models_transformer_omnigen.py
@@ -18,8 +18,8 @@ import unittest
import torch
from diffusers import OmniGenTransformer2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
new file mode 100644
index 0000000000..b24fa90503
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -0,0 +1,106 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import pytest
+import torch
+
+from diffusers import QwenImageTransformer2DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = QwenImageTransformer2DModel
+ main_input_name = "hidden_states"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.7, 0.6, 0.6]
+
+ # Skip tests that set the default AttnProcessor
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (16, 16)
+
+ @property
+ def output_shape(self):
+ return (16, 16)
+
+ def prepare_dummy_input(self, height=4, width=4):
+ batch_size = 1
+ num_latent_channels = embedding_dim = 16
+ sequence_length = 7
+ vae_scale_factor = 4
+
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+ orig_height = height * 2 * vae_scale_factor
+ orig_width = width * 2 * vae_scale_factor
+ img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
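+ # Note: img_shapes works out to (1, height, width) per sample in packed-latent units, matching the
+ # height * width tokens in hidden_states; txt_seq_lens below is derived from the attention mask.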
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_hidden_states_mask": encoder_hidden_states_mask,
+ "timestep": timestep,
+ "img_shapes": img_shapes,
+ "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
+ }
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 2,
+ "in_channels": 16,
+ "out_channels": 4,
+ "num_layers": 2,
+ "attention_head_dim": 16,
+ "num_attention_heads": 3,
+ "joint_attention_dim": 16,
+ "guidance_embeds": False,
+ "axes_dims_rope": (8, 4, 4),
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"QwenImageTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = QwenImageTransformer2DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
+
+ @pytest.mark.xfail(condition=True, reason="RoPE needs to be revisited.", strict=True)
+ def test_torch_compile_recompilation_and_graph_break(self):
+ super().test_torch_compile_recompilation_and_graph_break()
diff --git a/tests/models/transformers/test_models_transformer_sana.py b/tests/models/transformers/test_models_transformer_sana.py
index 6586af0e82..2e316c3aed 100644
--- a/tests/models/transformers/test_models_transformer_sana.py
+++ b/tests/models/transformers/test_models_transformer_sana.py
@@ -17,11 +17,11 @@ import unittest
import torch
from diffusers import SanaTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index 10469c0ca9..c4ee7017a3 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -19,11 +19,11 @@ import torch
from diffusers import SD3Transformer2DModel
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_skyreels_v2.py b/tests/models/transformers/test_models_transformer_skyreels_v2.py
index 884f168308..8c36d8256e 100644
--- a/tests/models/transformers/test_models_transformer_skyreels_v2.py
+++ b/tests/models/transformers/test_models_transformer_skyreels_v2.py
@@ -17,11 +17,11 @@ import unittest
import torch
from diffusers import SkyReelsV2Transformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_temporal.py b/tests/models/transformers/test_models_transformer_temporal.py
index 183ef22982..aff83be511 100644
--- a/tests/models/transformers/test_models_transformer_temporal.py
+++ b/tests/models/transformers/test_models_transformer_temporal.py
@@ -18,11 +18,11 @@ import unittest
import torch
from diffusers.models.transformers import TransformerTemporalModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin
diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py
index 932f255984..9f248f990c 100644
--- a/tests/models/transformers/test_models_transformer_wan.py
+++ b/tests/models/transformers/test_models_transformer_wan.py
@@ -17,11 +17,11 @@ import unittest
import torch
from diffusers import WanTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py
index e3dd608a25..bac017e7e7 100644
--- a/tests/models/unets/test_models_unet_1d.py
+++ b/tests/models/unets/test_models_unet_1d.py
@@ -19,13 +19,13 @@ import pytest
import torch
from diffusers import UNet1DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_manual_seed,
floats_tensor,
slow,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py
index f6fa82aeb7..e289f44303 100644
--- a/tests/models/unets/test_models_unet_2d.py
+++ b/tests/models/unets/test_models_unet_2d.py
@@ -21,7 +21,8 @@ import torch
from diffusers import UNet2DModel
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -30,7 +31,6 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 123dff16f8..4dbb8ca7c0 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -34,7 +34,8 @@ from diffusers.models.attention_processor import (
from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -51,7 +52,6 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
-
from ..test_modeling_common import (
LoraHotSwappingForModelTesterMixin,
ModelTesterMixin,
diff --git a/tests/models/unets/test_models_unet_2d_flax.py b/tests/models/unets/test_models_unet_2d_flax.py
deleted file mode 100644
index 69a0704dca..0000000000
--- a/tests/models/unets/test_models_unet_2d_flax.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import gc
-import unittest
-
-from parameterized import parameterized
-
-from diffusers import FlaxUNet2DConditionModel
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import load_hf_numpy, require_flax, slow
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
-
-
-@slow
-@require_flax
-class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):
- def get_file_format(self, seed, shape):
- return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
- dtype = jnp.bfloat16 if fp16 else jnp.float32
- image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
- return image
-
- def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
- dtype = jnp.bfloat16 if fp16 else jnp.float32
- revision = "bf16" if fp16 else None
-
- model, params = FlaxUNet2DConditionModel.from_pretrained(
- model_id, subfolder="unet", dtype=dtype, revision=revision
- )
- return model, params
-
- def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
- dtype = jnp.bfloat16 if fp16 else jnp.float32
- hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
- return hidden_states
-
- @parameterized.expand(
- [
- # fmt: off
- [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
- [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
- [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
- [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
- # fmt: on
- ]
- )
- def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
- model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
- latents = self.get_latents(seed, fp16=True)
- encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
-
- sample = model.apply(
- {"params": params},
- latents,
- jnp.array(timestep, dtype=jnp.int32),
- encoder_hidden_states=encoder_hidden_states,
- ).sample
-
- assert sample.shape == latents.shape
-
- output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
- expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
-
- # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware
- assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
-
- @parameterized.expand(
- [
- # fmt: off
- [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
- [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
- [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
- [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
- # fmt: on
- ]
- )
- def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
- model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
- latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
- encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
-
- sample = model.apply(
- {"params": params},
- latents,
- jnp.array(timestep, dtype=jnp.int32),
- encoder_hidden_states=encoder_hidden_states,
- ).sample
-
- assert sample.shape == latents.shape
-
- output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32)
- expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)
-
- # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
- assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
diff --git a/tests/models/unets/test_models_unet_3d_condition.py b/tests/models/unets/test_models_unet_3d_condition.py
index 72d692b6e7..f73e3461c3 100644
--- a/tests/models/unets/test_models_unet_3d_condition.py
+++ b/tests/models/unets/test_models_unet_3d_condition.py
@@ -21,8 +21,8 @@ import torch
from diffusers.models import ModelMixin, UNet3DConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
+from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py
index cebd18c10d..40773536df 100644
--- a/tests/models/unets/test_models_unet_controlnetxs.py
+++ b/tests/models/unets/test_models_unet_controlnetxs.py
@@ -21,8 +21,8 @@ from torch import nn
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
+from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py
index bf8d6bd007..d931b345fd 100644
--- a/tests/models/unets/test_models_unet_motion.py
+++ b/tests/models/unets/test_models_unet_motion.py
@@ -24,12 +24,12 @@ import torch
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py
index 86aa0f6a0e..7df868c9e9 100644
--- a/tests/models/unets/test_models_unet_spatiotemporal.py
+++ b/tests/models/unets/test_models_unet_spatiotemporal.py
@@ -21,13 +21,13 @@ import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
torch_device,
)
-
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
diff --git a/tests/models/unets/test_unet_2d_blocks.py b/tests/models/unets/test_unet_2d_blocks.py
index 21c0c0f08b..5c006963e3 100644
--- a/tests/models/unets/test_unet_2d_blocks.py
+++ b/tests/models/unets/test_unet_2d_blocks.py
@@ -15,8 +15,8 @@
import unittest
from diffusers.models.unets.unet_2d_blocks import * # noqa F403
-from diffusers.utils.testing_utils import torch_device
+from ...testing_utils import torch_device
from .test_unet_blocks_common import UNetBlockTesterMixin
diff --git a/tests/models/unets/test_unet_blocks_common.py b/tests/models/unets/test_unet_blocks_common.py
index ada7c83269..85f9bf8353 100644
--- a/tests/models/unets/test_unet_blocks_common.py
+++ b/tests/models/unets/test_unet_blocks_common.py
@@ -16,14 +16,15 @@ from typing import Tuple
import torch
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
floats_tensor,
require_torch,
require_torch_accelerator_with_training,
torch_all_close,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
@require_torch
diff --git a/tests/modular_pipelines/__init__.py b/tests/modular_pipelines/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/modular_pipelines/stable_diffusion_xl/__init__.py b/tests/modular_pipelines/stable_diffusion_xl/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
new file mode 100644
index 0000000000..d05f818135
--- /dev/null
+++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
@@ -0,0 +1,462 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+from typing import Any, Dict
+
+import numpy as np
+import torch
+from PIL import Image
+
+from diffusers import (
+ ClassifierFreeGuidance,
+ StableDiffusionXLAutoBlocks,
+ StableDiffusionXLModularPipeline,
+)
+from diffusers.loaders import ModularIPAdapterMixin
+
+from ...models.unets.test_models_unet_2d_condition import (
+ create_ip_adapter_state_dict,
+)
+from ...testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+from ..test_modular_pipelines_common import (
+ ModularPipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class SDXLModularTests:
+ """
+ This mixin defines the methods to create the pipeline, the base inputs, and the base tests shared across all SDXL modular tests.
+ """
+
+ pipeline_class = StableDiffusionXLModularPipeline
+ pipeline_blocks_class = StableDiffusionXLAutoBlocks
+ repo = "hf-internal-testing/tiny-sdxl-modular"
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "negative_prompt",
+ "cross_attention_kwargs",
+ "image",
+ "mask_image",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
+
+ def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
+ pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
+ pipeline.load_components(torch_dtype=torch_dtype)
+ return pipeline
+
+ def get_dummy_inputs(self, device, seed=0):
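+ # Fall back to the global (CPU-seeded) RNG on MPS, presumably because device-local generators
+ # are not reliably supported there.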
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "output_type": "np",
+ }
+ return inputs
+
+ def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ sd_pipe = self.get_pipeline()
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs, output="images")
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == expected_image_shape
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
+ "Image Slice does not match expected slice"
+ )
+
+
+class SDXLModularIPAdapterTests:
+ """
+ This mixin is designed to test IP Adapter.
+ """
+
+ def test_pipeline_inputs_and_blocks(self):
+ blocks = self.pipeline_blocks_class()
+ parameters = blocks.input_names
+
+ assert issubclass(self.pipeline_class, ModularIPAdapterMixin)
+ assert "ip_adapter_image" in parameters, (
+ "`ip_adapter_image` argument must be supported by the `__call__` method"
+ )
+ assert "ip_adapter" in blocks.sub_blocks, "pipeline must contain an IPAdapter block"
+
+ _ = blocks.sub_blocks.pop("ip_adapter")
+ parameters = blocks.input_names
+ assert "ip_adapter_image" not in parameters, (
+ "`ip_adapter_image` argument must be removed from the `__call__` method"
+ )
+
+ def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
+ return torch.randn((1, 1, cross_attention_dim), device=torch_device)
+
+ def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
+ return torch.randn((1, 1, 1, cross_attention_dim), device=torch_device)
+
+ def _get_dummy_masks(self, input_size: int = 64):
+ _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
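+ # mark only the left half of the width dimension as active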
+ _masks[0, :, :, : int(input_size / 2)] = 1
+ return _masks
+
+ def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
+ blocks = self.pipeline_blocks_class()
+ _ = blocks.sub_blocks.pop("ip_adapter")
+ parameters = blocks.input_names
+ if "image" in parameters and "strength" in parameters:
+ inputs["num_inference_steps"] = 4
+
+ inputs["output_type"] = "np"
+ return inputs
+
+ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+ r"""Tests for IP-Adapter.
+
+ The following scenarios are tested:
+ - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+ - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+ - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+ - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+ """
+ # Raising the tolerance for this test when it's run on a CPU because we
+ # compare against static slices and that can be shaky (with a very low probability).
+ expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
+
+ blocks = self.pipeline_blocks_class()
+ _ = blocks.sub_blocks.pop("ip_adapter")
+ pipe = blocks.init_pipeline(self.repo)
+ pipe.load_components(torch_dtype=torch.float32)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
+
+ # forward pass without ip adapter
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ if expected_pipe_slice is None:
+ output_without_adapter = pipe(**inputs, output="images")
+ else:
+ output_without_adapter = expected_pipe_slice
+
+ # 1. Single IP-Adapter test cases
+ adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
+ pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+ # forward pass with single ip adapter, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ pipe.set_ip_adapter_scale(0.0)
+ output_without_adapter_scale = pipe(**inputs, output="images")
+ if expected_pipe_slice is not None:
+ output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ # forward pass with single ip adapter, but with scale of adapter weights
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ pipe.set_ip_adapter_scale(42.0)
+ output_with_adapter_scale = pipe(**inputs, output="images")
+ if expected_pipe_slice is not None:
+ output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+ max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+ assert max_diff_without_adapter_scale < expected_max_diff, (
+ "Output without ip-adapter must be same as normal inference"
+ )
+ assert max_diff_with_adapter_scale > 1e-2, "Output with ip-adapter must be different from normal inference"
+
+ # 2. Multi IP-Adapter test cases
+ adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
+ adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
+ pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
+
+ # forward pass with multi ip adapter, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+ inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+ pipe.set_ip_adapter_scale([0.0, 0.0])
+ output_without_multi_adapter_scale = pipe(**inputs, output="images")
+ if expected_pipe_slice is not None:
+ output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ # forward pass with multi ip adapter, but with scale of adapter weights
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+ inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+ pipe.set_ip_adapter_scale([42.0, 42.0])
+ output_with_multi_adapter_scale = pipe(**inputs, output="images")
+ if expected_pipe_slice is not None:
+ output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_multi_adapter_scale = np.abs(
+ output_without_multi_adapter_scale - output_without_adapter
+ ).max()
+ max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
+ assert max_diff_without_multi_adapter_scale < expected_max_diff, (
+ "Output without multi-ip-adapter must be same as normal inference"
+ )
+ assert max_diff_with_multi_adapter_scale > 1e-2, (
+ "Output with multi-ip-adapter scale must be different from normal inference"
+ )
+
+
+class SDXLModularControlNetTests:
+ """
+ This mixin is designed to test ControlNet.
+ """
+
+ def test_pipeline_inputs(self):
+ blocks = self.pipeline_blocks_class()
+ parameters = blocks.input_names
+
+ assert "control_image" in parameters, "`control_image` argument must be supported by the `__call__` method"
+ assert "controlnet_conditioning_scale" in parameters, (
+ "`controlnet_conditioning_scale` argument must be supported by the `__call__` method"
+ )
+
+ def _modify_inputs_for_controlnet_test(self, inputs: Dict[str, Any]):
+ controlnet_embedder_scale_factor = 2
+ image = torch.randn(
+ (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+ device=torch_device,
+ )
+ inputs["control_image"] = image
+ return inputs
+
+ def test_controlnet(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
+ r"""Tests for ControlNet.
+
+ The following scenarios are tested:
+ - Single ControlNet with scale=0 should produce same output as no ControlNet.
+ - Single ControlNet with scale!=0 should produce different output compared to no ControlNet.
+ """
+ # Raising the tolerance for this test when it's run on a CPU because we
+ # compare against static slices and that can be shaky (with a very low probability).
+ expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
+
+ pipe = self.get_pipeline()
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # forward pass without controlnet
+ inputs = self.get_dummy_inputs(torch_device)
+ output_without_controlnet = pipe(**inputs, output="images")
+ output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
+
+ # forward pass with single controlnet, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
+ inputs["controlnet_conditioning_scale"] = 0.0
+ output_without_controlnet_scale = pipe(**inputs, output="images")
+ output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
+
+ # forward pass with single controlnet, but with scale of adapter weights
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
+ inputs["controlnet_conditioning_scale"] = 42.0
+ output_with_controlnet_scale = pipe(**inputs, output="images")
+ output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
+ max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
+
+ assert max_diff_without_controlnet_scale < expected_max_diff, (
+ "Output without controlnet must be same as normal inference"
+ )
+ assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"
+
+ def test_controlnet_cfg(self):
+ pipe = self.get_pipeline()
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # forward pass with CFG not applied
+ guider = ClassifierFreeGuidance(guidance_scale=1.0)
+ pipe.update_components(guider=guider)
+
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
+ out_no_cfg = pipe(**inputs, output="images")
+
+ # forward pass with CFG applied
+ guider = ClassifierFreeGuidance(guidance_scale=7.5)
+ pipe.update_components(guider=guider)
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
+ out_cfg = pipe(**inputs, output="images")
+
+ assert out_cfg.shape == out_no_cfg.shape
+ max_diff = np.abs(out_cfg - out_no_cfg).max()
+ assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
+
+
+class SDXLModularGuiderTests:
+ def test_guider_cfg(self):
+ pipe = self.get_pipeline()
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ # forward pass with CFG not applied
+ guider = ClassifierFreeGuidance(guidance_scale=1.0)
+ pipe.update_components(guider=guider)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ out_no_cfg = pipe(**inputs, output="images")
+
+ # forward pass with CFG applied
+ guider = ClassifierFreeGuidance(guidance_scale=7.5)
+ pipe.update_components(guider=guider)
+ inputs = self.get_dummy_inputs(torch_device)
+ out_cfg = pipe(**inputs, output="images")
+
+ assert out_cfg.shape == out_no_cfg.shape
+ max_diff = np.abs(out_cfg - out_no_cfg).max()
+ assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
+
+
+class SDXLModularPipelineFastTests(
+ SDXLModularTests,
+ SDXLModularIPAdapterTests,
+ SDXLModularControlNetTests,
+ SDXLModularGuiderTests,
+ ModularPipelineTesterMixin,
+ unittest.TestCase,
+):
+ """Test cases for Stable Diffusion XL modular pipeline fast tests."""
+
+ def test_stable_diffusion_xl_euler(self):
+ self._test_stable_diffusion_xl_euler(
+ expected_image_shape=(1, 64, 64, 3),
+ expected_slice=[
+ 0.5966781,
+ 0.62939394,
+ 0.48465094,
+ 0.51573336,
+ 0.57593524,
+ 0.47035995,
+ 0.53410417,
+ 0.51436996,
+ 0.47313565,
+ ],
+ expected_max_diff=1e-2,
+ )
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+
+
+class SDXLImg2ImgModularPipelineFastTests(
+ SDXLModularTests,
+ SDXLModularIPAdapterTests,
+ SDXLModularControlNetTests,
+ SDXLModularGuiderTests,
+ ModularPipelineTesterMixin,
+ unittest.TestCase,
+):
+ """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
+
+ def get_dummy_inputs(self, device, seed=0):
+ inputs = super().get_dummy_inputs(device, seed)
+ image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
+ image = image / 2 + 0.5
+ inputs["image"] = image
+ inputs["strength"] = 0.8
+
+ return inputs
+
+ def test_stable_diffusion_xl_euler(self):
+ self._test_stable_diffusion_xl_euler(
+ expected_image_shape=(1, 64, 64, 3),
+ expected_slice=[
+ 0.56943184,
+ 0.4702148,
+ 0.48048905,
+ 0.6235963,
+ 0.551138,
+ 0.49629188,
+ 0.60031277,
+ 0.5688907,
+ 0.43996853,
+ ],
+ expected_max_diff=1e-2,
+ )
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+
+
+class SDXLInpaintingModularPipelineFastTests(
+ SDXLModularTests,
+ SDXLModularIPAdapterTests,
+ SDXLModularControlNetTests,
+ SDXLModularGuiderTests,
+ ModularPipelineTesterMixin,
+ unittest.TestCase,
+):
+ """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
+
+ def get_dummy_inputs(self, device, seed=0):
+ inputs = super().get_dummy_inputs(device, seed)
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+ # create mask
+ image[8:, 8:, :] = 255
+ mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
+
+ inputs["image"] = init_image
+ inputs["mask_image"] = mask_image
+ inputs["strength"] = 1.0
+
+ return inputs
+
+ def test_stable_diffusion_xl_euler(self):
+ self._test_stable_diffusion_xl_euler(
+ expected_image_shape=(1, 64, 64, 3),
+ expected_slice=[
+ 0.40872607,
+ 0.38842705,
+ 0.34893104,
+ 0.47837183,
+ 0.43792963,
+ 0.5332134,
+ 0.3716843,
+ 0.47274873,
+ 0.45000193,
+ ],
+ expected_max_diff=1e-2,
+ )
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=3e-3)
diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py
new file mode 100644
index 0000000000..d309fcf353
--- /dev/null
+++ b/tests/modular_pipelines/test_modular_pipelines_common.py
@@ -0,0 +1,359 @@
+import gc
+import tempfile
+import unittest
+from typing import Callable, Union
+
+import numpy as np
+import torch
+
+import diffusers
+from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
+from diffusers.utils import logging
+
+from ..testing_utils import (
+ backend_empty_cache,
+ numpy_cosine_similarity_distance,
+ require_accelerator,
+ require_torch,
+ torch_device,
+)
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+@require_torch
+class ModularPipelineTesterMixin:
+ """
+ This mixin is designed to be used with unittest.TestCase classes.
+ It provides a set of common tests for each modular pipeline,
+ including:
+ - test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
+ - test_inference_batch_consistent: check if the pipeline's __call__ method can handle batch inputs
+ - test_inference_batch_single_identical: check if batched inference matches single-input inference
+ - test_float16_inference: check if float16 inference matches float32 inference
+ - test_to_device: check if the pipeline can be moved between the CPU and the accelerator device
+ """
+
+ # Canonical parameters that are passed to `__call__` regardless
+ # of the type of pipeline. They are always optional and have common
+ # sense default values.
+ optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "num_images_per_prompt",
+ "latents",
+ "output_type",
+ ]
+ )
+ # this is modular-specific: the generator needs to be an intermediate input because it is mutable
+ intermediate_params = frozenset(
+ [
+ "generator",
+ ]
+ )
+
+ def get_generator(self, seed):
+ device = torch_device if torch_device != "mps" else "cpu"
+ generator = torch.Generator(device).manual_seed(seed)
+ return generator
+
+ @property
+ def pipeline_class(self) -> Union[Callable, ModularPipeline]:
+ raise NotImplementedError(
+ "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ @property
+ def repo(self) -> str:
+ raise NotImplementedError(
+ "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
+ )
+
+ @property
+ def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
+ raise NotImplementedError(
+ "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ def get_pipeline(self):
+ raise NotImplementedError(
+ "You need to implement `get_pipeline(self)` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ def get_dummy_inputs(self, device, seed=0):
+ raise NotImplementedError(
+ "You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
+ "See existing pipeline tests for reference."
+ )
+
+ @property
+ def params(self) -> frozenset:
+ raise NotImplementedError(
+ "You need to set the attribute `params` in the child test class. "
+ "`params` are checked for if all values are present in `__call__`'s signature."
+ " You can set `params` using one of the common set of parameters defined in `pipeline_params.py`"
+ " e.g., `TEXT_TO_IMAGE_PARAMS` defines the common parameters used in text to "
+ "image pipelines, including prompts and prompt embedding overrides."
+ "If your pipeline's set of arguments has minor changes from one of the common sets of arguments, "
+ "do not make modifications to the existing common sets of arguments. I.e. a text to image pipeline "
+ "with non-configurable height and width arguments should set the attribute as "
+ "`params = TEXT_TO_IMAGE_PARAMS - {'height', 'width'}`. "
+ "See existing pipeline tests for reference."
+ )
+
+ @property
+ def batch_params(self) -> frozenset:
+ raise NotImplementedError(
+ "You need to set the attribute `batch_params` in the child test class. "
+ "`batch_params` are the parameters required to be batched when passed to the pipeline's "
+ "`__call__` method. `pipeline_params.py` provides some common sets of parameters such as "
+ "`TEXT_TO_IMAGE_BATCH_PARAMS`, `IMAGE_VARIATION_BATCH_PARAMS`, etc... If your pipeline's "
+ "set of batch arguments has minor changes from one of the common sets of batch arguments, "
+ "do not make modifications to the existing common sets of batch arguments. I.e. a text to "
+ "image pipeline `negative_prompt` is not batched should set the attribute as "
+ "`batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {'negative_prompt'}`. "
+ "See existing pipeline tests for reference."
+ )
+
+ def setUp(self):
+ # clean up the VRAM before each test
+ super().setUp()
+ torch.compiler.reset()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ # clean up the VRAM after each test in case of CUDA runtime errors
+ super().tearDown()
+ torch.compiler.reset()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_pipeline_call_signature(self):
+ pipe = self.get_pipeline()
+ input_parameters = pipe.blocks.input_names
+ optional_parameters = pipe.default_call_parameters
+
+ def _check_for_parameters(parameters, expected_parameters, param_type):
+ remaining_parameters = {param for param in parameters if param not in expected_parameters}
+ assert len(remaining_parameters) == 0, (
+ f"Required {param_type} parameters not present: {remaining_parameters}"
+ )
+
+ _check_for_parameters(self.params, input_parameters, "input")
+ _check_for_parameters(self.optional_params, optional_parameters, "optional")
+
+ def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
+ pipe = self.get_pipeline()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # prepare batched inputs
+ batched_inputs = []
+ for batch_size in batch_sizes:
+ batched_input = {}
+ batched_input.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ batched_input[name] = batch_size * [value]
+
+ if batch_generator and "generator" in inputs:
+ batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_input["batch_size"] = batch_size
+
+ batched_inputs.append(batched_input)
+
+ logger.setLevel(level=diffusers.logging.WARNING)
+ for batch_size, batched_input in zip(batch_sizes, batched_inputs):
+ output = pipe(**batched_input, output="images")
+ assert len(output) == batch_size, "Output batch size is different from the expected batch size"
+
+ def test_inference_batch_single_identical(
+ self,
+ batch_size=2,
+ expected_max_diff=1e-4,
+ ):
+ pipe = self.get_pipeline()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Reset the generator in case it has been used in self.get_dummy_inputs
+ inputs["generator"] = self.get_generator(0)
+
+ logger = logging.get_logger(pipe.__module__)
+ logger.setLevel(level=diffusers.logging.FATAL)
+
+ # batchify inputs
+ batched_inputs = {}
+ batched_inputs.update(inputs)
+
+ for name in self.batch_params:
+ if name not in inputs:
+ continue
+
+ value = inputs[name]
+ batched_inputs[name] = batch_size * [value]
+
+ if "generator" in inputs:
+ batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
+
+ if "batch_size" in inputs:
+ batched_inputs["batch_size"] = batch_size
+
+ output = pipe(**inputs, output="images")
+ output_batch = pipe(**batched_inputs, output="images")
+
+ assert output_batch.shape[0] == batch_size
+
+ max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
+ assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
+
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_accelerator
+ def test_float16_inference(self, expected_max_diff=5e-2):
+ pipe = self.get_pipeline()
+ pipe.to(torch_device, torch.float32)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe_fp16 = self.get_pipeline()
+ pipe_fp16.to(torch_device, torch.float16)
+ pipe_fp16.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ # Reset generator in case it is used inside dummy inputs
+ if "generator" in inputs:
+ inputs["generator"] = self.get_generator(0)
+ output = pipe(**inputs, output="images")
+
+ fp16_inputs = self.get_dummy_inputs(torch_device)
+ # Reset generator in case it is used inside dummy inputs
+ if "generator" in fp16_inputs:
+ fp16_inputs["generator"] = self.get_generator(0)
+ output_fp16 = pipe_fp16(**fp16_inputs, output="images")
+
+ if isinstance(output, torch.Tensor):
+ output = output.cpu()
+ output_fp16 = output_fp16.cpu()
+
+ max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
+ assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
+
+ @require_accelerator
+ def test_to_device(self):
+ pipe = self.get_pipeline()
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ assert all(device == "cpu" for device in model_devices), "All pipeline components are not on CPU"
+
+ pipe.to(torch_device)
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ assert all(device == torch_device for device in model_devices), (
+ "All pipeline components are not on accelerator device"
+ )
+
+ def test_inference_is_not_nan_cpu(self):
+ pipe = self.get_pipeline()
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to("cpu")
+
+ output = pipe(**self.get_dummy_inputs("cpu"), output="images")
+ assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
+
+ @require_accelerator
+ def test_inference_is_not_nan(self):
+ pipe = self.get_pipeline()
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ output = pipe(**self.get_dummy_inputs(torch_device), output="images")
+ assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
+
+ def test_num_images_per_prompt(self):
+ pipe = self.get_pipeline()
+
+ if "num_images_per_prompt" not in pipe.blocks.input_names:
+ return
+
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ batch_sizes = [1, 2]
+ num_images_per_prompts = [1, 2]
+
+ for batch_size in batch_sizes:
+ for num_images_per_prompt in num_images_per_prompts:
+ inputs = self.get_dummy_inputs(torch_device)
+
+ for key in inputs.keys():
+ if key in self.batch_params:
+ inputs[key] = batch_size * [inputs[key]]
+
+ images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt, output="images")
+
+ assert images.shape[0] == batch_size * num_images_per_prompt
+
+ @require_accelerator
+ def test_components_auto_cpu_offload_inference_consistent(self):
+ base_pipe = self.get_pipeline().to(torch_device)
+
+ cm = ComponentsManager()
+ cm.enable_auto_cpu_offload(device=torch_device)
+ offload_pipe = self.get_pipeline(components_manager=cm)
+
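+ # with auto CPU offload, the ComponentsManager moves components between CPU and the accelerator on demand;
+ # the resulting images should still match the base pipeline that keeps everything on the accelerator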
+ image_slices = []
+ for pipe in [base_pipe, offload_pipe]:
+ inputs = self.get_dummy_inputs(torch_device)
+ image = pipe(**inputs, output="images")
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+ def test_save_from_pretrained(self):
+ pipes = []
+ base_pipe = self.get_pipeline().to(torch_device)
+ pipes.append(base_pipe)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ base_pipe.save_pretrained(tmpdirname)
+ pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
+ pipe.load_components(torch_dtype=torch.float32)
+ pipe.to(torch_device)
+
+ pipes.append(pipe)
+
+ image_slices = []
+ for pipe in pipes:
+ inputs = self.get_dummy_inputs(torch_device)
+ image = pipe(**inputs, output="images")
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
diff --git a/tests/others/__init__.py b/tests/others/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/others/test_config.py b/tests/others/test_config.py
index a8f93024f7..232bf9d473 100644
--- a/tests/others/test_config.py
+++ b/tests/others/test_config.py
@@ -28,7 +28,8 @@ from diffusers import (
logging,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.utils.testing_utils import CaptureLogger
+
+from ..testing_utils import CaptureLogger
class SampleObject(ConfigMixin):
diff --git a/tests/others/test_dependencies.py b/tests/others/test_dependencies.py
index a08129a1e9..db22f10c4b 100644
--- a/tests/others/test_dependencies.py
+++ b/tests/others/test_dependencies.py
@@ -39,6 +39,8 @@ class DependencyTester(unittest.TestCase):
backend = "invisible-watermark"
elif backend == "opencv":
backend = "opencv-python"
+ elif backend == "nvidia_modelopt":
+ backend = "nvidia_modelopt[hf]"
assert backend in deps, f"{backend} is not in the deps table!"
def test_pipeline_imports(self):
diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py
index 14808b9e58..436bbe1d53 100644
--- a/tests/others/test_ema.py
+++ b/tests/others/test_ema.py
@@ -20,7 +20,8 @@ import torch
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+
+from ..testing_utils import enable_full_determinism, skip_mps, torch_device
enable_full_determinism()
diff --git a/tests/others/test_outputs.py b/tests/others/test_outputs.py
index cf709d93f7..c8069e6916 100644
--- a/tests/others/test_outputs.py
+++ b/tests/others/test_outputs.py
@@ -7,7 +7,8 @@ import numpy as np
import PIL.Image
from diffusers.utils.outputs import BaseOutput
-from diffusers.utils.testing_utils import require_torch
+
+from ..testing_utils import require_torch
@dataclass
diff --git a/tests/others/test_training.py b/tests/others/test_training.py
index fb64205301..2038a98a81 100644
--- a/tests/others/test_training.py
+++ b/tests/others/test_training.py
@@ -19,7 +19,8 @@ import torch
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from diffusers.training_utils import set_seed
-from diffusers.utils.testing_utils import slow
+
+from ..testing_utils import slow
torch.backends.cuda.matmul.allow_tf32 = False
diff --git a/tests/others/test_utils.py b/tests/others/test_utils.py
index 01b423f556..747b8d5840 100755
--- a/tests/others/test_utils.py
+++ b/tests/others/test_utils.py
@@ -20,7 +20,8 @@ import pytest
from diffusers import __version__
from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import Expectations, str_to_bool
+
+from ..testing_utils import Expectations, str_to_bool
# Used to test the hub
diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py
index c33b48e7e9..b2e588de06 100644
--- a/tests/pipelines/allegro/test_allegro.py
+++ b/tests/pipelines/allegro/test_allegro.py
@@ -23,7 +23,8 @@ import torch
from transformers import AutoTokenizer, T5Config, T5EncoderModel
from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -33,7 +34,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 4088d46df5..8d4cd4cf2c 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -19,7 +19,8 @@ from diffusers import (
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
@@ -27,7 +28,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/animatediff/test_animatediff_controlnet.py b/tests/pipelines/animatediff/test_animatediff_controlnet.py
index 7bde663b11..4b0eb01d06 100644
--- a/tests/pipelines/animatediff/test_animatediff_controlnet.py
+++ b/tests/pipelines/animatediff/test_animatediff_controlnet.py
@@ -21,8 +21,8 @@ from diffusers import (
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/animatediff/test_animatediff_sdxl.py b/tests/pipelines/animatediff/test_animatediff_sdxl.py
index f9686ec005..b5dcd87796 100644
--- a/tests/pipelines/animatediff/test_animatediff_sdxl.py
+++ b/tests/pipelines/animatediff/test_animatediff_sdxl.py
@@ -14,8 +14,8 @@ from diffusers import (
UNetMotionModel,
)
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
index 3e33326c8a..6b9f672cc4 100644
--- a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
+++ b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
@@ -20,8 +20,8 @@ from diffusers import (
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
index bc771e148e..1adb13dc4c 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -19,8 +19,8 @@ from diffusers import (
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
index 3babbbe4ba..c71c8c8817 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video_controlnet.py
@@ -20,8 +20,8 @@ from diffusers import (
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py
index 12b9694567..14ff1272a2 100644
--- a/tests/pipelines/audioldm2/test_audioldm2.py
+++ b/tests/pipelines/audioldm2/test_audioldm2.py
@@ -46,14 +46,14 @@ from diffusers import (
PNDMScheduler,
)
from diffusers.utils import is_transformers_version
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
is_torch_version,
nightly,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -138,10 +138,8 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
patch_stride=2,
patch_embed_input_channels=4,
)
- text_encoder_config = ClapConfig.from_text_audio_configs(
- text_config=text_branch_config,
- audio_config=audio_branch_config,
- projection_dim=16,
+ text_encoder_config = ClapConfig(
+ text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16
)
text_encoder = ClapModel(text_encoder_config)
tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
diff --git a/tests/pipelines/bria/__init__.py b/tests/pipelines/bria/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pipelines/bria/test_pipeline_bria.py b/tests/pipelines/bria/test_pipeline_bria.py
new file mode 100644
index 0000000000..844488e76f
--- /dev/null
+++ b/tests/pipelines/bria/test_pipeline_bria.py
@@ -0,0 +1,319 @@
+# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from huggingface_hub import hf_hub_download
+from transformers import T5EncoderModel, T5TokenizerFast
+
+from diffusers import (
+ AutoencoderKL,
+ BriaTransformer2DModel,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.pipelines.bria import BriaPipeline
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ numpy_cosine_similarity_distance,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class BriaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = BriaPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
+ batch_params = frozenset(["prompt"])
+
+ # there is no xformers attention processor for this Flux-style transformer
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = BriaTransformer2DModel(
+ patch_size=1,
+ in_channels=16,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=8,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=None,
+ axes_dims_rope=[0, 4, 4],
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ act_fn="silu",
+ block_out_channels=(32,),
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D"],
+ latent_channels=4,
+ sample_size=32,
+ shift_factor=0,
+ scaling_factor=0.13025,
+ use_post_quant_conv=True,
+ use_quant_conv=True,
+ force_upcast=False,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = T5TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
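+ # on mps, seed the global RNG; otherwise use an explicit CPU generator for reproducibility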
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "negative_prompt": "bad, ugly",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 16,
+ "width": 16,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ def test_bria_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+ assert max_diff > 1e-6
+
+ def test_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
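+ # outputs are expected to be rounded down to the nearest multiple of 2 * vae_scale_factor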
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
+ @require_torch_accelerator
+ def test_save_load_float16(self, expected_max_diff=1e-2):
+ components = self.get_dummy_components()
+ for name, module in components.items():
+ if hasattr(module, "half"):
+ components[name] = module.to(torch_device).half()
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for name, component in pipe_loaded.components.items():
+ if name == "vae":
+ continue
+ if hasattr(component, "dtype"):
+ self.assertTrue(
+ component.dtype == torch.float16,
+ f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_loaded = pipe_loaded(**inputs)[0]
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(
+ max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
+ )
+
+ def test_bria_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(16, 16), (32, 32), (64, 64)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+ self.assertEqual(len(model_dtypes), 3)
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ def test_torch_dtype_dict(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
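+ # "default" applies to every component not listed explicitly (here: the text encoder and VAE)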
+ torch_dtype_dict = {"transformer": torch.bfloat16, "default": torch.float16}
+ loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
+
+ self.assertEqual(loaded_pipe.transformer.dtype, torch.bfloat16)
+ self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16)
+ self.assertEqual(loaded_pipe.vae.dtype, torch.float16)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
+ torch_dtype_dict = {"default": torch.float16}
+ loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
+
+ self.assertEqual(loaded_pipe.transformer.dtype, torch.float16)
+ self.assertEqual(loaded_pipe.text_encoder.dtype, torch.float16)
+ self.assertEqual(loaded_pipe.vae.dtype, torch.float16)
+
+
+@slow
+@require_torch_accelerator
+class BriaPipelineSlowTests(unittest.TestCase):
+ pipeline_class = BriaPipeline
+ repo_id = "briaai/BRIA-3.2"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_inputs(self, device, seed=0):
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
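+ # precomputed T5 prompt embeddings (reused from the Flux test slices) so the text encoder does not need to be loaded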
+ prompt_embeds = torch.load(
+ hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
+ ).to(torch_device)
+
+ return {
+ "prompt_embeds": prompt_embeds,
+ "num_inference_steps": 2,
+ "guidance_scale": 0.0,
+ "max_sequence_length": 256,
+ "output_type": "np",
+ "generator": generator,
+ }
+
+ def test_bria_inference_bf16(self):
+ pipe = self.pipeline_class.from_pretrained(
+ self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, tokenizer=None
+ )
+ pipe.to(torch_device)
+
+ inputs = self.get_inputs(torch_device)
+
+ image = pipe(**inputs).images[0]
+ image_slice = image[0, :10, :10].flatten()
+
+ expected_slice = np.array(
+ [
+ 0.59729785,
+ 0.6153719,
+ 0.595112,
+ 0.5884763,
+ 0.59366125,
+ 0.5795311,
+ 0.58325,
+ 0.58449626,
+ 0.57737637,
+ 0.58432233,
+ 0.5867875,
+ 0.57824117,
+ 0.5819089,
+ 0.5830988,
+ 0.57730293,
+ 0.57647324,
+ 0.5769151,
+ 0.57312685,
+ 0.57926565,
+ 0.5823928,
+ 0.57783926,
+ 0.57162863,
+ 0.575649,
+ 0.5745547,
+ 0.5740556,
+ 0.5799735,
+ 0.57799566,
+ 0.5715559,
+ 0.5771242,
+ 0.5773058,
+ ],
+ dtype=np.float32,
+ )
+ max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice)
+ self.assertLess(max_diff, 1e-4, f"Image slice is different from expected slice: {max_diff:.4f}")
diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py
index 5121a2b52d..3edd58b75f 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma.py
@@ -5,8 +5,8 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ...testing_utils import torch_device
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
diff --git a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
index d518e1b7b8..4ed1393037 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
@@ -6,8 +6,8 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import floats_tensor, torch_device
+from ...testing_utils import floats_tensor, torch_device
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index a6cb558513..dca1725d8a 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -21,7 +21,8 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -29,7 +30,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
index 685823dc06..097e8df7b3 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
@@ -21,11 +21,11 @@ from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXFunControlPipeline, CogVideoXTransformer3DModel, DDIMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
index 90f767f9a7..1dd5e2ae14 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
@@ -23,7 +23,8 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -31,7 +32,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/cogvideo/test_cogvideox_video2video.py b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
index ba48079e74..3a1da7c4e7 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_video2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_video2video.py
@@ -21,8 +21,8 @@ from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py
index d868beffba..819d4b952f 100644
--- a/tests/pipelines/cogview3/test_cogview3plus.py
+++ b/tests/pipelines/cogview3/test_cogview3plus.py
@@ -21,7 +21,8 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -29,7 +30,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/cogview4/test_cogview4.py b/tests/pipelines/cogview4/test_cogview4.py
index 20d2afaea9..a1f0fc7a71 100644
--- a/tests/pipelines/cogview4/test_cogview4.py
+++ b/tests/pipelines/cogview4/test_cogview4.py
@@ -20,8 +20,8 @@ import torch
from transformers import AutoTokenizer, GlmConfig, GlmForCausalLM
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/consisid/test_consisid.py b/tests/pipelines/consisid/test_consisid.py
index 66bb7acf25..4fd9e536cd 100644
--- a/tests/pipelines/consisid/test_consisid.py
+++ b/tests/pipelines/consisid/test_consisid.py
@@ -23,7 +23,8 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, ConsisIDPipeline, ConsisIDTransformer3DModel, DDIMScheduler
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -31,7 +32,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/consistency_models/test_consistency_models.py b/tests/pipelines/consistency_models/test_consistency_models.py
index 7c7cecdfb0..0ab0c0af25 100644
--- a/tests/pipelines/consistency_models/test_consistency_models.py
+++ b/tests/pipelines/consistency_models/test_consistency_models.py
@@ -10,7 +10,9 @@ from diffusers import (
ConsistencyModelPipeline,
UNet2DModel,
)
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
@@ -19,8 +21,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py
index bd558a50cf..b142c2baf9 100644
--- a/tests/pipelines/controlnet/test_controlnet.py
+++ b/tests/pipelines/controlnet/test_controlnet.py
@@ -32,7 +32,9 @@ from diffusers import (
)
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -44,8 +46,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py
index dd7bb002a1..c5d438e934 100644
--- a/tests/pipelines/controlnet/test_controlnet_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_img2img.py
@@ -35,7 +35,9 @@ from diffusers import (
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -44,8 +46,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py
index c457c324c5..ebbe869e9e 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -35,7 +35,9 @@ from diffusers import (
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -45,8 +47,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
index ee12ce1723..c91f2c700c 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
@@ -37,13 +37,13 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index 47d0920b74..42ec446dbf 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -34,7 +34,9 @@ from diffusers import (
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_image,
@@ -42,8 +44,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_BATCH_PARAMS,
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
index 5a8dd10ad5..bd4a233741 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
@@ -28,13 +28,13 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/controlnet/test_flax_controlnet.py b/tests/pipelines/controlnet/test_flax_controlnet.py
deleted file mode 100644
index 07d3a09e5d..0000000000
--- a/tests/pipelines/controlnet/test_flax_controlnet.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
-from diffusers.utils import is_flax_available, load_image
-from diffusers.utils.testing_utils import require_flax, slow
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
-
-
-@slow
-@require_flax
-class FlaxControlNetPipelineIntegrationTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def test_canny(self):
- controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
- "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.bfloat16
- )
- pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
- )
- params["controlnet"] = controlnet_params
-
- prompts = "bird"
- num_samples = jax.device_count()
- prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
-
- canny_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- )
- processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
-
- rng = jax.random.PRNGKey(0)
- rng = jax.random.split(rng, jax.device_count())
-
- p_params = replicate(params)
- prompt_ids = shard(prompt_ids)
- processed_image = shard(processed_image)
-
- images = pipe(
- prompt_ids=prompt_ids,
- image=processed_image,
- params=p_params,
- prng_seed=rng,
- num_inference_steps=50,
- jit=True,
- ).images
- assert images.shape == (jax.device_count(), 1, 768, 512, 3)
-
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array(
- [0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
- )
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
-
- def test_pose(self):
- controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
- "lllyasviel/sd-controlnet-openpose", from_pt=True, dtype=jnp.bfloat16
- )
- pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, from_pt=True, dtype=jnp.bfloat16
- )
- params["controlnet"] = controlnet_params
-
- prompts = "Chef in the kitchen"
- num_samples = jax.device_count()
- prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
-
- pose_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/pose.png"
- )
- processed_image = pipe.prepare_image_inputs([pose_image] * num_samples)
-
- rng = jax.random.PRNGKey(0)
- rng = jax.random.split(rng, jax.device_count())
-
- p_params = replicate(params)
- prompt_ids = shard(prompt_ids)
- processed_image = shard(processed_image)
-
- images = pipe(
- prompt_ids=prompt_ids,
- image=processed_image,
- params=p_params,
- prng_seed=rng,
- num_inference_steps=50,
- jit=True,
- ).images
- assert images.shape == (jax.device_count(), 1, 768, 512, 3)
-
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array(
- [[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
- )
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 5b336edc7a..0895d9de35 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -29,7 +29,9 @@ from diffusers import (
)
from diffusers.models import FluxControlNetModel
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
nightly,
@@ -37,8 +39,6 @@ from diffusers.utils.testing_utils import (
require_big_accelerator,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
index ab4cf32734..3d8378a578 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -11,11 +11,11 @@ from diffusers import (
FluxControlNetModel,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
- torch_device,
-)
from diffusers.utils.torch_utils import randn_tensor
+from ...testing_utils import (
+ torch_device,
+)
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
index 94d97e9962..3ba475deb8 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py
@@ -20,13 +20,13 @@ from diffusers import (
FluxControlNetModel,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
index d6e7af34bd..bf31f2abcf 100644
--- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
+++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
@@ -28,15 +28,15 @@ from diffusers import (
)
from diffusers.models import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -155,7 +155,7 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix
if torch_device == "xpu":
expected_slice = np.array(
- [0.6376953, 0.84375, 0.58691406, 0.48046875, 0.43652344, 0.5517578, 0.54248047, 0.5644531, 0.48217773]
+ [0.6948242, 0.89160156, 0.59375, 0.5078125, 0.57910156, 0.6035156, 0.58447266, 0.53564453, 0.52246094]
)
else:
expected_slice = np.array(
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
index fcf8cade67..34c34b7a2c 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_inpaint_sd3.py
@@ -26,12 +26,12 @@ from diffusers import (
StableDiffusion3ControlNetInpaintingPipeline,
)
from diffusers.models import SD3ControlNetModel
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 1f1f800bcf..2b6cf8d1e8 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -29,7 +29,9 @@ from diffusers import (
)
from diffusers.models import SD3ControlNetModel, SD3MultiControlNetModel
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -37,8 +39,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/cosmos/test_cosmos.py b/tests/pipelines/cosmos/test_cosmos.py
index 4d3202f785..32eea9c98c 100644
--- a/tests/pipelines/cosmos/test_cosmos.py
+++ b/tests/pipelines/cosmos/test_cosmos.py
@@ -23,8 +23,8 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
diff --git a/tests/pipelines/cosmos/test_cosmos2_text2image.py b/tests/pipelines/cosmos/test_cosmos2_text2image.py
index cc2fcec641..8e3c5e4c29 100644
--- a/tests/pipelines/cosmos/test_cosmos2_text2image.py
+++ b/tests/pipelines/cosmos/test_cosmos2_text2image.py
@@ -28,8 +28,8 @@ from diffusers import (
CosmosTransformer3DModel,
FlowMatchEulerDiscreteScheduler,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
diff --git a/tests/pipelines/cosmos/test_cosmos2_video2world.py b/tests/pipelines/cosmos/test_cosmos2_video2world.py
index b23c8aed17..b0ca0e160d 100644
--- a/tests/pipelines/cosmos/test_cosmos2_video2world.py
+++ b/tests/pipelines/cosmos/test_cosmos2_video2world.py
@@ -29,8 +29,8 @@ from diffusers import (
CosmosTransformer3DModel,
FlowMatchEulerDiscreteScheduler,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
diff --git a/tests/pipelines/cosmos/test_cosmos_video2world.py b/tests/pipelines/cosmos/test_cosmos_video2world.py
index d0dba5575b..2633c2007a 100644
--- a/tests/pipelines/cosmos/test_cosmos_video2world.py
+++ b/tests/pipelines/cosmos/test_cosmos_video2world.py
@@ -24,8 +24,8 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCosmos, CosmosTransformer3DModel, CosmosVideoToWorldPipeline, EDMEulerScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker
diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py
index 57b97b4649..731635bea6 100644
--- a/tests/pipelines/ddim/test_ddim.py
+++ b/tests/pipelines/ddim/test_ddim.py
@@ -19,8 +19,8 @@ import numpy as np
import torch
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py
index 97bb53128d..04ee741d8e 100644
--- a/tests/pipelines/ddpm/test_ddpm.py
+++ b/tests/pipelines/ddpm/test_ddpm.py
@@ -19,7 +19,8 @@ import numpy as np
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
+
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device
enable_full_determinism()
diff --git a/tests/pipelines/deepfloyd_if/__init__.py b/tests/pipelines/deepfloyd_if/__init__.py
index 094254a618..d47374b07e 100644
--- a/tests/pipelines/deepfloyd_if/__init__.py
+++ b/tests/pipelines/deepfloyd_if/__init__.py
@@ -7,8 +7,8 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.pipelines.deepfloyd_if import IFWatermarker
-from diffusers.utils.testing_utils import torch_device
+from ...testing_utils import torch_device
from ..test_pipelines_common import to_np
diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py
index 633d802ab9..e1870ddcba 100644
--- a/tests/pipelines/deepfloyd_if/test_if.py
+++ b/tests/pipelines/deepfloyd_if/test_if.py
@@ -23,7 +23,8 @@ from diffusers import (
)
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from . import IFPipelineTesterMixin
diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py
index 739d2a0e16..9d3c96052b 100644
--- a/tests/pipelines/deepfloyd_if/test_if_img2img.py
+++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py
@@ -22,7 +22,8 @@ import torch
from diffusers import IFImg2ImgPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
index fb89aab8e2..e2114910ed 100644
--- a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
@@ -22,7 +22,8 @@ import torch
from diffusers import IFImg2ImgSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting.py b/tests/pipelines/deepfloyd_if/test_if_inpainting.py
index 127ae19aa6..2679e0b776 100644
--- a/tests/pipelines/deepfloyd_if/test_if_inpainting.py
+++ b/tests/pipelines/deepfloyd_if/test_if_inpainting.py
@@ -22,7 +22,8 @@ import torch
from diffusers import IFInpaintingPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
index 8b5210194a..3d64556c6e 100644
--- a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
@@ -22,7 +22,8 @@ import torch
from diffusers import IFInpaintingSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/deepfloyd_if/test_if_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_superresolution.py
index c16b3d6a56..fa7c0fb2e0 100644
--- a/tests/pipelines/deepfloyd_if/test_if_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_superresolution.py
@@ -22,7 +22,8 @@ import torch
from diffusers import IFSuperResolutionPipeline
from diffusers.models.attention_processor import AttnAddedKVProcessor
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -37,7 +38,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
from . import IFPipelineTesterMixin
diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py
index 46e28a28e1..cd5c08ced3 100644
--- a/tests/pipelines/dit/test_dit.py
+++ b/tests/pipelines/dit/test_dit.py
@@ -21,7 +21,8 @@ import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler
from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_numpy,
@@ -30,7 +31,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import (
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS,
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS,
diff --git a/tests/pipelines/easyanimate/test_easyanimate.py b/tests/pipelines/easyanimate/test_easyanimate.py
index 161734a166..5cb2a232bb 100644
--- a/tests/pipelines/easyanimate/test_easyanimate.py
+++ b/tests/pipelines/easyanimate/test_easyanimate.py
@@ -26,7 +26,8 @@ from diffusers import (
EasyAnimateTransformer3DModel,
FlowMatchEulerDiscreteScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -34,7 +35,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -48,6 +48,7 @@ class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ test_xformers_attention = False
required_optional_params = frozenset(
[
"num_inference_steps",
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index cc8266e1a5..1ddbd4ba3d 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -13,7 +13,9 @@ from diffusers import (
FluxPipeline,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
nightly,
numpy_cosine_similarity_distance,
@@ -21,7 +23,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
@@ -276,10 +277,14 @@ class FluxPipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
# fmt: off
- expected_slice = np.array(
- [0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203],
- dtype=np.float32,
+
+ expected_slices = Expectations(
+ {
+ ("cuda", None): np.array([0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203], dtype=np.float32,),
+ ("xpu", 3): np.array([0.3301, 0.3281, 0.3359, 0.3203, 0.3203, 0.3281, 0.3281, 0.3301, 0.3340, 0.3281, 0.3320, 0.3359, 0.3281, 0.3301, 0.3320, 0.3242, 0.3301, 0.3281, 0.3242, 0.3320, 0.3320, 0.3281, 0.3320, 0.3320, 0.3262, 0.3320, 0.3301, 0.3301, 0.3359, 0.3320], dtype=np.float32,),
+ }
)
+ expected_slice = expected_slices.get_expectation()
# fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
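The `Expectations` table used above appears to key reference slices by accelerator type plus an optional capability/major version, with `get_expectation()` selecting the entry matching the machine running the test. Below is a minimal, self-contained sketch of that lookup idea; the `ExpectationsSketch` class and `_current_device_and_major` helper are illustrative stand-ins, not the actual `Expectations` implementation from the diffusers test utilities.

```python
# Hedged sketch of an Expectations-style lookup keyed by (device_type, major_version).
from typing import Dict, Optional, Tuple

import numpy as np
import torch


def _current_device_and_major() -> Tuple[str, Optional[int]]:
    # CUDA exposes a compute-capability major version; other backends use None here.
    if torch.cuda.is_available():
        return "cuda", torch.cuda.get_device_capability()[0]
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu", None
    return "cpu", None


class ExpectationsSketch:
    def __init__(self, data: Dict[Tuple[str, Optional[int]], np.ndarray]):
        self._data = data

    def get_expectation(self) -> np.ndarray:
        device, major = _current_device_and_major()
        # Prefer an exact (device, major) match, then a device-only entry,
        # otherwise fall back to the first recorded reference slice.
        for key in ((device, major), (device, None)):
            if key in self._data:
                return self._data[key]
        return next(iter(self._data.values()))


# Usage mirroring the test above: pick the reference slice for the current accelerator.
refs = ExpectationsSketch({("cuda", None): np.zeros(4), ("xpu", 3): np.ones(4)})
expected_slice = refs.get_expectation()
```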
diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py
index 42283da6fd..7e966470a3 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control.py
@@ -6,8 +6,8 @@ from PIL import Image
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import torch_device
+from ...testing_utils import torch_device
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_img2img.py b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py
index 966543f63a..e56136f2e9 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control_img2img.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control_img2img.py
@@ -11,8 +11,8 @@ from diffusers import (
FluxControlImg2ImgPipeline,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
index 0abd08e373..e42c5fc2aa 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
@@ -11,10 +11,10 @@ from diffusers import (
FluxControlInpaintPipeline,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
diff --git a/tests/pipelines/flux/test_pipeline_flux_fill.py b/tests/pipelines/flux/test_pipeline_flux_fill.py
index 04d4c68db8..25a4a33548 100644
--- a/tests/pipelines/flux/test_pipeline_flux_fill.py
+++ b/tests/pipelines/flux/test_pipeline_flux_fill.py
@@ -6,12 +6,12 @@ import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxFillPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py
index 6d33ca721b..6f435760ae 100644
--- a/tests/pipelines/flux/test_pipeline_flux_img2img.py
+++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py
@@ -6,12 +6,12 @@ import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py
index 161348455c..6324ff236e 100644
--- a/tests/pipelines/flux/test_pipeline_flux_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py
@@ -6,12 +6,12 @@ import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxInpaintPipeline, FluxTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext.py b/tests/pipelines/flux/test_pipeline_flux_kontext.py
index 7471d78ad5..5c78964ea5 100644
--- a/tests/pipelines/flux/test_pipeline_flux_kontext.py
+++ b/tests/pipelines/flux/test_pipeline_flux_kontext.py
@@ -12,8 +12,8 @@ from diffusers import (
FluxKontextPipeline,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import torch_device
+from ...testing_utils import torch_device
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FluxIPAdapterTesterMixin,
diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
index 615209264d..9a2e32056d 100644
--- a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
@@ -12,8 +12,8 @@ from diffusers import (
FluxKontextInpaintPipeline,
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import floats_tensor, torch_device
+from ...testing_utils import floats_tensor, torch_device
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FluxIPAdapterTesterMixin,
diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py
index b73050a64d..bbeee28e6a 100644
--- a/tests/pipelines/flux/test_pipeline_flux_redux.py
+++ b/tests/pipelines/flux/test_pipeline_flux_redux.py
@@ -6,7 +6,8 @@ import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
numpy_cosine_similarity_distance,
diff --git a/tests/pipelines/hidream_image/test_pipeline_hidream.py b/tests/pipelines/hidream_image/test_pipeline_hidream.py
index 1c5f30e870..ddf39ba4c1 100644
--- a/tests/pipelines/hidream_image/test_pipeline_hidream.py
+++ b/tests/pipelines/hidream_image/test_pipeline_hidream.py
@@ -32,8 +32,8 @@ from diffusers import (
HiDreamImagePipeline,
HiDreamImageTransformer2DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -47,8 +47,8 @@ class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
required_optional_params = PipelineTesterMixin.required_optional_params
+ test_xformers_attention = False
test_layerwise_casting = True
supports_dduf = False
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
index 82281f28bc..27b5bde310 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
@@ -36,8 +36,8 @@ from diffusers import (
HunyuanVideoImageToVideoPipeline,
HunyuanVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
index fad159c06b..7ebe797feb 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
@@ -26,8 +26,8 @@ from diffusers import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, to_np
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
index 26ec861522..4bdf3ee20e 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
@@ -26,8 +26,8 @@ from diffusers import (
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
index 297c3df45a..51c258b15c 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
@@ -36,11 +36,11 @@ from diffusers import (
HunyuanVideoFramepackPipeline,
HunyuanVideoFramepackTransformer3DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..test_pipelines_common import (
FasterCacheTesterMixin,
PipelineTesterMixin,
diff --git a/tests/pipelines/hunyuandit/test_hunyuan_dit.py b/tests/pipelines/hunyuandit/test_hunyuan_dit.py
index 7a5f807213..2a329f10bc 100644
--- a/tests/pipelines/hunyuandit/test_hunyuan_dit.py
+++ b/tests/pipelines/hunyuandit/test_hunyuan_dit.py
@@ -22,7 +22,8 @@ import torch
from transformers import AutoTokenizer, BertModel, T5EncoderModel
from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -30,7 +31,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/ip_adapters/__init__.py b/tests/pipelines/ip_adapters/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index f5980f218a..32590111cd 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -33,7 +33,8 @@ from diffusers import (
)
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py
index 65a5195a8b..6207e71df8 100644
--- a/tests/pipelines/kandinsky/test_kandinsky.py
+++ b/tests/pipelines/kandinsky/test_kandinsky.py
@@ -18,12 +18,15 @@ import random
import unittest
import numpy as np
+import pytest
import torch
from transformers import XLMRobertaTokenizerFast
from diffusers import DDIMScheduler, KandinskyPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
-from diffusers.utils.testing_utils import (
+from diffusers.utils import is_transformers_version
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -32,7 +35,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -215,6 +217,11 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
dummy = Dummies()
return dummy.get_dummy_inputs(device=device, seed=seed)
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
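For context, the `xfail` guard added here (and repeated in the other Kandinsky tests below) only relaxes the slice comparison when a newer transformers release is installed, and `strict=False` keeps the test green if the slices still happen to match. A hedged, stand-alone sketch of that gating pattern follows; the test body and tolerance are placeholders, only the version bound is taken from the diff.

```python
# Illustrative version-gated xfail; not a test from the repository.
import numpy as np
import pytest

from diffusers.utils import is_transformers_version


@pytest.mark.xfail(
    condition=is_transformers_version(">=", "4.56.2"),
    reason="Latest transformers changes the slices",
    strict=False,  # an unexpected pass on newer transformers is still accepted
)
def test_reference_slice_placeholder():
    produced = np.zeros(9)   # stand-in for a pipeline output slice
    reference = np.zeros(9)  # stand-in for the recorded reference values
    assert np.abs(produced - reference).max() < 1e-2
```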
diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py
index 6dd8895952..eba8976597 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py
@@ -16,10 +16,12 @@
import unittest
import numpy as np
+import pytest
from diffusers import KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyInpaintCombinedPipeline
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
+from diffusers.utils import is_transformers_version
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
from .test_kandinsky import Dummies
from .test_kandinsky_img2img import Dummies as Img2ImgDummies
@@ -73,6 +75,11 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
)
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
@@ -181,6 +188,11 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
inputs.pop("negative_image_embeds")
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
@@ -292,6 +304,11 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
inputs.pop("negative_image_embeds")
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
index 5a0107838a..6d1b43a24f 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
@@ -18,6 +18,7 @@ import random
import unittest
import numpy as np
+import pytest
import torch
from PIL import Image
from transformers import XLMRobertaTokenizerFast
@@ -31,7 +32,9 @@ from diffusers import (
VQModel,
)
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
-from diffusers.utils.testing_utils import (
+from diffusers.utils import is_transformers_version
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -42,7 +45,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -237,6 +239,11 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
dummies = Dummies()
return dummies.get_dummy_inputs(device=device, seed=seed)
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky_img2img(self):
device = "cpu"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
index 03b555b2f0..e2f4aa2a4f 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
@@ -18,13 +18,16 @@ import random
import unittest
import numpy as np
+import pytest
import torch
from PIL import Image
from transformers import XLMRobertaTokenizerFast
from diffusers import DDIMScheduler, KandinskyInpaintPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
-from diffusers.utils.testing_utils import (
+from diffusers.utils import is_transformers_version
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -34,7 +37,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
@@ -231,6 +233,11 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
dummies = Dummies()
return dummies.get_dummy_inputs(device=device, seed=seed)
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky_inpaint(self):
device = "cpu"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py
index 8ecf2d855f..903a1e5dec 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py
@@ -28,8 +28,8 @@ from transformers import (
)
from diffusers import KandinskyPriorPipeline, PriorTransformer, UnCLIPScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+from ...testing_utils import enable_full_determinism, skip_mps, torch_device
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky.py b/tests/pipelines/kandinsky2_2/test_kandinsky.py
index 0ad5620eee..38294aa4c1 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky.py
@@ -21,7 +21,8 @@ import numpy as np
import torch
from diffusers import DDIMScheduler, KandinskyV22Pipeline, KandinskyV22PriorPipeline, UNet2DConditionModel, VQModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -31,7 +32,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
index 1e064d3368..476fc584cc 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py
@@ -22,8 +22,8 @@ from diffusers import (
KandinskyV22Img2ImgCombinedPipeline,
KandinskyV22InpaintCombinedPipeline,
)
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
from .test_kandinsky import Dummies
from .test_kandinsky_img2img import Dummies as Img2ImgDummies
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index b2d6f0fc05..4054e38c56 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -27,7 +27,8 @@ from diffusers import (
UNet2DConditionModel,
VQModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -38,7 +39,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
index 4f50f51819..a434660592 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet_img2img.py
@@ -28,7 +28,8 @@ from diffusers import (
UNet2DConditionModel,
VQModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -39,7 +40,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
index bc1477b97e..99f3fe0f40 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_img2img.py
@@ -28,7 +28,8 @@ from diffusers import (
UNet2DConditionModel,
VQModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -39,7 +40,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
index 8b3d8f74ec..d4eb650263 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py
@@ -28,7 +28,8 @@ from diffusers import (
UNet2DConditionModel,
VQModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -40,7 +41,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
index f5c2d6037b..adcc6cc216 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py
@@ -29,8 +29,8 @@ from transformers import (
)
from diffusers import KandinskyV22PriorPipeline, PriorTransformer, UnCLIPScheduler
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
+from ...testing_utils import enable_full_determinism, skip_mps, torch_device
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
index 54a9cf6d60..5377d91779 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py
@@ -30,13 +30,13 @@ from transformers import (
)
from diffusers import KandinskyV22PriorEmb2EmbPipeline, PriorTransformer, UnCLIPScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py
index 1acf076b3d..55500f729b 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3.py
@@ -30,7 +30,8 @@ from diffusers import (
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_image,
@@ -38,7 +39,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
index edad5b7d37..503fdb242d 100644
--- a/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
+++ b/tests/pipelines/kandinsky3/test_kandinsky3_img2img.py
@@ -30,7 +30,8 @@ from diffusers import (
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -39,7 +40,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py
index 839d06ab93..f1d4982d4d 100644
--- a/tests/pipelines/kolors/test_kolors.py
+++ b/tests/pipelines/kolors/test_kolors.py
@@ -25,8 +25,8 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/kolors/test_kolors_img2img.py b/tests/pipelines/kolors/test_kolors_img2img.py
index c8429322ca..5a5d31a464 100644
--- a/tests/pipelines/kolors/test_kolors_img2img.py
+++ b/tests/pipelines/kolors/test_kolors_img2img.py
@@ -26,11 +26,11 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
index 570fa8fadf..c7666244b3 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
@@ -12,14 +12,14 @@ from diffusers import (
LCMScheduler,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
index 88e31a97aa..d8e7745b78 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
@@ -13,7 +13,8 @@ from diffusers import (
LCMScheduler,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -22,7 +23,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py
index ec52f5aebf..21c5bcf5a5 100644
--- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py
+++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py
@@ -21,7 +21,8 @@ import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_numpy,
@@ -29,7 +30,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
index 2884dd3508..b2cbdb9f5b 100644
--- a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
+++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
@@ -21,7 +21,8 @@ import torch
from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
from diffusers.utils import PIL_INTERPOLATION
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
load_image,
diff --git a/tests/pipelines/latte/test_latte.py b/tests/pipelines/latte/test_latte.py
index 97b7edeb6c..a40d4bf8ee 100644
--- a/tests/pipelines/latte/test_latte.py
+++ b/tests/pipelines/latte/test_latte.py
@@ -31,7 +31,8 @@ from diffusers import (
PyramidAttentionBroadcastConfig,
)
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -39,7 +40,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
index ab0221dc81..6db20a464f 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
@@ -28,7 +28,8 @@ from diffusers import (
LEditsPPPipelineStableDiffusion,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
index 75795a3342..06c1ceb0cf 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
@@ -37,7 +37,7 @@ from diffusers import (
)
# from diffusers.image_processor import VaeImageProcessor
-from diffusers.utils.testing_utils import (
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
load_image,
diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py
index bf0c7fde59..aaf4161b51 100644
--- a/tests/pipelines/ltx/test_ltx.py
+++ b/tests/pipelines/ltx/test_ltx.py
@@ -20,8 +20,8 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
diff --git a/tests/pipelines/ltx/test_ltx_condition.py b/tests/pipelines/ltx/test_ltx_condition.py
index a586fadaa7..f5dfb01862 100644
--- a/tests/pipelines/ltx/test_ltx_condition.py
+++ b/tests/pipelines/ltx/test_ltx_condition.py
@@ -26,8 +26,8 @@ from diffusers import (
LTXVideoTransformer3DModel,
)
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py
index f43f66df53..2702993d4a 100644
--- a/tests/pipelines/ltx/test_ltx_image2video.py
+++ b/tests/pipelines/ltx/test_ltx_image2video.py
@@ -25,8 +25,8 @@ from diffusers import (
LTXImageToVideoPipeline,
LTXVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/ltx/test_ltx_latent_upsample.py b/tests/pipelines/ltx/test_ltx_latent_upsample.py
index f9ddb12186..0044a85c64 100644
--- a/tests/pipelines/ltx/test_ltx_latent_upsample.py
+++ b/tests/pipelines/ltx/test_ltx_latent_upsample.py
@@ -19,8 +19,8 @@ import torch
from diffusers import AutoencoderKLLTXVideo, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py
index c270a83841..d2c114825d 100644
--- a/tests/pipelines/lumina/test_lumina_nextdit.py
+++ b/tests/pipelines/lumina/test_lumina_nextdit.py
@@ -11,14 +11,14 @@ from diffusers import (
LuminaNextDiT2DModel,
LuminaPipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/marigold/test_marigold_depth.py b/tests/pipelines/marigold/test_marigold_depth.py
index 13f9a42186..3c85305992 100644
--- a/tests/pipelines/marigold/test_marigold_depth.py
+++ b/tests/pipelines/marigold/test_marigold_depth.py
@@ -31,7 +31,9 @@ from diffusers import (
MarigoldDepthPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -41,7 +43,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -356,7 +357,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f32_accelerator_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=False,
device=torch_device,
@@ -369,7 +370,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -382,7 +383,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G2024_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -395,12 +396,23 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S2_P768_E1_B1_M1(self):
+ # fmt: off
+ expected_slices = Expectations(
+ {
+ ("cuda", 7): np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
+                ("xpu", 3): np.array([0.1084, 0.1096, 0.1108, 0.1080, 0.1083, 0.1080, 0.1085, 0.1057, 0.0996]),
+ }
+ )
+ expected_slice = expected_slices.get_expectation()
+ # fmt: on
+
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
generator_seed=0,
- expected_slice=np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
+ expected_slice=expected_slice,
num_inference_steps=2,
processing_resolution=768,
ensemble_size=1,
@@ -408,7 +420,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -421,7 +433,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -435,7 +447,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -449,7 +461,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M0(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
diff --git a/tests/pipelines/marigold/test_marigold_intrinsics.py b/tests/pipelines/marigold/test_marigold_intrinsics.py
index f00650634a..7db14b67ce 100644
--- a/tests/pipelines/marigold/test_marigold_intrinsics.py
+++ b/tests/pipelines/marigold/test_marigold_intrinsics.py
@@ -32,7 +32,9 @@ from diffusers import (
MarigoldIntrinsicsPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -41,7 +43,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -416,7 +417,7 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
expected_slice: np.ndarray = None,
model_id: str = "prs-eth/marigold-iid-appearance-v1-1",
image_url: str = "https://marigoldmonodepth.github.io/images/einstein.jpg",
- atol: float = 1e-4,
+ atol: float = 1e-3,
**pipe_kwargs,
):
from_pretrained_kwargs = {}
@@ -531,11 +532,41 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.62655,
+ 0.62477,
+ 0.62161,
+ 0.62452,
+ 0.62454,
+ 0.62454,
+ 0.62255,
+ 0.62647,
+ 0.63379,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.61572,
+                        0.61377,
+ 0.61182,
+ 0.61426,
+ 0.61377,
+ 0.61426,
+ 0.61279,
+ 0.61572,
+ 0.62354,
+ ]
+ ),
+ }
+ )
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
generator_seed=0,
- expected_slice=np.array([0.61572, 0.61377, 0.61182, 0.61426, 0.61377, 0.61426, 0.61279, 0.61572, 0.62354]),
+ expected_slice=expected_slices.get_expectation(),
num_inference_steps=1,
processing_resolution=768,
ensemble_size=3,
@@ -545,11 +576,41 @@ class MarigoldIntrinsicsPipelineIntegrationTests(unittest.TestCase):
)
def test_marigold_intrinsics_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.62988,
+ 0.62792,
+ 0.62548,
+ 0.62841,
+ 0.62792,
+ 0.62792,
+ 0.62646,
+ 0.62939,
+ 0.63721,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.61914,
+ 0.6167,
+ 0.61475,
+ 0.61719,
+ 0.61719,
+ 0.61768,
+ 0.61572,
+ 0.61914,
+ 0.62695,
+ ]
+ ),
+ }
+ )
self._test_marigold_intrinsics(
is_fp16=True,
device=torch_device,
generator_seed=0,
- expected_slice=np.array([0.61914, 0.6167, 0.61475, 0.61719, 0.61719, 0.61768, 0.61572, 0.61914, 0.62695]),
+ expected_slice=expected_slices.get_expectation(),
num_inference_steps=1,
processing_resolution=768,
ensemble_size=4,
diff --git a/tests/pipelines/marigold/test_marigold_normals.py b/tests/pipelines/marigold/test_marigold_normals.py
index 1797f99b21..108163bf22 100644
--- a/tests/pipelines/marigold/test_marigold_normals.py
+++ b/tests/pipelines/marigold/test_marigold_normals.py
@@ -31,7 +31,8 @@ from diffusers import (
MarigoldNormalsPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -40,7 +41,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py
index f1684cce72..5615720a93 100644
--- a/tests/pipelines/mochi/test_mochi.py
+++ b/tests/pipelines/mochi/test_mochi.py
@@ -21,7 +21,8 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
nightly,
@@ -30,7 +31,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py
index e8f84eb913..1a758b7050 100644
--- a/tests/pipelines/omnigen/test_pipeline_omnigen.py
+++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py
@@ -6,7 +6,8 @@ import torch
from transformers import AutoTokenizer
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
numpy_cosine_similarity_distance,
@@ -14,7 +15,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
@@ -22,7 +22,7 @@ class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = OmniGenPipeline
params = frozenset(["prompt", "guidance_scale"])
batch_params = frozenset(["prompt"])
-
+ test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
diff --git a/tests/pipelines/pag/test_pag_animatediff.py b/tests/pipelines/pag/test_pag_animatediff.py
index b9ce29c70b..b1cbd82d76 100644
--- a/tests/pipelines/pag/test_pag_animatediff.py
+++ b/tests/pipelines/pag/test_pag_animatediff.py
@@ -19,8 +19,8 @@ from diffusers import (
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import require_accelerator, torch_device
+from ...testing_utils import require_accelerator, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd.py b/tests/pipelines/pag/test_pag_controlnet_sd.py
index 378f0a130c..36d5ae100a 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd.py
@@ -28,9 +28,9 @@ from diffusers import (
StableDiffusionControlNetPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.torch_utils import randn_tensor
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
index 5eff71ed64..948381f976 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sd_inpaint.py
@@ -32,9 +32,9 @@ from diffusers import (
StableDiffusionControlNetPAGInpaintPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from diffusers.utils.torch_utils import randn_tensor
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl.py b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
index 4d7e4f072e..51b00f6932 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl.py
@@ -28,9 +28,9 @@ from diffusers import (
StableDiffusionXLControlNetPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism
from diffusers.utils.torch_utils import randn_tensor
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
index dec029a499..3c1088adbc 100644
--- a/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_controlnet_sdxl_img2img.py
@@ -29,8 +29,8 @@ from diffusers import (
StableDiffusionXLControlNetPAGImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor
+from ...testing_utils import enable_full_determinism, floats_tensor
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_hunyuan_dit.py b/tests/pipelines/pag/test_pag_hunyuan_dit.py
index 65f39f585d..f268a614f8 100644
--- a/tests/pipelines/pag/test_pag_hunyuan_dit.py
+++ b/tests/pipelines/pag/test_pag_hunyuan_dit.py
@@ -28,8 +28,8 @@ from diffusers import (
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py
index b504b77801..1bbb4e79e4 100644
--- a/tests/pipelines/pag/test_pag_kolors.py
+++ b/tests/pipelines/pag/test_pag_kolors.py
@@ -27,8 +27,8 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_pixart_sigma.py b/tests/pipelines/pag/test_pag_pixart_sigma.py
index eb9399c9b3..c04ebad08f 100644
--- a/tests/pipelines/pag/test_pag_pixart_sigma.py
+++ b/tests/pipelines/pag/test_pag_pixart_sigma.py
@@ -30,12 +30,12 @@ from diffusers import (
PixArtTransformer2DModel,
)
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
enable_full_determinism,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_sana.py b/tests/pipelines/pag/test_pag_sana.py
index 31b384f3eb..5408595c72 100644
--- a/tests/pipelines/pag/test_pag_sana.py
+++ b/tests/pipelines/pag/test_pag_sana.py
@@ -26,8 +26,8 @@ from diffusers import (
SanaPipeline,
SanaTransformer2DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py
index ee9a74ed03..064815d136 100644
--- a/tests/pipelines/pag/test_pag_sd.py
+++ b/tests/pipelines/pag/test_pag_sd.py
@@ -29,14 +29,14 @@ from diffusers import (
StableDiffusionPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_sd3.py b/tests/pipelines/pag/test_pag_sd3.py
index 737e238e5f..26e6ca0992 100644
--- a/tests/pipelines/pag/test_pag_sd3.py
+++ b/tests/pipelines/pag/test_pag_sd3.py
@@ -12,10 +12,10 @@ from diffusers import (
StableDiffusion3PAGPipeline,
StableDiffusion3Pipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
torch_device,
)
-
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
diff --git a/tests/pipelines/pag/test_pag_sd3_img2img.py b/tests/pipelines/pag/test_pag_sd3_img2img.py
index fe593d47dc..19a36e283d 100644
--- a/tests/pipelines/pag/test_pag_sd3_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd3_img2img.py
@@ -15,7 +15,8 @@ from diffusers import (
StableDiffusion3Img2ImgPipeline,
StableDiffusion3PAGImg2ImgPipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -24,7 +25,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py
index 668e798463..0b440d5ec9 100644
--- a/tests/pipelines/pag/test_pag_sd_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd_img2img.py
@@ -31,7 +31,8 @@ from diffusers import (
StableDiffusionPAGImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -40,7 +41,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py
index f856041422..709df68370 100644
--- a/tests/pipelines/pag/test_pag_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sd_inpaint.py
@@ -29,7 +29,8 @@ from diffusers import (
StableDiffusionPAGInpaintPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -38,7 +39,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_sdxl.py b/tests/pipelines/pag/test_pag_sdxl.py
index 5c1608d210..cca5c61651 100644
--- a/tests/pipelines/pag/test_pag_sdxl.py
+++ b/tests/pipelines/pag/test_pag_sdxl.py
@@ -29,14 +29,14 @@ from diffusers import (
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py
index 2e18fdcebb..d311500d3c 100644
--- a/tests/pipelines/pag/test_pag_sdxl_img2img.py
+++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py
@@ -38,7 +38,8 @@ from diffusers import (
StableDiffusionXLPAGImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -47,7 +48,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
index e36716b603..00a07582e2 100644
--- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py
@@ -39,7 +39,8 @@ from diffusers import (
StableDiffusionXLPAGInpaintPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -48,7 +49,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py
index 4e2c4dcdd9..3db7c9fa1b 100644
--- a/tests/pipelines/pipeline_params.py
+++ b/tests/pipelines/pipeline_params.py
@@ -20,12 +20,6 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
]
)
-TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
-
-TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
-
-IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
-
IMAGE_VARIATION_PARAMS = frozenset(
[
"image",
@@ -35,8 +29,6 @@ IMAGE_VARIATION_PARAMS = frozenset(
]
)
-IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
-
TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
[
"prompt",
@@ -50,8 +42,6 @@ TEXT_GUIDED_IMAGE_VARIATION_PARAMS = frozenset(
]
)
-TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
-
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
# Text guided image variation with an image mask
@@ -67,8 +57,6 @@ TEXT_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
]
)
-TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
-
IMAGE_INPAINTING_PARAMS = frozenset(
[
# image variation with an image mask
@@ -80,8 +68,6 @@ IMAGE_INPAINTING_PARAMS = frozenset(
]
)
-IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
-
IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
[
"example_image",
@@ -93,20 +79,12 @@ IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS = frozenset(
]
)
-IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
CLASS_CONDITIONED_IMAGE_GENERATION_PARAMS = frozenset(["class_labels"])
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS = frozenset(["class_labels"])
-UNCONDITIONAL_IMAGE_GENERATION_PARAMS = frozenset(["batch_size"])
-
-UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
-
-UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
-
-UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
-
TEXT_TO_AUDIO_PARAMS = frozenset(
[
"prompt",
@@ -119,11 +97,38 @@ TEXT_TO_AUDIO_PARAMS = frozenset(
]
)
-TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
TOKENS_TO_AUDIO_GENERATION_PARAMS = frozenset(["input_tokens"])
+UNCONDITIONAL_AUDIO_GENERATION_PARAMS = frozenset(["batch_size"])
+
+# image params
+TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
+
+IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
+
+
+# batch params
+TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
+IMAGE_VARIATION_BATCH_PARAMS = frozenset(["image"])
+
+TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS = frozenset(["prompt", "image", "negative_prompt"])
+
+TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["prompt", "image", "mask_image", "negative_prompt"])
+
+IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["image", "mask_image"])
+
+IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS = frozenset(["example_image", "image", "mask_image"])
+
+UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS = frozenset([])
+
+UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS = frozenset([])
+
+TEXT_TO_AUDIO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
+
TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS = frozenset(["input_tokens"])
-TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
-
VIDEO_TO_VIDEO_BATCH_PARAMS = frozenset(["prompt", "negative_prompt", "video"])
+
+# callback params
+TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS = frozenset(["prompt_embeds"])
diff --git a/tests/pipelines/pixart_alpha/test_pixart.py b/tests/pipelines/pixart_alpha/test_pixart.py
index 933a005c4a..fd41c9887d 100644
--- a/tests/pipelines/pixart_alpha/test_pixart.py
+++ b/tests/pipelines/pixart_alpha/test_pixart.py
@@ -27,7 +27,8 @@ from diffusers import (
PixArtAlphaPipeline,
PixArtTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -35,7 +36,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/pixart_sigma/test_pixart.py b/tests/pipelines/pixart_sigma/test_pixart.py
index cda7b442d7..2cb80df81a 100644
--- a/tests/pipelines/pixart_sigma/test_pixart.py
+++ b/tests/pipelines/pixart_sigma/test_pixart.py
@@ -27,7 +27,8 @@ from diffusers import (
PixArtSigmaPipeline,
PixArtTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -35,7 +36,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/pndm/test_pndm.py b/tests/pipelines/pndm/test_pndm.py
index 2c12690ad1..61d6efe88c 100644
--- a/tests/pipelines/pndm/test_pndm.py
+++ b/tests/pipelines/pndm/test_pndm.py
@@ -19,7 +19,8 @@ import numpy as np
import torch
from diffusers import PNDMPipeline, PNDMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch, torch_device
+
+from ...testing_utils import enable_full_determinism, nightly, require_torch, torch_device
enable_full_determinism()
diff --git a/tests/pipelines/qwenimage/test_qwenimage.py b/tests/pipelines/qwenimage/test_qwenimage.py
index a312d0658f..8ebfe7d08b 100644
--- a/tests/pipelines/qwenimage/test_qwenimage.py
+++ b/tests/pipelines/qwenimage/test_qwenimage.py
@@ -24,8 +24,8 @@ from diffusers import (
QwenImagePipeline,
QwenImageTransformer2DModel,
)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/qwenimage/test_qwenimage_controlnet.py b/tests/pipelines/qwenimage/test_qwenimage_controlnet.py
new file mode 100644
index 0000000000..188106b49b
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_controlnet.py
@@ -0,0 +1,338 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageControlNetModel,
+ QwenImageControlNetPipeline,
+ QwenImageMultiControlNetModel,
+ QwenImageTransformer2DModel,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImageControlNetPipeline
+ params = (TEXT_TO_IMAGE_PARAMS | frozenset(["control_image", "controlnet_conditioning_scale"])) - {
+ "cross_attention_kwargs"
+ }
+ batch_params = frozenset(["prompt", "negative_prompt", "control_image"])
+ image_params = frozenset(["control_image"])
+ image_latents_params = frozenset(["latents"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "control_image",
+ "controlnet_conditioning_scale",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ controlnet = QwenImageControlNetModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ latents_mean=[0.0] * z_dim,
+ latents_std=[1.0] * z_dim,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1_000_000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "controlnet": controlnet,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ control_image = randn_tensor(
+ (1, 3, 32, 32),
+ generator=generator,
+ device=torch.device(device),
+ dtype=torch.float32,
+ )
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "control_image": control_image,
+ "controlnet_conditioning_scale": 0.5,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_qwen_controlnet(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # Expected slice from the generated image
+ expected_slice = torch.tensor(
+ [
+ 0.4726,
+ 0.5549,
+ 0.6324,
+ 0.6548,
+ 0.4968,
+ 0.4639,
+ 0.4749,
+ 0.4898,
+ 0.4725,
+ 0.4645,
+ 0.4435,
+ 0.3339,
+ 0.3400,
+ 0.4630,
+ 0.3879,
+ 0.4406,
+ ]
+ )
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_qwen_controlnet_multicondition(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+
+ components["controlnet"] = QwenImageMultiControlNetModel([components["controlnet"]])
+
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ control_image = inputs["control_image"]
+ inputs["control_image"] = [control_image, control_image]
+ inputs["controlnet_conditioning_scale"] = [0.5, 0.5]
+
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ # Expected slice from the generated image
+ expected_slice = torch.tensor(
+ [
+ 0.6239,
+ 0.6642,
+ 0.5768,
+ 0.6039,
+ 0.5270,
+ 0.5070,
+ 0.5006,
+ 0.5271,
+ 0.4506,
+ 0.3085,
+ 0.3435,
+ 0.5152,
+ 0.5096,
+ 0.5422,
+ 0.4286,
+ 0.5752,
+ ]
+ )
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ inputs["control_image"] = randn_tensor(
+ (1, 3, 128, 128),
+ generator=inputs["generator"],
+ device=torch.device(generator_device),
+ dtype=torch.float32,
+ )
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ inputs["control_image"] = randn_tensor(
+ (1, 3, 128, 128),
+ generator=inputs["generator"],
+ device=torch.device(generator_device),
+ dtype=torch.float32,
+ )
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/qwenimage/test_qwenimage_edit.py b/tests/pipelines/qwenimage/test_qwenimage_edit.py
new file mode 100644
index 0000000000..058548cf5f
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_edit.py
@@ -0,0 +1,243 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageEditPipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImageEditPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["prompt", "image"])
+ image_params = frozenset(["image"])
+ image_latents_params = frozenset(["latents"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ latents_mean=[0.0] * z_dim,
+ latents_std=[1.0] * z_dim,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "processor": Qwen2VLProcessor.from_pretrained(tiny_ckpt_id),
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "image": Image.new("RGB", (32, 32)),
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174, 0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
+ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+ super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
diff --git a/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
new file mode 100644
index 0000000000..6faf347282
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
@@ -0,0 +1,253 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageEditPlusPipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImageEditPlusPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["prompt", "image"])
+ image_params = frozenset(["image"])
+ image_latents_params = frozenset(["latents"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ latents_mean=[0.0] * z_dim,
+ latents_std=[1.0] * z_dim,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "processor": Qwen2VLProcessor.from_pretrained(tiny_ckpt_id),
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image = Image.new("RGB", (32, 32))
+ inputs = {
+ "prompt": "dance monkey",
+ "image": [image, image],
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174, 0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
+ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+ super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+    def test_num_images_per_prompt(self):
+ super().test_num_images_per_prompt()
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+    def test_inference_batch_consistent(self):
+ super().test_inference_batch_consistent()
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+    def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical()
diff --git a/tests/pipelines/qwenimage/test_qwenimage_img2img.py b/tests/pipelines/qwenimage/test_qwenimage_img2img.py
new file mode 100644
index 0000000000..07e683ec7f
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_img2img.py
@@ -0,0 +1,218 @@
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageImg2ImgPipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImageImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = QwenImageImg2ImgPipeline
+ params = frozenset(["prompt", "image", "height", "width", "guidance_scale", "true_cfg_scale", "strength"])
+ batch_params = frozenset(["prompt", "image"])
+ image_params = frozenset(["image"])
+ image_latents_params = frozenset(["latents"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_attention_slicing = True
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ latents_mean=[0.0] * 4,
+ latents_std=[1.0] * 4,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ return {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs).images[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs).images[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs).images[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/qwenimage/test_qwenimage_inpaint.py b/tests/pipelines/qwenimage/test_qwenimage_inpaint.py
new file mode 100644
index 0000000000..b564624540
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_inpaint.py
@@ -0,0 +1,233 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageInpaintPipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImageInpaintPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ # fmt: off
+ latents_mean=[0.0] * 4,
+ latents_std=[1.0] * 4,
+ # fmt: on
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ mask_image = torch.ones((1, 1, 32, 32)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "bad quality",
+ "image": image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.0,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py
index 26c06c1c9e..34ea3079b1 100644
--- a/tests/pipelines/sana/test_sana.py
+++ b/tests/pipelines/sana/test_sana.py
@@ -21,14 +21,14 @@ import torch
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/sana/test_sana_controlnet.py b/tests/pipelines/sana/test_sana_controlnet.py
index 9b5c9e439e..043e276fcb 100644
--- a/tests/pipelines/sana/test_sana_controlnet.py
+++ b/tests/pipelines/sana/test_sana_controlnet.py
@@ -26,12 +26,12 @@ from diffusers import (
SanaControlNetPipeline,
SanaTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/sana/test_sana_sprint.py b/tests/pipelines/sana/test_sana_sprint.py
index 021e559637..fee2304dce 100644
--- a/tests/pipelines/sana/test_sana_sprint.py
+++ b/tests/pipelines/sana/test_sana_sprint.py
@@ -20,11 +20,11 @@ import torch
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/sana/test_sana_sprint_img2img.py b/tests/pipelines/sana/test_sana_sprint_img2img.py
index c0e4bf8e35..c218abb8e9 100644
--- a/tests/pipelines/sana/test_sana_sprint_img2img.py
+++ b/tests/pipelines/sana/test_sana_sprint_img2img.py
@@ -20,12 +20,12 @@ import torch
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, SanaSprintImg2ImgPipeline, SanaTransformer2DModel, SCMScheduler
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/shap_e/test_shap_e.py b/tests/pipelines/shap_e/test_shap_e.py
index 47cc97844e..99fd286929 100644
--- a/tests/pipelines/shap_e/test_shap_e.py
+++ b/tests/pipelines/shap_e/test_shap_e.py
@@ -21,14 +21,14 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni
from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEPipeline
from diffusers.pipelines.shap_e import ShapERenderer
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
load_numpy,
nightly,
require_torch_accelerator,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index ba9f9fe521..b1867db249 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -22,7 +22,8 @@ from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
from diffusers import HeunDiscreteScheduler, PriorTransformer, ShapEImg2ImgPipeline
from diffusers.pipelines.shap_e import ShapERenderer
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
floats_tensor,
load_image,
@@ -31,7 +32,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2.py b/tests/pipelines/skyreels_v2/test_skyreels_v2.py
index adbbf05325..1bcec877c3 100644
--- a/tests/pipelines/skyreels_v2/test_skyreels_v2.py
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2.py
@@ -24,10 +24,10 @@ from diffusers import (
SkyReelsV2Transformer3DModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py
index cf9070bb95..74235d59ef 100644
--- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py
@@ -24,10 +24,10 @@ from diffusers import (
SkyReelsV2Transformer3DModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py
index 7b8a299281..f0cbc710df 100644
--- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py
@@ -28,8 +28,8 @@ from diffusers import (
SkyReelsV2Transformer3DModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py
index bc6a9acbf7..1b0b23318e 100644
--- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py
@@ -26,11 +26,11 @@ from diffusers import (
SkyReelsV2Transformer3DModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py
index 3ca5862072..784f701a29 100644
--- a/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py
@@ -31,8 +31,8 @@ from diffusers import (
SkyReelsV2Transformer3DModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/stable_audio/test_stable_audio.py b/tests/pipelines/stable_audio/test_stable_audio.py
index 5167dfdf0c..dd03f4d07f 100644
--- a/tests/pipelines/stable_audio/test_stable_audio.py
+++ b/tests/pipelines/stable_audio/test_stable_audio.py
@@ -32,7 +32,8 @@ from diffusers import (
StableAudioProjectionModel,
)
from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
enable_full_determinism,
@@ -40,7 +41,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
index 0a75b1e8b9..afa0db39f3 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_combined.py
@@ -22,8 +22,8 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni
from diffusers import DDPMWuerstchenScheduler, StableCascadeCombinedPipeline
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
+from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
index d0c9fc891f..5b3acb8705 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_decoder.py
@@ -23,7 +23,9 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni
from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
-from diffusers.utils.testing_utils import (
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_numpy,
@@ -34,8 +36,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-from diffusers.utils.torch_utils import randn_tensor
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
index 90633adea9..f8267186db 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
@@ -23,7 +23,8 @@ from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokeni
from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
from diffusers.models import StableCascadeUNet
from diffusers.utils.import_utils import is_peft_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_numpy,
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
index 69c105743b..62414f3f19 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py
@@ -27,8 +27,8 @@ from diffusers import (
OnnxStableDiffusionPipeline,
PNDMScheduler,
)
-from diffusers.utils.testing_utils import is_onnx_available, nightly, require_onnxruntime, require_torch_gpu
+from ...testing_utils import is_onnx_available, nightly, require_onnxruntime, require_torch_gpu
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
index 8a470fc668..28d1d0f37f 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py
@@ -26,7 +26,8 @@ from diffusers import (
OnnxStableDiffusionImg2ImgPipeline,
PNDMScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
floats_tensor,
is_onnx_available,
load_image,
@@ -34,7 +35,6 @@ from diffusers.utils.testing_utils import (
require_onnxruntime,
require_torch_gpu,
)
-
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
index 6bca7b288c..1d46ff9a2f 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py
@@ -18,14 +18,14 @@ import unittest
import numpy as np
from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionInpaintPipeline
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
is_onnx_available,
load_image,
nightly,
require_onnxruntime,
require_torch_gpu,
)
-
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py
index e25118575f..55d9d38d64 100644
--- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py
+++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_upscale.py
@@ -26,7 +26,8 @@ from diffusers import (
OnnxStableDiffusionUpscalePipeline,
PNDMScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
floats_tensor,
is_onnx_available,
load_image,
@@ -34,7 +35,6 @@ from diffusers.utils.testing_utils import (
require_onnxruntime,
require_torch_gpu,
)
-
from ..test_pipelines_onnx_common import OnnxPipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index bcad693501..c9d9525b2e 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -41,7 +41,8 @@ from diffusers import (
UNet2DConditionModel,
logging,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
backend_max_memory_allocated,
@@ -58,7 +59,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index c80667656e..a0b7268b9d 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -33,7 +33,8 @@ from diffusers import (
StableDiffusionImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -48,7 +49,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index 20a9848118..259806a947 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -35,7 +35,8 @@ from diffusers import (
StableDiffusionInpaintPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
backend_max_memory_allocated,
@@ -50,7 +51,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
index 1654831a99..4758c5dab4 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
@@ -32,7 +32,8 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -44,7 +45,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
index b3b5ba3de4..3b2552b432 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py
@@ -31,7 +31,8 @@ from diffusers import (
UNet2DConditionModel,
logging,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
backend_max_memory_allocated,
@@ -45,7 +46,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
index 6f772e5df1..bea7c09904 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py
@@ -36,7 +36,8 @@ from diffusers import (
StableDiffusionDepth2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -50,7 +51,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
deleted file mode 100644
index 77014bd7a5..0000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-from diffusers import FlaxDPMSolverMultistepScheduler, FlaxStableDiffusionPipeline
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import nightly, require_flax
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
-
-
-@nightly
-@require_flax
-class FlaxStableDiffusion2PipelineIntegrationTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def test_stable_diffusion_flax(self):
- sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2",
- variant="bf16",
- dtype=jnp.bfloat16,
- )
-
- prompt = "A painting of a squirrel eating a burger"
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = sd_pipe.prepare_inputs(prompt)
-
- params = replicate(params)
- prompt_ids = shard(prompt_ids)
-
- prng_seed = jax.random.PRNGKey(0)
- prng_seed = jax.random.split(prng_seed, jax.device_count())
-
- images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
- assert images.shape == (jax.device_count(), 1, 768, 768, 3)
-
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
-
-
-@nightly
-@require_flax
-class FlaxStableDiffusion2PipelineNightlyTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def test_stable_diffusion_dpm_flax(self):
- model_id = "stabilityai/stable-diffusion-2"
- scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
- sd_pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
- model_id,
- scheduler=scheduler,
- variant="bf16",
- dtype=jnp.bfloat16,
- )
- params["scheduler"] = scheduler_params
-
- prompt = "A painting of a squirrel eating a burger"
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = sd_pipe.prepare_inputs(prompt)
-
- params = replicate(params)
- prompt_ids = shard(prompt_ids)
-
- prng_seed = jax.random.PRNGKey(0)
- prng_seed = jax.random.split(prng_seed, jax.device_count())
-
- images = sd_pipe(prompt_ids, params, prng_seed, num_inference_steps=25, jit=True)[0]
- assert images.shape == (jax.device_count(), 1, 768, 768, 3)
-
- images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
deleted file mode 100644
index d83c696736..0000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-from diffusers import FlaxStableDiffusionInpaintPipeline
-from diffusers.utils import is_flax_available, load_image
-from diffusers.utils.testing_utils import require_flax, slow
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
-
-
-@slow
-@require_flax
-class FlaxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
-
- def test_stable_diffusion_inpaint_pipeline(self):
- init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/sd2-inpaint/init_image.png"
- )
- mask_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-inpaint/mask.png"
- )
-
- model_id = "xvjiarui/stable-diffusion-2-inpainting"
- pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None)
-
- prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- init_image = num_samples * [init_image]
- mask_image = num_samples * [mask_image]
- prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, jax.device_count())
- prompt_ids = shard(prompt_ids)
- processed_masked_images = shard(processed_masked_images)
- processed_masks = shard(processed_masks)
-
- output = pipeline(
- prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True
- )
-
- images = output.images.reshape(num_samples, 512, 512, 3)
-
- image_slice = images[0, 253:256, 253:256, -1]
-
- output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
- expected_slice = jnp.array(
- [0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084]
- )
-
- assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
index 238874c7f8..f010c1b03f 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
@@ -23,7 +23,8 @@ from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -36,7 +37,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
index 50cb9aa4b7..2e4b428dfe 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py
@@ -30,7 +30,8 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -40,7 +41,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
index a0949db7ee..481ac7f2d1 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py
@@ -24,7 +24,8 @@ from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionUpscalePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
index 55d801fd6c..37b309c4ca 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py
@@ -30,7 +30,8 @@ from diffusers import (
StableDiffusionPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 2179ec8e22..3ccefe3de3 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -6,14 +6,14 @@ import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_big_accelerator,
slow,
torch_device,
)
-
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -124,37 +124,22 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
}
return inputs
- def test_stable_diffusion_3_different_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ def test_inference(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
+ image = pipe(**inputs).images[0]
+ generated_slice = image.flatten()
+ generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "a different prompt"
- inputs["prompt_3"] = "another different prompt"
- output_different_prompts = pipe(**inputs).images[0]
+ # fmt: off
+ expected_slice = np.array([0.5112, 0.5228, 0.5235, 0.5524, 0.3188, 0.5017, 0.5574, 0.4899, 0.6812, 0.5991, 0.3908, 0.5213, 0.5582, 0.4457, 0.4204, 0.5616])
+ # fmt: on
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- assert max_diff > 1e-2
-
- def test_stable_diffusion_3_different_negative_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt_2"] = "deformed"
- inputs["negative_prompt_3"] = "blurry"
- output_different_prompts = pipe(**inputs).images[0]
-
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- assert max_diff > 1e-2
+ self.assertTrue(
+ np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
+ )
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -268,40 +253,9 @@ class StableDiffusion3PipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
- expected_slice = np.array(
- [
- 0.4648,
- 0.4404,
- 0.4177,
- 0.5063,
- 0.4800,
- 0.4287,
- 0.5425,
- 0.5190,
- 0.4717,
- 0.5430,
- 0.5195,
- 0.4766,
- 0.5361,
- 0.5122,
- 0.4612,
- 0.4871,
- 0.4749,
- 0.4058,
- 0.4756,
- 0.4678,
- 0.3804,
- 0.4832,
- 0.4822,
- 0.3799,
- 0.5103,
- 0.5034,
- 0.3953,
- 0.5073,
- 0.4839,
- 0.3884,
- ]
- )
+ # fmt: off
+ expected_slice = np.array([0.4648, 0.4404, 0.4177, 0.5063, 0.4800, 0.4287, 0.5425, 0.5190, 0.4717, 0.5430, 0.5195, 0.4766, 0.5361, 0.5122, 0.4612, 0.4871, 0.4749, 0.4058, 0.4756, 0.4678, 0.3804, 0.4832, 0.4822, 0.3799, 0.5103, 0.5034, 0.3953, 0.5073, 0.4839, 0.3884])
+ # fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
index 7f913cb63d..9025b1060c 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
@@ -13,7 +13,8 @@ from diffusers import (
StableDiffusion3Img2ImgPipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
floats_tensor,
@@ -22,7 +23,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -128,37 +128,22 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
}
return inputs
- def test_stable_diffusion_3_img2img_different_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ def test_inference(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
+ image = pipe(**inputs).images[0]
+ generated_slice = image.flatten()
+ generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "a different prompt"
- inputs["prompt_3"] = "another different prompt"
- output_different_prompts = pipe(**inputs).images[0]
+ # fmt: off
+ expected_slice = np.array([0.4564, 0.5486, 0.4868, 0.5923, 0.3775, 0.5543, 0.4807, 0.4177, 0.3778, 0.5957, 0.5726, 0.4333, 0.6312, 0.5062, 0.4838, 0.5984])
+ # fmt: on
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- assert max_diff > 1e-2
-
- def test_stable_diffusion_3_img2img_different_negative_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt_2"] = "deformed"
- inputs["negative_prompt_3"] = "blurry"
- output_different_prompts = pipe(**inputs).images[0]
-
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- assert max_diff > 1e-2
+ self.assertTrue(
+ np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
+ )
@unittest.skip("Skip for now.")
def test_multi_vae(self):
@@ -207,112 +192,16 @@ class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
+
+ # fmt: off
expected_slices = Expectations(
{
- ("xpu", 3): np.array(
- [
- 0.5117,
- 0.4421,
- 0.3852,
- 0.5044,
- 0.4219,
- 0.3262,
- 0.5024,
- 0.4329,
- 0.3276,
- 0.4978,
- 0.4412,
- 0.3355,
- 0.4983,
- 0.4338,
- 0.3279,
- 0.4893,
- 0.4241,
- 0.3129,
- 0.4875,
- 0.4253,
- 0.3030,
- 0.4961,
- 0.4267,
- 0.2988,
- 0.5029,
- 0.4255,
- 0.3054,
- 0.5132,
- 0.4248,
- 0.3222,
- ]
- ),
- ("cuda", 7): np.array(
- [
- 0.5435,
- 0.4673,
- 0.5732,
- 0.4438,
- 0.3557,
- 0.4912,
- 0.4331,
- 0.3491,
- 0.4915,
- 0.4287,
- 0.347,
- 0.4849,
- 0.4355,
- 0.3469,
- 0.4871,
- 0.4431,
- 0.3538,
- 0.4912,
- 0.4521,
- 0.3643,
- 0.5059,
- 0.4587,
- 0.373,
- 0.5166,
- 0.4685,
- 0.3845,
- 0.5264,
- 0.4746,
- 0.3914,
- 0.5342,
- ]
- ),
- ("cuda", 8): np.array(
- [
- 0.5146,
- 0.4385,
- 0.3826,
- 0.5098,
- 0.4150,
- 0.3218,
- 0.5142,
- 0.4312,
- 0.3298,
- 0.5127,
- 0.4431,
- 0.3411,
- 0.5171,
- 0.4424,
- 0.3374,
- 0.5088,
- 0.4348,
- 0.3242,
- 0.5073,
- 0.4380,
- 0.3174,
- 0.5132,
- 0.4397,
- 0.3115,
- 0.5132,
- 0.4343,
- 0.3118,
- 0.5219,
- 0.4328,
- 0.3256,
- ]
- ),
+ ("xpu", 3): np.array([0.5117, 0.4421, 0.3852, 0.5044, 0.4219, 0.3262, 0.5024, 0.4329, 0.3276, 0.4978, 0.4412, 0.3355, 0.4983, 0.4338, 0.3279, 0.4893, 0.4241, 0.3129, 0.4875, 0.4253, 0.3030, 0.4961, 0.4267, 0.2988, 0.5029, 0.4255, 0.3054, 0.5132, 0.4248, 0.3222]),
+ ("cuda", 7): np.array([0.5435, 0.4673, 0.5732, 0.4438, 0.3557, 0.4912, 0.4331, 0.3491, 0.4915, 0.4287, 0.347, 0.4849, 0.4355, 0.3469, 0.4871, 0.4431, 0.3538, 0.4912, 0.4521, 0.3643, 0.5059, 0.4587, 0.373, 0.5166, 0.4685, 0.3845, 0.5264, 0.4746, 0.3914, 0.5342]),
+ ("cuda", 8): np.array([0.5146, 0.4385, 0.3826, 0.5098, 0.4150, 0.3218, 0.5142, 0.4312, 0.3298, 0.5127, 0.4431, 0.3411, 0.5171, 0.4424, 0.3374, 0.5088, 0.4348, 0.3242, 0.5073, 0.4380, 0.3174, 0.5132, 0.4397, 0.3115, 0.5132, 0.4343, 0.3118, 0.5219, 0.4328, 0.3256]),
}
)
+ # fmt: on
expected_slice = expected_slices.get_expectation()
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py
index 4090306dec..6289303402 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_inpaint.py
@@ -11,12 +11,12 @@ from diffusers import (
SD3Transformer2DModel,
StableDiffusion3InpaintPipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
@@ -132,37 +132,23 @@ class StableDiffusion3InpaintPipelineFastTests(PipelineLatentTesterMixin, unitte
}
return inputs
- def test_stable_diffusion_3_inpaint_different_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ def test_inference(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
+ image = pipe(**inputs).images[0]
+ generated_slice = image.flatten()
+ generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "a different prompt"
- inputs["prompt_3"] = "another different prompt"
- output_different_prompts = pipe(**inputs).images[0]
+ # fmt: off
+ expected_slice = np.array([0.5035, 0.6661, 0.5859, 0.413, 0.4224, 0.4234, 0.7181, 0.5062, 0.5183, 0.6877, 0.5074, 0.585, 0.6111, 0.5422, 0.5306, 0.5891])
+ # fmt: on
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- assert max_diff > 1e-2
-
- def test_stable_diffusion_3_inpaint_different_negative_prompts(self):
- pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_same_prompt = pipe(**inputs).images[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt_2"] = "deformed"
- inputs["negative_prompt_3"] = "blurry"
- output_different_prompts = pipe(**inputs).images[0]
-
- max_diff = np.abs(output_same_prompt - output_different_prompts).max()
-
- # Outputs should be different here
- assert max_diff > 1e-2
+ self.assertTrue(
+ np.allclose(generated_slice, expected_slice, atol=1e-3), "Output does not match expected slice."
+ )
+ @unittest.skip("Skip for now.")
def test_multi_vae(self):
pass
diff --git a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
index 009c75df42..79b38d1cad 100644
--- a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
+++ b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py
@@ -34,7 +34,8 @@ from diffusers import (
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -45,7 +46,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference
diff --git a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
index 5eca6c2380..dbf5a7b68e 100644
--- a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
+++ b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py
@@ -29,7 +29,8 @@ from diffusers import (
StableDiffusionImageVariationPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -44,7 +45,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 966d864843..b318a505e9 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -34,7 +34,8 @@ from diffusers import (
UNet2DConditionModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_image,
@@ -43,7 +44,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
index c39c9bedaf..3d72270dda 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
@@ -32,12 +32,12 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import (
IPAdapterTesterMixin,
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index 450891b257..c549984706 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -38,7 +38,8 @@ from diffusers import (
StableDiffusionXLImg2ImgPipeline,
UNet2DConditionModel,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -47,7 +48,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 6ac820547d..d3f5779c76 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -41,14 +41,14 @@ from diffusers import (
UNet2DConditionModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
floats_tensor,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
index 932a249689..20a03583e7 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
@@ -29,8 +29,8 @@ from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instruct_pix2pix import (
StableDiffusionXLInstructPix2PixPipeline,
)
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py
index e3cbb1891b..8923c2f63c 100644
--- a/tests/pipelines/stable_unclip/test_stable_unclip.py
+++ b/tests/pipelines/stable_unclip/test_stable_unclip.py
@@ -13,7 +13,8 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -24,7 +25,6 @@ from diffusers.utils.testing_utils import (
require_torch_accelerator,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineKarrasSchedulerTesterMixin,
diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
index 8ca5723ce6..e7a0fbccef 100644
--- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
+++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py
@@ -17,7 +17,8 @@ from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableUnCLIPI
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_max_memory_allocated,
@@ -31,7 +32,6 @@ from diffusers.utils.testing_utils import (
skip_mps,
torch_device,
)
-
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import (
PipelineKarrasSchedulerTesterMixin,
diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
index f77a5b1620..52595f7a8c 100644
--- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
+++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
@@ -20,7 +20,8 @@ from diffusers import (
)
from diffusers.utils import load_image, logging
from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
enable_full_determinism,
@@ -32,7 +33,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py
index f49ad282f3..6d9e681979 100644
--- a/tests/pipelines/test_pipeline_utils.py
+++ b/tests/pipelines/test_pipeline_utils.py
@@ -19,7 +19,8 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
-from diffusers.utils.testing_utils import require_torch_accelerator, torch_device
+
+from ..testing_utils import require_torch_accelerator, torch_device
class IsSafetensorsCompatibleTests(unittest.TestCase):
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 6c342bcbe6..5b86423553 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -28,14 +28,15 @@ import warnings
import numpy as np
import PIL.Image
+import pytest
import requests_mock
import safetensors.torch
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
+from huggingface_hub.utils import HfHubHTTPError
from parameterized import parameterized
from PIL import Image
-from requests.exceptions import HTTPError
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
@@ -62,11 +63,10 @@ from diffusers import (
)
from diffusers.pipelines.pipeline_utils import _get_pipeline_class
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import (
- CONFIG_NAME,
- WEIGHTS_NAME,
-)
-from diffusers.utils.testing_utils import (
+from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, is_transformers_version
+from diffusers.utils.torch_utils import is_compiled_module
+
+from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
enable_full_determinism,
@@ -89,7 +89,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-from diffusers.utils.torch_utils import is_compiled_module
enable_full_determinism()
@@ -429,7 +428,7 @@ class DownloadTests(unittest.TestCase):
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
- response_mock.raise_for_status.side_effect = HTTPError
+ response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
@@ -456,7 +455,7 @@ class DownloadTests(unittest.TestCase):
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
- response_mock.raise_for_status.side_effect = HTTPError
+ response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
response_mock.json.return_value = {}
# first check that with local files only the pipeline can only be used if cached
@@ -583,6 +582,7 @@ class DownloadTests(unittest.TestCase):
assert not any(f.endswith(unexpected_ext) for f in files)
assert all(variant in f for f in model_files if f.endswith(model_ext) and variant is not None)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=True)
def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
@@ -689,6 +689,7 @@ class DownloadTests(unittest.TestCase):
)
assert "Error no file name" in str(error_context.exception)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=True)
def test_local_save_load_index(self):
prompt = "hello"
for variant in [None, "fp16"]:
@@ -1583,6 +1584,7 @@ class PipelineFastTests(unittest.TestCase):
assert pipeline.scheduler is not None
assert pipeline.feature_extractor is not None
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=True)
def test_no_pytorch_download_when_doing_safetensors(self):
# by default we don't download
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -1602,6 +1604,7 @@ class PipelineFastTests(unittest.TestCase):
# pytorch does not
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=True)
def test_no_safetensors_download_when_doing_pytorch(self):
use_safetensors = False
diff --git a/tests/pipelines/test_pipelines_auto.py b/tests/pipelines/test_pipelines_auto.py
index de4b447f66..f3c639c367 100644
--- a/tests/pipelines/test_pipelines_auto.py
+++ b/tests/pipelines/test_pipelines_auto.py
@@ -35,7 +35,8 @@ from diffusers.pipelines.auto_pipeline import (
AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
)
-from diffusers.utils.testing_utils import slow
+
+from ..testing_utils import slow
PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 387eb6a614..db8209835b 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, Union
import numpy as np
import PIL.Image
+import pytest
import torch
import torch.nn as nn
from huggingface_hub import ModelCard, delete_repo
@@ -48,19 +49,6 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
-from diffusers.utils.testing_utils import (
- CaptureLogger,
- backend_empty_cache,
- numpy_cosine_similarity_distance,
- require_accelerate_version_greater,
- require_accelerator,
- require_hf_hub_version_greater,
- require_torch,
- require_torch_accelerator,
- require_transformers_version_greater,
- skip_mps,
- torch_device,
-)
from ..models.autoencoders.vae import (
get_asym_autoencoder_kl_config,
@@ -74,6 +62,19 @@ from ..models.unets.test_models_unet_2d_condition import (
create_ip_adapter_state_dict,
)
from ..others.test_utils import TOKEN, USER, is_staging_test
+from ..testing_utils import (
+ CaptureLogger,
+ backend_empty_cache,
+ numpy_cosine_similarity_distance,
+ require_accelerate_version_greater,
+ require_accelerator,
+ require_hf_hub_version_greater,
+ require_torch,
+ require_torch_accelerator,
+ require_transformers_version_greater,
+ skip_mps,
+ torch_device,
+)
def to_np(tensor):
@@ -2339,6 +2340,96 @@ class PipelineTesterMixin:
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
)
+ @require_torch_accelerator
+ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["generator"] = torch.manual_seed(0)
+ out = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device)
+ for component in loaded_pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ inputs["generator"] = torch.manual_seed(0)
+ loaded_out = loaded_pipe(**inputs)[0]
+ max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ @require_torch_accelerator
+ def test_pipeline_level_group_offloading_sanity_checks(self):
+ components = self.get_dummy_components()
+ pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+ for name, component in pipe.components.items():
+ if hasattr(component, "_supports_group_offloading"):
+ if not component._supports_group_offloading:
+ pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+ module_names = sorted(
+ [name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]
+ )
+ exclude_module_name = module_names[0]
+ offload_device = "cpu"
+ pipe.enable_group_offload(
+ onload_device=torch_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ exclude_modules=exclude_module_name,
+ )
+ excluded_module = getattr(pipe, exclude_module_name)
+ self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type)
+
+ for name, component in pipe.components.items():
+ if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
+ # `component.device` prints the `onload_device` type. We should probably override the
+ # `device` property in `ModelMixin`.
+ component_device = next(component.parameters())[0].device
+ self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
+
+ @require_torch_accelerator
+ def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
+ components = self.get_dummy_components()
+ pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+ for name, component in pipe.components.items():
+ if hasattr(component, "_supports_group_offloading"):
+ if not component._supports_group_offloading:
+ pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+ # Regular inference.
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["generator"] = torch.manual_seed(0)
+ out = pipe(**inputs)[0]
+
+ pipe.to("cpu")
+ del pipe
+
+ # Inference with offloading
+ pipe: DiffusionPipeline = self.pipeline_class(**components)
+ offload_device = "cpu"
+ pipe.enable_group_offload(
+ onload_device=torch_device,
+ offload_device=offload_device,
+ offload_type="leaf_level",
+ )
+ pipe.set_progress_bar_config(disable=None)
+ inputs["generator"] = torch.manual_seed(0)
+ out_offload = pipe(**inputs)[0]
+
+ max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
+ self.assertLess(max_diff, expected_max_difference)
+
@is_staging_test
class PipelinePushToHubTester(unittest.TestCase):
diff --git a/tests/pipelines/test_pipelines_flax.py b/tests/pipelines/test_pipelines_flax.py
deleted file mode 100644
index ffe43ac9d7..0000000000
--- a/tests/pipelines/test_pipelines_flax.py
+++ /dev/null
@@ -1,260 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import tempfile
-import unittest
-
-import numpy as np
-
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax, slow
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from flax.jax_utils import replicate
- from flax.training.common_utils import shard
-
- from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
-
-
-@require_flax
-class DownloadTests(unittest.TestCase):
- def test_download_only_pytorch(self):
- with tempfile.TemporaryDirectory() as tmpdirname:
- # pipeline has Flax weights
- _ = FlaxDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
- )
-
- all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
- files = [item for sublist in all_root_files for item in sublist]
-
- # None of the downloaded files should be a PyTorch file even if we have some here:
- # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
- assert not any(f.endswith(".bin") for f in files)
-
-
-@slow
-@require_flax
-class FlaxPipelineTests(unittest.TestCase):
- def test_dummy_all_tpus(self):
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
- )
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 4
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 64, 64, 3)
- if jax.device_count() == 8:
- assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3
- assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1
-
- images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
- assert len(images_pil) == num_samples
-
- def test_stable_diffusion_v1_4(self):
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None
- )
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 512, 512, 3)
- if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-2
- assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 5e-1
-
- def test_stable_diffusion_v1_4_bfloat_16(self):
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16, safety_checker=None
- )
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 512, 512, 3)
- if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
- assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
-
- def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", variant="bf16", dtype=jnp.bfloat16
- )
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 512, 512, 3)
- if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 5e-2
- assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
-
- def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
- scheduler = FlaxDDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- set_alpha_to_one=False,
- steps_offset=1,
- )
-
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- variant="bf16",
- dtype=jnp.bfloat16,
- scheduler=scheduler,
- safety_checker=None,
- )
- scheduler_state = scheduler.create_state()
-
- params["scheduler"] = scheduler_state
-
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- prng_seed = jax.random.PRNGKey(0)
- num_inference_steps = 50
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prompt_ids = pipeline.prepare_inputs(prompt)
-
- # shard inputs and rng
- params = replicate(params)
- prng_seed = jax.random.split(prng_seed, num_samples)
- prompt_ids = shard(prompt_ids)
-
- images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
-
- assert images.shape == (num_samples, 1, 512, 512, 3)
- if jax.device_count() == 8:
- assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.045043945)) < 5e-2
- assert np.abs((np.abs(images, dtype=np.float32).sum() - 2347693.5)) < 5e-1
-
- def test_jax_memory_efficient_attention(self):
- prompt = (
- "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of"
- " field, close up, split lighting, cinematic"
- )
-
- num_samples = jax.device_count()
- prompt = num_samples * [prompt]
- prng_seed = jax.random.split(jax.random.PRNGKey(0), num_samples)
-
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- variant="bf16",
- dtype=jnp.bfloat16,
- safety_checker=None,
- )
-
- params = replicate(params)
- prompt_ids = pipeline.prepare_inputs(prompt)
- prompt_ids = shard(prompt_ids)
- images = pipeline(prompt_ids, params, prng_seed, jit=True).images
- assert images.shape == (num_samples, 1, 512, 512, 3)
- slice = images[2, 0, 256, 10:17, 1]
-
- # With memory efficient attention
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- variant="bf16",
- dtype=jnp.bfloat16,
- safety_checker=None,
- use_memory_efficient_attention=True,
- )
-
- params = replicate(params)
- prompt_ids = pipeline.prepare_inputs(prompt)
- prompt_ids = shard(prompt_ids)
- images_eff = pipeline(prompt_ids, params, prng_seed, jit=True).images
- assert images_eff.shape == (num_samples, 1, 512, 512, 3)
- slice_eff = images[2, 0, 256, 10:17, 1]
-
- # I checked the results visually and they are very similar. However, I saw that the max diff is `1` and the `sum`
- # over the 8 images is exactly `256`, which is very suspicious. Testing a random slice for now.
- assert abs(slice_eff - slice).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_onnx_common.py b/tests/pipelines/test_pipelines_onnx_common.py
index 575ecd0075..fa077efb8a 100644
--- a/tests/pipelines/test_pipelines_onnx_common.py
+++ b/tests/pipelines/test_pipelines_onnx_common.py
@@ -1,4 +1,4 @@
-from diffusers.utils.testing_utils import require_onnxruntime
+from ..testing_utils import require_onnxruntime
@require_onnxruntime
diff --git a/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py b/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py
index 7e2aa25709..00ae0441fe 100644
--- a/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py
+++ b/tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py
@@ -10,14 +10,14 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPToken
import diffusers
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel, VisualClozePipeline
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
enable_full_determinism,
floats_tensor,
require_accelerator,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py b/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py
index 0cd714af17..ab6b3ca5c5 100644
--- a/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py
+++ b/tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py
@@ -15,14 +15,14 @@ from diffusers import (
VisualClozeGenerationPipeline,
)
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
enable_full_determinism,
floats_tensor,
require_accelerator,
torch_device,
)
-
from ..test_pipelines_common import PipelineTesterMixin, to_np
diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py
index 90b7978ec7..106a7b2946 100644
--- a/tests/pipelines/wan/test_wan.py
+++ b/tests/pipelines/wan/test_wan.py
@@ -21,14 +21,14 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/wan/test_wan_22.py b/tests/pipelines/wan/test_wan_22.py
index 9fdae66980..56ef5ceb97 100644
--- a/tests/pipelines/wan/test_wan_22.py
+++ b/tests/pipelines/wan/test_wan_22.py
@@ -20,11 +20,11 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanPipeline, WanTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/wan/test_wan_22_image_to_video.py b/tests/pipelines/wan/test_wan_22_image_to_video.py
index 3f72a74e44..6294d62044 100644
--- a/tests/pipelines/wan/test_wan_22_image_to_video.py
+++ b/tests/pipelines/wan/test_wan_22_image_to_video.py
@@ -21,11 +21,11 @@ from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline, WanTransformer3DModel
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
torch_device,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py
index 1c938ce2de..07a9142f25 100644
--- a/tests/pipelines/wan/test_wan_image_to_video.py
+++ b/tests/pipelines/wan/test_wan_image_to_video.py
@@ -27,8 +27,8 @@ from transformers import (
)
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanImageToVideoPipeline, WanTransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
diff --git a/tests/pipelines/wan/test_wan_vace.py b/tests/pipelines/wan/test_wan_vace.py
index 885defcfb4..f99863c880 100644
--- a/tests/pipelines/wan/test_wan_vace.py
+++ b/tests/pipelines/wan/test_wan_vace.py
@@ -20,8 +20,8 @@ from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
-from diffusers.utils.testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -87,6 +87,7 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
+ "transformer_2": None,
}
return components
diff --git a/tests/pipelines/wan/test_wan_video_to_video.py b/tests/pipelines/wan/test_wan_video_to_video.py
index f4bb0960ac..27ada121ca 100644
--- a/tests/pipelines/wan/test_wan_video_to_video.py
+++ b/tests/pipelines/wan/test_wan_video_to_video.py
@@ -19,10 +19,10 @@ from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanTransformer3DModel, WanVideoToVideoPipeline
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
enable_full_determinism,
)
-
from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 8e2a8515c6..c1da8f1ece 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -32,7 +32,8 @@ from diffusers import (
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version, logging
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
is_bitsandbytes_available,
@@ -50,7 +51,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_torch_compile_utils import QuantCompileTests
@@ -886,6 +886,7 @@ class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
components_to_quantize=["transformer", "text_encoder_2"],
)
+ @require_bitsandbytes_version_greater("0.46.1")
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super().test_torch_compile()
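
# --- Illustrative sketch (not part of the patch): the compile path that the newly gated
# Bnb4BitCompileTests exercises, written as standalone user code. The checkpoint id and the
# use of a model-level BitsAndBytesConfig are assumptions for illustration only; the test
# itself builds its components through PipelineQuantizationConfig fixtures.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint, not taken from the diff
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

# The test flips this flag before compiling, presumably because the bnb 4-bit kernels
# involve data-dependent output shapes that torch.compile would otherwise not capture.
torch._dynamo.config.capture_dynamic_output_shape_ops = True
transformer.compile(fullgraph=True)
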
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 64f56b02b0..fde3966dec 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -32,7 +32,8 @@ from diffusers import (
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
CaptureLogger,
backend_empty_cache,
is_bitsandbytes_available,
@@ -51,7 +52,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_torch_compile_utils import QuantCompileTests
@@ -847,6 +847,10 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
components_to_quantize=["transformer", "text_encoder_2"],
)
+ @pytest.mark.xfail(
+ reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
+ " Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
+ )
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
super()._test_torch_compile(torch_dtype=torch.float16)
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index 9c79daf791..0f4fd408a7 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -20,7 +20,8 @@ from diffusers import (
WanVACETransformer3DModel,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
Expectations,
backend_empty_cache,
backend_max_memory_allocated,
@@ -38,7 +39,6 @@ from diffusers.utils.testing_utils import (
require_torch_version_greater,
torch_device,
)
-
from ..test_torch_compile_utils import QuantCompileTests
@@ -212,6 +212,7 @@ class GGUFSingleFileTesterMixin:
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+ diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
torch_dtype = torch.bfloat16
model_cls = FluxTransformer2DModel
expected_memory_use_in_gb = 5
@@ -296,6 +297,16 @@ class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
+ def test_loading_gguf_diffusers_format(self):
+ model = self.model_cls.from_single_file(
+ self.diffusers_ckpt_path,
+ subfolder="transformer",
+ quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+ config="black-forest-labs/FLUX.1-dev",
+ )
+ model.to(torch_device)
+ model(**self.get_dummy_inputs())
+
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
@@ -349,33 +360,33 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase)
{
("xpu", 3): np.array(
[
- 0.16210938,
- 0.2734375,
- 0.27734375,
- 0.109375,
- 0.27148438,
- 0.2578125,
- 0.1015625,
- 0.2578125,
- 0.2578125,
- 0.14453125,
- 0.26953125,
- 0.29492188,
- 0.12890625,
- 0.28710938,
- 0.30078125,
- 0.11132812,
- 0.27734375,
+ 0.16796875,
0.27929688,
- 0.15625,
- 0.31054688,
+ 0.28320312,
+ 0.11328125,
+ 0.27539062,
+ 0.26171875,
+ 0.10742188,
+ 0.26367188,
+ 0.26171875,
+ 0.1484375,
+ 0.2734375,
0.296875,
- 0.15234375,
- 0.3203125,
- 0.29492188,
- 0.140625,
- 0.3046875,
- 0.28515625,
+ 0.13476562,
+ 0.2890625,
+ 0.30078125,
+ 0.1171875,
+ 0.28125,
+ 0.28125,
+ 0.16015625,
+ 0.31445312,
+ 0.30078125,
+ 0.15625,
+ 0.32421875,
+ 0.296875,
+ 0.14453125,
+ 0.30859375,
+ 0.2890625,
]
),
("cuda", 7): np.array(
diff --git a/tests/quantization/modelopt/__init__.py b/tests/quantization/modelopt/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/quantization/modelopt/test_modelopt.py b/tests/quantization/modelopt/test_modelopt.py
new file mode 100644
index 0000000000..6b0624a280
--- /dev/null
+++ b/tests/quantization/modelopt/test_modelopt.py
@@ -0,0 +1,306 @@
+import gc
+import tempfile
+import unittest
+
+from diffusers import NVIDIAModelOptConfig, SD3Transformer2DModel, StableDiffusion3Pipeline
+from diffusers.utils import is_nvidia_modelopt_available, is_torch_available
+from diffusers.utils.testing_utils import (
+ backend_empty_cache,
+ backend_reset_peak_memory_stats,
+ enable_full_determinism,
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_accelerate,
+ require_big_accelerator,
+ require_modelopt_version_greater_or_equal,
+ require_torch_cuda_compatibility,
+ torch_device,
+)
+
+
+if is_nvidia_modelopt_available():
+ import modelopt.torch.quantization as mtq
+
+if is_torch_available():
+ import torch
+
+ from ..utils import LoRALayer, get_memory_consumption_stat
+
+enable_full_determinism()
+
+
+@nightly
+@require_big_accelerator
+@require_accelerate
+@require_modelopt_version_greater_or_equal("0.33.1")
+class ModelOptBaseTesterMixin:
+ model_id = "hf-internal-testing/tiny-sd3-pipe"
+ model_cls = SD3Transformer2DModel
+ pipeline_cls = StableDiffusion3Pipeline
+ torch_dtype = torch.bfloat16
+ expected_memory_reduction = 0.0
+ keep_in_fp32_module = ""
+ modules_to_not_convert = ""
+ _test_torch_compile = False
+
+ def setUp(self):
+ backend_reset_peak_memory_stats(torch_device)
+ backend_empty_cache(torch_device)
+ gc.collect()
+
+ def tearDown(self):
+ backend_reset_peak_memory_stats(torch_device)
+ backend_empty_cache(torch_device)
+ gc.collect()
+
+ def get_dummy_init_kwargs(self):
+ return {"quant_type": "FP8"}
+
+ def get_dummy_model_init_kwargs(self):
+ return {
+ "pretrained_model_name_or_path": self.model_id,
+ "torch_dtype": self.torch_dtype,
+ "quantization_config": NVIDIAModelOptConfig(**self.get_dummy_init_kwargs()),
+ "subfolder": "transformer",
+ }
+
+ def test_modelopt_layers(self):
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ assert mtq.utils.is_quantized(module)
+
+ def test_modelopt_memory_usage(self):
+ inputs = self.get_dummy_inputs()
+ inputs = {
+ k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
+ }
+
+ unquantized_model = self.model_cls.from_pretrained(
+ self.model_id, torch_dtype=self.torch_dtype, subfolder="transformer"
+ )
+ unquantized_model.to(torch_device)
+ unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)
+
+ quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ quantized_model.to(torch_device)
+ quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)
+
+ assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction
+
+ def test_keep_modules_in_fp32(self):
+ _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
+ self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
+
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ model.to(torch_device)
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name in model._keep_in_fp32_modules:
+ assert module.weight.dtype == torch.float32
+ self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
+
+ def test_modules_to_not_convert(self):
+ init_kwargs = self.get_dummy_model_init_kwargs()
+ quantization_config_kwargs = self.get_dummy_init_kwargs()
+ quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
+ quantization_config = NVIDIAModelOptConfig(**quantization_config_kwargs)
+ init_kwargs.update({"quantization_config": quantization_config})
+
+ model = self.model_cls.from_pretrained(**init_kwargs)
+ model.to(torch_device)
+
+ for name, module in model.named_modules():
+ if name in self.modules_to_not_convert:
+ assert not mtq.utils.is_quantized(module)
+
+ def test_dtype_assignment(self):
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+
+ with self.assertRaises(ValueError):
+ model.to(torch.float16)
+
+ with self.assertRaises(ValueError):
+ device_0 = f"{torch_device}:0"
+ model.to(device=device_0, dtype=torch.float16)
+
+ with self.assertRaises(ValueError):
+ model.float()
+
+ with self.assertRaises(ValueError):
+ model.half()
+
+ model.to(torch_device)
+
+ def test_serialization(self):
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ inputs = self.get_dummy_inputs()
+
+ model.to(torch_device)
+ with torch.no_grad():
+ model_output = model(**inputs)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model.save_pretrained(tmp_dir)
+ saved_model = self.model_cls.from_pretrained(
+ tmp_dir,
+ torch_dtype=torch.bfloat16,
+ )
+
+ saved_model.to(torch_device)
+ with torch.no_grad():
+ saved_model_output = saved_model(**inputs)
+
+ assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)
+
+ def test_torch_compile(self):
+ if not self._test_torch_compile:
+ return
+
+ model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
+ compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)
+
+ model.to(torch_device)
+ with torch.no_grad():
+ model_output = model(**self.get_dummy_inputs()).sample
+
+ compiled_model.to(torch_device)
+ with torch.no_grad():
+ compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample
+
+ model_output = model_output.detach().float().cpu().numpy()
+ compiled_model_output = compiled_model_output.detach().float().cpu().numpy()
+
+ max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
+ assert max_diff < 1e-3
+
+ def test_device_map_error(self):
+ with self.assertRaises(ValueError):
+ _ = self.model_cls.from_pretrained(
+ **self.get_dummy_model_init_kwargs(),
+ device_map={0: "8GB", "cpu": "16GB"},
+ )
+
+ def get_dummy_inputs(self):
+ batch_size = 1
+ seq_len = 16
+ height = width = 32
+ num_latent_channels = 4
+ caption_channels = 8
+
+ torch.manual_seed(0)
+ hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(
+ torch_device, dtype=torch.bfloat16
+ )
+ encoder_hidden_states = torch.randn((batch_size, seq_len, caption_channels)).to(
+ torch_device, dtype=torch.bfloat16
+ )
+ timestep = torch.tensor([1.0]).to(torch_device, dtype=torch.bfloat16).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ def test_model_cpu_offload(self):
+ init_kwargs = self.get_dummy_init_kwargs()
+ transformer = self.model_cls.from_pretrained(
+ self.model_id,
+ quantization_config=NVIDIAModelOptConfig(**init_kwargs),
+ subfolder="transformer",
+ torch_dtype=torch.bfloat16,
+ )
+ pipe = self.pipeline_cls.from_pretrained(self.model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+ pipe.enable_model_cpu_offload(device=torch_device)
+ _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
+
+ def test_training(self):
+ quantization_config = NVIDIAModelOptConfig(**self.get_dummy_init_kwargs())
+ quantized_model = self.model_cls.from_pretrained(
+ self.model_id,
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.bfloat16,
+ ).to(torch_device)
+
+ for param in quantized_model.parameters():
+ param.requires_grad = False
+ if param.ndim == 1:
+ param.data = param.data.to(torch.float32)
+
+ for _, module in quantized_model.named_modules():
+ if hasattr(module, "to_q"):
+ module.to_q = LoRALayer(module.to_q, rank=4)
+ if hasattr(module, "to_k"):
+ module.to_k = LoRALayer(module.to_k, rank=4)
+ if hasattr(module, "to_v"):
+ module.to_v = LoRALayer(module.to_v, rank=4)
+
+ with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
+ inputs = self.get_dummy_inputs()
+ output = quantized_model(**inputs)[0]
+ output.norm().backward()
+
+ for module in quantized_model.modules():
+ if isinstance(module, LoRALayer):
+ self.assertTrue(module.adapter[1].weight.grad is not None)
+
+
+class SanaTransformerFP8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.6
+
+ def get_dummy_init_kwargs(self):
+ return {"quant_type": "FP8"}
+
+
+class SanaTransformerINT8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.6
+ _test_torch_compile = True
+
+ def get_dummy_init_kwargs(self):
+ return {"quant_type": "INT8"}
+
+
+@require_torch_cuda_compatibility(8.0)
+class SanaTransformerINT4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.55
+
+ def get_dummy_init_kwargs(self):
+ return {
+ "quant_type": "INT4",
+ "block_quantize": 128,
+ "channel_quantize": -1,
+ "disable_conv_quantization": True,
+ }
+
+
+@require_torch_cuda_compatibility(8.0)
+class SanaTransformerNF4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.65
+
+ def get_dummy_init_kwargs(self):
+ return {
+ "quant_type": "NF4",
+ "block_quantize": 128,
+ "channel_quantize": -1,
+ "scale_block_quantize": 8,
+ "scale_channel_quantize": -1,
+ "modules_to_not_convert": ["conv"],
+ }
+
+
+@require_torch_cuda_compatibility(8.0)
+class SanaTransformerNVFP4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
+ expected_memory_reduction = 0.65
+
+ def get_dummy_init_kwargs(self):
+ return {
+ "quant_type": "NVFP4",
+ "block_quantize": 128,
+ "channel_quantize": -1,
+ "scale_block_quantize": 8,
+ "scale_channel_quantize": -1,
+ "modules_to_not_convert": ["conv"],
+ }
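
# --- Illustrative sketch (not part of the patch): the user-facing flow the new ModelOpt
# tests cover. The tiny checkpoint id, FP8 quant type, prompt, step count, and cpu-offload
# call all come from the test bodies above; grabbing `.images[0]` is standard pipeline
# output handling added here only for illustration.
import torch
from diffusers import NVIDIAModelOptConfig, SD3Transformer2DModel, StableDiffusion3Pipeline

quant_config = NVIDIAModelOptConfig(quant_type="FP8")
transformer = SD3Transformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-sd3-pipe",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    "hf-internal-testing/tiny-sd3-pipe", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = pipe("a cat holding a sign that says hello", num_inference_steps=2).images[0]
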
diff --git a/tests/quantization/quanto/test_quanto.py b/tests/quantization/quanto/test_quanto.py
index d7bde6591d..e3463f136f 100644
--- a/tests/quantization/quanto/test_quanto.py
+++ b/tests/quantization/quanto/test_quanto.py
@@ -5,14 +5,15 @@ import unittest
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from diffusers.models.attention_processor import Attention
from diffusers.utils import is_optimum_quanto_available, is_torch_available
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_reset_peak_memory_stats,
enable_full_determinism,
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
- require_big_accelerator,
+ require_accelerator,
require_torch_cuda_compatibility,
torch_device,
)
@@ -30,7 +31,7 @@ enable_full_determinism()
@nightly
-@require_big_accelerator
+@require_accelerator
@require_accelerate
class QuantoBaseTesterMixin:
model_id = None
diff --git a/tests/quantization/test_pipeline_level_quantization.py b/tests/quantization/test_pipeline_level_quantization.py
index e91fe6d4cb..5f1a3de2e5 100644
--- a/tests/quantization/test_pipeline_level_quantization.py
+++ b/tests/quantization/test_pipeline_level_quantization.py
@@ -22,7 +22,8 @@ from parameterized import parameterized
from diffusers import BitsAndBytesConfig, DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
CaptureLogger,
is_transformers_available,
require_accelerate,
@@ -298,3 +299,19 @@ transformer BitsAndBytesConfig {
data = json.loads(json_part)
return data
+
+ def test_single_component_to_quantize(self):
+ component_to_quantize = "transformer"
+ quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ components_to_quantize=component_to_quantize,
+ )
+ pipe = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ )
+ for name, component in pipe.components.items():
+ if name == component_to_quantize:
+ self.assertTrue(hasattr(component.config, "quantization_config"))
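
# --- Illustrative sketch (not part of the patch): the behavior checked by the new
# test_single_component_to_quantize, namely that `components_to_quantize` accepts a bare
# string instead of a list. The model id below is a placeholder; the test uses its own
# self.model_name fixture, which this diff does not show.
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize="transformer",  # a single name, not ["transformer"]
)
pipe = DiffusionPipeline.from_pretrained(
    "<model-id>",  # placeholder
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
assert hasattr(pipe.transformer.config, "quantization_config")
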
diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py
index c742927646..29758cbdd7 100644
--- a/tests/quantization/test_torch_compile_utils.py
+++ b/tests/quantization/test_torch_compile_utils.py
@@ -18,7 +18,8 @@ import inspect
import torch
from diffusers import DiffusionPipeline
-from diffusers.utils.testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device
+
+from ..testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device
@require_torch_accelerator
@@ -56,12 +57,18 @@ class QuantCompileTests:
pipe.transformer.compile(fullgraph=True)
# small resolutions to ensure speedy execution.
- pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+ with torch._dynamo.config.patch(error_on_recompile=True):
+ pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
pipe.enable_model_cpu_offload()
- pipe.transformer.compile()
+ # Regional compilation is better suited to CPU offloading.
+ # see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
+ if getattr(pipe.transformer, "_repeated_blocks", None):
+ pipe.transformer.compile_repeated_blocks(fullgraph=True)
+ else:
+ pipe.transformer.compile()
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 5dcc207e65..38997de17b 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -14,11 +14,13 @@
# limitations under the License.
import gc
+import importlib.metadata
import tempfile
import unittest
from typing import List
import numpy as np
+from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -31,7 +33,8 @@ from diffusers import (
)
from diffusers.models.attention_processor import Attention
from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers.utils.testing_utils import (
+
+from ...testing_utils import (
backend_empty_cache,
backend_synchronize,
enable_full_determinism,
@@ -45,7 +48,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from ..test_torch_compile_utils import QuantCompileTests
@@ -65,6 +67,9 @@ if is_torchao_available():
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes
+ if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"):
+ from torchao.quantization import Int8WeightOnlyConfig
+
@require_torch
@require_torch_accelerator
@@ -522,6 +527,15 @@ class TorchAoTest(unittest.TestCase):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)
+ @require_torchao_version_greater_or_equal("0.9.0")
+ def test_aobase_config(self):
+ quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
+ components = self.get_dummy_components(quantization_config)
+ pipe = FluxPipeline(**components).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ _ = pipe(**inputs)
+
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@@ -628,6 +642,14 @@ class TorchAoSerializationTest(unittest.TestCase):
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+ @require_torchao_version_greater_or_equal("0.9.0")
+ def test_aobase_config(self):
+ quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {}
+ expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+ device = torch_device
+ self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+ self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
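
# --- Illustrative sketch (not part of the patch): the AOBaseConfig path the new
# test_aobase_config cases cover. With torchao >= 0.9.0 a torchao config object can be
# passed to TorchAoConfig directly instead of a string quant type. The checkpoint id is an
# illustrative assumption.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from torchao.quantization import Int8WeightOnlyConfig

quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
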
diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py
index d458a3e6d5..a74ece5a3a 100644
--- a/tests/quantization/utils.py
+++ b/tests/quantization/utils.py
@@ -1,5 +1,6 @@
from diffusers.utils import is_torch_available
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py
index cec96e729a..27170cba08 100644
--- a/tests/remote/test_remote_decode.py
+++ b/tests/remote/test_remote_decode.py
@@ -30,13 +30,14 @@ from diffusers.utils.constants import (
from diffusers.utils.remote_utils import (
remote_decode,
)
-from diffusers.utils.testing_utils import (
+from diffusers.video_processor import VideoProcessor
+
+from ..testing_utils import (
enable_full_determinism,
slow,
torch_all_close,
torch_device,
)
-from diffusers.video_processor import VideoProcessor
enable_full_determinism()
diff --git a/tests/remote/test_remote_encode.py b/tests/remote/test_remote_encode.py
index 62ed97ee8f..4c0daf08fd 100644
--- a/tests/remote/test_remote_encode.py
+++ b/tests/remote/test_remote_encode.py
@@ -31,7 +31,8 @@ from diffusers.utils.remote_utils import (
remote_decode,
remote_encode,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
enable_full_determinism,
slow,
)
diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py
index 69b6111734..e4dde67344 100644
--- a/tests/schedulers/test_scheduler_dpm_sde.py
+++ b/tests/schedulers/test_scheduler_dpm_sde.py
@@ -1,8 +1,8 @@
import torch
from diffusers import DPMSolverSDEScheduler
-from diffusers.utils.testing_utils import require_torchsde, torch_device
+from ..testing_utils import require_torchsde, torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_euler.py b/tests/schedulers/test_scheduler_euler.py
index 01e173a631..ee99465abf 100644
--- a/tests/schedulers/test_scheduler_euler.py
+++ b/tests/schedulers/test_scheduler_euler.py
@@ -1,8 +1,8 @@
import torch
from diffusers import EulerDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_euler_ancestral.py b/tests/schedulers/test_scheduler_euler_ancestral.py
index 9f22ab38dd..c4fe61bfc3 100644
--- a/tests/schedulers/test_scheduler_euler_ancestral.py
+++ b/tests/schedulers/test_scheduler_euler_ancestral.py
@@ -1,8 +1,8 @@
import torch
from diffusers import EulerAncestralDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py
deleted file mode 100644
index c8121d3341..0000000000
--- a/tests/schedulers/test_scheduler_flax.py
+++ /dev/null
@@ -1,920 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import inspect
-import tempfile
-import unittest
-from typing import Dict, List, Tuple
-
-from diffusers import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler
-from diffusers.utils import is_flax_available
-from diffusers.utils.testing_utils import require_flax
-
-
-if is_flax_available():
- import jax
- import jax.numpy as jnp
- from jax import random
-
- jax_device = jax.default_backend()
-
-
-@require_flax
-class FlaxSchedulerCommonTest(unittest.TestCase):
- scheduler_classes = ()
- forward_default_kwargs = ()
-
- @property
- def dummy_sample(self):
- batch_size = 4
- num_channels = 3
- height = 8
- width = 8
-
- key1, key2 = random.split(random.PRNGKey(0))
- sample = random.uniform(key1, (batch_size, num_channels, height, width))
-
- return sample, key2
-
- @property
- def dummy_sample_deter(self):
- batch_size = 4
- num_channels = 3
- height = 8
- width = 8
-
- num_elems = batch_size * num_channels * height * width
- sample = jnp.arange(num_elems)
- sample = sample.reshape(num_channels, height, width, batch_size)
- sample = sample / num_elems
- return jnp.transpose(sample, (3, 0, 1, 2))
-
- def get_scheduler_config(self):
- raise NotImplementedError
-
- def dummy_model(self):
- def model(sample, t, *args):
- return sample * t / (t + 1)
-
- return model
-
- def check_over_configs(self, time_step=0, **config):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def check_over_forward(self, time_step=0, **forward_kwargs):
- kwargs = dict(self.forward_default_kwargs)
- kwargs.update(forward_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def test_from_save_pretrained(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def test_step_shape(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample
- output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample
-
- self.assertEqual(output_0.shape, sample.shape)
- self.assertEqual(output_0.shape, output_1.shape)
-
- def test_scheduler_outputs_equivalence(self):
- def set_nan_tensor_to_zero(t):
- return t.at[t != t].set(0)
-
- def recursive_check(tuple_object, dict_object):
- if isinstance(tuple_object, (List, Tuple)):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif isinstance(tuple_object, Dict):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif tuple_object is None:
- return
- else:
- self.assertTrue(
- jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
- msg=(
- "Tuple and dict output are not equal. Difference:"
- f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
- f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
- f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
- ),
- )
-
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, key = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs)
-
- recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
-
- def test_deprecated_kwargs(self):
- for scheduler_class in self.scheduler_classes:
- has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters
- has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0
-
- if has_kwarg_in_model_class and not has_deprecated_kwarg:
- raise ValueError(
- f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated"
- " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if"
- " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
- " []`"
- )
-
- if not has_kwarg_in_model_class and has_deprecated_kwarg:
- raise ValueError(
- f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated"
- " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`"
- f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the"
- " deprecated argument from `_deprecated_kwargs = []`"
- )
-
-
-@require_flax
-class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
- scheduler_classes = (FlaxDDPMScheduler,)
-
- def get_scheduler_config(self, **kwargs):
- config = {
- "num_train_timesteps": 1000,
- "beta_start": 0.0001,
- "beta_end": 0.02,
- "beta_schedule": "linear",
- "variance_type": "fixed_small",
- "clip_sample": True,
- }
-
- config.update(**kwargs)
- return config
-
- def test_timesteps(self):
- for timesteps in [1, 5, 100, 1000]:
- self.check_over_configs(num_train_timesteps=timesteps)
-
- def test_betas(self):
- for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
- self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
-
- def test_schedules(self):
- for schedule in ["linear", "squaredcos_cap_v2"]:
- self.check_over_configs(beta_schedule=schedule)
-
- def test_variance_type(self):
- for variance in ["fixed_small", "fixed_large", "other"]:
- self.check_over_configs(variance_type=variance)
-
- def test_clip_sample(self):
- for clip_sample in [True, False]:
- self.check_over_configs(clip_sample=clip_sample)
-
- def test_time_indices(self):
- for t in [0, 500, 999]:
- self.check_over_forward(time_step=t)
-
- def test_variance(self):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5
-
- def test_full_loop_no_noise(self):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- num_trained_timesteps = len(scheduler)
-
- model = self.dummy_model()
- sample = self.dummy_sample_deter
- key1, key2 = random.split(random.PRNGKey(0))
-
- for t in reversed(range(num_trained_timesteps)):
- # 1. predict noise residual
- residual = model(sample, t)
-
- # 2. predict previous mean of sample x_t-1
- output = scheduler.step(state, residual, t, sample, key1)
- pred_prev_sample = output.prev_sample
- state = output.state
- key1, key2 = random.split(key2)
-
- # if t > 0:
- # noise = self.dummy_sample_deter
- # variance = scheduler.get_variance(t) ** (0.5) * noise
- #
- # sample = pred_prev_sample + variance
- sample = pred_prev_sample
-
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 255.0714) < 1e-2
- assert abs(result_mean - 0.332124) < 1e-3
- else:
- assert abs(result_sum - 270.2) < 1e-1
- assert abs(result_mean - 0.3519494) < 1e-3
-
-
-@require_flax
-class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
- scheduler_classes = (FlaxDDIMScheduler,)
- forward_default_kwargs = (("num_inference_steps", 50),)
-
- def get_scheduler_config(self, **kwargs):
- config = {
- "num_train_timesteps": 1000,
- "beta_start": 0.0001,
- "beta_end": 0.02,
- "beta_schedule": "linear",
- }
-
- config.update(**kwargs)
- return config
-
- def full_loop(self, **config):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- key1, key2 = random.split(random.PRNGKey(0))
-
- num_inference_steps = 10
-
- model = self.dummy_model()
- sample = self.dummy_sample_deter
-
- state = scheduler.set_timesteps(state, num_inference_steps)
-
- for t in state.timesteps:
- residual = model(sample, t)
- output = scheduler.step(state, residual, t, sample)
- sample = output.prev_sample
- state = output.state
- key1, key2 = random.split(key2)
-
- return sample
-
- def check_over_configs(self, time_step=0, **config):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def test_from_save_pretrained(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def check_over_forward(self, time_step=0, **forward_kwargs):
- kwargs = dict(self.forward_default_kwargs)
- kwargs.update(forward_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample
- new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def test_scheduler_outputs_equivalence(self):
- def set_nan_tensor_to_zero(t):
- return t.at[t != t].set(0)
-
- def recursive_check(tuple_object, dict_object):
- if isinstance(tuple_object, (List, Tuple)):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif isinstance(tuple_object, Dict):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif tuple_object is None:
- return
- else:
- self.assertTrue(
- jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
- msg=(
- "Tuple and dict output are not equal. Difference:"
- f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
- f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
- f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
- ),
- )
-
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)
-
- recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
-
- def test_step_shape(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample
- output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample
-
- self.assertEqual(output_0.shape, sample.shape)
- self.assertEqual(output_0.shape, output_1.shape)
-
- def test_timesteps(self):
- for timesteps in [100, 500, 1000]:
- self.check_over_configs(num_train_timesteps=timesteps)
-
- def test_steps_offset(self):
- for steps_offset in [0, 1]:
- self.check_over_configs(steps_offset=steps_offset)
-
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config(steps_offset=1)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- state = scheduler.set_timesteps(state, 5)
- assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all()
-
- def test_betas(self):
- for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
- self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
-
- def test_schedules(self):
- for schedule in ["linear", "squaredcos_cap_v2"]:
- self.check_over_configs(beta_schedule=schedule)
-
- def test_time_indices(self):
- for t in [1, 10, 49]:
- self.check_over_forward(time_step=t)
-
- def test_inference_steps(self):
- for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
- self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
-
- def test_variance(self):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5
- assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5
-
- def test_full_loop_no_noise(self):
- sample = self.full_loop()
-
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- assert abs(result_sum - 172.0067) < 1e-2
- assert abs(result_mean - 0.223967) < 1e-3
-
- def test_full_loop_with_set_alpha_to_one(self):
- # We specify different beta, so that the first alpha is 0.99
- sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 149.8409) < 1e-2
- assert abs(result_mean - 0.1951) < 1e-3
- else:
- assert abs(result_sum - 149.8295) < 1e-2
- assert abs(result_mean - 0.1951) < 1e-3
-
- def test_full_loop_with_no_set_alpha_to_one(self):
- # We specify different beta, so that the first alpha is 0.99
- sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- pass
- # FIXME: both result_sum and result_mean are nan on TPU
- # assert jnp.isnan(result_sum)
- # assert jnp.isnan(result_mean)
- else:
- assert abs(result_sum - 149.0784) < 1e-2
- assert abs(result_mean - 0.1941) < 1e-3
-
- def test_prediction_type(self):
- for prediction_type in ["epsilon", "sample", "v_prediction"]:
- self.check_over_configs(prediction_type=prediction_type)
-
-
-@require_flax
-class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
- scheduler_classes = (FlaxPNDMScheduler,)
- forward_default_kwargs = (("num_inference_steps", 50),)
-
- def get_scheduler_config(self, **kwargs):
- config = {
- "num_train_timesteps": 1000,
- "beta_start": 0.0001,
- "beta_end": 0.02,
- "beta_schedule": "linear",
- }
-
- config.update(**kwargs)
- return config
-
- def check_over_configs(self, time_step=0, **config):
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
- dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
- # copy over dummy past residuals
- state = state.replace(ets=dummy_past_residuals[:])
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
- # copy over dummy past residuals
- new_state = new_state.replace(ets=dummy_past_residuals[:])
-
- (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
- (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)
-
- assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical"
-
- output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
- new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- @unittest.skip("Test not supported.")
- def test_from_save_pretrained(self):
- pass
-
- def test_scheduler_outputs_equivalence(self):
- def set_nan_tensor_to_zero(t):
- return t.at[t != t].set(0)
-
- def recursive_check(tuple_object, dict_object):
- if isinstance(tuple_object, (List, Tuple)):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif isinstance(tuple_object, Dict):
- for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
- recursive_check(tuple_iterable_value, dict_iterable_value)
- elif tuple_object is None:
- return
- else:
- self.assertTrue(
- jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5),
- msg=(
- "Tuple and dict output are not equal. Difference:"
- f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:"
- f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has"
- f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}."
- ),
- )
-
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs)
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs)
-
- recursive_check(outputs_tuple[0], outputs_dict.prev_sample)
-
- def check_over_forward(self, time_step=0, **forward_kwargs):
- kwargs = dict(self.forward_default_kwargs)
- num_inference_steps = kwargs.pop("num_inference_steps", None)
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
- dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
-
- # copy over dummy past residuals (must be after setting timesteps)
- scheduler.ets = dummy_past_residuals[:]
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- scheduler.save_config(tmpdirname)
- new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
- # copy over dummy past residuals
- new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
-
- # copy over dummy past residual (must be after setting timesteps)
- new_state.replace(ets=dummy_past_residuals[:])
-
- output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs)
- new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs)
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs)
- new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs)
-
- assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
-
- def full_loop(self, **config):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config(**config)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- num_inference_steps = 10
- model = self.dummy_model()
- sample = self.dummy_sample_deter
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
-
- for i, t in enumerate(state.prk_timesteps):
- residual = model(sample, t)
- sample, state = scheduler.step_prk(state, residual, t, sample)
-
- for i, t in enumerate(state.plms_timesteps):
- residual = model(sample, t)
- sample, state = scheduler.step_plms(state, residual, t, sample)
-
- return sample
-
- def test_step_shape(self):
- kwargs = dict(self.forward_default_kwargs)
-
- num_inference_steps = kwargs.pop("num_inference_steps", None)
-
- for scheduler_class in self.scheduler_classes:
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
- elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
- kwargs["num_inference_steps"] = num_inference_steps
-
- # copy over dummy past residuals (must be done after set_timesteps)
- dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05])
- state = state.replace(ets=dummy_past_residuals[:])
-
- output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs)
- output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs)
-
- self.assertEqual(output_0.shape, sample.shape)
- self.assertEqual(output_0.shape, output_1.shape)
-
- output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs)
- output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs)
-
- self.assertEqual(output_0.shape, sample.shape)
- self.assertEqual(output_0.shape, output_1.shape)
-
- def test_timesteps(self):
- for timesteps in [100, 1000]:
- self.check_over_configs(num_train_timesteps=timesteps)
-
- def test_steps_offset(self):
- for steps_offset in [0, 1]:
- self.check_over_configs(steps_offset=steps_offset)
-
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config(steps_offset=1)
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
- state = scheduler.set_timesteps(state, 10, shape=())
- assert jnp.equal(
- state.timesteps,
- jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
- ).all()
-
- def test_betas(self):
- for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
- self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
-
- def test_schedules(self):
- for schedule in ["linear", "squaredcos_cap_v2"]:
- self.check_over_configs(beta_schedule=schedule)
-
- def test_time_indices(self):
- for t in [1, 5, 10]:
- self.check_over_forward(time_step=t)
-
- def test_inference_steps(self):
- for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
- self.check_over_forward(num_inference_steps=num_inference_steps)
-
- def test_pow_of_3_inference_steps(self):
- # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
- num_inference_steps = 27
-
- for scheduler_class in self.scheduler_classes:
- sample, _ = self.dummy_sample
- residual = 0.1 * sample
-
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape)
-
- # before power of 3 fix, would error on first step, so we only need to do two
- for i, t in enumerate(state.prk_timesteps[:2]):
- sample, state = scheduler.step_prk(state, residual, t, sample)
-
- def test_inference_plms_no_past_residuals(self):
- with self.assertRaises(ValueError):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config()
- scheduler = scheduler_class(**scheduler_config)
- state = scheduler.create_state()
-
- scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample
-
- def test_full_loop_no_noise(self):
- sample = self.full_loop()
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 198.1275) < 1e-2
- assert abs(result_mean - 0.2580) < 1e-3
- else:
- assert abs(result_sum - 198.1318) < 1e-2
- assert abs(result_mean - 0.2580) < 1e-3
-
- def test_full_loop_with_set_alpha_to_one(self):
- # We specify different beta, so that the first alpha is 0.99
- sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 186.83226) < 1e-2
- assert abs(result_mean - 0.24327) < 1e-3
- else:
- assert abs(result_sum - 186.9466) < 1e-2
- assert abs(result_mean - 0.24342) < 1e-3
-
- def test_full_loop_with_no_set_alpha_to_one(self):
- # We specify different beta, so that the first alpha is 0.99
- sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
- result_sum = jnp.sum(jnp.abs(sample))
- result_mean = jnp.mean(jnp.abs(sample))
-
- if jax_device == "tpu":
- assert abs(result_sum - 186.83226) < 1e-2
- assert abs(result_mean - 0.24327) < 1e-3
- else:
- assert abs(result_sum - 186.9482) < 1e-2
- assert abs(result_mean - 0.2434) < 1e-3
diff --git a/tests/schedulers/test_scheduler_heun.py b/tests/schedulers/test_scheduler_heun.py
index 90012f5525..97bef50048 100644
--- a/tests/schedulers/test_scheduler_heun.py
+++ b/tests/schedulers/test_scheduler_heun.py
@@ -1,8 +1,8 @@
import torch
from diffusers import HeunDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_kdpm2_ancestral.py b/tests/schedulers/test_scheduler_kdpm2_ancestral.py
index fa85c2be45..135534db45 100644
--- a/tests/schedulers/test_scheduler_kdpm2_ancestral.py
+++ b/tests/schedulers/test_scheduler_kdpm2_ancestral.py
@@ -1,8 +1,8 @@
import torch
from diffusers import KDPM2AncestralDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_kdpm2_discrete.py b/tests/schedulers/test_scheduler_kdpm2_discrete.py
index 4d8923b694..370ba2253e 100644
--- a/tests/schedulers/test_scheduler_kdpm2_discrete.py
+++ b/tests/schedulers/test_scheduler_kdpm2_discrete.py
@@ -1,8 +1,8 @@
import torch
from diffusers import KDPM2DiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_lcm.py b/tests/schedulers/test_scheduler_lcm.py
index f3f6e9ba58..f54970e0eb 100644
--- a/tests/schedulers/test_scheduler_lcm.py
+++ b/tests/schedulers/test_scheduler_lcm.py
@@ -4,8 +4,8 @@ from typing import Dict, List, Tuple
import torch
from diffusers import LCMScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_lms.py b/tests/schedulers/test_scheduler_lms.py
index 3bfcd57c1b..c4abca3ac9 100644
--- a/tests/schedulers/test_scheduler_lms.py
+++ b/tests/schedulers/test_scheduler_lms.py
@@ -1,8 +1,8 @@
import torch
from diffusers import LMSDiscreteScheduler
-from diffusers.utils.testing_utils import torch_device
+from ..testing_utils import torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_scheduler_sasolver.py b/tests/schedulers/test_scheduler_sasolver.py
index baa2736b2f..2c2d2c0397 100644
--- a/tests/schedulers/test_scheduler_sasolver.py
+++ b/tests/schedulers/test_scheduler_sasolver.py
@@ -1,8 +1,8 @@
import torch
from diffusers import SASolverScheduler
-from diffusers.utils.testing_utils import require_torchsde, torch_device
+from ..testing_utils import require_torchsde, torch_device
from .test_schedulers import SchedulerCommonTest
diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index cd8dc5ccf1..5a8380e659 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -41,9 +41,9 @@ from diffusers import (
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import logging
-from diffusers.utils.testing_utils import CaptureLogger, torch_device
from ..others.test_utils import TOKEN, USER, is_staging_test
+from ..testing_utils import CaptureLogger, torch_device
torch.backends.cuda.matmul.allow_tf32 = False
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index 4e1713c9ce..52fd2f5bfc 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -1,3 +1,4 @@
+import gc
import tempfile
from io import BytesIO
@@ -7,8 +8,12 @@ from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.models.attention_processor import AttnProcessor
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
+ backend_empty_cache,
+ nightly,
numpy_cosine_similarity_distance,
+ require_torch_accelerator,
torch_device,
)
@@ -46,6 +51,93 @@ def download_diffusers_config(repo_id, tmpdir):
return path
+@nightly
+@require_torch_accelerator
+class SingleFileModelTesterMixin:
+ def setup_method(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def teardown_method(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_single_file_model_config(self):
+ pretrained_kwargs = {}
+ single_file_kwargs = {}
+
+ if hasattr(self, "subfolder") and self.subfolder:
+ pretrained_kwargs["subfolder"] = self.subfolder
+
+ if hasattr(self, "torch_dtype") and self.torch_dtype:
+ pretrained_kwargs["torch_dtype"] = self.torch_dtype
+ single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+ model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
+ model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
+
+ def test_single_file_model_parameters(self):
+ pretrained_kwargs = {}
+ single_file_kwargs = {}
+
+ if hasattr(self, "subfolder") and self.subfolder:
+ pretrained_kwargs["subfolder"] = self.subfolder
+
+ if hasattr(self, "torch_dtype") and self.torch_dtype:
+ pretrained_kwargs["torch_dtype"] = self.torch_dtype
+ single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+ model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
+ model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
+
+ state_dict = model.state_dict()
+ state_dict_single_file = model_single_file.state_dict()
+
+ assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
+ "Model parameters keys differ between pretrained and single file loading"
+ )
+
+ for key in state_dict.keys():
+ param = state_dict[key]
+ param_single_file = state_dict_single_file[key]
+
+ assert param.shape == param_single_file.shape, (
+ f"Parameter shape mismatch for {key}: "
+ f"pretrained {param.shape} vs single file {param_single_file.shape}"
+ )
+
+ assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
+ f"Parameter values differ for {key}: "
+ f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
+ )
+
+ def test_checkpoint_altered_keys_loading(self):
+ # Test loading with checkpoints that have altered keys
+ if not hasattr(self, "alternate_keys_ckpt_paths") or not self.alternate_keys_ckpt_paths:
+ return
+
+ for ckpt_path in self.alternate_keys_ckpt_paths:
+ backend_empty_cache(torch_device)
+
+ single_file_kwargs = {}
+ if hasattr(self, "torch_dtype") and self.torch_dtype:
+ single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+ model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
+
+ del model
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+
class SDSingleFileTesterMixin:
single_file_kwargs = {}
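
The new `SingleFileModelTesterMixin` added above centralizes the pretrained-vs-single-file config and parameter checks that the later hunks delete from each individual model test. A minimal sketch of how a concrete test class is expected to opt in, mirroring the converted classes further down in this patch; the checkpoint URL and class name below are illustrative placeholders, not part of the patch, and the relative imports assume the diffusers `tests/single_file/` package layout:

```python
# Sketch only: assumes the tests/single_file/ package so the relative imports resolve.
# The repo/checkpoint identifiers here are placeholders for illustration.
from diffusers import AutoencoderKL

from ..testing_utils import enable_full_determinism
from .single_file_testing_utils import SingleFileModelTesterMixin

enable_full_determinism()


class TestMyAutoencoderSingleFile(SingleFileModelTesterMixin):
    # Class attributes are all the mixin needs; the inherited tests
    # (test_single_file_model_config, test_single_file_model_parameters,
    # test_checkpoint_altered_keys_loading) pick them up automatically.
    model_class = AutoencoderKL
    ckpt_path = "https://huggingface.co/some-org/some-repo/blob/main/vae.safetensors"  # placeholder
    repo_id = "some-org/some-diffusers-repo"  # placeholder
```
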
diff --git a/tests/single_file/test_lumina2_transformer.py b/tests/single_file/test_lumina2_transformer.py
index 2ac681897d..bb5a0bf473 100644
--- a/tests/single_file/test_lumina2_transformer.py
+++ b/tests/single_file/test_lumina2_transformer.py
@@ -13,25 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
from diffusers import (
Lumina2Transformer2DModel,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
+class TestLumina2Transformer2DModelSingleFile(SingleFileModelTesterMixin):
model_class = Lumina2Transformer2DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
alternate_keys_ckpt_paths = [
@@ -39,34 +35,4 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
]
repo_id = "Alpha-VLLM/Lumina-Image-2.0"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
-
- def test_checkpoint_loading(self):
- for ckpt_path in self.alternate_keys_ckpt_paths:
- backend_empty_cache(torch_device)
- model = self.model_class.from_single_file(ckpt_path)
-
- del model
- gc.collect()
- backend_empty_cache(torch_device)
+ subfolder = "transformer"
diff --git a/tests/single_file/test_model_autoencoder_dc_single_file.py b/tests/single_file/test_model_autoencoder_dc_single_file.py
index 184498ca2f..444ca40469 100644
--- a/tests/single_file/test_model_autoencoder_dc_single_file.py
+++ b/tests/single_file/test_model_autoencoder_dc_single_file.py
@@ -13,47 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
from diffusers import (
AutoencoderDC,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
- require_torch_accelerator,
- slow,
torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@slow
-@require_torch_accelerator
-class AutoencoderDCSingleFileTests(unittest.TestCase):
+class TestAutoencoderDCSingleFile(SingleFileModelTesterMixin):
model_class = AutoencoderDC
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
main_input_name = "sample"
base_precision = 1e-2
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -79,18 +64,6 @@ class AutoencoderDCSingleFileTests(unittest.TestCase):
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id)
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
-
def test_single_file_in_type_variant_components(self):
# `in` variant checkpoints require passing in a `config` parameter
# in order to set the scaling factor correctly.
diff --git a/tests/single_file/test_model_controlnet_single_file.py b/tests/single_file/test_model_controlnet_single_file.py
index ade6f63a50..2fa81fe3ae 100644
--- a/tests/single_file/test_model_controlnet_single_file.py
+++ b/tests/single_file/test_model_controlnet_single_file.py
@@ -13,55 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
from diffusers import (
ControlNetModel,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@slow
-@require_torch_accelerator
-class ControlNetModelSingleFileTests(unittest.TestCase):
+class TestControlNetModelSingleFile(SingleFileModelTesterMixin):
model_class = ControlNetModel
ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
repo_id = "lllyasviel/control_v11p_sd15_canny"
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id)
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
-
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path)
diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py
index 2f837bd18e..0642a71c57 100644
--- a/tests/single_file/test_model_flux_transformer_single_file.py
+++ b/tests/single_file/test_model_flux_transformer_single_file.py
@@ -14,57 +14,34 @@
# limitations under the License.
import gc
-import unittest
from diffusers import (
FluxTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
- require_torch_accelerator,
torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
+class TestFluxTransformer2DModelSingleFile(SingleFileModelTesterMixin):
model_class = FluxTransformer2DModel
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
repo_id = "black-forest-labs/FLUX.1-dev"
+ subfolder = "transformer"
- def setUp(self):
- super().setUp()
+ def test_device_map_cuda(self):
+ backend_empty_cache(torch_device)
+ model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")
+
+ del model
gc.collect()
backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
-
- def test_checkpoint_loading(self):
- for ckpt_path in self.alternate_keys_ckpt_paths:
- backend_empty_cache(torch_device)
- model = self.model_class.from_single_file(ckpt_path)
-
- del model
- gc.collect()
- backend_empty_cache(torch_device)
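
The Flux transformer test above gains a `test_device_map_cuda` case. A hedged sketch of the call it exercises, using the same checkpoint URL as the test; it assumes a CUDA-capable accelerator with enough memory for the FLUX.1-dev checkpoint, and asserts nothing beyond successful placement:

```python
# Sketch: load a single-file checkpoint directly onto CUDA via device_map,
# mirroring the new test_device_map_cuda added above.
from diffusers import FluxTransformer2DModel

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
model = FluxTransformer2DModel.from_single_file(ckpt_path, device_map="cuda")
print(next(model.parameters()).device)  # should report a CUDA device
```
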
diff --git a/tests/single_file/test_model_motion_adapter_single_file.py b/tests/single_file/test_model_motion_adapter_single_file.py
index dc08a95b84..a047c81b47 100644
--- a/tests/single_file/test_model_motion_adapter_single_file.py
+++ b/tests/single_file/test_model_motion_adapter_single_file.py
@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
from diffusers import (
MotionAdapter,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
enable_full_determinism,
)
@@ -26,7 +26,7 @@ from diffusers.utils.testing_utils import (
enable_full_determinism()
-class MotionAdapterSingleFileTests(unittest.TestCase):
+class MotionAdapterSingleFileTests:
model_class = MotionAdapter
def test_single_file_components_version_v1_5(self):
diff --git a/tests/single_file/test_model_sd_cascade_unet_single_file.py b/tests/single_file/test_model_sd_cascade_unet_single_file.py
index a16278c6b0..7472122710 100644
--- a/tests/single_file/test_model_sd_cascade_unet_single_file.py
+++ b/tests/single_file/test_model_sd_cascade_unet_single_file.py
@@ -14,13 +14,13 @@
# limitations under the License.
import gc
-import unittest
import torch
from diffusers import StableCascadeUNet
from diffusers.utils import logging
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
@@ -36,14 +36,12 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableCascadeUNetSingleFileTest(unittest.TestCase):
- def setUp(self):
- super().setUp()
+class StableCascadeUNetSingleFileTest:
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py
index 9d994b5b49..9198d9b163 100644
--- a/tests/single_file/test_model_vae_single_file.py
+++ b/tests/single_file/test_model_vae_single_file.py
@@ -13,31 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
from diffusers import (
AutoencoderKL,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
- require_torch_accelerator,
- slow,
torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@slow
-@require_torch_accelerator
-class AutoencoderKLSingleFileTests(unittest.TestCase):
+class TestAutoencoderKLSingleFile(SingleFileModelTesterMixin):
model_class = AutoencoderKL
ckpt_path = (
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
@@ -46,16 +41,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
main_input_name = "sample"
base_precision = 1e-2
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -83,18 +68,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id)
- model_single_file = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
-
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py
index 7f0e1c1a4b..0babf30234 100644
--- a/tests/single_file/test_model_wan_autoencoder_single_file.py
+++ b/tests/single_file/test_model_wan_autoencoder_single_file.py
@@ -13,49 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
from diffusers import (
AutoencoderKLWan,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class AutoencoderKLWanSingleFileTests(unittest.TestCase):
+class TestAutoencoderKLWanSingleFile(SingleFileModelTesterMixin):
model_class = AutoencoderKLWan
ckpt_path = (
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ subfolder = "vae"
diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py
index 72b4b3a58a..b769092060 100644
--- a/tests/single_file/test_model_wan_transformer3d_single_file.py
+++ b/tests/single_file/test_model_wan_transformer3d_single_file.py
@@ -13,81 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
from diffusers import (
WanTransformer3DModel,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
require_big_accelerator,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
+class TestWanTransformer3DModelText2VideoSingleFile(SingleFileModelTesterMixin):
model_class = WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ subfolder = "transformer"
@require_big_accelerator
-@require_torch_accelerator
-class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
+class TestWanTransformer3DModelImage2VideoSingleFile(SingleFileModelTesterMixin):
model_class = WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"
repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
torch_dtype = torch.float8_e4m3fn
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype)
- model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ subfolder = "transformer"
diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py
index e74c5be6ff..9e2adb93bf 100644
--- a/tests/single_file/test_sana_transformer.py
+++ b/tests/single_file/test_sana_transformer.py
@@ -1,22 +1,17 @@
-import gc
-import unittest
-
from diffusers import (
SanaTransformer2DModel,
)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
+
+from ..testing_utils import (
enable_full_determinism,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
+class TestSanaTransformer2DModelSingleFile(SingleFileModelTesterMixin):
model_class = SanaTransformer2DModel
ckpt_path = (
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
@@ -26,34 +21,4 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
]
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
-
- def test_checkpoint_loading(self):
- for ckpt_path in self.alternate_keys_ckpt_paths:
- backend_empty_cache(torch_device)
- model = self.model_class.from_single_file(ckpt_path)
-
- del model
- gc.collect()
- backend_empty_cache(torch_device)
+ subfolder = "transformer"
diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
index 7589b48028..141748b084 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
@@ -1,13 +1,13 @@
import gc
import tempfile
-import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -15,7 +15,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_diffusers_config,
@@ -29,7 +28,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionControlNetPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
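
The lifecycle change applied here recurs throughout the remaining pipeline test files: `unittest.TestCase` is dropped, so `setUp`/`tearDown` (with their `super()` calls) become pytest's `setup_method`/`teardown_method`, each clearing the accelerator cache between tests. A minimal sketch of the pattern, assuming the `tests/` package layout so the relative import resolves:

```python
# Sketch of the pytest-style lifecycle hooks used by the converted test classes.
import gc

from ..testing_utils import backend_empty_cache, torch_device


class ExampleSingleFilePipelineTest:  # plain class: pytest collects it without unittest.TestCase
    def setup_method(self):
        # pytest runs this before every test method; no super().setUp() is needed
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        # and this after every test method, releasing accelerator memory between tests
        gc.collect()
        backend_empty_cache(torch_device)
```
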
diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
index 1555831db6..8238866cbf 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
@@ -1,13 +1,14 @@
import gc
import tempfile
-import unittest
+import pytest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -15,7 +16,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_diffusers_config,
@@ -29,19 +29,17 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionControlNetInpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "stable-diffusion-v1-5/stable-diffusion-inpainting"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -115,7 +113,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
super()._compare_component_configs(pipe, pipe_single_file)
- @unittest.skip("runwayml original config repo does not exist")
+ @pytest.mark.skip(reason="runwayml original config repo does not exist")
def test_single_file_components_with_original_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
@@ -125,7 +123,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
super()._compare_component_configs(pipe, pipe_single_file)
- @unittest.skip("runwayml original config repo does not exist")
+ @pytest.mark.skip(reason="runwayml original config repo does not exist")
def test_single_file_components_with_original_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
index 2c1e414e5e..80ef6c2574 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
@@ -1,13 +1,13 @@
import gc
import tempfile
-import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -15,7 +15,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_diffusers_config,
@@ -29,7 +28,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionControlNetPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_img2img_single_file.py b/tests/single_file/test_stable_diffusion_img2img_single_file.py
index 9ad9355824..e76846c800 100644
--- a/tests/single_file/test_stable_diffusion_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_img2img_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import torch
@@ -7,14 +6,14 @@ from diffusers import (
StableDiffusionImg2ImgPipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -23,7 +22,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionImg2ImgPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -33,13 +32,11 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -66,19 +63,17 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
@slow
@require_torch_accelerator
-class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusion21Img2ImgPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
index b05a098c0b..6e5d27cdff 100644
--- a/tests/single_file/test_stable_diffusion_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
@@ -1,20 +1,20 @@
import gc
-import unittest
+import pytest
import torch
from diffusers import (
StableDiffusionInpaintPipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -23,19 +23,17 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionInpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "botp/stable-diffusion-v1-5-inpainting"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -70,18 +68,18 @@ class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSin
assert pipe.unet.config.in_channels == 4
- @unittest.skip("runwayml original config has been removed")
+ @pytest.mark.skip(reason="runwayml original config has been removed")
def test_single_file_components_with_original_config(self):
return
- @unittest.skip("runwayml original config has been removed")
+ @pytest.mark.skip(reason="runwayml original config has been removed")
def test_single_file_components_with_original_config_local_files_only(self):
return
@slow
@require_torch_accelerator
-class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusion21InpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/512-inpainting-ema.safetensors"
@@ -89,13 +87,11 @@ class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDS
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inpainting-inference.yaml"
repo_id = "stabilityai/stable-diffusion-2-inpainting"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
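
The same file also shows the skip-decorator side of the unittest-to-pytest conversion: `@unittest.skip("...")` becomes `@pytest.mark.skip(reason="...")`. A small sketch of the resulting shape; the class name and reason string are illustrative only:

```python
# Sketch: with unittest.TestCase removed, skips use the pytest marker instead.
import pytest


class ExampleInpaintSingleFileTest:
    @pytest.mark.skip(reason="original config repo has been removed")  # illustrative reason
    def test_single_file_components_with_original_config(self):
        return
```
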
diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py
index 78baeb9492..377dedbc57 100644
--- a/tests/single_file/test_stable_diffusion_single_file.py
+++ b/tests/single_file/test_stable_diffusion_single_file.py
@@ -1,13 +1,13 @@
import gc
import tempfile
-import unittest
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
nightly,
@@ -15,7 +15,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDSingleFileTesterMixin,
download_original_config,
@@ -28,7 +27,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -38,13 +37,11 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -90,19 +87,17 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
@slow
-class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusion21PipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -125,7 +120,7 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFi
@nightly
@slow
@require_torch_accelerator
-class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionInstructPix2PixPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
original_config = (
@@ -134,13 +129,11 @@ class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCas
repo_id = "timbrooks/instruct-pix2pix"
single_file_kwargs = {"extract_ema": True}
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_upscale_single_file.py b/tests/single_file/test_stable_diffusion_upscale_single_file.py
index 398fc9ece3..ba4819fadf 100644
--- a/tests/single_file/test_stable_diffusion_upscale_single_file.py
+++ b/tests/single_file/test_stable_diffusion_upscale_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import pytest
import torch
@@ -8,7 +7,8 @@ from diffusers import (
StableDiffusionUpscalePipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -16,7 +16,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from .single_file_testing_utils import SDSingleFileTesterMixin
@@ -25,19 +24,17 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionUpscalePipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionUpscalePipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
repo_id = "stabilityai/stable-diffusion-x4-upscaler"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
index fb5f8725b8..3d124fa8c2 100644
--- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
@@ -1,6 +1,5 @@
import gc
import tempfile
-import unittest
import torch
@@ -10,7 +9,8 @@ from diffusers import (
)
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -18,7 +18,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDXLSingleFileTesterMixin,
download_diffusers_config,
@@ -32,7 +31,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLAdapterPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLAdapterPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -40,13 +39,11 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
index 6d8c4369e1..6f50370261 100644
--- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
@@ -1,13 +1,13 @@
import gc
import tempfile
-import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -15,7 +15,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from .single_file_testing_utils import (
SDXLSingleFileTesterMixin,
download_diffusers_config,
@@ -28,7 +27,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLControlNetPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLControlNetPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -36,13 +35,11 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py b/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
index 7df8b84bc2..56657f37d9 100644
--- a/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import torch
@@ -8,7 +7,8 @@ from diffusers import (
StableDiffusionXLImg2ImgPipeline,
)
from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -16,7 +16,6 @@ from diffusers.utils.testing_utils import (
slow,
torch_device,
)
-
from .single_file_testing_utils import SDXLSingleFileTesterMixin
@@ -25,7 +24,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLImg2ImgPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -33,13 +32,11 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -66,7 +63,7 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
@slow
@require_torch_accelerator
-class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests(unittest.TestCase):
+class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests:
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
diff --git a/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py b/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
index 5a01463863..d755b70105 100644
--- a/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
+++ b/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
@@ -1,10 +1,10 @@
import gc
-import unittest
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
@@ -18,19 +18,17 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase):
+class TestStableDiffusionXLInstructPix2PixPipelineSingleFileSlow:
pipeline_class = StableDiffusionXLInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
original_config = None
repo_id = "diffusers/sdxl-instructpix2pix-768"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_single_file.py b/tests/single_file/test_stable_diffusion_xl_single_file.py
index 77f58d8592..4e5319ca25 100644
--- a/tests/single_file/test_stable_diffusion_xl_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_single_file.py
@@ -1,19 +1,18 @@
import gc
-import unittest
import torch
from diffusers import (
StableDiffusionXLPipeline,
)
-from diffusers.utils.testing_utils import (
+
+from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
-
from .single_file_testing_utils import SDXLSingleFileTesterMixin
@@ -22,7 +21,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -30,13 +29,11 @@ class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingle
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
new file mode 100644
index 0000000000..7f849219c1
--- /dev/null
+++ b/tests/testing_utils.py
@@ -0,0 +1,1557 @@
+import functools
+import glob
+import importlib
+import importlib.metadata
+import inspect
+import io
+import logging
+import multiprocessing
+import os
+import random
+import re
+import struct
+import sys
+import tempfile
+import time
+import unittest
+import urllib.parse
+from collections import UserDict
+from contextlib import contextmanager
+from io import BytesIO, StringIO
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import PIL.ImageOps
+import requests
+from numpy.linalg import norm
+from packaging import version
+
+from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
+from diffusers.utils.import_utils import (
+ BACKENDS_MAPPING,
+ is_accelerate_available,
+ is_bitsandbytes_available,
+ is_compel_available,
+ is_flax_available,
+ is_gguf_available,
+ is_kernels_available,
+ is_note_seq_available,
+ is_onnx_available,
+ is_opencv_available,
+ is_optimum_quanto_available,
+ is_peft_available,
+ is_timm_available,
+ is_torch_available,
+ is_torch_version,
+ is_torchao_available,
+ is_torchsde_available,
+ is_transformers_available,
+)
+from diffusers.utils.logging import get_logger
+
+
+if is_torch_available():
+ import torch
+
+ IS_ROCM_SYSTEM = torch.version.hip is not None
+ IS_CUDA_SYSTEM = torch.version.cuda is not None
+ IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
+else:
+ IS_ROCM_SYSTEM = False
+ IS_CUDA_SYSTEM = False
+ IS_XPU_SYSTEM = False
+
+global_rng = random.Random()
+
+logger = get_logger(__name__)
+
+_required_peft_version = is_peft_available() and version.parse(
+ version.parse(importlib.metadata.version("peft")).base_version
+) > version.parse("0.5")
+_required_transformers_version = is_transformers_available() and version.parse(
+ version.parse(importlib.metadata.version("transformers")).base_version
+) > version.parse("4.33")
+
+USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
+BIG_GPU_MEMORY = int(os.getenv("BIG_GPU_MEMORY", 40))
+
+if is_torch_available():
+ import torch
+
+ # Set a backend environment variable for any extra module import required for a custom accelerator
+ if "DIFFUSERS_TEST_BACKEND" in os.environ:
+ backend = os.environ["DIFFUSERS_TEST_BACKEND"]
+ try:
+ _ = importlib.import_module(backend)
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+                f"Failed to import `DIFFUSERS_TEST_BACKEND` '{backend}'! This should be the name of an installed "
+                f"module to enable a specified backend:\n{e}"
+ ) from e
+
+ if "DIFFUSERS_TEST_DEVICE" in os.environ:
+ torch_device = os.environ["DIFFUSERS_TEST_DEVICE"]
+ try:
+ # try creating device to see if provided device is valid
+ _ = torch.device(torch_device)
+ except RuntimeError as e:
+ raise RuntimeError(
+ f"Unknown testing device specified by environment variable `DIFFUSERS_TEST_DEVICE`: {torch_device}"
+ ) from e
+        logger.info(f"torch_device overridden to {torch_device}")
+ else:
+ if torch.cuda.is_available():
+ torch_device = "cuda"
+ elif torch.xpu.is_available():
+ torch_device = "xpu"
+ else:
+ torch_device = "cpu"
+ is_torch_higher_equal_than_1_12 = version.parse(
+ version.parse(torch.__version__).base_version
+ ) >= version.parse("1.12")
+
+ if is_torch_higher_equal_than_1_12:
+ # Some builds of torch 1.12 don't have the mps backend registered. See #892 for more details
+ mps_backend_registered = hasattr(torch.backends, "mps")
+ torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
+
+ from diffusers.utils.torch_utils import get_torch_cuda_device_capability
+
+
+def torch_all_close(a, b, *args, **kwargs):
+ if not is_torch_available():
+ raise ValueError("PyTorch needs to be installed to use this function.")
+ if not torch.allclose(a, b, *args, **kwargs):
+ assert False, f"Max diff is absolute {(a - b).abs().max()}. Diff tensor is {(a - b).abs()}."
+ return True
+
+
+def numpy_cosine_similarity_distance(a, b):
+ similarity = np.dot(a, b) / (norm(a) * norm(b))
+ distance = 1.0 - similarity.mean()
+
+ return distance
+
+
+def check_if_dicts_are_equal(dict1, dict2):
+ dict1, dict2 = dict1.copy(), dict2.copy()
+
+ for key, value in dict1.items():
+ if isinstance(value, set):
+ dict1[key] = sorted(value)
+ for key, value in dict2.items():
+ if isinstance(value, set):
+ dict2[key] = sorted(value)
+
+ for key in dict1:
+ if key not in dict2:
+ return False
+ if dict1[key] != dict2[key]:
+ return False
+
+ for key in dict2:
+ if key not in dict1:
+ return False
+
+ return True
+
+
+def print_tensor_test(
+ tensor,
+ limit_to_slices=None,
+ max_torch_print=None,
+ filename="test_corrections.txt",
+ expected_tensor_name="expected_slice",
+):
+ if max_torch_print:
+ torch.set_printoptions(threshold=10_000)
+
+ test_name = os.environ.get("PYTEST_CURRENT_TEST")
+ if not torch.is_tensor(tensor):
+ tensor = torch.from_numpy(tensor)
+ if limit_to_slices:
+ tensor = tensor[0, -3:, -3:, -1]
+
+ tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
+ # format is usually:
+ # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
+ output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array")
+ test_file, test_class, test_fn = test_name.split("::")
+ test_fn = test_fn.split()[0]
+ with open(filename, "a") as f:
+ print("::".join([test_file, test_class, test_fn, output_str]), file=f)
+
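+# Usage note (hedged sketch): inside a failing slow test run under pytest, calling
+# print_tensor_test(image_slice) appends a regenerated `expected_slice = np.array([...])` line,
+# keyed by the PYTEST_CURRENT_TEST environment variable, to test_corrections.txt so expected
+# values can be updated in bulk.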
+
+def get_tests_dir(append_path=None):
+ """
+ Args:
+ append_path: optional path to append to the tests dir path
+ Return:
+        The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+        joined after the `tests` dir if the former is provided.
+ """
+ # this function caller's __file__
+ caller__file__ = inspect.stack()[1][1]
+ tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+ while not tests_dir.endswith("tests"):
+ tests_dir = os.path.dirname(tests_dir)
+
+ if append_path:
+ return Path(tests_dir, append_path).as_posix()
+ else:
+ return tests_dir
+
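+# Example (illustrative path): get_tests_dir("fixtures") resolves to "<repo>/tests/fixtures"
+# regardless of the working directory the suite is launched from.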
+
+# Taken from the following PR:
+# https://github.com/huggingface/accelerate/pull/1964
+def str_to_bool(value) -> int:
+ """
+ Converts a string representation of truth to `True` (1) or `False` (0). True values are `y`, `yes`, `t`, `true`,
+    `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`.
+ """
+ value = value.lower()
+ if value in ("y", "yes", "t", "true", "on", "1"):
+ return 1
+ elif value in ("n", "no", "f", "false", "off", "0"):
+ return 0
+ else:
+ raise ValueError(f"invalid truth value {value}")
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = str_to_bool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
+_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
+
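+# Example: `RUN_SLOW=yes pytest tests/pipelines` enables tests decorated with @slow below;
+# RUN_NIGHTLY and RUN_COMPILE gate @nightly and @is_torch_compile in the same way.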
+
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+
+ values = []
+ for _ in range(total_dims):
+ values.append(rng.random() * scale)
+
+ return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
+
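+# Example (illustrative shape): floats_tensor((1, 3, 32, 32)) returns a float32 tensor of that
+# shape with values drawn uniformly from [0, scale).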
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def nightly(test_case):
+ """
+ Decorator marking a test that runs nightly in the diffusers CI.
+
+    Nightly tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
+
+
+def is_torch_compile(test_case):
+ """
+ Decorator marking a test that runs compile tests in the diffusers CI.
+
+ Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)
+
+
+def require_torch(test_case):
+ """
+ Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
+ """
+ return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
+
+
+def require_torch_2(test_case):
+ """
+ Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
+ """
+ return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
+ test_case
+ )
+
+
+def require_torch_version_greater_equal(torch_version):
+ """Decorator marking a test that requires torch with a specific version or greater."""
+
+ def decorator(test_case):
+ correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
+ return unittest.skipUnless(
+ correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
+ )(test_case)
+
+ return decorator
+
+
+def require_torch_version_greater(torch_version):
+ """Decorator marking a test that requires torch with a specific version greater."""
+
+ def decorator(test_case):
+ correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
+ return unittest.skipUnless(
+ correct_torch_version, f"test requires torch with the version greater than {torch_version}"
+ )(test_case)
+
+ return decorator
+
+
+def require_torch_gpu(test_case):
+ """Decorator marking a test that requires CUDA and PyTorch."""
+ return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
+ test_case
+ )
+
+
+def require_torch_cuda_compatibility(expected_compute_capability):
+ def decorator(test_case):
+ if torch.cuda.is_available():
+ current_compute_capability = get_torch_cuda_device_capability()
+ return unittest.skipUnless(
+ float(current_compute_capability) == float(expected_compute_capability),
+ "Test not supported for this compute capability.",
+ )
+
+ return decorator
+
+
+# These decorators are for accelerator-specific behaviours that are not GPU-specific
+def require_torch_accelerator(test_case):
+ """Decorator marking a test that requires an accelerator backend and PyTorch."""
+ return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
+ test_case
+ )
+
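+# Example: the single-file tests above stack these markers on a class, e.g.
+#
+#     @slow
+#     @require_torch_accelerator
+#     class TestStableDiffusionXLPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
+#         ...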
+
+def require_torch_multi_gpu(test_case):
+ """
+ Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
+ multiple GPUs. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests
+ -k "multi_gpu"
+ """
+ if not is_torch_available():
+ return unittest.skip("test requires PyTorch")(test_case)
+
+ import torch
+
+ return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
+
+
+def require_torch_multi_accelerator(test_case):
+ """
+ Decorator marking a test that requires a multi-accelerator setup (in PyTorch). These tests are skipped on a machine
+ without multiple hardware accelerators.
+ """
+ if not is_torch_available():
+ return unittest.skip("test requires PyTorch")(test_case)
+
+ import torch
+
+ return unittest.skipUnless(
+ torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
+ )(test_case)
+
+
+def require_torch_accelerator_with_fp16(test_case):
+ """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
+ return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
+ test_case
+ )
+
+
+def require_torch_accelerator_with_fp64(test_case):
+ """Decorator marking a test that requires an accelerator with support for the FP64 data type."""
+ return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
+ test_case
+ )
+
+
+def require_big_gpu_with_torch_cuda(test_case):
+ """
+    Decorator marking a test that requires a big GPU (at least `BIG_GPU_MEMORY` GB, 40 by default) for execution.
+    Some example pipelines: Flux, SD3, Cog, etc.
+ """
+ if not is_torch_available():
+ return unittest.skip("test requires PyTorch")(test_case)
+
+ import torch
+
+ if not torch.cuda.is_available():
+ return unittest.skip("test requires PyTorch CUDA")(test_case)
+
+ device_properties = torch.cuda.get_device_properties(0)
+ total_memory = device_properties.total_memory / (1024**3)
+ return unittest.skipUnless(
+ total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
+ )(test_case)
+
+
+def require_big_accelerator(test_case):
+ """
+    Decorator marking a test that requires a big hardware accelerator (at least `BIG_GPU_MEMORY` GB, 40 by default)
+    for execution. Some example pipelines: Flux, SD3, Cog, etc.
+ """
+ import pytest
+
+ test_case = pytest.mark.big_accelerator(test_case)
+
+ if not is_torch_available():
+ return unittest.skip("test requires PyTorch")(test_case)
+
+ import torch
+
+ if not (torch.cuda.is_available() or torch.xpu.is_available()):
+        return unittest.skip("test requires a CUDA or XPU device")(test_case)
+
+ if torch.xpu.is_available():
+ device_properties = torch.xpu.get_device_properties(0)
+ else:
+ device_properties = torch.cuda.get_device_properties(0)
+
+ total_memory = device_properties.total_memory / (1024**3)
+ return unittest.skipUnless(
+ total_memory >= BIG_GPU_MEMORY,
+ f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
+ )(test_case)
+
+
+def require_torch_accelerator_with_training(test_case):
+ """Decorator marking a test that requires an accelerator with support for training."""
+ return unittest.skipUnless(
+ is_torch_available() and backend_supports_training(torch_device),
+ "test requires accelerator with training support",
+ )(test_case)
+
+
+def skip_mps(test_case):
+ """Decorator marking a test to skip if torch_device is 'mps'"""
+ return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
+
+
+def require_flax(test_case):
+ """
+ Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+ """
+ return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+
+
+def require_compel(test_case):
+ """
+ Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
+ the library is not installed.
+ """
+ return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
+
+
+def require_onnxruntime(test_case):
+ """
+ Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
+ """
+ return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
+
+
+def require_note_seq(test_case):
+ """
+ Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
+ """
+ return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
+
+
+def require_accelerator(test_case):
+ """
+    Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when no hardware
+    accelerator is available.
+ """
+ return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
+
+
+def require_torchsde(test_case):
+ """
+ Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
+ """
+ return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
+
+
+def require_peft_backend(test_case):
+ """
+    Decorator marking a test that requires the PEFT backend. This requires specific versions of PEFT and
+    transformers.
+ """
+ return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
+
+
+def require_timm(test_case):
+ """
+ Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
+ """
+ return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
+
+
+def require_bitsandbytes(test_case):
+ """
+ Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
+ """
+ return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
+
+
+def require_quanto(test_case):
+ """
+ Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
+ """
+ return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
+
+
+def require_accelerate(test_case):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+
+
+def require_peft_version_greater(peft_version):
+ """
+    Decorator marking a test that requires the PEFT backend with a version greater than the one given. This requires
+    specific versions of PEFT and transformers.
+ """
+
+ def decorator(test_case):
+ correct_peft_version = is_peft_available() and version.parse(
+ version.parse(importlib.metadata.version("peft")).base_version
+ ) > version.parse(peft_version)
+ return unittest.skipUnless(
+ correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
+ )(test_case)
+
+ return decorator
+
+
+def require_transformers_version_greater(transformers_version):
+ """
+    Decorator marking a test that requires a transformers version greater than the one given. These tests are
+    skipped otherwise.
+ """
+
+ def decorator(test_case):
+ correct_transformers_version = is_transformers_available() and version.parse(
+ version.parse(importlib.metadata.version("transformers")).base_version
+ ) > version.parse(transformers_version)
+ return unittest.skipUnless(
+ correct_transformers_version,
+ f"test requires transformers with the version greater than {transformers_version}",
+ )(test_case)
+
+ return decorator
+
+
+def require_accelerate_version_greater(accelerate_version):
+ def decorator(test_case):
+ correct_accelerate_version = is_accelerate_available() and version.parse(
+ version.parse(importlib.metadata.version("accelerate")).base_version
+ ) > version.parse(accelerate_version)
+ return unittest.skipUnless(
+ correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_bitsandbytes_version_greater(bnb_version):
+ def decorator(test_case):
+ correct_bnb_version = is_bitsandbytes_available() and version.parse(
+ version.parse(importlib.metadata.version("bitsandbytes")).base_version
+ ) > version.parse(bnb_version)
+ return unittest.skipUnless(
+ correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_hf_hub_version_greater(hf_hub_version):
+ def decorator(test_case):
+ correct_hf_hub_version = version.parse(
+ version.parse(importlib.metadata.version("huggingface_hub")).base_version
+ ) > version.parse(hf_hub_version)
+ return unittest.skipUnless(
+ correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_gguf_version_greater_or_equal(gguf_version):
+ def decorator(test_case):
+ correct_gguf_version = is_gguf_available() and version.parse(
+ version.parse(importlib.metadata.version("gguf")).base_version
+ ) >= version.parse(gguf_version)
+ return unittest.skipUnless(
+ correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_torchao_version_greater_or_equal(torchao_version):
+ def decorator(test_case):
+ correct_torchao_version = is_torchao_available() and version.parse(
+ version.parse(importlib.metadata.version("torchao")).base_version
+ ) >= version.parse(torchao_version)
+ return unittest.skipUnless(
+ correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
+ )(test_case)
+
+ return decorator
+
+
+def require_kernels_version_greater_or_equal(kernels_version):
+ def decorator(test_case):
+ correct_kernels_version = is_kernels_available() and version.parse(
+ version.parse(importlib.metadata.version("kernels")).base_version
+ ) >= version.parse(kernels_version)
+ return unittest.skipUnless(
+ correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
+ )(test_case)
+
+ return decorator
+
+
+def deprecate_after_peft_backend(test_case):
+ """
+ Decorator marking a test that will be skipped after PEFT backend
+ """
+ return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
+
+
+def get_python_version():
+ sys_info = sys.version_info
+ major, minor = sys_info.major, sys_info.minor
+ return major, minor
+
+
+def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
+ if isinstance(arry, str):
+ if local_path is not None:
+ # local_path can be passed to correct images of tests
+ return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix()
+ elif arry.startswith("http://") or arry.startswith("https://"):
+ response = requests.get(arry, timeout=DIFFUSERS_REQUEST_TIMEOUT)
+ response.raise_for_status()
+ arry = np.load(BytesIO(response.content))
+ elif os.path.isfile(arry):
+ arry = np.load(arry)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {arry} is not a valid path"
+ )
+ elif isinstance(arry, np.ndarray):
+ pass
+ else:
+ raise ValueError(
+            "Incorrect format used for numpy ndarray. Should be a URL linking to an array, a local path, or an"
+            " ndarray."
+ )
+
+ return arry
+
+
+def load_pt(url: str, map_location: Optional[str] = None, weights_only: Optional[bool] = True):
+ response = requests.get(url, timeout=DIFFUSERS_REQUEST_TIMEOUT)
+ response.raise_for_status()
+ arry = torch.load(BytesIO(response.content), map_location=map_location, weights_only=weights_only)
+ return arry
+
+
+def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
+ """
+ Loads `image` to a PIL Image.
+
+ Args:
+ image (`str` or `PIL.Image.Image`):
+ The image to convert to the PIL Image format.
+ Returns:
+ `PIL.Image.Image`:
+ A PIL Image.
+ """
+ if isinstance(image, str):
+ if image.startswith("http://") or image.startswith("https://"):
+ image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
+ elif os.path.isfile(image):
+ image = PIL.Image.open(image)
+ else:
+ raise ValueError(
+ f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
+ )
+ elif isinstance(image, PIL.Image.Image):
+ image = image
+ else:
+ raise ValueError(
+            "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
+ )
+ image = PIL.ImageOps.exif_transpose(image)
+ image = image.convert("RGB")
+ return image
+
+
+def preprocess_image(image: PIL.Image, batch_size: int):
+ w, h = image.size
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def export_to_gif(image: List[PIL.Image.Image], output_gif_path: str = None) -> str:
+ if output_gif_path is None:
+ output_gif_path = tempfile.NamedTemporaryFile(suffix=".gif").name
+
+ image[0].save(
+ output_gif_path,
+ save_all=True,
+ append_images=image[1:],
+ optimize=False,
+ duration=100,
+ loop=0,
+ )
+ return output_gif_path
+
+
+@contextmanager
+def buffered_writer(raw_f):
+ f = io.BufferedWriter(raw_f)
+ yield f
+ f.flush()
+
+
+def export_to_ply(mesh, output_ply_path: str = None):
+ """
+ Write a PLY file for a mesh.
+ """
+ if output_ply_path is None:
+ output_ply_path = tempfile.NamedTemporaryFile(suffix=".ply").name
+
+ coords = mesh.verts.detach().cpu().numpy()
+ faces = mesh.faces.cpu().numpy()
+ rgb = np.stack([mesh.vertex_channels[x].detach().cpu().numpy() for x in "RGB"], axis=1)
+
+ with buffered_writer(open(output_ply_path, "wb")) as f:
+ f.write(b"ply\n")
+ f.write(b"format binary_little_endian 1.0\n")
+ f.write(bytes(f"element vertex {len(coords)}\n", "ascii"))
+ f.write(b"property float x\n")
+ f.write(b"property float y\n")
+ f.write(b"property float z\n")
+ if rgb is not None:
+ f.write(b"property uchar red\n")
+ f.write(b"property uchar green\n")
+ f.write(b"property uchar blue\n")
+ if faces is not None:
+ f.write(bytes(f"element face {len(faces)}\n", "ascii"))
+ f.write(b"property list uchar int vertex_index\n")
+ f.write(b"end_header\n")
+
+ if rgb is not None:
+ rgb = (rgb * 255.499).round().astype(int)
+ vertices = [
+ (*coord, *rgb)
+ for coord, rgb in zip(
+ coords.tolist(),
+ rgb.tolist(),
+ )
+ ]
+ format = struct.Struct("<3f3B")
+ for item in vertices:
+ f.write(format.pack(*item))
+ else:
+ format = struct.Struct("<3f")
+ for vertex in coords.tolist():
+ f.write(format.pack(*vertex))
+
+        if faces is not None:
+            format = struct.Struct("<B3I")
+            for tri in faces.tolist():
+                f.write(format.pack(len(tri), *tri))
+
+    return output_ply_path
+
+
+def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
+ if is_opencv_available():
+ import cv2
+ else:
+ raise ImportError(BACKENDS_MAPPING["opencv"][1].format("export_to_video"))
+ if output_video_path is None:
+ output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
+
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ h, w, c = video_frames[0].shape
+ video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=8, frameSize=(w, h))
+ for i in range(len(video_frames)):
+ img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
+ video_writer.write(img)
+ return output_video_path
+
+
+def load_hf_numpy(path) -> np.ndarray:
+ base_url = "https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main"
+
+ if not path.startswith("http://") and not path.startswith("https://"):
+ path = os.path.join(base_url, urllib.parse.quote(path))
+
+ return load_numpy(path)
+
+
+# --- pytest conf functions --- #
+
+# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
+pytest_opt_registered = {}
+
+
+def pytest_addoption_shared(parser):
+ """
+ This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
+
+ It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
+ option.
+
+ """
+ option = "--make-reports"
+ if option not in pytest_opt_registered:
+ parser.addoption(
+ option,
+ action="store",
+ default=False,
+ help="generate report files. The value of this option is used as a prefix to report names",
+ )
+ pytest_opt_registered[option] = 1
+
+
+def pytest_terminal_summary_main(tr, id):
+ """
+ Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
+ directory. The report files are prefixed with the test suite name.
+
+ This function emulates --duration and -rA pytest arguments.
+
+ This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
+ there.
+
+ Args:
+ - tr: `terminalreporter` passed from `conftest.py`
+ - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
+ needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
+
+    NB: this function taps into a private _pytest API and while unlikely, it could break should
+ pytest do internal changes - also it calls default internal methods of terminalreporter which
+ can be hijacked by various `pytest-` plugins and interfere.
+
+ """
+ from _pytest.config import create_terminal_writer
+
+ if not len(id):
+ id = "tests"
+
+ config = tr.config
+ orig_writer = config.get_terminal_writer()
+ orig_tbstyle = config.option.tbstyle
+ orig_reportchars = tr.reportchars
+
+ dir = "reports"
+ Path(dir).mkdir(parents=True, exist_ok=True)
+ report_files = {
+ k: f"{dir}/{id}_{k}.txt"
+ for k in [
+ "durations",
+ "errors",
+ "failures_long",
+ "failures_short",
+ "failures_line",
+ "passes",
+ "stats",
+ "summary_short",
+ "warnings",
+ ]
+ }
+
+ # custom durations report
+ # note: there is no need to call pytest --durations=XX to get this separate report
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
+ dlist = []
+ for replist in tr.stats.values():
+ for rep in replist:
+ if hasattr(rep, "duration"):
+ dlist.append(rep)
+ if dlist:
+ dlist.sort(key=lambda x: x.duration, reverse=True)
+ with open(report_files["durations"], "w") as f:
+ durations_min = 0.05 # sec
+ f.write("slowest durations\n")
+ for i, rep in enumerate(dlist):
+ if rep.duration < durations_min:
+ f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
+ break
+ f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
+
+ def summary_failures_short(tr):
+ # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
+ reports = tr.getreports("failed")
+ if not reports:
+ return
+ tr.write_sep("=", "FAILURES SHORT STACK")
+ for rep in reports:
+ msg = tr._getfailureheadline(rep)
+ tr.write_sep("_", msg, red=True, bold=True)
+ # chop off the optional leading extra frames, leaving only the last one
+ longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
+ tr._tw.line(longrepr)
+ # note: not printing out any rep.sections to keep the report short
+
+ # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
+ # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
+ # pytest-instafail does that)
+
+ # report failures with line/short/long styles
+ config.option.tbstyle = "auto" # full tb
+ with open(report_files["failures_long"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ # config.option.tbstyle = "short" # short tb
+ with open(report_files["failures_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ summary_failures_short(tr)
+
+ config.option.tbstyle = "line" # one line per error
+ with open(report_files["failures_line"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ with open(report_files["errors"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_errors()
+
+ with open(report_files["warnings"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_warnings() # normal warnings
+ tr.summary_warnings() # final warnings
+
+ tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
+ with open(report_files["passes"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_passes()
+
+ with open(report_files["summary_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.short_test_summary()
+
+ with open(report_files["stats"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_stats()
+
+ # restore:
+ tr._tw = orig_writer
+ tr.reportchars = orig_reportchars
+ config.option.tbstyle = orig_tbstyle
+
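+# Wiring sketch (hedged; mirrors how a typical tests/conftest.py consumes these helpers):
+#
+#     def pytest_addoption(parser):
+#         pytest_addoption_shared(parser)
+#
+#     def pytest_terminal_summary(terminalreporter):
+#         make_reports = terminalreporter.config.getoption("--make-reports")
+#         if make_reports:
+#             pytest_terminal_summary_main(terminalreporter, id=make_reports)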
+
+# Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
+def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
+ """
+ To decorate flaky tests (methods or entire classes). They will be retried on failures.
+
+ Args:
+ max_attempts (`int`, *optional*, defaults to 5):
+ The maximum number of attempts to retry the flaky test.
+ wait_before_retry (`float`, *optional*):
+ If provided, will wait that number of seconds before retrying the test.
+ description (`str`, *optional*):
+ A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
+ etc.)
+ """
+
+ def decorator(obj):
+ # If decorating a class, wrap each test method on it
+ if inspect.isclass(obj):
+ for attr_name, attr_value in list(obj.__dict__.items()):
+ if callable(attr_value) and attr_name.startswith("test"):
+ # recursively decorate the method
+ setattr(obj, attr_name, decorator(attr_value))
+ return obj
+
+ # Otherwise we're decorating a single test function / method
+ @functools.wraps(obj)
+ def wrapper(*args, **kwargs):
+ retry_count = 1
+ while retry_count < max_attempts:
+ try:
+ return obj(*args, **kwargs)
+ except Exception as err:
+ msg = (
+ f"[FLAKY] {description or obj.__name__!r} "
+ f"failed on attempt {retry_count}/{max_attempts}: {err}"
+ )
+ print(msg, file=sys.stderr)
+ if wait_before_retry is not None:
+ time.sleep(wait_before_retry)
+ retry_count += 1
+
+ return obj(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
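+# Example usage (hypothetical test): retry a network-dependent test a few times before failing.
+#
+#     @is_flaky(max_attempts=3, wait_before_retry=1.0, description="Hub downloads occasionally time out")
+#     def test_download_single_file(self):
+#         ...
+#
+# The wrapped test re-runs until it passes or `max_attempts` is reached; the final attempt's
+# exception propagates normally.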
+
+# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787
+def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
+ """
+    To run a test in a subprocess. In particular, this can avoid (GPU) memory issues.
+
+ Args:
+ test_case (`unittest.TestCase`):
+ The test that will run `target_func`.
+ target_func (`Callable`):
+ The function implementing the actual testing logic.
+ inputs (`dict`, *optional*, defaults to `None`):
+ The inputs that will be passed to `target_func` through an (input) queue.
+ timeout (`int`, *optional*, defaults to `None`):
+ The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
+ variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
+ """
+ if timeout is None:
+ timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+
+    start_method = "spawn"
+    ctx = multiprocessing.get_context(start_method)
+
+ input_queue = ctx.Queue(1)
+ output_queue = ctx.JoinableQueue(1)
+
+ # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
+ input_queue.put(inputs, timeout=timeout)
+
+ process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
+ process.start()
+    # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
+    # the test from exiting properly.
+ try:
+ results = output_queue.get(timeout=timeout)
+ output_queue.task_done()
+ except Exception as e:
+ process.terminate()
+ test_case.fail(e)
+ process.join(timeout=timeout)
+
+ if results["error"] is not None:
+ test_case.fail(f"{results['error']}")
+
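+# Sketch of a compatible target_func (hedged; names are illustrative): it receives
+# (in_queue, out_queue, timeout) and must put a dict with an "error" key on out_queue so the
+# parent can report failures. Requires `import traceback` in the test module.
+#
+#     def _test_fn(in_queue, out_queue, timeout):
+#         error = None
+#         try:
+#             inputs = in_queue.get(timeout=timeout)
+#             ...  # run the actual assertions here
+#         except Exception:
+#             error = traceback.format_exc()
+#         out_queue.put({"error": error}, timeout=timeout)
+#         out_queue.join()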
+
+class CaptureLogger:
+ """
+    Context manager to capture `logging` streams.
+
+    Args:
+        logger: `logging` logger object
+ Returns:
+ The captured output is available via `self.out`
+ Example:
+ ```python
+ >>> from diffusers import logging
+    >>> from diffusers.utils.testing_utils import CaptureLogger
+
+ >>> msg = "Testing 1, 2, 3"
+ >>> logging.set_verbosity_info()
+ >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
+ >>> with CaptureLogger(logger) as cl:
+ ... logger.info(msg)
+ >>> assert cl.out, msg + "\n"
+ ```
+ """
+
+ def __init__(self, logger):
+ self.logger = logger
+ self.io = StringIO()
+ self.sh = logging.StreamHandler(self.io)
+ self.out = ""
+
+ def __enter__(self):
+ self.logger.addHandler(self.sh)
+ return self
+
+ def __exit__(self, *exc):
+ self.logger.removeHandler(self.sh)
+ self.out = self.io.getvalue()
+
+ def __repr__(self):
+ return f"captured: {self.out}\n"
+
+
+def enable_full_determinism():
+ """
+    Helper function for reproducible behavior during tests. See
+ - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+ """
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cuda.matmul.allow_tf32 = False
+
+
+def disable_full_determinism():
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ""
+ torch.use_deterministic_algorithms(False)
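+# Usage note: slow pipeline tests typically compare flattened output slices against stored
+# references, e.g. `assert numpy_cosine_similarity_distance(expected_slice, output_slice) < 1e-4`
+# (the tolerance value here is illustrative).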
+
+
+# Utils for custom and alternative accelerator devices
+def _is_torch_fp16_available(device):
+ if not is_torch_available():
+ return False
+
+ import torch
+
+ device = torch.device(device)
+
+ try:
+ x = torch.zeros((2, 2), dtype=torch.float16).to(device)
+ _ = torch.mul(x, x)
+ return True
+
+ except Exception as e:
+ if device.type == "cuda":
+ raise ValueError(
+ f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+ )
+
+ return False
+
+
+def _is_torch_fp64_available(device):
+ if not is_torch_available():
+ return False
+
+ import torch
+
+ device = torch.device(device)
+
+ try:
+ x = torch.zeros((2, 2), dtype=torch.float64).to(device)
+ _ = torch.mul(x, x)
+ return True
+
+ except Exception as e:
+ if device.type == "cuda":
+ raise ValueError(
+ f"You have passed a device of type 'cuda' which should work with 'fp64', but 'cuda' does not seem to be correctly installed on your machine: {e}"
+ )
+
+ return False
+
+
+# Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
+if is_torch_available():
+ # Behaviour flags
+ BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}
+
+ # Function definitions
+ BACKEND_EMPTY_CACHE = {
+ "cuda": torch.cuda.empty_cache,
+ "xpu": torch.xpu.empty_cache,
+ "cpu": None,
+ "mps": torch.mps.empty_cache,
+ "default": None,
+ }
+ BACKEND_DEVICE_COUNT = {
+ "cuda": torch.cuda.device_count,
+ "xpu": torch.xpu.device_count,
+ "cpu": lambda: 0,
+ "mps": lambda: 0,
+ "default": 0,
+ }
+ BACKEND_MANUAL_SEED = {
+ "cuda": torch.cuda.manual_seed,
+ "xpu": torch.xpu.manual_seed,
+ "cpu": torch.manual_seed,
+ "mps": torch.mps.manual_seed,
+ "default": torch.manual_seed,
+ }
+ BACKEND_RESET_PEAK_MEMORY_STATS = {
+ "cuda": torch.cuda.reset_peak_memory_stats,
+ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+ BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.reset_max_memory_allocated,
+ "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+ BACKEND_MAX_MEMORY_ALLOCATED = {
+ "cuda": torch.cuda.max_memory_allocated,
+ "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+ "cpu": 0,
+ "mps": 0,
+ "default": 0,
+ }
+ BACKEND_SYNCHRONIZE = {
+ "cuda": torch.cuda.synchronize,
+ "xpu": getattr(torch.xpu, "synchronize", None),
+ "cpu": None,
+ "mps": None,
+ "default": None,
+ }
+
+
+# This dispatches a defined function according to the accelerator from the function definitions.
+def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs):
+ if device not in dispatch_table:
+ return dispatch_table["default"](*args, **kwargs)
+
+ fn = dispatch_table[device]
+
+ # Some device agnostic functions return values. Need to guard against 'None' instead at
+ # user level
+ if not callable(fn):
+ return fn
+
+ return fn(*args, **kwargs)
+
+
+# These are callables which automatically dispatch the function specific to the accelerator
+def backend_manual_seed(device: str, seed: int):
+ return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_synchronize(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
+def backend_empty_cache(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+def backend_reset_peak_memory_stats(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+ return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
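+# Example: backend_empty_cache("cuda") dispatches to torch.cuda.empty_cache(), while
+# backend_empty_cache("cpu") is a no-op because the table entry is None and non-callable
+# entries are returned as-is by _device_agnostic_dispatch.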
+
+# These are callables which return boolean behaviour flags and can be used to specify some
+# device agnostic alternative where the feature is unsupported.
+def backend_supports_training(device: str):
+ if not is_torch_available():
+ return False
+
+ if device not in BACKEND_SUPPORTS_TRAINING:
+ device = "default"
+
+ return BACKEND_SUPPORTS_TRAINING[device]
+
+
+# Guard for when Torch is not available
+if is_torch_available():
+ # Update device function dict mapping
+ def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str):
+ try:
+ # Try to import the function directly
+ spec_fn = getattr(device_spec_module, attribute_name)
+ device_fn_dict[torch_device] = spec_fn
+ except AttributeError as e:
+ # If the function doesn't exist, and there is no default, throw an error
+ if "default" not in device_fn_dict:
+ raise AttributeError(
+ f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+ ) from e
+
+ if "DIFFUSERS_TEST_DEVICE_SPEC" in os.environ:
+ device_spec_path = os.environ["DIFFUSERS_TEST_DEVICE_SPEC"]
+ if not Path(device_spec_path).is_file():
+ raise ValueError(f"Specified path to device specification file is not found. Received {device_spec_path}")
+
+ try:
+ import_name = device_spec_path[: device_spec_path.index(".py")]
+ except ValueError as e:
+ raise ValueError(f"Provided device spec file is not a Python file! Received {device_spec_path}") from e
+
+ device_spec_module = importlib.import_module(import_name)
+
+ try:
+ device_name = device_spec_module.DEVICE_NAME
+ except AttributeError:
+ raise AttributeError("Device spec file did not contain `DEVICE_NAME`")
+
+ if "DIFFUSERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+ msg = f"Mismatch between environment variable `DIFFUSERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+ msg += "Either unset `DIFFUSERS_TEST_DEVICE` or ensure it matches device spec name."
+ raise ValueError(msg)
+
+ torch_device = device_name
+
+ # Add one entry here for each `BACKEND_*` dictionary.
+ update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+ update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+ update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+ update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
+ update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
+ update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
+ update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
+
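+# A device spec file referenced by DIFFUSERS_TEST_DEVICE_SPEC is a plain Python module; a hedged
+# sketch for a hypothetical NPU backend (names below are illustrative, not shipped with diffusers):
+#
+#     # my_npu_spec.py
+#     import torch
+#     import torch_npu  # noqa: F401
+#
+#     DEVICE_NAME = "npu"
+#     MANUAL_SEED_FN = torch.npu.manual_seed
+#     EMPTY_CACHE_FN = torch.npu.empty_cache
+#     DEVICE_COUNT_FN = torch.npu.device_count
+#
+# Attributes missing from the spec fall back to the "default" entry of the matching BACKEND_* dict;
+# if there is no default either, update_mapping_from_spec raises an AttributeError.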
+
+# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090
+
+# Type definition of key used in `Expectations` class.
+DeviceProperties = Tuple[Union[str, None], Union[int, None]]
+
+
+@functools.lru_cache
+def get_device_properties() -> DeviceProperties:
+ """
+ Get environment device properties.
+ """
+ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+ import torch
+
+ major, _ = torch.cuda.get_device_capability()
+ if IS_ROCM_SYSTEM:
+ return ("rocm", major)
+ else:
+ return ("cuda", major)
+ elif IS_XPU_SYSTEM:
+ import torch
+
+ # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
+ arch = torch.xpu.get_device_capability()["architecture"]
+ gen_mask = 0x000000FF00000000
+ gen = (arch & gen_mask) >> 32
+ return ("xpu", gen)
+ else:
+ return (torch_device, None)
+
+
+if TYPE_CHECKING:
+ DevicePropertiesUserDict = UserDict[DeviceProperties, Any]
+else:
+ DevicePropertiesUserDict = UserDict
+
+if is_torch_available():
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks.group_offloading import (
+ _GROUP_ID_LAZY_LEAF,
+ _compute_group_hash,
+ _find_parent_module_in_module_dict,
+ _gather_buffers_with_no_group_offloading_parent,
+ _gather_parameters_with_no_group_offloading_parent,
+ )
+
+ def _get_expected_safetensors_files(
+ module: torch.nn.Module,
+ offload_to_disk_path: str,
+ offload_type: str,
+ num_blocks_per_group: Optional[int] = None,
+ ) -> Set[str]:
+ expected_files = set()
+
+ def get_hashed_filename(group_id: str) -> str:
+ short_hash = _compute_group_hash(group_id)
+ return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+ if offload_type == "block_level":
+ if num_blocks_per_group is None:
+ raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
+
+ # Handle groups of ModuleList and Sequential blocks
+ unmatched_modules = []
+ for name, submodule in module.named_children():
+ if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+ unmatched_modules.append(module)
+ continue
+
+ for i in range(0, len(submodule), num_blocks_per_group):
+ current_modules = submodule[i : i + num_blocks_per_group]
+ if not current_modules:
+ continue
+ group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+ expected_files.add(get_hashed_filename(group_id))
+
+ # Handle the group for unmatched top-level modules and parameters
+ for module in unmatched_modules:
+ expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
+
+ elif offload_type == "leaf_level":
+ # Handle leaf-level module groups
+ for name, submodule in module.named_modules():
+ if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
+ # These groups will always have parameters, so a file is expected
+ expected_files.add(get_hashed_filename(name))
+
+ # Handle groups for non-leaf parameters/buffers
+ modules_with_group_offloading = {
+ name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
+ }
+ parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+ buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+ all_orphans = parameters + buffers
+ if all_orphans:
+ parent_to_tensors = {}
+ module_dict = dict(module.named_modules())
+ for tensor_name, _ in all_orphans:
+ parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
+ if parent_name not in parent_to_tensors:
+ parent_to_tensors[parent_name] = []
+ parent_to_tensors[parent_name].append(tensor_name)
+
+ for parent_name in parent_to_tensors:
+ # A file is expected for each parent that gathers orphaned tensors
+ expected_files.add(get_hashed_filename(parent_name))
+ expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
+
+ else:
+ raise ValueError(f"Unsupported offload_type: {offload_type}")
+
+ return expected_files
+
+ def _check_safetensors_serialization(
+ module: torch.nn.Module,
+ offload_to_disk_path: str,
+ offload_type: str,
+ num_blocks_per_group: Optional[int] = None,
+    ) -> Tuple[bool, Optional[Set[str]], Optional[Set[str]]]:
+ if not os.path.isdir(offload_to_disk_path):
+ return False, None, None
+
+ expected_files = _get_expected_safetensors_files(
+ module, offload_to_disk_path, offload_type, num_blocks_per_group
+ )
+ actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
+ missing_files = expected_files - actual_files
+ extra_files = actual_files - expected_files
+
+ is_correct = not missing_files and not extra_files
+ return is_correct, extra_files, missing_files
+
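+# Usage sketch (hedged): after enabling group offloading to disk on `model` (e.g. via diffusers'
+# `apply_group_offloading`, assuming it is configured with a disk-offload path), a test can verify
+# the on-disk layout with:
+#
+#     ok, extra, missing = _check_safetensors_serialization(model, tmp_dir, "block_level", num_blocks_per_group=2)
+#     assert ok, f"unexpected files: {extra}, missing files: {missing}"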
+
+class Expectations(DevicePropertiesUserDict):
+ def get_expectation(self) -> Any:
+ """
+ Find best matching expectation based on environment device properties.
+ """
+ return self.find_expectation(get_device_properties())
+
+ @staticmethod
+ def is_default(key: DeviceProperties) -> bool:
+ return all(p is None for p in key)
+
+ @staticmethod
+ def score(key: DeviceProperties, other: DeviceProperties) -> int:
+ """
+        Returns a score indicating how similar two instances of the `DeviceProperties` tuple are. Points are
+        calculated using bits, but documented as int. Rules are as follows:
+        * Matching `type` gives 8 points.
+        * Semi-matching `type`, for example cuda and rocm, gives 4 points.
+        * Matching `major` (compute capability major version) gives 2 points.
+        * Default expectation (if present) gives 1 point.
+ """
+ (device_type, major) = key
+ (other_device_type, other_major) = other
+
+ score = 0b0
+ if device_type == other_device_type:
+ score |= 0b1000
+ elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
+ score |= 0b100
+
+ if major == other_major and other_major is not None:
+ score |= 0b10
+
+ if Expectations.is_default(other):
+ score |= 0b1
+
+ return int(score)
+
+ def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
+ """
+ Find best matching expectation based on provided device properties.
+ """
+ (result_key, result) = max(self.data.items(), key=lambda x: Expectations.score(key, x[0]))
+
+ if Expectations.score(key, result_key) == 0:
+ raise ValueError(f"No matching expectation found for {key}")
+
+ return result
+
+ def __repr__(self):
+ return f"{self.data}"
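+
+
+# Example (illustrative values): per-device expected outputs with a generic fallback keyed by
+# (None, None); get_expectation() picks the closest match for the current device.
+#
+#     expected_slices = Expectations(
+#         {
+#             ("cuda", 8): [0.1234, 0.5678],
+#             ("rocm", 9): [0.1230, 0.5670],
+#             (None, None): [0.12, 0.56],
+#         }
+#     )
+#     expected_slice = expected_slices.get_expectation()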