mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into 1d_blocks
This commit is contained in:
5
.github/ISSUE_TEMPLATE/config.yml
vendored
5
.github/ISSUE_TEMPLATE/config.yml
vendored
@@ -1,7 +1,4 @@
|
||||
contact_links:
|
||||
- name: Forum
|
||||
url: https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63
|
||||
about: General usage questions and community discussions
|
||||
- name: Blank issue
|
||||
url: https://github.com/huggingface/diffusers/issues/new
|
||||
about: Please note that the Forum is in most places the right place for discussions
|
||||
about: General usage questions and community discussions
|
||||
|
||||
50
.github/workflows/build_docker_images.yml
vendored
Normal file
50
.github/workflows/build_docker_images.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: Build Docker images (nightly)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *" # every day at midnight
|
||||
|
||||
concurrency:
|
||||
group: docker-image-builds
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
REGISTRY: diffusers
|
||||
|
||||
jobs:
|
||||
build-docker-images:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
image-name:
|
||||
- diffusers-pytorch-cpu
|
||||
- diffusers-pytorch-cuda
|
||||
- diffusers-flax-cpu
|
||||
- diffusers-flax-tpu
|
||||
- diffusers-onnxruntime-cpu
|
||||
- diffusers-onnxruntime-cuda
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ env.REGISTRY }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v3
|
||||
with:
|
||||
no-cache: true
|
||||
context: ./docker/${{ matrix.image-name }}
|
||||
push: true
|
||||
tags: ${{ env.REGISTRY }}/${{ matrix.image-name }}:latest
|
||||
5
.github/workflows/build_pr_documentation.yml
vendored
5
.github/workflows/build_pr_documentation.yml
vendored
@@ -9,8 +9,11 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@use_hf_hub
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
package: diffusers
|
||||
secrets:
|
||||
token: ${{ secrets.HF_DOC_PUSH }}
|
||||
comment_bot_token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
|
||||
5
.github/workflows/delete_doc_comment.yml
vendored
5
.github/workflows/delete_doc_comment.yml
vendored
@@ -7,7 +7,10 @@ on:
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@use_hf_hub
|
||||
with:
|
||||
pr_number: ${{ github.event.number }}
|
||||
package: diffusers
|
||||
secrets:
|
||||
token: ${{ secrets.HF_DOC_PUSH }}
|
||||
comment_bot_token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
|
||||
80
.github/workflows/pr_tests.yml
vendored
80
.github/workflows/pr_tests.yml
vendored
@@ -11,19 +11,45 @@ concurrency:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
MPS_TORCH_VERSION: 1.13.0
|
||||
|
||||
jobs:
|
||||
run_tests_cpu:
|
||||
name: CPU tests on Ubuntu
|
||||
runs-on: [ self-hosted, docker-gpu ]
|
||||
run_fast_tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch CPU tests on Ubuntu
|
||||
framework: pytorch
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu
|
||||
- name: Fast Flax CPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: Fast ONNXRuntime CPU tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-onnxruntime-cpu
|
||||
report: onnx_cpu
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: python:3.7
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
@@ -32,8 +58,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
@@ -41,25 +65,43 @@ jobs:
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run all fast tests on CPU
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
- name: Run fast PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run fast ONNXRuntime CPU tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_cpu_failures_short.txt
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: pr_torch_cpu_test_reports
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
run_tests_apple_m1:
|
||||
name: MPS tests on Apple M1
|
||||
run_fast_tests_apple_m1:
|
||||
name: Fast PyTorch MPS tests on MacOS
|
||||
runs-on: [ self-hosted, apple-m1 ]
|
||||
|
||||
steps:
|
||||
@@ -91,12 +133,10 @@ jobs:
|
||||
run: |
|
||||
${CONDA_RUN} python utils/print_env.py
|
||||
|
||||
- name: Run all fast tests on MPS
|
||||
- name: Run fast PyTorch tests on M1 (MPS)
|
||||
shell: arch -arch arm64 bash {0}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
|
||||
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
91
.github/workflows/push_tests.yml
vendored
91
.github/workflows/push_tests.yml
vendored
@@ -14,12 +14,38 @@ env:
|
||||
RUN_SLOW: yes
|
||||
|
||||
jobs:
|
||||
run_tests_single_gpu:
|
||||
name: Diffusers tests
|
||||
runs-on: [ self-hosted, docker-gpu, single-gpu ]
|
||||
run_slow_tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Slow PyTorch CUDA tests on Ubuntu
|
||||
framework: pytorch
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
report: torch_cuda
|
||||
- name: Slow Flax TPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: docker-tpu
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
report: flax_tpu
|
||||
- name: Slow ONNXRuntime CUDA tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
report: onnx_cuda
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: nvcr.io/nvidia/pytorch:22.07-py3
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ ${{ matrix.config.runner == 'docker-tpu' && '--privileged' || '--gpus 0'}}
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -28,14 +54,12 @@ jobs:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
if : ${{ matrix.config.runner == 'docker-gpu' }}
|
||||
run: |
|
||||
nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip uninstall -y torch torchvision torchtext
|
||||
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
@@ -43,29 +67,55 @@ jobs:
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run all (incl. slow) tests on GPU
|
||||
- name: Run slow PyTorch CUDA tests
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_gpu tests/
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run slow Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run slow ONNXRuntime CUDA tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_gpu_failures_short.txt
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: torch_test_reports
|
||||
name: ${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
run_examples_single_gpu:
|
||||
name: Examples tests
|
||||
runs-on: [ self-hosted, docker-gpu, single-gpu ]
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
|
||||
runs-on: docker-gpu
|
||||
|
||||
container:
|
||||
image: nvcr.io/nvidia/pytorch:22.07-py3
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -79,9 +129,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip uninstall -y torch torchvision torchtext
|
||||
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu117
|
||||
python -m pip install -e .[quality,test,training]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
@@ -93,11 +140,11 @@ jobs:
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_gpu examples/
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/examples_torch_gpu_failures_short.txt
|
||||
run: cat reports/examples_torch_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -163,4 +163,6 @@ tags
|
||||
*.lock
|
||||
|
||||
# DS_Store (MacOS)
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
# RL pipelines may produce mp4 outputs
|
||||
*.mp4
|
||||
22
README.md
22
README.md
@@ -27,10 +27,12 @@ More precisely, 🤗 Diffusers offers:
|
||||
|
||||
## Installation
|
||||
|
||||
### For PyTorch
|
||||
|
||||
**With `pip`**
|
||||
|
||||
```bash
|
||||
pip install --upgrade diffusers
|
||||
pip install --upgrade diffusers[torch]
|
||||
```
|
||||
|
||||
**With `conda`**
|
||||
@@ -39,6 +41,14 @@ pip install --upgrade diffusers
|
||||
conda install -c conda-forge diffusers
|
||||
```
|
||||
|
||||
### For Flax
|
||||
|
||||
**With `pip`**
|
||||
|
||||
```bash
|
||||
pip install --upgrade diffusers[flax]
|
||||
```
|
||||
|
||||
**Apple Silicon (M1/M2) support**
|
||||
|
||||
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
|
||||
@@ -336,14 +346,14 @@ Textual Inversion is a technique for capturing novel concepts from a small numbe
|
||||
|
||||
- Textual Inversion. Capture novel concepts from a small set of sample images, and associate them with new "words" in the embedding space of the text encoder. Please, refer to [our training examples](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) or [documentation](https://huggingface.co/docs/diffusers/training/text_inversion) to try for yourself.
|
||||
|
||||
- Dreambooth. Another technique to capture new concepts in Stable Diffusion. This method fine-tunes the UNet (and, optionally, also the text encoder) of the pipeline to achieve impressive results. Please, refer to [our training examples](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) and [training report](https://wandb.ai/psuraj/dreambooth/reports/Dreambooth-Training-Analysis--VmlldzoyNzk0NDc3) for additional details and training recommendations.
|
||||
- Dreambooth. Another technique to capture new concepts in Stable Diffusion. This method fine-tunes the UNet (and, optionally, also the text encoder) of the pipeline to achieve impressive results. Please, refer to [our training example](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) and [training report](https://huggingface.co/blog/dreambooth) for additional details and training recommendations.
|
||||
|
||||
- Full Stable Diffusion fine-tuning. If you have a more sizable dataset with a specific look or style, you can fine-tune Stable Diffusion so that it outputs images following those examples. This was the approach taken to create [a Pokémon Stable Diffusion model](https://huggingface.co/justinpinkney/pokemon-stable-diffusion) (by Justing Pinkney / Lambda Labs), [a Japanese specific version of Stable Diffusion](https://huggingface.co/spaces/rinna/japanese-stable-diffusion) (by [Rinna Co.](https://github.com/rinnakk/japanese-stable-diffusion/) and others. You can start at [our text-to-image fine-tuning example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) and go from there.
|
||||
|
||||
|
||||
## Stable Diffusion Community Pipelines
|
||||
|
||||
The release of Stable Diffusion as an open source model has fostered a lot of interesting ideas and experimentation. Our [Community Examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community) contains many ideas worth exploring, like interpolating to create animated videos, using CLIP Guidance for additional prompt fidelity, term weighting, and much more! Take a look and [contribute your own](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipelines).
|
||||
The release of Stable Diffusion as an open source model has fostered a lot of interesting ideas and experimentation. Our [Community Examples folder](https://github.com/huggingface/diffusers/tree/main/examples/community) contains many ideas worth exploring, like interpolating to create animated videos, using CLIP Guidance for additional prompt fidelity, term weighting, and much more! [Take a look](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) and [contribute your own](https://huggingface.co/docs/diffusers/using-diffusers/contribute_pipeline).
|
||||
|
||||
## Other Examples
|
||||
|
||||
@@ -354,7 +364,7 @@ There are many ways to try running Diffusers! Here we outline code-focused tools
|
||||
If you want to run the code yourself 💻, you can try out:
|
||||
- [Text-to-Image Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256)
|
||||
```python
|
||||
# !pip install diffusers transformers
|
||||
# !pip install diffusers["torch"] transformers
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
device = "cuda"
|
||||
@@ -373,7 +383,7 @@ image.save("squirrel.png")
|
||||
```
|
||||
- [Unconditional Diffusion with discrete scheduler](https://huggingface.co/google/ddpm-celebahq-256)
|
||||
```python
|
||||
# !pip install diffusers
|
||||
# !pip install diffusers["torch"]
|
||||
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline
|
||||
|
||||
model_id = "google/ddpm-celebahq-256"
|
||||
@@ -418,7 +428,7 @@ If you just want to play around with some web demos, you can try out the followi
|
||||
<p>
|
||||
|
||||
**Schedulers**: Algorithm class for both **inference** and **training**.
|
||||
The class provides functionality to compute previous image according to alpha, beta schedule as well as predict noise for training.
|
||||
The class provides functionality to compute previous image according to alpha, beta schedule as well as predict noise for training. Also known as **Samplers**.
|
||||
*Examples*: [DDPM](https://arxiv.org/abs/2006.11239), [DDIM](https://arxiv.org/abs/2010.02502), [PNDM](https://arxiv.org/abs/2202.09778), [DEIS](https://arxiv.org/abs/2204.13902)
|
||||
|
||||
<p align="center">
|
||||
|
||||
42
docker/diffusers-flax-cpu/Dockerfile
Normal file
42
docker/diffusers-flax-cpu/Dockerfile
Normal file
@@ -0,0 +1,42 @@
|
||||
FROM ubuntu:20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --upgrade --no-cache-dir \
|
||||
clu \
|
||||
"jax[cpu]>=0.2.16,!=0.3.2" \
|
||||
"flax>=0.4.1" \
|
||||
"jaxlib>=0.1.65" && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
44
docker/diffusers-flax-tpu/Dockerfile
Normal file
44
docker/diffusers-flax-tpu/Dockerfile
Normal file
@@ -0,0 +1,44 @@
|
||||
FROM ubuntu:20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
"jax[tpu]>=0.2.16,!=0.3.2" \
|
||||
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
|
||||
python3 -m pip install --upgrade --no-cache-dir \
|
||||
clu \
|
||||
"flax>=0.4.1" \
|
||||
"jaxlib>=0.1.65" && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
42
docker/diffusers-onnxruntime-cpu/Dockerfile
Normal file
42
docker/diffusers-onnxruntime-cpu/Dockerfile
Normal file
@@ -0,0 +1,42 @@
|
||||
FROM ubuntu:20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
onnxruntime \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
42
docker/diffusers-onnxruntime-cuda/Dockerfile
Normal file
42
docker/diffusers-onnxruntime-cuda/Dockerfile
Normal file
@@ -0,0 +1,42 @@
|
||||
FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
"onnxruntime-gpu>=1.13.1" \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117 && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
41
docker/diffusers-pytorch-cpu/Dockerfile
Normal file
41
docker/diffusers-pytorch-cpu/Dockerfile
Normal file
@@ -0,0 +1,41 @@
|
||||
FROM ubuntu:20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
41
docker/diffusers-pytorch-cuda/Dockerfile
Normal file
41
docker/diffusers-pytorch-cuda/Dockerfile
Normal file
@@ -0,0 +1,41 @@
|
||||
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu117 && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
modelcards \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -78,6 +78,8 @@
|
||||
- sections:
|
||||
- local: api/pipelines/overview
|
||||
title: "Overview"
|
||||
- local: api/pipelines/cycle_diffusion
|
||||
title: "Cycle Diffusion"
|
||||
- local: api/pipelines/ddim
|
||||
title: "DDIM"
|
||||
- local: api/pipelines/ddpm
|
||||
@@ -96,5 +98,9 @@
|
||||
title: "Stochastic Karras VE"
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: "Dance Diffusion"
|
||||
- local: api/pipelines/vq_diffusion
|
||||
title: "VQ Diffusion"
|
||||
- local: api/pipelines/repaint
|
||||
title: "RePaint"
|
||||
title: "Pipelines"
|
||||
title: "API"
|
||||
|
||||
@@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
|
||||
## UNet2DOutput
|
||||
[[autodoc]] models.unet_2d.UNet2DOutput
|
||||
|
||||
## UNet1DModel
|
||||
[[autodoc]] UNet1DModel
|
||||
|
||||
## UNet2DModel
|
||||
[[autodoc]] UNet2DModel
|
||||
|
||||
## UNet1DOutput
|
||||
[[autodoc]] models.unet_1d.UNet1DOutput
|
||||
|
||||
## UNet1DModel
|
||||
[[autodoc]] UNet1DModel
|
||||
|
||||
## UNet2DConditionOutput
|
||||
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
|
||||
|
||||
@@ -49,6 +52,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
|
||||
## AutoencoderKL
|
||||
[[autodoc]] AutoencoderKL
|
||||
|
||||
## Transformer2DModel
|
||||
[[autodoc]] Transformer2DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
[[autodoc]] models.attention.Transformer2DModelOutput
|
||||
|
||||
## FlaxModelMixin
|
||||
[[autodoc]] FlaxModelMixin
|
||||
|
||||
|
||||
99
docs/source/api/pipelines/cycle_diffusion.mdx
Normal file
99
docs/source/api/pipelines/cycle_diffusion.mdx
Normal file
@@ -0,0 +1,99 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Cycle Diffusion
|
||||
|
||||
## Overview
|
||||
|
||||
Cycle Diffusion is a Text-Guided Image-to-Image Generation model proposed in [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) by Chen Henry Wu, Fernando De la Torre.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*Diffusion models have achieved unprecedented performance in generative modeling. The commonly-adopted formulation of the latent code of diffusion models is a sequence of gradually denoised samples, as opposed to the simpler (e.g., Gaussian) latent space of GANs, VAEs, and normalizing flows. This paper provides an alternative, Gaussian formulation of the latent space of various diffusion models, as well as an invertible DPM-Encoder that maps images into the latent space. While our formulation is purely based on the definition of diffusion models, we demonstrate several intriguing consequences. (1) Empirically, we observe that a common latent space emerges from two diffusion models trained independently on related domains. In light of this finding, we propose CycleDiffusion, which uses DPM-Encoder for unpaired image-to-image translation. Furthermore, applying CycleDiffusion to text-to-image diffusion models, we show that large-scale text-to-image diffusion models can be used as zero-shot image-to-image editors. (2) One can guide pre-trained diffusion models and GANs by controlling the latent codes in a unified, plug-and-play formulation based on energy-based models. Using the CLIP model and a face recognition model as guidance, we demonstrate that diffusion models have better coverage of low-density sub-populations and individuals than GANs.*
|
||||
|
||||
*Tips*:
|
||||
- The Cycle Diffusion pipeline is fully compatible with any [Stable Diffusion](./stable_diffusion) checkpoints
|
||||
- Currently Cycle Diffusion only works with the [`DDIMScheduler`].
|
||||
|
||||
*Example*:
|
||||
|
||||
In the following we should how to best use the [`CycleDiffusionPipeline`]
|
||||
|
||||
```python
|
||||
import requests
|
||||
import torch
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
from diffusers import CycleDiffusionPipeline, DDIMScheduler
|
||||
|
||||
# load the pipeline
|
||||
# make sure you're logged in with `huggingface-cli login`
|
||||
model_id_or_path = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
|
||||
|
||||
# let's download an initial image
|
||||
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/An%20astronaut%20riding%20a%20horse.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
init_image = init_image.resize((512, 512))
|
||||
init_image.save("horse.png")
|
||||
|
||||
# let's specify a prompt
|
||||
source_prompt = "An astronaut riding a horse"
|
||||
prompt = "An astronaut riding an elephant"
|
||||
|
||||
# call the pipeline
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
source_prompt=source_prompt,
|
||||
init_image=init_image,
|
||||
num_inference_steps=100,
|
||||
eta=0.1,
|
||||
strength=0.8,
|
||||
guidance_scale=2,
|
||||
source_guidance_scale=1,
|
||||
).images[0]
|
||||
|
||||
image.save("horse_to_elephant.png")
|
||||
|
||||
# let's try another example
|
||||
# See more samples at the original repo: https://github.com/ChenWu98/cycle-diffusion
|
||||
url = "https://raw.githubusercontent.com/ChenWu98/cycle-diffusion/main/data/dalle2/A%20black%20colored%20car.png"
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
init_image = init_image.resize((512, 512))
|
||||
init_image.save("black.png")
|
||||
|
||||
source_prompt = "A black colored car"
|
||||
prompt = "A blue colored car"
|
||||
|
||||
# call the pipeline
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
source_prompt=source_prompt,
|
||||
init_image=init_image,
|
||||
num_inference_steps=100,
|
||||
eta=0.1,
|
||||
strength=0.85,
|
||||
guidance_scale=3,
|
||||
source_guidance_scale=1,
|
||||
).images[0]
|
||||
|
||||
image.save("black_to_blue.png")
|
||||
```
|
||||
|
||||
## CycleDiffusionPipeline
|
||||
[[autodoc]] CycleDiffusionPipeline
|
||||
- __call__
|
||||
@@ -20,7 +20,8 @@ The abstract of the paper is the following:
|
||||
|
||||
Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.
|
||||
|
||||
The original codebase of this paper can be found [here](https://github.com/ermongroup/ddim).
|
||||
The original codebase of this paper can be found here: [ermongroup/ddim](https://github.com/ermongroup/ddim).
|
||||
For questions, feel free to contact the author on [tsong.me](https://tsong.me/).
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
|
||||
| Pipeline | Tasks | Colab
|
||||
|---|---|:---:|
|
||||
| [pipeline_latent_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) | *Text-to-Image Generation* | - |
|
||||
| [pipeline_latent_diffusion_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py) | *Super Resolution* | - |
|
||||
|
||||
## Examples:
|
||||
|
||||
@@ -40,3 +41,7 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
|
||||
## LDMTextToImagePipeline
|
||||
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
|
||||
- __call__
|
||||
|
||||
## LDMSuperResolutionPipeline
|
||||
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline
|
||||
- __call__
|
||||
|
||||
@@ -28,7 +28,7 @@ or created independently from each other.
|
||||
|
||||
To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API.
|
||||
More specifically, we strive to provide pipelines that
|
||||
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LatentDiffusionPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
|
||||
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
|
||||
- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section),
|
||||
- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)),
|
||||
- 4. can easily be contributed by the community (see the [Contribution](#contribution) section).
|
||||
@@ -41,19 +41,24 @@ If you are looking for *official* training examples, please have a look at [exam
|
||||
The following table summarizes all officially supported pipelines, their corresponding paper, and if
|
||||
available a colab notebook to directly try them out.
|
||||
|
||||
|
||||
| Pipeline | Paper | Tasks | Colab
|
||||
|---|---|:---:|:---:|
|
||||
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
|
||||
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
|
||||
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
|
||||
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
|
||||
| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
|
||||
| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
|
||||
|
||||
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
|
||||
|
||||
|
||||
77
docs/source/api/pipelines/repaint.mdx
Normal file
77
docs/source/api/pipelines/repaint.mdx
Normal file
@@ -0,0 +1,77 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# RePaint
|
||||
|
||||
## Overview
|
||||
|
||||
[RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865) (PNDM) by Andreas Lugmayr, Martin Danelljan, Andres Romero, Fisher Yu, Radu Timofte, Luc Van Gool.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
Free-form inpainting is the task of adding new content to an image in the regions specified by an arbitrary binary mask. Most existing approaches train for a certain distribution of masks, which limits their generalization capabilities to unseen mask types. Furthermore, training with pixel-wise and perceptual losses often leads to simple textural extensions towards the missing areas instead of semantically meaningful generation. In this work, we propose RePaint: A Denoising Diffusion Probabilistic Model (DDPM) based inpainting approach that is applicable to even extreme masks. We employ a pretrained unconditional DDPM as the generative prior. To condition the generation process, we only alter the reverse diffusion iterations by sampling the unmasked regions using the given image information. Since this technique does not modify or condition the original DDPM network itself, the model produces high-quality and diverse output images for any inpainting form. We validate our method for both faces and general-purpose image inpainting using standard and extreme masks.
|
||||
RePaint outperforms state-of-the-art Autoregressive, and GAN approaches for at least five out of six mask distributions.
|
||||
|
||||
The original codebase can be found [here](https://github.com/andreas128/RePaint).
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks | Colab
|
||||
|-------------------------------------------------------------------------------------------------------------------------------|--------------------|:---:|
|
||||
| [pipeline_repaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/repaint/pipeline_repaint.py) | *Image Inpainting* | - |
|
||||
|
||||
## Usage example
|
||||
|
||||
```python
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
import requests
|
||||
from diffusers import RePaintPipeline, RePaintScheduler
|
||||
|
||||
|
||||
def download_image(url):
|
||||
response = requests.get(url)
|
||||
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
|
||||
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
|
||||
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
|
||||
|
||||
# Load the original image and the mask as PIL images
|
||||
original_image = download_image(img_url).resize((256, 256))
|
||||
mask_image = download_image(mask_url).resize((256, 256))
|
||||
|
||||
# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
|
||||
scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
|
||||
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
output = pipe(
|
||||
original_image=original_image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=250,
|
||||
eta=0.0,
|
||||
jump_length=10,
|
||||
jump_n_sample=10,
|
||||
generator=generator,
|
||||
)
|
||||
inpainted_image = output.images[0]
|
||||
```
|
||||
|
||||
## RePaintPipeline
|
||||
[[autodoc]] pipelines.repaint.pipeline_repaint.RePaintPipeline
|
||||
- __call__
|
||||
|
||||
@@ -31,6 +31,21 @@ For more details about how Stable Diffusion works and how it differs from the ba
|
||||
|
||||
## Tips
|
||||
|
||||
### How to load and use different schedulers.
|
||||
|
||||
The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
|
||||
To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
|
||||
|
||||
euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
|
||||
```
|
||||
|
||||
|
||||
### How to conver all use cases with multiple or single pipeline
|
||||
|
||||
If you want to use all possible use cases in a single `DiffusionPipeline` you can either:
|
||||
- Make use of the [Stable Diffusion Mega Pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#stable-diffusion-mega) or
|
||||
- Make use of the `components` functionality to instantiate all components in the most memory-efficient way:
|
||||
|
||||
34
docs/source/api/pipelines/vq_diffusion.mdx
Normal file
34
docs/source/api/pipelines/vq_diffusion.mdx
Normal file
@@ -0,0 +1,34 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# VQDiffusion
|
||||
|
||||
## Overview
|
||||
|
||||
[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.
|
||||
|
||||
The original codebase can be found [here](https://github.com/microsoft/VQ-Diffusion).
|
||||
|
||||
## Available Pipelines:
|
||||
|
||||
| Pipeline | Tasks | Colab
|
||||
|---|---|:---:|
|
||||
| [pipeline_vq_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py) | *Text-to-Image Generation* | - |
|
||||
|
||||
|
||||
## VQDiffusionPipeline
|
||||
[[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline
|
||||
- __call__
|
||||
@@ -16,7 +16,7 @@ Diffusers contains multiple pre-built schedule functions for the diffusion proce
|
||||
|
||||
## What is a scheduler?
|
||||
|
||||
The schedule functions, denoted *Schedulers* in the library take in the output of a trained model, a sample which the diffusion process is iterating on, and a timestep to return a denoised sample.
|
||||
The schedule functions, denoted *Schedulers* in the library take in the output of a trained model, a sample which the diffusion process is iterating on, and a timestep to return a denoised sample. That's why schedulers may also be called *Samplers* in other diffusion models implementations.
|
||||
|
||||
- Schedulers define the methodology for iteratively adding noise to an image or for updating a sample based on model outputs.
|
||||
- adding noise in different manners represent the algorithmic processes to train a diffusion model by adding noise to images.
|
||||
@@ -70,6 +70,12 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502).
|
||||
|
||||
[[autodoc]] DDPMScheduler
|
||||
|
||||
#### Multistep DPM-Solver
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
|
||||
|
||||
[[autodoc]] DPMSolverMultistepScheduler
|
||||
|
||||
#### Variance exploding, stochastic sampling from Karras et. al
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
|
||||
@@ -112,3 +118,34 @@ Score SDE-VP is under construction.
|
||||
</Tip>
|
||||
|
||||
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
|
||||
|
||||
#### Euler scheduler
|
||||
|
||||
Euler scheduler (Algorithm 2) from the paper [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) by Karras et al. (2022). Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by Katherine Crowson.
|
||||
Fast scheduler which often times generates good outputs with 20-30 steps.
|
||||
|
||||
[[autodoc]] EulerDiscreteScheduler
|
||||
|
||||
|
||||
#### Euler Ancestral scheduler
|
||||
|
||||
Ancestral sampling with Euler method steps. Based on the original (k-diffusion)[https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72] implementation by Katherine Crowson.
|
||||
Fast scheduler which often times generates good outputs with 20-30 steps.
|
||||
|
||||
[[autodoc]] EulerAncestralDiscreteScheduler
|
||||
|
||||
|
||||
#### VQDiffusionScheduler
|
||||
|
||||
Original paper can be found [here](https://arxiv.org/abs/2111.14822)
|
||||
|
||||
[[autodoc]] VQDiffusionScheduler
|
||||
|
||||
#### RePaint scheduler
|
||||
|
||||
DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks.
|
||||
Intended for use with [`RePaintPipeline`].
|
||||
Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865)
|
||||
and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint
|
||||
|
||||
[[autodoc]] RePaintScheduler
|
||||
|
||||
BIN
docs/source/imgs/access_request.png
Normal file
BIN
docs/source/imgs/access_request.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 102 KiB |
@@ -34,6 +34,8 @@ available a colab notebook to directly try them out.
|
||||
|
||||
| Pipeline | Paper | Tasks | Colab
|
||||
|---|---|:---:|:---:|
|
||||
| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation |
|
||||
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
|
||||
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
|
||||
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
|
||||
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
|
||||
@@ -45,5 +47,6 @@ available a colab notebook to directly try them out.
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
|
||||
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
|
||||
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
|
||||
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
|
||||
|
||||
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
|
||||
|
||||
@@ -12,9 +12,12 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Installation
|
||||
|
||||
Install Diffusers for with PyTorch. Support for other libraries will come in the future
|
||||
Install 🤗 Diffusers for whichever deep learning library you’re working with.
|
||||
|
||||
🤗 Diffusers is tested on Python 3.7+, and PyTorch 1.7.0+.
|
||||
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and flax. Follow the installation instructions below for the deep learning library you are using:
|
||||
|
||||
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
|
||||
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
|
||||
|
||||
## Install with pip
|
||||
|
||||
@@ -36,12 +39,30 @@ source .env/bin/activate
|
||||
|
||||
Now you're ready to install 🤗 Diffusers with the following command:
|
||||
|
||||
**For PyTorch**
|
||||
|
||||
```bash
|
||||
pip install diffusers
|
||||
pip install diffusers["torch"]
|
||||
```
|
||||
|
||||
**For Flax**
|
||||
|
||||
```bash
|
||||
pip install diffusers["flax"]
|
||||
```
|
||||
|
||||
## Install from source
|
||||
|
||||
Before intsalling `diffusers` from source, make sure you have `torch` and `accelerate` installed.
|
||||
|
||||
For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally).
|
||||
|
||||
To install `accelerate`
|
||||
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
Install 🤗 Diffusers from source with the following command:
|
||||
|
||||
```bash
|
||||
@@ -67,7 +88,18 @@ Clone the repository and install 🤗 Diffusers with the following commands:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
**For PyTorch**
|
||||
|
||||
```
|
||||
pip install -e ".[torch]"
|
||||
```
|
||||
|
||||
**For Flax**
|
||||
|
||||
```
|
||||
pip install -e ".[flax]"
|
||||
```
|
||||
|
||||
These commands will link the folder you cloned the repository to and your Python library paths.
|
||||
|
||||
@@ -22,6 +22,7 @@ We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for
|
||||
| fp16 | 3.61s | x2.63 |
|
||||
| channels last | 3.30s | x2.88 |
|
||||
| traced UNet | 3.21s | x2.96 |
|
||||
| memory efficient attention | 2.63s | x3.61 |
|
||||
|
||||
<em>
|
||||
obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from
|
||||
@@ -290,3 +291,41 @@ pipe.unet = TracedUNet()
|
||||
with torch.inference_mode():
|
||||
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
|
||||
```
|
||||
|
||||
|
||||
## Memory Efficient Attention
|
||||
Recent work on optimizing the bandwitdh in the attention block have generated huge speed ups and gains in GPU memory usage. The most recent being Flash Attention (from @tridao, [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)) .
|
||||
Here are the speedups we obtain on a few Nvidia GPUs when running the inference at 512x512 with a batch size of 1 (one prompt):
|
||||
|
||||
| GPU | Base Attention FP16 | Memory Efficient Attention FP16 |
|
||||
|------------------ |--------------------- |--------------------------------- |
|
||||
| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s |
|
||||
| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s |
|
||||
| NVIDIA A10G | 8.88it/s | 15.6it/s |
|
||||
| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s |
|
||||
| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s |
|
||||
| A100-SXM4-40GB | 18.6it/s | 29.it/s |
|
||||
| A100-SXM-80GB | 18.7it/s | 29.5it/s |
|
||||
|
||||
To leverage it just make sure you have:
|
||||
- PyTorch > 1.12
|
||||
- Cuda available
|
||||
- Installed the [xformers](https://github.com/facebookresearch/xformers) library
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
with torch.inference_mode():
|
||||
sample = pipe("a small cat")
|
||||
|
||||
# optional: You can disable it via
|
||||
# pipe.disable_xformers_memory_efficient_attention()
|
||||
```
|
||||
@@ -19,11 +19,8 @@ specific language governing permissions and limitations under the License.
|
||||
- Mac computer with Apple silicon (M1/M2) hardware.
|
||||
- macOS 12.6 or later (13.0 or later recommended).
|
||||
- arm64 version of Python.
|
||||
- PyTorch 1.13.0 RC (Release Candidate). You can install it with `pip` using:
|
||||
- PyTorch 1.13. You can install it with `pip` or `conda` using the instructions in https://pytorch.org/get-started/locally/.
|
||||
|
||||
```
|
||||
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/test/cpu
|
||||
```
|
||||
|
||||
## Inference Pipeline
|
||||
|
||||
@@ -63,4 +60,4 @@ pipeline.enable_attention_slicing()
|
||||
## Known Issues
|
||||
|
||||
- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
|
||||
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). For now, we recommend to iterate instead of batching.
|
||||
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). This is being resolved, but for now we recommend to iterate instead of batching.
|
||||
|
||||
@@ -23,7 +23,7 @@ The [Dreambooth training script](https://github.com/huggingface/diffusers/tree/m
|
||||
|
||||
<!-- TODO: replace with our blog when it's done -->
|
||||
|
||||
Dreambooth fine-tuning is very sensitive to hyperparameters and easy to overfit. We recommend you take a look at our [in-depth analysis](https://wandb.ai/psuraj/dreambooth/reports/Dreambooth-Training-Analysis--VmlldzoyNzk0NDc3) with recommended settings for different subjects, and go from there.
|
||||
Dreambooth fine-tuning is very sensitive to hyperparameters and easy to overfit. We recommend you take a look at our [in-depth analysis](https://huggingface.co/blog/dreambooth) with recommended settings for different subjects, and go from there.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -148,7 +148,7 @@ accelerate launch train_dreambooth.py \
|
||||
|
||||
### Fine-tune the text encoder in addition to the UNet
|
||||
|
||||
The script also allows to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that this gives much better results, especially on faces. Please, refer to [our report](https://wandb.ai/psuraj/dreambooth/reports/Dreambooth-Training-Analysis--VmlldzoyNzk0NDc3) for more details.
|
||||
The script also allows to fine-tune the `text_encoder` along with the `unet`. It has been observed experimentally that this gives much better results, especially on faces. Please, refer to [our blog](https://huggingface.co/blog/dreambooth) for more details.
|
||||
|
||||
To enable this option, pass the `--train_text_encoder` argument to the training script.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Stable Diffusion text-to-image fine-tuning
|
||||
|
||||
The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion) script shows how to fine-tune the stable diffusion model on your own dataset.
|
||||
The [`train_text_to_image.py`](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) script shows how to fine-tune the stable diffusion model on your own dataset.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -128,7 +128,7 @@ pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeli
|
||||
pipe()
|
||||
```
|
||||
|
||||
Another way to upload your custom_pipeline, besides sending a PR, is uploading the code that contains it to the Hugging Face Hub, [as exemplified here](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipelines#loading-custom-pipelines-from-the-hub).
|
||||
Another way to upload your custom_pipeline, besides sending a PR, is uploading the code that contains it to the Hugging Face Hub, [as exemplified here](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview#loading-custom-pipelines-from-the-hub).
|
||||
|
||||
**Try it out now - it works!**
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Loading and Saving Custom Pipelines
|
||||
# Loading and Adding Custom Pipelines
|
||||
|
||||
Diffusers allows you to conveniently load any custom pipeline from the Hugging Face Hub as well as any [official community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community)
|
||||
via the [`DiffusionPipeline`] class.
|
||||
|
||||
@@ -33,7 +33,7 @@ url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/st
|
||||
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
init_image = init_image.resize((768, 512))
|
||||
init_image.thumbnail((768, 768))
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
|
||||
|
||||
@@ -12,7 +12,374 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Loading
|
||||
|
||||
The core functionality for saving and loading systems in `Diffusers` is the HuggingFace Hub.
|
||||
A core premise of the diffusers library is to make diffusion models **as accessible as possible**.
|
||||
Accessibility is therefore achieved by providing an API to load complete diffusion pipelines as well as individual components with a single line of code.
|
||||
|
||||
In the following we explain in-detail how to easily load:
|
||||
|
||||
- *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`]
|
||||
- *Diffusion Models* via [`ModelMixin.from_pretrained`]
|
||||
- *Schedulers* via [`ConfigMixin.from_config`]
|
||||
|
||||
## Loading pipelines
|
||||
|
||||
The [`DiffusionPipeline`] class is the easiest way to access any diffusion model that is [available on the Hub](https://huggingface.co/models?library=diffusers). Let's look at an example on how to download [CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256).
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
ldm = DiffusionPipeline.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
Here [`DiffusionPipeline`] automatically detects the correct pipeline (*i.e.* [`LDMTextToImagePipeline`]), downloads and caches all required configuration and weight files (if not already done so), and finally returns a pipeline instance, called `ldm`.
|
||||
The pipeline instance can then be called using [`LDMTextToImagePipeline.__call__`] (i.e., `ldm("image of a astronaut riding a horse")`) for text-to-image generation.
|
||||
|
||||
Instead of using the generic [`DiffusionPipeline`] class for loading, you can also load the appropriate pipeline class directly. The code snippet above yields the same instance as when doing:
|
||||
|
||||
```python
|
||||
from diffusers import LDMTextToImagePipeline
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
ldm = LDMTextToImagePipeline.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
Diffusion pipelines like `LDMTextToImagePipeline` often consist of multiple components. These components can be both parameterized models, such as `"unet"`, `"vqvae"` and "bert", tokenizers or schedulers. These components can interact in complex ways with each other when using the pipeline in inference, *e.g.* for [`LDMTextToImagePipeline`] or [`StableDiffusionPipeline`] the inference call is explained [here](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work).
|
||||
The purpose of the [pipeline classes](./api/overview#diffusers-summary) is to wrap the complexity of these diffusion systems and give the user an easy-to-use API while staying flexible for customization, as will be shown later.
|
||||
|
||||
### Loading pipelines that require access request
|
||||
|
||||
Due to the capabilities of diffusion models to generate extremely realistic images, there is a certain danger that such models might be misused for unwanted applications, *e.g.* generating pornography or violent images.
|
||||
In order to minimize the possibility of such unsolicited use cases, some of the most powerful diffusion models require users to acknowledge a license before being able to use the model. If the user does not agree to the license, the pipeline cannot be downloaded.
|
||||
If you try to load [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) the same way as done previously:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
it will only work if you have both *click-accepted* the license on [the model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) and are logged into the Hugging Face Hub. Otherwise you will get an error message
|
||||
such as the following:
|
||||
|
||||
```
|
||||
OSError: runwayml/stable-diffusion-v1-5 is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
|
||||
If this is a private repository, make sure to pass a token having permission to this repo with `use_auth_token` or log in with `huggingface-cli login`
|
||||
```
|
||||
|
||||
Therefore, we need to make sure to *click-accept* the license. You can do this by simply visiting
|
||||
the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) and clicking on "Agree and access repository":
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
<img src="https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/imgs/access_request.png" width="400"/>
|
||||
<br>
|
||||
</p>
|
||||
|
||||
Second, you need to login with your access token:
|
||||
|
||||
```
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
before trying to load the model. Or alternatively, you can pass [your access token](https://huggingface.co/docs/hub/security-tokens#user-access-tokens) directly via the flag `use_auth_token`. In this case you do **not** need
|
||||
to run `huggingface-cli login` before:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, use_auth_token="<your-access-token>")
|
||||
```
|
||||
|
||||
The final option to use pipelines that require access without having to rely on the Hugging Face Hub is to load the pipeline locally as explained in the next section.
|
||||
|
||||
### Loading pipelines locally
|
||||
|
||||
If you prefer to have complete control over the pipeline and its corresponding files or, as said before, if you want to use pipelines that require an access request without having to be connected to the Hugging Face Hub,
|
||||
we recommend loading pipelines locally.
|
||||
|
||||
To load a diffusion pipeline locally, you first need to manually download the whole folder structure on your local disk and then pass a local path to the [`DiffusionPipeline.from_pretrained`]. Let's again look at an example for
|
||||
[CompVis' Latent Diffusion model](https://huggingface.co/CompVis/ldm-text2im-large-256).
|
||||
|
||||
First, you should make use of [`git-lfs`](https://git-lfs.github.com/) to download the whole folder structure that has been uploaded to the [model repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main):
|
||||
|
||||
```
|
||||
git lfs install
|
||||
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
|
||||
```
|
||||
|
||||
The command above will create a local folder called `./stable-diffusion-v1-5` on your disk.
|
||||
Now, all you have to do is to simply pass the local folder path to `from_pretrained`:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
repo_id = "./stable-diffusion-v1-5"
|
||||
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
If `repo_id` is a local path, as it is the case here, [`DiffusionPipeline.from_pretrained`] will automatically detect it and therefore not try to download any files from the Hub.
|
||||
While we usually recommend to load weights directly from the Hub to be certain to stay up to date with the newest changes, loading pipelines locally should be preferred if one
|
||||
wants to stay anonymous, self-contained applications, etc...
|
||||
|
||||
### Loading customized pipelines
|
||||
|
||||
Advanced users that want to load customized versions of diffusion pipelines can do so by swapping any of the default components, *e.g.* the scheduler, with other scheduler classes.
|
||||
A classical use case of this functionality is to swap the scheduler. [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) uses the [`PNDMScheduler`] by default which is generally not the most performant scheduler. Since the release
|
||||
of stable diffusion, multiple improved schedulers have been published. To use those, the user has to manually load their preferred scheduler and pass it into [`DiffusionPipeline.from_pretrained`].
|
||||
|
||||
*E.g.* to use [`EulerDiscreteScheduler`] or [`DPMSolverMultistepScheduler`] to have a better quality vs. generation speed trade-off for inference, one could load them as follows:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler
|
||||
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
scheduler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
# or
|
||||
# scheduler = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
|
||||
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
|
||||
```
|
||||
|
||||
Three things are worth paying attention to here.
|
||||
- First, the scheduler is loaded with [`ConfigMixin.from_config`] since it only depends on a configuration file and not any parameterized weights
|
||||
- Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler)
|
||||
- Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`]
|
||||
|
||||
Not only the scheduler components can be customized for diffusion pipelines; in theory, all components of a pipeline can be customized. In practice, however, it often only makes sense to switch out a component that has **compatible** alternatives to what the pipeline expects.
|
||||
Many scheduler classes are compatible with each other as can be seen [here](https://github.com/huggingface/diffusers/blob/0dd8c6b4dbab4069de9ed1cafb53cbd495873879/src/diffusers/schedulers/scheduling_ddim.py#L112). This is not always the case for other components, such as the `"unet"`.
|
||||
|
||||
One special case that can also be customized is the `"safety_checker"` of stable diffusion. If you believe the safety checker doesn't serve you any good, you can simply disable it by passing `None`:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler
|
||||
|
||||
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None)
|
||||
```
|
||||
|
||||
Another common use case is to reuse the same components in multiple pipelines, *e.g.* the weights and configurations of [`"runwayml/stable-diffusion-v1-5"`](https://huggingface.co/runwayml/stable-diffusion-v1-5) can be used for both [`StableDiffusionPipeline`] and [`StableDiffusionImg2ImgPipeline`] and we might not want to
|
||||
use the exact same weights into RAM twice. In this case, customizing all the input instances would help us
|
||||
to only load the weights into RAM once:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
|
||||
|
||||
model_id = "runwayml/stable-diffusion-v1-5"
|
||||
stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)
|
||||
|
||||
components = stable_diffusion_txt2img.components
|
||||
|
||||
# weights are not reloaded into RAM
|
||||
stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)
|
||||
```
|
||||
|
||||
Note how the above code snippet makes use of [`DiffusionPipeline.components`].
|
||||
|
||||
### How does loading work?
|
||||
|
||||
As a class method, [`DiffusionPipeline.from_pretrained`] is responsible for two things:
|
||||
- Download the latest version of the folder structure required to run the `repo_id` with `diffusers` and cache them. If the latest folder structure is available in the local cache, [`DiffusionPipeline.from_pretrained`] will simply reuse the cache and **not** re-download the files.
|
||||
- Load the cached weights into the _correct_ pipeline class – one of the [officially supported pipeline classes](./api/overview#diffusers-summary) - and return an instance of the class. The _correct_ pipeline class is thereby retrieved from the `model_index.json` file.
|
||||
|
||||
The underlying folder structure of diffusion pipelines correspond 1-to-1 to their corresponding class instances, *e.g.* [`LDMTextToImagePipeline`] for [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256)
|
||||
This can be understood better by looking at an example. Let's print out pipeline class instance `pipeline` we just defined:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
ldm = DiffusionPipeline.from_pretrained(repo_id)
|
||||
print(ldm)
|
||||
```
|
||||
|
||||
*Output*:
|
||||
```
|
||||
LDMTextToImagePipeline {
|
||||
"bert": [
|
||||
"latent_diffusion",
|
||||
"LDMBertModel"
|
||||
],
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"DDIMScheduler"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"BertTokenizer"
|
||||
],
|
||||
"unet": [
|
||||
"diffusers",
|
||||
"UNet2DConditionModel"
|
||||
],
|
||||
"vqvae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
First, we see that the official pipeline is the [`LDMTextToImagePipeline`], and second we see that the `LDMTextToImagePipeline` consists of 5 components:
|
||||
- `"bert"` of class `LDMBertModel` as defined [in the pipeline](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L664)
|
||||
- `"scheduler"` of class [`DDIMScheduler`]
|
||||
- `"tokenizer"` of class `BertTokenizer` as defined [in `transformers`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer)
|
||||
- `"unet"` of class [`UNet2DConditionModel`]
|
||||
- `"vqvae"` of class [`AutoencoderKL`]
|
||||
|
||||
Let's now compare the pipeline instance to the folder structure of the model repository `CompVis/ldm-text2im-large-256`. Looking at the folder structure of [`CompVis/ldm-text2im-large-256`](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main) on the Hub, we can see it matches 1-to-1 the printed out instance of `LDMTextToImagePipeline` above:
|
||||
|
||||
```
|
||||
.
|
||||
├── bert
|
||||
│ ├── config.json
|
||||
│ └── pytorch_model.bin
|
||||
├── model_index.json
|
||||
├── scheduler
|
||||
│ └── scheduler_config.json
|
||||
├── tokenizer
|
||||
│ ├── special_tokens_map.json
|
||||
│ ├── tokenizer_config.json
|
||||
│ └── vocab.txt
|
||||
├── unet
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.bin
|
||||
└── vqvae
|
||||
├── config.json
|
||||
└── diffusion_pytorch_model.bin
|
||||
```
|
||||
|
||||
As we can see each attribute of the instance of `LDMTextToImagePipeline` has its configuration and possibly weights defined in a subfolder that is called **exactly** like the class attribute (`"bert"`, `"scheduler"`, `"tokenizer"`, `"unet"`, `"vqvae"`). Importantly, every pipeline expects a `model_index.json` file that tells the `DiffusionPipeline` both:
|
||||
- which pipeline class should be loaded, and
|
||||
- what sub-classes from which library are stored in which subfolders
|
||||
|
||||
In the case of `CompVis/ldm-text2im-large-256` the `model_index.json` is therefore defined as follows:
|
||||
|
||||
```
|
||||
{
|
||||
"_class_name": "LDMTextToImagePipeline",
|
||||
"_diffusers_version": "0.0.4",
|
||||
"bert": [
|
||||
"latent_diffusion",
|
||||
"LDMBertModel"
|
||||
],
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"DDIMScheduler"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"BertTokenizer"
|
||||
],
|
||||
"unet": [
|
||||
"diffusers",
|
||||
"UNet2DConditionModel"
|
||||
],
|
||||
"vqvae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
- `_class_name` tells `DiffusionPipeline` which pipeline class should be loaded.
|
||||
- `_diffusers_version` can be useful to know under which `diffusers` version this model was created.
|
||||
- Every component of the pipeline is then defined under the form:
|
||||
```
|
||||
"name" : [
|
||||
"library",
|
||||
"class"
|
||||
]
|
||||
```
|
||||
- The `"name"` field corresponds both to the name of the subfolder in which the configuration and weights are stored as well as the attribute name of the pipeline class (as can be seen [here](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/bert) and [here](https://github.com/huggingface/diffusers/blob/cd502b25cf0debac6f98d27a6638ef95208d1ea2/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py#L42)
|
||||
- The `"library"` field corresponds to the name of the library, *e.g.* `diffusers` or `transformers` from which the `"class"` should be loaded
|
||||
- The `"class"` field corresponds to the name of the class, *e.g.* [`BertTokenizer`](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) or [`UNet2DConditionModel`]
|
||||
|
||||
|
||||
## Loading models
|
||||
|
||||
Models as defined under [src/diffusers/models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) can be loaded via the [`ModelMixin.from_pretrained`] function. The API is very similar the [`DiffusionPipeline.from_pretrained`] and works in the same way:
|
||||
- Download the latest version of the model weights and configuration with `diffusers` and cache them. If the latest files are available in the local cache, [`ModelMixin.from_pretrained`] will simply reuse the cache and **not** re-download the files.
|
||||
- Load the cached weights into the _defined_ model class - one of [the existing model classes](./api/models) - and return an instance of the class.
|
||||
|
||||
In constrast to [`DiffusionPipeline.from_pretrained`], models rely on fewer files that usually don't require a folder structure, but just a `diffusion_pytorch_model.bin` and `config.json` file.
|
||||
|
||||
Let's look at an example:
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")
|
||||
```
|
||||
|
||||
Note how we have to define the `subfolder="unet"` argument to tell [`ModelMixin.from_pretrained`] that the model weights are located in a [subfolder of the repository](https://huggingface.co/CompVis/ldm-text2im-large-256/tree/main/unet).
|
||||
|
||||
As explained in [Loading customized pipelines]("./using-diffusers/loading#loading-customized-pipelines"), one can pass a loaded model to a diffusion pipeline, via [`DiffusionPipeline.from_pretrained`]:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
repo_id = "CompVis/ldm-text2im-large-256"
|
||||
ldm = DiffusionPipeline.from_pretrained(repo_id, unet=model)
|
||||
```
|
||||
|
||||
If the model files can be found directly at the root level, which is usually only the case for some very simple diffusion models, such as [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32), we don't
|
||||
need to pass a `subfolder` argument:
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DModel
|
||||
|
||||
repo_id = "google/ddpm-cifar10-32"
|
||||
model = UNet2DModel.from_pretrained(repo_id)
|
||||
```
|
||||
|
||||
## Loading schedulers
|
||||
|
||||
Schedulers cannot be loaded via a `from_pretrained` method, but instead rely on [`ConfigMixin.from_config`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
|
||||
Therefore the loading method was given a different name here.
|
||||
|
||||
In constrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers.
|
||||
For example, all of:
|
||||
|
||||
- [`DDPMScheduler`]
|
||||
- [`DDIMScheduler`]
|
||||
- [`PNDMScheduler`]
|
||||
- [`LMSDiscreteScheduler`]
|
||||
- [`EulerDiscreteScheduler`]
|
||||
- [`EulerAncestralDiscreteScheduler`]
|
||||
- [`DPMSolverMultistepScheduler`]
|
||||
|
||||
are compatible with [`StableDiffusionPipeline`] and therefore the same scheduler configuration file can be loaded in any of those classes:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import (
|
||||
DDPMScheduler,
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
)
|
||||
|
||||
repo_id = "runwayml/stable-diffusion-v1-5"
|
||||
|
||||
ddpm = DDPMScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
ddim = DDIMScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
pndm = PNDMScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
lms = LMSDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
euler_anc = EulerAncestralDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
euler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
|
||||
|
||||
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
[[autodoc]] modeling_utils.ModelMixin
|
||||
- from_pretrained
|
||||
@@ -29,6 +396,3 @@ The core functionality for saving and loading systems in `Diffusers` is the Hugg
|
||||
[[autodoc]] pipeline_flax_utils.FlaxDiffusionPipeline
|
||||
- from_pretrained
|
||||
- save_pretrained
|
||||
|
||||
|
||||
Under further construction 🚧, open a [PR](https://github.com/huggingface/diffusers/compare) if you want to contribute!
|
||||
|
||||
@@ -38,11 +38,11 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
|
||||
|
||||
| Task | 🤗 Accelerate | 🤗 Datasets | Colab
|
||||
|---|---|:---:|:---:|
|
||||
| [**Unconditional Image Generation**](./unconditional_training) | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [**Text-to-Image fine-tuning**](./text2image) | ✅ | ✅ |
|
||||
| [**Textual Inversion**](./text_inversion) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
| [**Unconditional Image Generation**](./unconditional_image_generation) | ✅ | ✅ | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
|
||||
| [**Text-to-Image fine-tuning**](./text_to_image) | ✅ | ✅ |
|
||||
| [**Textual Inversion**](./textual_inversion) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb)
|
||||
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
|
||||
|
||||
| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon.
|
||||
|
||||
## Community
|
||||
|
||||
|
||||
@@ -17,6 +17,9 @@ If a community doesn't work as expected, please open an issue and ping the autho
|
||||
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) |
|
||||
| Composable Stable Diffusion| Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting| [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
|
||||
|
||||
|
||||
|
||||
@@ -176,9 +179,20 @@ images = pipe.inpaint(prompt=prompt, init_image=init_image, mask_image=mask_imag
|
||||
As shown above this one pipeline can run all both "text-to-image", "image-to-image", and "inpainting" in one pipeline.
|
||||
|
||||
### Long Prompt Weighting Stable Diffusion
|
||||
Features of this custom pipeline:
|
||||
- Input a prompt without the 77 token length limit.
|
||||
- Includes tx2img, img2img. and inpainting pipelines.
|
||||
- Emphasize/weigh part of your prompt with parentheses as so: `a baby deer with (big eyes)`
|
||||
- De-emphasize part of your prompt as so: `a [baby] deer with big eyes`
|
||||
- Precisely weigh part of your prompt as so: `a baby deer with (big eyes:1.3)`
|
||||
|
||||
The Pipeline lets you input prompt without 77 token length limit. And you can increase words weighting by using "()" or decrease words weighting by using "[]"
|
||||
The Pipeline also lets you use the main use cases of the stable diffusion pipeline in a single class.
|
||||
Prompt weighting equivalents:
|
||||
- `a baby deer with` == `(a baby deer with:1.0)`
|
||||
- `(big eyes)` == `(big eyes:1.1)`
|
||||
- `((big eyes))` == `(big eyes:1.21)`
|
||||
- `[big eyes]` == `(big eyes:0.91)`
|
||||
|
||||
You can run this custom pipeline as so:
|
||||
|
||||
#### pytorch
|
||||
|
||||
@@ -373,6 +387,49 @@ for i in range(4):
|
||||
for i, img in enumerate(images):
|
||||
img.save(f"./composable_diffusion/image_{i}.png")
|
||||
```
|
||||
|
||||
### Imagic Stable Diffusion
|
||||
Allows you to edit an image using stable diffusion.
|
||||
|
||||
```python
|
||||
import requests
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, DDIMScheduler
|
||||
has_cuda = torch.cuda.is_available()
|
||||
device = torch.device('cpu' if not has_cuda else 'cuda')
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
safety_checker=None,
|
||||
use_auth_token=True,
|
||||
custom_pipeline="imagic_stable_diffusion",
|
||||
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
||||
).to(device)
|
||||
generator = th.Generator("cuda").manual_seed(0)
|
||||
seed = 0
|
||||
prompt = "A photo of Barack Obama smiling with a big grin"
|
||||
url = 'https://www.dropbox.com/s/6tlwzr73jd1r9yk/obama.png?dl=1'
|
||||
response = requests.get(url)
|
||||
init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
init_image = init_image.resize((512, 512))
|
||||
res = pipe.train(
|
||||
prompt,
|
||||
init_image,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=50,
|
||||
generator=generator)
|
||||
res = pipe(alpha=1)
|
||||
image = res.images[0]
|
||||
image.save('./imagic/imagic_image_alpha_1.png')
|
||||
res = pipe(alpha=1.5)
|
||||
image = res.images[0]
|
||||
image.save('./imagic/imagic_image_alpha_1_5.png')
|
||||
res = pipe(alpha=2)
|
||||
image = res.images[0]
|
||||
image.save('./imagic/imagic_image_alpha_2.png')
|
||||
```
|
||||
|
||||
### Seed Resizing
|
||||
Test seed resizing. Originally generate an image in 512 by 512, then generate image with same seed at 512 by 592 using seed resizing. Finally, generate 512 by 592 using original stable diffusion pipeline.
|
||||
|
||||
@@ -456,4 +513,106 @@ res = pipe_compare(
|
||||
|
||||
image = res.images[0]
|
||||
image.save('./seed_resize/seed_resize_{w}_{h}_image_compare.png'.format(w=width, h=height))
|
||||
```
|
||||
```
|
||||
|
||||
### Multilingual Stable Diffusion Pipeline
|
||||
|
||||
The following code can generate an images from texts in different languages using the pre-trained [mBART-50 many-to-one multilingual machine translation model](https://huggingface.co/facebook/mbart-large-50-many-to-one-mmt) and Stable Diffusion.
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from transformers import (
|
||||
pipeline,
|
||||
MBart50TokenizerFast,
|
||||
MBartForConditionalGeneration,
|
||||
)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device_dict = {"cuda": 0, "cpu": -1}
|
||||
|
||||
# helper function taken from: https://huggingface.co/blog/stable_diffusion
|
||||
def image_grid(imgs, rows, cols):
|
||||
assert len(imgs) == rows*cols
|
||||
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new('RGB', size=(cols*w, rows*h))
|
||||
grid_w, grid_h = grid.size
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
grid.paste(img, box=(i%cols*w, i//cols*h))
|
||||
return grid
|
||||
|
||||
# Add language detection pipeline
|
||||
language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
|
||||
language_detection_pipeline = pipeline("text-classification",
|
||||
model=language_detection_model_ckpt,
|
||||
device=device_dict[device])
|
||||
|
||||
# Add model for language translation
|
||||
trans_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")
|
||||
trans_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-one-mmt").to(device)
|
||||
|
||||
diffuser_pipeline = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
custom_pipeline="multilingual_stable_diffusion",
|
||||
detection_pipeline=language_detection_pipeline,
|
||||
translation_model=trans_model,
|
||||
translation_tokenizer=trans_tokenizer,
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
diffuser_pipeline.enable_attention_slicing()
|
||||
diffuser_pipeline = diffuser_pipeline.to(device)
|
||||
|
||||
prompt = ["a photograph of an astronaut riding a horse",
|
||||
"Una casa en la playa",
|
||||
"Ein Hund, der Orange isst",
|
||||
"Un restaurant parisien"]
|
||||
|
||||
output = diffuser_pipeline(prompt)
|
||||
|
||||
images = output.images
|
||||
|
||||
grid = image_grid(images, rows=2, cols=2)
|
||||
```
|
||||
|
||||
This example produces the following images:
|
||||

|
||||
|
||||
### Image to Image Inpainting Stable Diffusion
|
||||
|
||||
Similar to the standard stable diffusion inpainting example, except with the addition of an `inner_image` argument.
|
||||
|
||||
`image`, `inner_image`, and `mask` should have the same dimensions. `inner_image` should have an alpha (transparency) channel.
|
||||
|
||||
The aim is to overlay two images, then mask out the boundary between `image` and `inner_image` to allow stable diffusion to make the connection more seamless.
|
||||
For example, this could be used to place a logo on a shirt and make it blend seamlessly.
|
||||
|
||||
```python
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
|
||||
image_path = "./path-to-image.png"
|
||||
inner_image_path = "./path-to-inner-image.png"
|
||||
mask_path = "./path-to-mask.png"
|
||||
|
||||
init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
|
||||
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
|
||||
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
|
||||
|
||||
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
revision="fp16",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "Your prompt here!"
|
||||
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
|
||||
```
|
||||
|
||||
@@ -5,7 +5,14 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from diffusers import AutoencoderKL, DiffusionPipeline, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
|
||||
@@ -56,7 +63,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
||||
clip_model: CLIPModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler],
|
||||
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -123,7 +130,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
if isinstance(self.scheduler, PNDMScheduler):
|
||||
if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler)):
|
||||
alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
# compute predicted original sample from predicted noise also called
|
||||
@@ -176,6 +183,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
clip_guidance_scale: Optional[float] = 100,
|
||||
clip_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_cutouts: Optional[int] = 4,
|
||||
@@ -275,6 +283,20 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
@@ -306,7 +328,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
|
||||
@@ -32,7 +32,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offsensive or harmful.
|
||||
|
||||
476
examples/community/imagic_stable_diffusion.py
Normal file
476
examples/community/imagic_stable_diffusion.py
Normal file
@@ -0,0 +1,476 @@
|
||||
"""
|
||||
modeled after the textual_inversion.py / train_dreambooth.py and the work
|
||||
of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
|
||||
"""
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import PIL
|
||||
from accelerate import Accelerator
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import logging
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
class ImagicStableDiffusionPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for imagic image editing.
|
||||
See paper here: https://arxiv.org/pdf/2210.09276.pdf
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offsensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def train(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
init_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
embedding_learning_rate: float = 0.001,
|
||||
diffusion_model_learning_rate: float = 2e-6,
|
||||
text_embedding_optimization_steps: int = 500,
|
||||
model_fine_tuning_optimization_steps: int = 1000,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=1,
|
||||
mixed_precision="fp16",
|
||||
)
|
||||
|
||||
if "torch_device" in kwargs:
|
||||
device = kwargs.pop("torch_device")
|
||||
warnings.warn(
|
||||
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
|
||||
" Consider using `pipe.to(torch_device)` instead."
|
||||
)
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self.to(device)
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
# Freeze vae and unet
|
||||
self.vae.requires_grad_(False)
|
||||
self.unet.requires_grad_(False)
|
||||
self.text_encoder.requires_grad_(False)
|
||||
self.unet.eval()
|
||||
self.vae.eval()
|
||||
self.text_encoder.eval()
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers(
|
||||
"imagic",
|
||||
config={
|
||||
"embedding_learning_rate": embedding_learning_rate,
|
||||
"text_embedding_optimization_steps": text_embedding_optimization_steps,
|
||||
},
|
||||
)
|
||||
|
||||
# get text embeddings for prompt
|
||||
text_input = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncaton=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = torch.nn.Parameter(
|
||||
self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
|
||||
)
|
||||
text_embeddings = text_embeddings.detach()
|
||||
text_embeddings.requires_grad_()
|
||||
text_embeddings_orig = text_embeddings.clone()
|
||||
|
||||
# Initialize the optimizer
|
||||
optimizer = torch.optim.Adam(
|
||||
[text_embeddings], # only optimize the embeddings
|
||||
lr=embedding_learning_rate,
|
||||
)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image)
|
||||
|
||||
latents_dtype = text_embeddings.dtype
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
init_latent_image_dist = self.vae.encode(init_image).latent_dist
|
||||
init_image_latents = init_latent_image_dist.sample(generator=generator)
|
||||
init_image_latents = 0.18215 * init_image_latents
|
||||
|
||||
progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
global_step = 0
|
||||
|
||||
logger.info("First optimizing the text embedding to better reconstruct the init image")
|
||||
for _ in range(text_embedding_optimization_steps):
|
||||
with accelerator.accumulate(text_embeddings):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
|
||||
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
|
||||
|
||||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_embeddings.requires_grad_(False)
|
||||
|
||||
# Now we fine tune the unet to better reconstruct the image
|
||||
self.unet.requires_grad_(True)
|
||||
self.unet.train()
|
||||
optimizer = torch.optim.Adam(
|
||||
self.unet.parameters(), # only optimize unet
|
||||
lr=diffusion_model_learning_rate,
|
||||
)
|
||||
progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)
|
||||
|
||||
logger.info("Next fine tuning the entire model to better reconstruct the init image")
|
||||
for _ in range(model_fine_tuning_optimization_steps):
|
||||
with accelerator.accumulate(self.unet.parameters()):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(init_image_latents.shape).to(init_image_latents.device)
|
||||
timesteps = torch.randint(1000, (1,), device=init_image_latents.device)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = self.scheduler.add_noise(init_image_latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample
|
||||
|
||||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {"loss": loss.detach().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
self.text_embeddings_orig = text_embeddings_orig
|
||||
self.text_embeddings = text_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
alpha: float = 1.2,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
guidance_scale: float = 7.5,
|
||||
eta: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
if self.text_embeddings is None:
|
||||
raise ValueError("Please run the pipe.train() before trying to generate an image.")
|
||||
if self.text_embeddings_orig is None:
|
||||
raise ValueError("Please run the pipe.train() before trying to generate an image.")
|
||||
|
||||
text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens = [""]
|
||||
max_length = self.tokenizer.model_max_length
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
latents_shape = (1, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if self.device.type == "mps":
|
||||
# randn does not exist on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
|
||||
self.device
|
||||
)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimized to move all timesteps to correct device beforehand
|
||||
timesteps_tensor = self.scheduler.timesteps.to(self.device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
463
examples/community/img2img_inpainting.py
Normal file
463
examples/community/img2img_inpainting.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import deprecate, logging
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def prepare_mask_and_masked_image(image, mask):
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
|
||||
mask = np.array(mask.convert("L"))
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None, None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = image * (mask < 0.5)
|
||||
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
def check_size(image, height, width):
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
w, h = image.size
|
||||
elif isinstance(image, torch.Tensor):
|
||||
*_, h, w = image.shape
|
||||
|
||||
if h != height or w != width:
|
||||
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
|
||||
|
||||
|
||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
|
||||
inner_image = inner_image.convert("RGBA")
|
||||
image = image.convert("RGB")
|
||||
|
||||
image.paste(inner_image, paste_offset, inner_image)
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class ImageToImageInpaintingPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-guided image-to-image inpainting using Stable Diffusion. *This is an experimental feature*.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
inner_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
image (`torch.Tensor` or `PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
|
||||
be masked out with `mask_image` and repainted according to `prompt`.
|
||||
inner_image (`torch.Tensor` or `PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch which will be overlayed onto `image`. Non-transparent
|
||||
regions of `inner_image` must fit inside white pixels in `mask_image`. Expects four channels, with
|
||||
the last channel representing the alpha channel, which will be used to blend `inner_image` with
|
||||
`image`. If not provided, it will be forcibly cast to RGBA.
|
||||
mask_image (`PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
||||
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
|
||||
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
|
||||
instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# check if input sizes are correct
|
||||
check_size(image, height, width)
|
||||
check_size(inner_image, height, width)
|
||||
check_size(mask_image, height, width)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""]
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if latents is None:
|
||||
if self.device.type == "mps":
|
||||
# randn does not exist on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
|
||||
self.device
|
||||
)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# overlay the inner image
|
||||
image = overlay_inner_image(image, inner_image)
|
||||
|
||||
# prepare mask and masked_image
|
||||
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
|
||||
mask = mask.to(device=self.device, dtype=text_embeddings.dtype)
|
||||
masked_image = masked_image.to(device=self.device, dtype=text_embeddings.dtype)
|
||||
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
mask = torch.nn.functional.interpolate(mask, size=(height // 8, width // 8))
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
masked_image_latents = 0.18215 * masked_image_latents
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
|
||||
masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
|
||||
|
||||
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
||||
masked_image_latents = (
|
||||
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
||||
)
|
||||
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
raise ValueError(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimized to move all timesteps to correct device beforehand
|
||||
timesteps_tensor = self.scheduler.timesteps.to(self.device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -65,7 +65,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
@@ -278,7 +278,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""]
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -307,7 +307,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
|
||||
@@ -12,7 +12,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import deprecate, logging
|
||||
from diffusers.utils import deprecate, is_accelerate_available, logging
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
@@ -340,13 +340,15 @@ def get_weighted_text_embeddings(
|
||||
# assign weights to the prompts and normalize in the sense of mean
|
||||
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
||||
if (not skip_parsing) and (not skip_weighting):
|
||||
previous_mean = text_embeddings.mean(axis=[-2, -1])
|
||||
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||
text_embeddings *= (previous_mean / text_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
|
||||
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
if uncond_prompt is not None:
|
||||
previous_mean = uncond_embeddings.mean(axis=[-2, -1])
|
||||
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||
uncond_embeddings *= (previous_mean / uncond_embeddings.mean(axis=[-2, -1])).unsqueeze(-1).unsqueeze(-1)
|
||||
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
if uncond_prompt is not None:
|
||||
return text_embeddings, uncond_embeddings
|
||||
@@ -396,7 +398,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
@@ -431,6 +433,19 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
@@ -451,6 +466,24 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
||||
time. Speed up at training time is not guaranteed.
|
||||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Disable memory efficient attention as implemented in xformers.
|
||||
"""
|
||||
self.unet.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
@@ -478,6 +511,23 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_sequential_cpu_offload(self):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
|
||||
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
|
||||
"""
|
||||
if is_accelerate_available():
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
device = self.device
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
|
||||
if cpu_offloaded_model is not None:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
@@ -498,6 +548,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -560,11 +611,15 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
is_cancelled_callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. If the function returns
|
||||
`True`, the inference will be cancelled.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
`None` if cancelled by `is_cancelled_callback`,
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
@@ -757,8 +812,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
@@ -435,6 +435,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
|
||||
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -496,11 +497,15 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
|
||||
is_cancelled_callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. If the function returns
|
||||
`True`, the inference will be cancelled.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
`None` if cancelled by `is_cancelled_callback`,
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
@@ -668,8 +673,11 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
||||
@@ -693,7 +701,7 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
clip_input=safety_checker_input[i : i + 1], images=image[i : i + 1]
|
||||
)
|
||||
images.append(image_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i)
|
||||
has_nsfw_concept.append(has_nsfw_concept_i[0])
|
||||
image = np.concatenate(images)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
436
examples/community/multilingual_stable_diffusion.py
Normal file
436
examples/community/multilingual_stable_diffusion.py
Normal file
@@ -0,0 +1,436 @@
|
||||
import inspect
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import deprecate, logging
|
||||
from transformers import (
|
||||
CLIPFeatureExtractor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
MBart50TokenizerFast,
|
||||
MBartForConditionalGeneration,
|
||||
pipeline,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def detect_language(pipe, prompt, batch_size):
|
||||
"""helper function to detect language(s) of prompt"""
|
||||
|
||||
if batch_size == 1:
|
||||
preds = pipe(prompt, top_k=1, truncation=True, max_length=128)
|
||||
return preds[0]["label"]
|
||||
else:
|
||||
detected_languages = []
|
||||
for p in prompt:
|
||||
preds = pipe(p, top_k=1, truncation=True, max_length=128)
|
||||
detected_languages.append(preds[0]["label"])
|
||||
|
||||
return detected_languages
|
||||
|
||||
|
||||
def translate_prompt(prompt, translation_tokenizer, translation_model, device):
|
||||
"""helper function to translate prompt to English"""
|
||||
|
||||
encoded_prompt = translation_tokenizer(prompt, return_tensors="pt").to(device)
|
||||
generated_tokens = translation_model.generate(**encoded_prompt, max_new_tokens=1000)
|
||||
en_trans = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
||||
|
||||
return en_trans[0]
|
||||
|
||||
|
||||
class MultilingualStableDiffusion(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion in different languages.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
detection_pipeline ([`pipeline`]):
|
||||
Transformers pipeline to detect prompt's language.
|
||||
translation_model ([`MBartForConditionalGeneration`]):
|
||||
Model to translate prompt to English, if necessary. Please refer to the
|
||||
[model card](https://huggingface.co/docs/transformers/model_doc/mbart) for details.
|
||||
translation_tokenizer ([`MBart50TokenizerFast`]):
|
||||
Tokenizer of the translation model.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Stable Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detection_pipeline: pipeline,
|
||||
translation_model: MBartForConditionalGeneration,
|
||||
translation_tokenizer: MBart50TokenizerFast,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
detection_pipeline=detection_pipeline,
|
||||
translation_model=translation_model,
|
||||
translation_tokenizer=translation_tokenizer,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.unet.config.attention_head_dim // 2
|
||||
self.unet.set_attention_slice(slice_size)
|
||||
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation. Can be in different languages.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# detect language and translate if necessary
|
||||
prompt_language = detect_language(self.detection_pipeline, prompt, batch_size)
|
||||
if batch_size == 1 and prompt_language != "en":
|
||||
prompt = translate_prompt(prompt, self.translation_tokenizer, self.translation_model, self.device)
|
||||
|
||||
if isinstance(prompt, list):
|
||||
for index in range(batch_size):
|
||||
if prompt_language[index] != "en":
|
||||
p = translate_prompt(
|
||||
prompt[index], self.translation_tokenizer, self.translation_model, self.device
|
||||
)
|
||||
prompt[index] = p
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
# detect language and translate it if necessary
|
||||
negative_prompt_language = detect_language(self.detection_pipeline, negative_prompt, batch_size)
|
||||
if negative_prompt_language != "en":
|
||||
negative_prompt = translate_prompt(
|
||||
negative_prompt, self.translation_tokenizer, self.translation_model, self.device
|
||||
)
|
||||
if isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
# detect language and translate it if necessary
|
||||
if isinstance(negative_prompt, list):
|
||||
negative_prompt_languages = detect_language(self.detection_pipeline, negative_prompt, batch_size)
|
||||
for index in range(batch_size):
|
||||
if negative_prompt_languages[index] != "en":
|
||||
p = translate_prompt(
|
||||
negative_prompt[index], self.translation_tokenizer, self.translation_model, self.device
|
||||
)
|
||||
negative_prompt[index] = p
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
# get the initial random noise unless the user supplied it
|
||||
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_dtype = text_embeddings.dtype
|
||||
if latents is None:
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
|
||||
self.device
|
||||
)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimized to move all timesteps to correct device beforehand
|
||||
timesteps_tensor = self.scheduler.timesteps.to(self.device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if self.safety_checker is not None:
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
@@ -37,7 +37,7 @@ class SeedResizeStableDiffusionPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
|
||||
@@ -148,7 +148,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""]
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -177,7 +177,7 @@ class SpeechToImagePipeline(DiffusionPipeline):
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
|
||||
@@ -42,7 +42,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionMegaSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
|
||||
@@ -99,7 +99,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
@@ -295,7 +295,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""]
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
@@ -324,7 +324,7 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
|
||||
@@ -185,7 +185,7 @@ accelerate launch train_dreambooth.py \
|
||||
--class_prompt="a photo of dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--use_8bit_adam
|
||||
--use_8bit_adam \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=2e-6 \
|
||||
--lr_scheduler="constant" \
|
||||
@@ -291,4 +291,4 @@ python train_dreambooth_flax.py \
|
||||
--learning_rate=2e-6 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
```
|
||||
```
|
||||
|
||||
@@ -66,6 +66,7 @@ def parse_args(input_args=None):
|
||||
"--instance_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="The prompt with identifier specifying the instance",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -205,14 +206,16 @@ def parse_args(input_args=None):
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.instance_data_dir is None:
|
||||
raise ValueError("You must specify a train data directory.")
|
||||
|
||||
if args.with_prior_preservation:
|
||||
if args.class_data_dir is None:
|
||||
raise ValueError("You must specify a data directory for class images.")
|
||||
if args.class_prompt is None:
|
||||
raise ValueError("You must specify prompt for class images.")
|
||||
else:
|
||||
if args.class_data_dir is not None:
|
||||
logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
|
||||
if args.class_prompt is not None:
|
||||
logger.warning("You need not use --class_prompt without --with_prior_preservation.")
|
||||
|
||||
return args
|
||||
|
||||
@@ -494,7 +497,12 @@ def main(args):
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
|
||||
input_ids = tokenizer.pad(
|
||||
{"input_ids": input_ids},
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
|
||||
@@ -327,22 +327,6 @@ def main():
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
if jax.process_index() == 0:
|
||||
if args.push_to_hub:
|
||||
if args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
||||
else:
|
||||
repo_name = args.hub_model_id
|
||||
repo = Repository(args.output_dir, clone_from=repo_name)
|
||||
|
||||
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
||||
if "step_*" not in gitignore:
|
||||
gitignore.write("step_*\n")
|
||||
if "epoch_*" not in gitignore:
|
||||
gitignore.write("epoch_*\n")
|
||||
elif args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
rng = jax.random.PRNGKey(args.seed)
|
||||
|
||||
if args.with_prior_preservation:
|
||||
@@ -361,7 +345,8 @@ def main():
|
||||
logger.info(f"Number of class images to sample: {num_new_images}.")
|
||||
|
||||
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
||||
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
||||
total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
|
||||
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
|
||||
|
||||
for example in tqdm(
|
||||
sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
|
||||
@@ -451,7 +436,9 @@ def main():
|
||||
weight_dtype = jnp.bfloat16
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype)
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
|
||||
)
|
||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
|
||||
)
|
||||
|
||||
19
examples/rl/README.md
Normal file
19
examples/rl/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# Overview
|
||||
|
||||
These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers.
|
||||
There are four scripts,
|
||||
1. `run_diffuser_locomotion.py` to sample actions and run them in the environment,
|
||||
2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model.
|
||||
|
||||
You will need some RL specific requirements to run the examples:
|
||||
|
||||
```
|
||||
pip install -f https://download.pytorch.org/whl/torch_stable.html \
|
||||
free-mujoco-py \
|
||||
einops \
|
||||
gym==0.24.1 \
|
||||
protobuf==3.20.1 \
|
||||
git+https://github.com/rail-berkeley/d4rl.git \
|
||||
mediapy \
|
||||
Pillow==9.0.0
|
||||
```
|
||||
57
examples/rl/run_diffuser_gen_trajectories.py
Normal file
57
examples/rl/run_diffuser_gen_trajectories.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import d4rl # noqa
|
||||
import gym
|
||||
import tqdm
|
||||
from diffusers.experimental import ValueGuidedRLPipeline
|
||||
|
||||
|
||||
config = dict(
|
||||
n_samples=64,
|
||||
horizon=32,
|
||||
num_inference_steps=20,
|
||||
n_guide_steps=0,
|
||||
scale_grad_by_std=True,
|
||||
scale=0.1,
|
||||
eta=0.0,
|
||||
t_grad_cutoff=2,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_name = "hopper-medium-v2"
|
||||
env = gym.make(env_name)
|
||||
|
||||
pipeline = ValueGuidedRLPipeline.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32",
|
||||
env=env,
|
||||
)
|
||||
|
||||
env.seed(0)
|
||||
obs = env.reset()
|
||||
total_reward = 0
|
||||
total_score = 0
|
||||
T = 1000
|
||||
rollout = [obs.copy()]
|
||||
try:
|
||||
for t in tqdm.tqdm(range(T)):
|
||||
# Call the policy
|
||||
denorm_actions = pipeline(obs, planning_horizon=32)
|
||||
|
||||
# execute action in environment
|
||||
next_observation, reward, terminal, _ = env.step(denorm_actions)
|
||||
score = env.get_normalized_score(total_reward)
|
||||
# update return
|
||||
total_reward += reward
|
||||
total_score += score
|
||||
print(
|
||||
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
|
||||
f" {total_score}"
|
||||
)
|
||||
# save observations for rendering
|
||||
rollout.append(next_observation.copy())
|
||||
|
||||
obs = next_observation
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
print(f"Total reward: {total_reward}")
|
||||
57
examples/rl/run_diffuser_locomotion.py
Normal file
57
examples/rl/run_diffuser_locomotion.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import d4rl # noqa
|
||||
import gym
|
||||
import tqdm
|
||||
from diffusers.experimental import ValueGuidedRLPipeline
|
||||
|
||||
|
||||
config = dict(
|
||||
n_samples=64,
|
||||
horizon=32,
|
||||
num_inference_steps=20,
|
||||
n_guide_steps=2,
|
||||
scale_grad_by_std=True,
|
||||
scale=0.1,
|
||||
eta=0.0,
|
||||
t_grad_cutoff=2,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_name = "hopper-medium-v2"
|
||||
env = gym.make(env_name)
|
||||
|
||||
pipeline = ValueGuidedRLPipeline.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32",
|
||||
env=env,
|
||||
)
|
||||
|
||||
env.seed(0)
|
||||
obs = env.reset()
|
||||
total_reward = 0
|
||||
total_score = 0
|
||||
T = 1000
|
||||
rollout = [obs.copy()]
|
||||
try:
|
||||
for t in tqdm.tqdm(range(T)):
|
||||
# call the policy
|
||||
denorm_actions = pipeline(obs, planning_horizon=32)
|
||||
|
||||
# execute action in environment
|
||||
next_observation, reward, terminal, _ = env.step(denorm_actions)
|
||||
score = env.get_normalized_score(total_reward)
|
||||
# update return
|
||||
total_reward += reward
|
||||
total_score += score
|
||||
print(
|
||||
f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:"
|
||||
f" {total_score}"
|
||||
)
|
||||
# save observations for rendering
|
||||
rollout.append(next_observation.copy())
|
||||
|
||||
obs = next_observation
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
print(f"Total reward: {total_reward}")
|
||||
@@ -379,7 +379,9 @@ def main():
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", dtype=weight_dtype)
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
|
||||
)
|
||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ accelerate config
|
||||
|
||||
### Cat toy example
|
||||
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree.
|
||||
You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-5`, so you'll need to visit [its card](https://huggingface.co/runwayml/stable-diffusion-v1-5), read the license and tick the checkbox if you agree.
|
||||
|
||||
You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens).
|
||||
|
||||
@@ -111,4 +111,4 @@ python textual_inversion_flax.py \
|
||||
--learning_rate=5.0e-04 --scale_lr \
|
||||
--output_dir="textual_inversion_cat"
|
||||
```
|
||||
It should be at least 70% faster than the PyTorch script with the same configuration.
|
||||
It should be at least 70% faster than the PyTorch script with the same configuration.
|
||||
|
||||
@@ -391,7 +391,7 @@ def main():
|
||||
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -10,10 +11,12 @@ import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import load_dataset
|
||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import deprecate
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from packaging import version
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
Compose,
|
||||
@@ -27,6 +30,25 @@ from tqdm.auto import tqdm
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
diffusers_version = version.parse(version.parse(__version__).base_version)
|
||||
|
||||
|
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
|
||||
:param arr: the 1-D numpy array.
|
||||
:param timesteps: a tensor of indices into the array to extract.
|
||||
:param broadcast_shape: a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
if not isinstance(arr, torch.Tensor):
|
||||
arr = torch.from_numpy(arr)
|
||||
res = arr[timesteps].float().to(timesteps.device)
|
||||
while len(res.shape) < len(broadcast_shape):
|
||||
res = res[..., None]
|
||||
return res.expand(broadcast_shape)
|
||||
|
||||
|
||||
def parse_args():
|
||||
@@ -171,6 +193,16 @@ def parse_args():
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--predict_epsilon",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
|
||||
)
|
||||
|
||||
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
|
||||
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
@@ -224,7 +256,17 @@ def main(args):
|
||||
"UpBlock2D",
|
||||
),
|
||||
)
|
||||
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
|
||||
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
||||
|
||||
if accepts_predict_epsilon:
|
||||
noise_scheduler = DDPMScheduler(
|
||||
num_train_timesteps=args.ddpm_num_steps,
|
||||
beta_schedule=args.ddpm_beta_schedule,
|
||||
predict_epsilon=args.predict_epsilon,
|
||||
)
|
||||
else:
|
||||
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=args.learning_rate,
|
||||
@@ -257,6 +299,8 @@ def main(args):
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
|
||||
logger.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
dataset.set_transform(transforms)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
|
||||
@@ -319,8 +363,20 @@ def main(args):
|
||||
|
||||
with accelerator.accumulate(model):
|
||||
# Predict the noise residual
|
||||
noise_pred = model(noisy_images, timesteps).sample
|
||||
loss = F.mse_loss(noise_pred, noise)
|
||||
model_output = model(noisy_images, timesteps).sample
|
||||
|
||||
if args.predict_epsilon:
|
||||
loss = F.mse_loss(model_output, noise) # this could have different weights!
|
||||
else:
|
||||
alpha_t = _extract_into_tensor(
|
||||
noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
|
||||
)
|
||||
snr_weights = alpha_t / (1 - alpha_t)
|
||||
loss = snr_weights * F.mse_loss(
|
||||
model_output, clean_images, reduction="none"
|
||||
) # use SNR weighting from distillation paper
|
||||
loss = loss.mean()
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
@@ -353,9 +409,17 @@ def main(args):
|
||||
scheduler=noise_scheduler,
|
||||
)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0")
|
||||
if diffusers_version < version.parse("0.8.0"):
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=pipeline.device).manual_seed(0)
|
||||
# run pipeline in inference (sample random noise and denoise)
|
||||
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
|
||||
images = pipeline(
|
||||
generator=generator,
|
||||
batch_size=args.eval_batch_size,
|
||||
output_type="numpy",
|
||||
).images
|
||||
|
||||
# denormalize the images and save to tensorboard
|
||||
images_processed = (images * 255).round().astype("uint8")
|
||||
|
||||
@@ -112,9 +112,9 @@ def assign_to_checkpoint(
|
||||
continue
|
||||
|
||||
# Global renaming happens here
|
||||
new_path = new_path.replace("middle_block.0", "mid.resnets.0")
|
||||
new_path = new_path.replace("middle_block.1", "mid.attentions.0")
|
||||
new_path = new_path.replace("middle_block.2", "mid.resnets.1")
|
||||
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
||||
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
||||
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
||||
|
||||
if additional_replacements is not None:
|
||||
for replacement in additional_replacements:
|
||||
@@ -175,15 +175,16 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
||||
|
||||
if f"input_blocks.{i}.0.op.weight" in checkpoint:
|
||||
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = checkpoint[
|
||||
f"input_blocks.{i}.0.op.weight"
|
||||
]
|
||||
new_checkpoint[f"downsample_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
|
||||
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = checkpoint[
|
||||
f"input_blocks.{i}.0.op.bias"
|
||||
]
|
||||
continue
|
||||
|
||||
paths = renew_resnet_paths(resnets)
|
||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"downsample_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||
resnet_op = {"old": "resnets.2.op", "new": "downsamplers.0.op"}
|
||||
assign_to_checkpoint(
|
||||
paths, new_checkpoint, checkpoint, additional_replacements=[meta_path, resnet_op], config=config
|
||||
@@ -193,18 +194,18 @@ def convert_ldm_checkpoint(checkpoint, config):
|
||||
paths = renew_attention_paths(attentions)
|
||||
meta_path = {
|
||||
"old": f"input_blocks.{i}.1",
|
||||
"new": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||
"new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
|
||||
}
|
||||
to_split = {
|
||||
f"input_blocks.{i}.1.qkv.bias": {
|
||||
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
|
||||
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
|
||||
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
|
||||
"key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.bias",
|
||||
"query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.bias",
|
||||
"value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.bias",
|
||||
},
|
||||
f"input_blocks.{i}.1.qkv.weight": {
|
||||
"key": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
|
||||
"query": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
|
||||
"value": f"downsample_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
|
||||
"key": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.key.weight",
|
||||
"query": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.query.weight",
|
||||
"value": f"down_blocks.{block_id}.attentions.{layer_in_block_id}.value.weight",
|
||||
},
|
||||
}
|
||||
assign_to_checkpoint(
|
||||
|
||||
100
scripts/convert_models_diffuser_to_diffusers.py
Normal file
100
scripts/convert_models_diffuser_to_diffusers.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import UNet1DModel
|
||||
|
||||
|
||||
os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True)
|
||||
os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True)
|
||||
|
||||
os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True)
|
||||
|
||||
|
||||
def unet(hor):
|
||||
if hor == 128:
|
||||
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
|
||||
block_out_channels = (32, 128, 256)
|
||||
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
|
||||
|
||||
elif hor == 32:
|
||||
down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
|
||||
block_out_channels = (32, 64, 128, 256)
|
||||
up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
|
||||
model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
|
||||
state_dict = model.state_dict()
|
||||
config = dict(
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
up_block_types=up_block_types,
|
||||
layers_per_block=1,
|
||||
use_timestep_embedding=True,
|
||||
out_block_type="OutConv1DBlock",
|
||||
norm_num_groups=8,
|
||||
downsample_each_block=False,
|
||||
in_channels=14,
|
||||
out_channels=14,
|
||||
extra_in_channels=0,
|
||||
time_embedding_type="positional",
|
||||
flip_sin_to_cos=False,
|
||||
freq_shift=1,
|
||||
sample_size=65536,
|
||||
mid_block_type="MidResTemporalBlock1D",
|
||||
act_fn="mish",
|
||||
)
|
||||
hf_value_function = UNet1DModel(**config)
|
||||
print(f"length of state dict: {len(state_dict.keys())}")
|
||||
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
|
||||
mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys()))
|
||||
for k, v in mapping.items():
|
||||
state_dict[v] = state_dict.pop(k)
|
||||
hf_value_function.load_state_dict(state_dict)
|
||||
|
||||
torch.save(hf_value_function.state_dict(), f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin")
|
||||
with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
|
||||
def value_function():
|
||||
config = dict(
|
||||
in_channels=14,
|
||||
down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
|
||||
up_block_types=(),
|
||||
out_block_type="ValueFunction",
|
||||
mid_block_type="ValueFunctionMidBlock1D",
|
||||
block_out_channels=(32, 64, 128, 256),
|
||||
layers_per_block=1,
|
||||
downsample_each_block=True,
|
||||
sample_size=65536,
|
||||
out_channels=14,
|
||||
extra_in_channels=0,
|
||||
time_embedding_type="positional",
|
||||
use_timestep_embedding=True,
|
||||
flip_sin_to_cos=False,
|
||||
freq_shift=1,
|
||||
norm_num_groups=8,
|
||||
act_fn="mish",
|
||||
)
|
||||
|
||||
model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
|
||||
state_dict = model
|
||||
hf_value_function = UNet1DModel(**config)
|
||||
print(f"length of state dict: {len(state_dict.keys())}")
|
||||
print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}")
|
||||
|
||||
mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys()))
|
||||
for k, v in mapping.items():
|
||||
state_dict[v] = state_dict.pop(k)
|
||||
|
||||
hf_value_function.load_state_dict(state_dict)
|
||||
|
||||
torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin")
|
||||
with open("hub/hopper-medium-v2/value_function/config.json", "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unet(32)
|
||||
# unet(128)
|
||||
value_function()
|
||||
@@ -30,6 +30,9 @@ except ImportError:
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LDMTextToImagePipeline,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
@@ -647,7 +650,7 @@ if __name__ == "__main__":
|
||||
"--scheduler_type",
|
||||
default="pndm",
|
||||
type=str,
|
||||
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
|
||||
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancest', 'dpm']",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--extract_ema",
|
||||
@@ -686,6 +689,16 @@ if __name__ == "__main__":
|
||||
)
|
||||
elif args.scheduler_type == "lms":
|
||||
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
|
||||
elif args.scheduler_type == "euler":
|
||||
scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
|
||||
elif args.scheduler_type == "euler-ancestral":
|
||||
scheduler = EulerAncestralDiscreteScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
|
||||
)
|
||||
elif args.scheduler_type == "dpm":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
|
||||
)
|
||||
elif args.scheduler_type == "ddim":
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=beta_start,
|
||||
|
||||
@@ -81,6 +81,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
||||
output_path = Path(output_path)
|
||||
|
||||
# TEXT ENCODER
|
||||
num_tokens = pipeline.text_encoder.config.max_position_embeddings
|
||||
text_hidden_size = pipeline.text_encoder.config.hidden_size
|
||||
text_input = pipeline.tokenizer(
|
||||
"A sample prompt",
|
||||
padding="max_length",
|
||||
@@ -103,13 +105,15 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
||||
del pipeline.text_encoder
|
||||
|
||||
# UNET
|
||||
unet_in_channels = pipeline.unet.config.in_channels
|
||||
unet_sample_size = pipeline.unet.config.sample_size
|
||||
unet_path = output_path / "unet" / "model.onnx"
|
||||
onnx_export(
|
||||
pipeline.unet,
|
||||
model_args=(
|
||||
torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
|
||||
torch.LongTensor([0, 1]).to(device=device),
|
||||
torch.randn(2, 77, 768).to(device=device, dtype=dtype),
|
||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
|
||||
torch.randn(2).to(device=device, dtype=dtype),
|
||||
torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=unet_path,
|
||||
@@ -142,11 +146,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
||||
|
||||
# VAE ENCODER
|
||||
vae_encoder = pipeline.vae
|
||||
vae_in_channels = vae_encoder.config.in_channels
|
||||
vae_sample_size = vae_encoder.config.sample_size
|
||||
# need to get the raw tensor output (sample) from the encoder
|
||||
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
|
||||
onnx_export(
|
||||
vae_encoder,
|
||||
model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
|
||||
model_args=(
|
||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_encoder" / "model.onnx",
|
||||
ordered_input_names=["sample", "return_dict"],
|
||||
output_names=["latent_sample"],
|
||||
@@ -158,11 +167,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
||||
|
||||
# VAE DECODER
|
||||
vae_decoder = pipeline.vae
|
||||
vae_latent_channels = vae_decoder.config.latent_channels
|
||||
vae_out_channels = vae_decoder.config.out_channels
|
||||
# forward only through the decoder part
|
||||
vae_decoder.forward = vae_encoder.decode
|
||||
onnx_export(
|
||||
vae_decoder,
|
||||
model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
|
||||
model_args=(
|
||||
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||
ordered_input_names=["latent_sample", "return_dict"],
|
||||
output_names=["sample"],
|
||||
@@ -174,24 +188,35 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
||||
del pipeline.vae
|
||||
|
||||
# SAFETY CHECKER
|
||||
safety_checker = pipeline.safety_checker
|
||||
safety_checker.forward = safety_checker.forward_onnx
|
||||
onnx_export(
|
||||
pipeline.safety_checker,
|
||||
model_args=(
|
||||
torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
|
||||
torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
|
||||
),
|
||||
output_path=output_path / "safety_checker" / "model.onnx",
|
||||
ordered_input_names=["clip_input", "images"],
|
||||
output_names=["out_images", "has_nsfw_concepts"],
|
||||
dynamic_axes={
|
||||
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
del pipeline.safety_checker
|
||||
if pipeline.safety_checker is not None:
|
||||
safety_checker = pipeline.safety_checker
|
||||
clip_num_channels = safety_checker.config.vision_config.num_channels
|
||||
clip_image_size = safety_checker.config.vision_config.image_size
|
||||
safety_checker.forward = safety_checker.forward_onnx
|
||||
onnx_export(
|
||||
pipeline.safety_checker,
|
||||
model_args=(
|
||||
torch.randn(
|
||||
1,
|
||||
clip_num_channels,
|
||||
clip_image_size,
|
||||
clip_image_size,
|
||||
).to(device=device, dtype=dtype),
|
||||
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype),
|
||||
),
|
||||
output_path=output_path / "safety_checker" / "model.onnx",
|
||||
ordered_input_names=["clip_input", "images"],
|
||||
output_names=["out_images", "has_nsfw_concepts"],
|
||||
dynamic_axes={
|
||||
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
del pipeline.safety_checker
|
||||
safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
|
||||
else:
|
||||
safety_checker = None
|
||||
|
||||
onnx_pipeline = OnnxStableDiffusionPipeline(
|
||||
vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
|
||||
@@ -200,7 +225,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = F
|
||||
tokenizer=pipeline.tokenizer,
|
||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||
scheduler=pipeline.scheduler,
|
||||
safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=pipeline.feature_extractor,
|
||||
)
|
||||
|
||||
|
||||
885
scripts/convert_vq_diffusion_to_diffusers.py
Normal file
885
scripts/convert_vq_diffusion_to_diffusers.py
Normal file
@@ -0,0 +1,885 @@
|
||||
"""
|
||||
This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.
|
||||
|
||||
It currently only supports porting the ITHQ dataset.
|
||||
|
||||
ITHQ dataset:
|
||||
```sh
|
||||
# From the root directory of diffusers.
|
||||
|
||||
# Download the VQVAE checkpoint
|
||||
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth
|
||||
|
||||
# Download the VQVAE config
|
||||
# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class
|
||||
# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
|
||||
# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`
|
||||
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml
|
||||
|
||||
# Download the main model checkpoint
|
||||
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth
|
||||
|
||||
# Download the main model config
|
||||
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml
|
||||
|
||||
# run the convert script
|
||||
$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
|
||||
--checkpoint_path ./ithq_learnable.pth \
|
||||
--original_config_file ./ithq.yaml \
|
||||
--vqvae_checkpoint_path ./ithq_vqvae.pth \
|
||||
--vqvae_original_config_file ./ithq_vqvae.yaml \
|
||||
--dump_path <path to save pre-trained `VQDiffusionPipeline`>
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
import yaml
|
||||
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
||||
from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
|
||||
from diffusers.models.attention import Transformer2DModel
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from yaml.loader import FullLoader
|
||||
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
|
||||
" OmegaConf`."
|
||||
)
|
||||
|
||||
# vqvae model
|
||||
|
||||
PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]
|
||||
|
||||
|
||||
def vqvae_model_from_original_config(original_config):
|
||||
assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."
|
||||
|
||||
original_config = original_config.params
|
||||
|
||||
original_encoder_config = original_config.encoder_config.params
|
||||
original_decoder_config = original_config.decoder_config.params
|
||||
|
||||
in_channels = original_encoder_config.in_channels
|
||||
out_channels = original_decoder_config.out_ch
|
||||
|
||||
down_block_types = get_down_block_types(original_encoder_config)
|
||||
up_block_types = get_up_block_types(original_decoder_config)
|
||||
|
||||
assert original_encoder_config.ch == original_decoder_config.ch
|
||||
assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
|
||||
block_out_channels = tuple(
|
||||
[original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
|
||||
)
|
||||
|
||||
assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
|
||||
layers_per_block = original_encoder_config.num_res_blocks
|
||||
|
||||
assert original_encoder_config.z_channels == original_decoder_config.z_channels
|
||||
latent_channels = original_encoder_config.z_channels
|
||||
|
||||
num_vq_embeddings = original_config.n_embed
|
||||
|
||||
# Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion
|
||||
norm_num_groups = 32
|
||||
|
||||
e_dim = original_config.embed_dim
|
||||
|
||||
model = VQModel(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
down_block_types=down_block_types,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
latent_channels=latent_channels,
|
||||
num_vq_embeddings=num_vq_embeddings,
|
||||
norm_num_groups=norm_num_groups,
|
||||
vq_embed_dim=e_dim,
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_down_block_types(original_encoder_config):
|
||||
attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
|
||||
num_resolutions = len(original_encoder_config.ch_mult)
|
||||
resolution = coerce_resolution(original_encoder_config.resolution)
|
||||
|
||||
curr_res = resolution
|
||||
down_block_types = []
|
||||
|
||||
for _ in range(num_resolutions):
|
||||
if curr_res in attn_resolutions:
|
||||
down_block_type = "AttnDownEncoderBlock2D"
|
||||
else:
|
||||
down_block_type = "DownEncoderBlock2D"
|
||||
|
||||
down_block_types.append(down_block_type)
|
||||
|
||||
curr_res = [r // 2 for r in curr_res]
|
||||
|
||||
return down_block_types
|
||||
|
||||
|
||||
def get_up_block_types(original_decoder_config):
|
||||
attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
|
||||
num_resolutions = len(original_decoder_config.ch_mult)
|
||||
resolution = coerce_resolution(original_decoder_config.resolution)
|
||||
|
||||
curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
|
||||
up_block_types = []
|
||||
|
||||
for _ in reversed(range(num_resolutions)):
|
||||
if curr_res in attn_resolutions:
|
||||
up_block_type = "AttnUpDecoderBlock2D"
|
||||
else:
|
||||
up_block_type = "UpDecoderBlock2D"
|
||||
|
||||
up_block_types.append(up_block_type)
|
||||
|
||||
curr_res = [r * 2 for r in curr_res]
|
||||
|
||||
return up_block_types
|
||||
|
||||
|
||||
def coerce_attn_resolutions(attn_resolutions):
|
||||
attn_resolutions = OmegaConf.to_object(attn_resolutions)
|
||||
attn_resolutions_ = []
|
||||
for ar in attn_resolutions:
|
||||
if isinstance(ar, (list, tuple)):
|
||||
attn_resolutions_.append(list(ar))
|
||||
else:
|
||||
attn_resolutions_.append([ar, ar])
|
||||
return attn_resolutions_
|
||||
|
||||
|
||||
def coerce_resolution(resolution):
|
||||
resolution = OmegaConf.to_object(resolution)
|
||||
if isinstance(resolution, int):
|
||||
resolution = [resolution, resolution] # H, W
|
||||
elif isinstance(resolution, (tuple, list)):
|
||||
resolution = list(resolution)
|
||||
else:
|
||||
raise ValueError("Unknown type of resolution:", resolution)
|
||||
return resolution
|
||||
|
||||
|
||||
# done vqvae model
|
||||
|
||||
# vqvae checkpoint
|
||||
|
||||
|
||||
def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint))
|
||||
|
||||
# quant_conv
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"quant_conv.weight": checkpoint["quant_conv.weight"],
|
||||
"quant_conv.bias": checkpoint["quant_conv.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# quantize
|
||||
diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})
|
||||
|
||||
# post_quant_conv
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
|
||||
"post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# decoder
|
||||
diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint))
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# conv_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
|
||||
"encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# down_blocks
|
||||
for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
|
||||
diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
|
||||
down_block_prefix = f"encoder.down.{down_block_idx}"
|
||||
|
||||
# resnets
|
||||
for resnet_idx, resnet in enumerate(down_block.resnets):
|
||||
diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
|
||||
resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# downsample
|
||||
|
||||
# do not include the downsample when on the last down block
|
||||
# There is no downsample on the last down block
|
||||
if down_block_idx != len(model.encoder.down_blocks) - 1:
|
||||
# There's a single downsample in the original checkpoint but a list of downsamples
|
||||
# in the diffusers model.
|
||||
diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
|
||||
downsample_prefix = f"{down_block_prefix}.downsample.conv"
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
|
||||
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# attentions
|
||||
|
||||
if hasattr(down_block, "attentions"):
|
||||
for attention_idx, _ in enumerate(down_block.attentions):
|
||||
diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
|
||||
attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
attention_prefix=attention_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
# mid block
|
||||
|
||||
# mid block attentions
|
||||
|
||||
# There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
|
||||
diffusers_attention_prefix = "encoder.mid_block.attentions.0"
|
||||
attention_prefix = "encoder.mid.attn_1"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# mid block resnets
|
||||
|
||||
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
|
||||
diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"
|
||||
|
||||
# the hardcoded prefixes to `block_` are 1 and 2
|
||||
orig_resnet_idx = diffusers_resnet_idx + 1
|
||||
# There are two hardcoded resnets in the middle of the VQ-diffusion encoder
|
||||
resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
# conv_norm_out
|
||||
"encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
|
||||
"encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
|
||||
# conv_out
|
||||
"encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
|
||||
"encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# conv in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
|
||||
"decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# up_blocks
|
||||
|
||||
for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
|
||||
# up_blocks are stored in reverse order in the VQ-diffusion checkpoint
|
||||
orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx
|
||||
|
||||
diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
|
||||
up_block_prefix = f"decoder.up.{orig_up_block_idx}"
|
||||
|
||||
# resnets
|
||||
for resnet_idx, resnet in enumerate(up_block.resnets):
|
||||
diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
|
||||
resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# upsample
|
||||
|
||||
# there is no up sample on the last up block
|
||||
if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
|
||||
# There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples
|
||||
# in the diffusers model.
|
||||
diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
|
||||
downsample_prefix = f"{up_block_prefix}.upsample.conv"
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
|
||||
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# attentions
|
||||
|
||||
if hasattr(up_block, "attentions"):
|
||||
for attention_idx, _ in enumerate(up_block.attentions):
|
||||
diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
|
||||
attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
attention_prefix=attention_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
# mid block
|
||||
|
||||
# mid block attentions
|
||||
|
||||
# There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
|
||||
diffusers_attention_prefix = "decoder.mid_block.attentions.0"
|
||||
attention_prefix = "decoder.mid.attn_1"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# mid block resnets
|
||||
|
||||
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
|
||||
diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"
|
||||
|
||||
# the hardcoded prefixes to `block_` are 1 and 2
|
||||
orig_resnet_idx = diffusers_resnet_idx + 1
|
||||
# There are two hardcoded resnets in the middle of the VQ-diffusion decoder
|
||||
resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
# conv_norm_out
|
||||
"decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
|
||||
"decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
|
||||
# conv_out
|
||||
"decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
|
||||
"decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
|
||||
rv = {
|
||||
# norm1
|
||||
f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
|
||||
f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
|
||||
# conv1
|
||||
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
|
||||
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
|
||||
# norm2
|
||||
f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
|
||||
f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
|
||||
# conv2
|
||||
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
|
||||
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
|
||||
}
|
||||
|
||||
if resnet.conv_shortcut is not None:
|
||||
rv.update(
|
||||
{
|
||||
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
|
||||
f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
||||
return {
|
||||
# group_norm
|
||||
f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
|
||||
f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
|
||||
# query
|
||||
f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
|
||||
f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
|
||||
# key
|
||||
f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
|
||||
f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
|
||||
# value
|
||||
f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
|
||||
f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
|
||||
# proj_attn
|
||||
f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
|
||||
:, :, 0, 0
|
||||
],
|
||||
f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
|
||||
}
|
||||
|
||||
|
||||
# done vqvae checkpoint
|
||||
|
||||
# transformer model
|
||||
|
||||
PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"]
|
||||
PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"]
|
||||
PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"]
|
||||
|
||||
|
||||
def transformer_model_from_original_config(
|
||||
original_diffusion_config, original_transformer_config, original_content_embedding_config
|
||||
):
|
||||
assert (
|
||||
original_diffusion_config.target in PORTED_DIFFUSIONS
|
||||
), f"{original_diffusion_config.target} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_transformer_config.target in PORTED_TRANSFORMERS
|
||||
), f"{original_transformer_config.target} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
|
||||
), f"{original_content_embedding_config.target} has not yet been ported to diffusers."
|
||||
|
||||
original_diffusion_config = original_diffusion_config.params
|
||||
original_transformer_config = original_transformer_config.params
|
||||
original_content_embedding_config = original_content_embedding_config.params
|
||||
|
||||
inner_dim = original_transformer_config["n_embd"]
|
||||
|
||||
n_heads = original_transformer_config["n_head"]
|
||||
|
||||
# VQ-Diffusion gives dimension of the multi-headed attention layers as the
|
||||
# number of attention heads times the sequence length (the dimension) of a
|
||||
# single head. We want to specify our attention blocks with those values
|
||||
# specified separately
|
||||
assert inner_dim % n_heads == 0
|
||||
d_head = inner_dim // n_heads
|
||||
|
||||
depth = original_transformer_config["n_layer"]
|
||||
context_dim = original_transformer_config["condition_dim"]
|
||||
|
||||
num_embed = original_content_embedding_config["num_embed"]
|
||||
# the number of embeddings in the transformer includes the mask embedding.
|
||||
# the content embedding (the vqvae) does not include the mask embedding.
|
||||
num_embed = num_embed + 1
|
||||
|
||||
height = original_transformer_config["content_spatial_size"][0]
|
||||
width = original_transformer_config["content_spatial_size"][1]
|
||||
|
||||
assert width == height, "width has to be equal to height"
|
||||
dropout = original_transformer_config["resid_pdrop"]
|
||||
num_embeds_ada_norm = original_diffusion_config["diffusion_step"]
|
||||
|
||||
model_kwargs = {
|
||||
"attention_bias": True,
|
||||
"cross_attention_dim": context_dim,
|
||||
"attention_head_dim": d_head,
|
||||
"num_layers": depth,
|
||||
"dropout": dropout,
|
||||
"num_attention_heads": n_heads,
|
||||
"num_vector_embeds": num_embed,
|
||||
"num_embeds_ada_norm": num_embeds_ada_norm,
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": width,
|
||||
"activation_fn": "geglu-approximate",
|
||||
}
|
||||
|
||||
model = Transformer2DModel(**model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
# done transformer model
|
||||
|
||||
# transformer checkpoint
|
||||
|
||||
|
||||
def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
transformer_prefix = "transformer.transformer"
|
||||
|
||||
diffusers_latent_image_embedding_prefix = "latent_image_embedding"
|
||||
latent_image_embedding_prefix = f"{transformer_prefix}.content_emb"
|
||||
|
||||
# DalleMaskImageEmbedding
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[
|
||||
f"{latent_image_embedding_prefix}.emb.weight"
|
||||
],
|
||||
f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[
|
||||
f"{latent_image_embedding_prefix}.height_emb.weight"
|
||||
],
|
||||
f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[
|
||||
f"{latent_image_embedding_prefix}.width_emb.weight"
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# transformer blocks
|
||||
for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks):
|
||||
diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}"
|
||||
transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}"
|
||||
|
||||
# ada norm block
|
||||
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1"
|
||||
ada_norm_prefix = f"{transformer_block_prefix}.ln1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_ada_norm_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# attention block
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1"
|
||||
attention_prefix = f"{transformer_block_prefix}.attn1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# ada norm block
|
||||
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2"
|
||||
ada_norm_prefix = f"{transformer_block_prefix}.ln1_1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_ada_norm_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# attention block
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2"
|
||||
attention_prefix = f"{transformer_block_prefix}.attn2"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# norm block
|
||||
diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3"
|
||||
norm_block_prefix = f"{transformer_block_prefix}.ln2"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"],
|
||||
f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# feedforward block
|
||||
diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff"
|
||||
feedforward_prefix = f"{transformer_block_prefix}.mlp"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_feedforward_to_diffusers_checkpoint(
|
||||
checkpoint,
|
||||
diffusers_feedforward_prefix=diffusers_feedforward_prefix,
|
||||
feedforward_prefix=feedforward_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
# to logits
|
||||
|
||||
diffusers_norm_out_prefix = "norm_out"
|
||||
norm_out_prefix = f"{transformer_prefix}.to_logits.0"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"],
|
||||
f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
diffusers_out_prefix = "out"
|
||||
out_prefix = f"{transformer_prefix}.to_logits.1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"],
|
||||
f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix):
|
||||
return {
|
||||
f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"],
|
||||
f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"],
|
||||
f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"],
|
||||
}
|
||||
|
||||
|
||||
def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
||||
return {
|
||||
# key
|
||||
f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"],
|
||||
f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"],
|
||||
# query
|
||||
f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"],
|
||||
f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"],
|
||||
# value
|
||||
f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"],
|
||||
f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"],
|
||||
# linear out
|
||||
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"],
|
||||
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"],
|
||||
}
|
||||
|
||||
|
||||
def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix):
|
||||
return {
|
||||
f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"],
|
||||
f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"],
|
||||
f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"],
|
||||
f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"],
|
||||
}
|
||||
|
||||
|
||||
# done transformer checkpoint
|
||||
|
||||
|
||||
def read_config_file(filename):
|
||||
# The yaml file contains annotations that certain values should
|
||||
# loaded as tuples. By default, OmegaConf will panic when reading
|
||||
# these. Instead, we can manually read the yaml with the FullLoader and then
|
||||
# construct the OmegaConf object.
|
||||
with open(filename) as f:
|
||||
original_config = yaml.load(f, FullLoader)
|
||||
|
||||
return OmegaConf.create(original_config)
|
||||
|
||||
|
||||
# We take separate arguments for the vqvae because the ITHQ vqvae config file
|
||||
# is separate from the config file for the rest of the model.
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--vqvae_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the vqvae checkpoint to convert.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vqvae_original_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The YAML config file corresponding to the original architecture for the vqvae.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--original_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The YAML config file corresponding to the original architecture.",
|
||||
)
|
||||
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_load_device",
|
||||
default="cpu",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The device passed to `map_location` when loading checkpoints.",
|
||||
)
|
||||
|
||||
# See link for how ema weights are always selected
|
||||
# https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65
|
||||
parser.add_argument(
|
||||
"--no_use_ema",
|
||||
action="store_true",
|
||||
required=False,
|
||||
help=(
|
||||
"Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set"
|
||||
" it as the original VQ-Diffusion always uses the ema weights when loading models."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
use_ema = not args.no_use_ema
|
||||
|
||||
print(f"loading checkpoints to {args.checkpoint_load_device}")
|
||||
|
||||
checkpoint_map_location = torch.device(args.checkpoint_load_device)
|
||||
|
||||
# vqvae_model
|
||||
|
||||
print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}")
|
||||
|
||||
vqvae_original_config = read_config_file(args.vqvae_original_config_file).model
|
||||
vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"]
|
||||
|
||||
with init_empty_weights():
|
||||
vqvae_model = vqvae_model_from_original_config(vqvae_original_config)
|
||||
|
||||
vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint)
|
||||
|
||||
with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file:
|
||||
torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name)
|
||||
del vqvae_diffusers_checkpoint
|
||||
del vqvae_checkpoint
|
||||
load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto")
|
||||
|
||||
print("done loading vqvae")
|
||||
|
||||
# done vqvae_model
|
||||
|
||||
# transformer_model
|
||||
|
||||
print(
|
||||
f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:"
|
||||
f" {use_ema}"
|
||||
)
|
||||
|
||||
original_config = read_config_file(args.original_config_file).model
|
||||
|
||||
diffusion_config = original_config.params.diffusion_config
|
||||
transformer_config = original_config.params.diffusion_config.params.transformer_config
|
||||
content_embedding_config = original_config.params.diffusion_config.params.content_emb_config
|
||||
|
||||
pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
if use_ema:
|
||||
if "ema" in pre_checkpoint:
|
||||
checkpoint = {}
|
||||
for k, v in pre_checkpoint["model"].items():
|
||||
checkpoint[k] = v
|
||||
|
||||
for k, v in pre_checkpoint["ema"].items():
|
||||
# The ema weights are only used on the transformer. To mimic their key as if they came
|
||||
# from the state_dict for the top level model, we prefix with an additional "transformer."
|
||||
# See the source linked in the args.use_ema config for more information.
|
||||
checkpoint[f"transformer.{k}"] = v
|
||||
else:
|
||||
print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.")
|
||||
checkpoint = pre_checkpoint["model"]
|
||||
else:
|
||||
checkpoint = pre_checkpoint["model"]
|
||||
|
||||
del pre_checkpoint
|
||||
|
||||
with init_empty_weights():
|
||||
transformer_model = transformer_model_from_original_config(
|
||||
diffusion_config, transformer_config, content_embedding_config
|
||||
)
|
||||
|
||||
diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint(
|
||||
transformer_model, checkpoint
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
|
||||
torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
|
||||
del diffusers_transformer_checkpoint
|
||||
del checkpoint
|
||||
load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto")
|
||||
|
||||
print("done loading transformer")
|
||||
|
||||
# done transformer_model
|
||||
|
||||
# text encoder
|
||||
|
||||
print("loading CLIP text encoder")
|
||||
|
||||
clip_name = "openai/clip-vit-base-patch32"
|
||||
|
||||
# The original VQ-Diffusion specifies the pad value by the int used in the
|
||||
# returned tokens. Each model uses `0` as the pad value. The transformers clip api
|
||||
# specifies the pad value via the token before it has been tokenized. The `!` pad
|
||||
# token is the same as padding with the `0` pad value.
|
||||
pad_token = "!"
|
||||
|
||||
tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")
|
||||
|
||||
assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0
|
||||
|
||||
text_encoder_model = CLIPTextModel.from_pretrained(
|
||||
clip_name,
|
||||
# `CLIPTextModel` does not support device_map="auto"
|
||||
# device_map="auto"
|
||||
)
|
||||
|
||||
print("done loading CLIP text encoder")
|
||||
|
||||
# done text encoder
|
||||
|
||||
# scheduler
|
||||
|
||||
scheduler_model = VQDiffusionScheduler(
|
||||
# the scheduler has the same number of embeddings as the transformer
|
||||
num_vec_classes=transformer_model.num_vector_embeds
|
||||
)
|
||||
|
||||
# done scheduler
|
||||
|
||||
print(f"saving VQ diffusion model, path: {args.dump_path}")
|
||||
|
||||
pipe = VQDiffusionPipeline(
|
||||
vqvae=vqvae_model,
|
||||
transformer=transformer_model,
|
||||
tokenizer=tokenizer_model,
|
||||
text_encoder=text_encoder_model,
|
||||
scheduler=scheduler_model,
|
||||
)
|
||||
pipe.save_pretrained(args.dump_path)
|
||||
|
||||
print("done writing VQ diffusion model")
|
||||
11
setup.py
11
setup.py
@@ -89,11 +89,10 @@ _deps = [
|
||||
"huggingface-hub>=0.10.0",
|
||||
"importlib_metadata",
|
||||
"isort>=5.5.4",
|
||||
"jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||
"jaxlib>=0.1.65,<=0.3.6",
|
||||
"jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib>=0.1.65",
|
||||
"modelcards>=0.1.4",
|
||||
"numpy",
|
||||
"onnxruntime",
|
||||
"parameterized",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
@@ -179,9 +178,7 @@ extras["quality"] = deps_list("black", "isort", "flake8", "hf-doc-builder")
|
||||
extras["docs"] = deps_list("hf-doc-builder")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
|
||||
extras["test"] = deps_list(
|
||||
"accelerate",
|
||||
"datasets",
|
||||
"onnxruntime",
|
||||
"parameterized",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
@@ -190,7 +187,7 @@ extras["test"] = deps_list(
|
||||
"torchvision",
|
||||
"transformers"
|
||||
)
|
||||
extras["torch"] = deps_list("torch")
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
@@ -213,7 +210,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.7.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.8.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="Diffusers",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -9,7 +9,7 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
__version__ = "0.7.0.dev0"
|
||||
__version__ = "0.8.0.dev0"
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
@@ -18,7 +18,7 @@ from .utils import logging
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
@@ -35,19 +35,24 @@ if is_torch_available():
|
||||
DDPMPipeline,
|
||||
KarrasVePipeline,
|
||||
LDMPipeline,
|
||||
LDMSuperResolutionPipeline,
|
||||
PNDMPipeline,
|
||||
RePaintPipeline,
|
||||
ScoreSdeVePipeline,
|
||||
)
|
||||
from .schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
VQDiffusionScheduler,
|
||||
)
|
||||
from .training_utils import EMAModel
|
||||
else:
|
||||
@@ -60,11 +65,13 @@ else:
|
||||
|
||||
if is_torch_available() and is_transformers_available():
|
||||
from .pipelines import (
|
||||
CycleDiffusionPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
else:
|
||||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
@@ -87,6 +94,7 @@ if is_flax_available():
|
||||
from .schedulers import (
|
||||
FlaxDDIMScheduler,
|
||||
FlaxDDPMScheduler,
|
||||
FlaxDPMSolverMultistepScheduler,
|
||||
FlaxKarrasVeScheduler,
|
||||
FlaxLMSDiscreteScheduler,
|
||||
FlaxPNDMScheduler,
|
||||
|
||||
@@ -101,7 +101,7 @@ class ConfigMixin:
|
||||
output_config_file = os.path.join(save_directory, self.config_name)
|
||||
|
||||
self.to_json_file(output_config_file)
|
||||
logger.info(f"ConfigMixinuration saved in {output_config_file}")
|
||||
logger.info(f"Configuration saved in {output_config_file}")
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
|
||||
@@ -334,6 +334,11 @@ class ConfigMixin:
|
||||
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
|
||||
init_dict = {}
|
||||
for key in expected_keys:
|
||||
# if config param is passed to kwarg and is present in config dict
|
||||
# it should overwrite existing config dict key
|
||||
if key in kwargs and key in config_dict:
|
||||
config_dict[key] = kwargs.pop(key)
|
||||
|
||||
if key in kwargs:
|
||||
# overwrite key
|
||||
init_dict[key] = kwargs.pop(key)
|
||||
|
||||
@@ -13,11 +13,10 @@ deps = {
|
||||
"huggingface-hub": "huggingface-hub>=0.10.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"isort": "isort>=5.5.4",
|
||||
"jax": "jax>=0.2.8,!=0.3.2,<=0.3.6",
|
||||
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
|
||||
"jax": "jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib": "jaxlib>=0.1.65",
|
||||
"modelcards": "modelcards>=0.1.4",
|
||||
"numpy": "numpy",
|
||||
"onnxruntime": "onnxruntime",
|
||||
"parameterized": "parameterized",
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
|
||||
5
src/diffusers/experimental/README.md
Normal file
5
src/diffusers/experimental/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# 🧨 Diffusers Experimental
|
||||
|
||||
We are adding experimental code to support novel applications and usages of the Diffusers library.
|
||||
Currently, the following experiments are supported:
|
||||
* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.
|
||||
1
src/diffusers/experimental/__init__.py
Normal file
1
src/diffusers/experimental/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .rl import ValueGuidedRLPipeline
|
||||
1
src/diffusers/experimental/rl/__init__.py
Normal file
1
src/diffusers/experimental/rl/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .value_guided_sampling import ValueGuidedRLPipeline
|
||||
129
src/diffusers/experimental/rl/value_guided_sampling.py
Normal file
129
src/diffusers/experimental/rl/value_guided_sampling.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import tqdm
|
||||
|
||||
from ...models.unet_1d import UNet1DModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...utils.dummy_pt_objects import DDPMScheduler
|
||||
|
||||
|
||||
class ValueGuidedRLPipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
value_function: UNet1DModel,
|
||||
unet: UNet1DModel,
|
||||
scheduler: DDPMScheduler,
|
||||
env,
|
||||
):
|
||||
super().__init__()
|
||||
self.value_function = value_function
|
||||
self.unet = unet
|
||||
self.scheduler = scheduler
|
||||
self.env = env
|
||||
self.data = env.get_dataset()
|
||||
self.means = dict()
|
||||
for key in self.data.keys():
|
||||
try:
|
||||
self.means[key] = self.data[key].mean()
|
||||
except:
|
||||
pass
|
||||
self.stds = dict()
|
||||
for key in self.data.keys():
|
||||
try:
|
||||
self.stds[key] = self.data[key].std()
|
||||
except:
|
||||
pass
|
||||
self.state_dim = env.observation_space.shape[0]
|
||||
self.action_dim = env.action_space.shape[0]
|
||||
|
||||
def normalize(self, x_in, key):
|
||||
return (x_in - self.means[key]) / self.stds[key]
|
||||
|
||||
def de_normalize(self, x_in, key):
|
||||
return x_in * self.stds[key] + self.means[key]
|
||||
|
||||
def to_torch(self, x_in):
|
||||
if type(x_in) is dict:
|
||||
return {k: self.to_torch(v) for k, v in x_in.items()}
|
||||
elif torch.is_tensor(x_in):
|
||||
return x_in.to(self.unet.device)
|
||||
return torch.tensor(x_in, device=self.unet.device)
|
||||
|
||||
def reset_x0(self, x_in, cond, act_dim):
|
||||
for key, val in cond.items():
|
||||
x_in[:, key, act_dim:] = val.clone()
|
||||
return x_in
|
||||
|
||||
def run_diffusion(self, x, conditions, n_guide_steps, scale):
|
||||
batch_size = x.shape[0]
|
||||
y = None
|
||||
for i in tqdm.tqdm(self.scheduler.timesteps):
|
||||
# create batch of timesteps to pass into model
|
||||
timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
|
||||
for _ in range(n_guide_steps):
|
||||
with torch.enable_grad():
|
||||
x.requires_grad_()
|
||||
y = self.value_function(x.permute(0, 2, 1), timesteps).sample
|
||||
grad = torch.autograd.grad([y.sum()], [x])[0]
|
||||
|
||||
posterior_variance = self.scheduler._get_variance(i)
|
||||
model_std = torch.exp(0.5 * posterior_variance)
|
||||
grad = model_std * grad
|
||||
grad[timesteps < 2] = 0
|
||||
x = x.detach()
|
||||
x = x + scale * grad
|
||||
x = self.reset_x0(x, conditions, self.action_dim)
|
||||
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
|
||||
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
|
||||
|
||||
# apply conditions to the trajectory
|
||||
x = self.reset_x0(x, conditions, self.action_dim)
|
||||
x = self.to_torch(x)
|
||||
return x, y
|
||||
|
||||
def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
|
||||
# normalize the observations and create batch dimension
|
||||
obs = self.normalize(obs, "observations")
|
||||
obs = obs[None].repeat(batch_size, axis=0)
|
||||
|
||||
conditions = {0: self.to_torch(obs)}
|
||||
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
|
||||
|
||||
# generate initial noise and apply our conditions (to make the trajectories start at current state)
|
||||
x1 = torch.randn(shape, device=self.unet.device)
|
||||
x = self.reset_x0(x1, conditions, self.action_dim)
|
||||
x = self.to_torch(x)
|
||||
|
||||
# run the diffusion process
|
||||
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
|
||||
|
||||
# sort output trajectories by value
|
||||
sorted_idx = y.argsort(0, descending=True).squeeze()
|
||||
sorted_values = x[sorted_idx]
|
||||
actions = sorted_values[:, :, : self.action_dim]
|
||||
actions = actions.detach().cpu().numpy()
|
||||
denorm_actions = self.de_normalize(actions, key="actions")
|
||||
|
||||
# select the action with the highest value
|
||||
if y is not None:
|
||||
selected_index = 0
|
||||
else:
|
||||
# if we didn't run value guiding, select a random action
|
||||
selected_index = np.random.randint(0, batch_size)
|
||||
denorm_actions = denorm_actions[selected_index, 0]
|
||||
return denorm_actions
|
||||
@@ -21,18 +21,37 @@ from typing import Callable, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import Tensor, device
|
||||
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_torch_version(">=", "1.9.0"):
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
||||
else:
|
||||
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
try:
|
||||
return next(parameter.parameters()).device
|
||||
@@ -268,6 +287,19 @@ class ModelMixin(torch.nn.Module):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -296,6 +328,41 @@ class ModelMixin(torch.nn.Module):
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warn(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_accelerate_available():
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
||||
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
||||
)
|
||||
|
||||
# Check if we can handle device_map and dispatching the weights
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
@@ -378,12 +445,8 @@ class ModelMixin(torch.nn.Module):
|
||||
|
||||
# restore default dtype
|
||||
|
||||
if device_map == "auto":
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
else:
|
||||
raise ImportError("Please install accelerate via `pip install accelerate`")
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
with accelerate.init_empty_weights():
|
||||
model, unused_kwargs = cls.from_config(
|
||||
config_path,
|
||||
@@ -400,7 +463,17 @@ class ModelMixin(torch.nn.Module):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
||||
# if device_map is Non,e load the state dict on move the params from meta device to the cpu
|
||||
if device_map is None:
|
||||
param_device = "cpu"
|
||||
state_dict = load_state_dict(model_file)
|
||||
# move the parms from meta device to cpu
|
||||
for param_name, param in state_dict.items():
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||
else: # else let accelerate handle loading and dispatching.
|
||||
# Load weights and dispatch according to the device_map
|
||||
# by deafult the device_map is None and the weights are loaded on the CPU
|
||||
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": [],
|
||||
|
||||
@@ -16,6 +16,7 @@ from ..utils import is_flax_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .attention import Transformer2DModel
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
@@ -12,12 +12,218 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
|
||||
for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
if is_xformers_available():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
else:
|
||||
xformers = None
|
||||
|
||||
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
|
||||
embeddings) inputs.
|
||||
|
||||
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
|
||||
transformer action. Finally, reshape to image.
|
||||
|
||||
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
|
||||
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
|
||||
classes of unnoised image.
|
||||
|
||||
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
|
||||
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
||||
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
||||
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
||||
up to but not more than steps than `num_embeds_ada_norm`.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = in_channels is not None
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is not None."
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
if self.is_input_continuous:
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
self.num_vector_embeds = num_vector_embeds
|
||||
self.num_latent_pixels = self.height * self.width
|
||||
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
if self.is_input_continuous:
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
for block in self.transformer_blocks:
|
||||
block._set_attention_slice(slice_size)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
|
||||
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
|
||||
tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, channel, height, weight = hidden_states.shape
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for block in self.transformer_blocks:
|
||||
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
@@ -27,19 +233,19 @@ class AttentionBlock(nn.Module):
|
||||
Uses three q, k, v linear layers to compute attention.
|
||||
|
||||
Parameters:
|
||||
channels (:obj:`int`): The number of channels in the input and output.
|
||||
num_head_channels (:obj:`int`, *optional*):
|
||||
channels (`int`): The number of channels in the input and output.
|
||||
num_head_channels (`int`, *optional*):
|
||||
The number of channels in each head. If None, then `num_heads` = 1.
|
||||
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
||||
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
||||
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
||||
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
||||
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_head_channels: Optional[int] = None,
|
||||
num_groups: int = 32,
|
||||
norm_num_groups: int = 32,
|
||||
rescale_output_factor: float = 1.0,
|
||||
eps: float = 1e-5,
|
||||
):
|
||||
@@ -48,7 +254,7 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
||||
self.num_head_size = num_head_channels
|
||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
|
||||
|
||||
# define q,k,v as linear layers
|
||||
self.query = nn.Linear(channels, channels)
|
||||
@@ -104,112 +310,108 @@ class AttentionBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
||||
standard transformer action. Finally, reshape to image.
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`): The number of channels in the input and output.
|
||||
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
|
||||
d_head (:obj:`int`): The number of channels in each head.
|
||||
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
n_heads: int,
|
||||
d_head: int,
|
||||
depth: int = 1,
|
||||
dropout: float = 0.0,
|
||||
num_groups: int = 32,
|
||||
context_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.d_head = d_head
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
for block in self.transformer_blocks:
|
||||
block._set_attention_slice(slice_size)
|
||||
|
||||
def forward(self, hidden_states, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
batch, channel, height, weight = hidden_states.shape
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, context=context)
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (:obj:`int`): The number of channels in the input and output.
|
||||
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
|
||||
d_head (:obj:`int`): The number of channels in each head.
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
|
||||
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
|
||||
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
d_head: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
context_dim: Optional[int] = None,
|
||||
gated_ff: bool = True,
|
||||
checkpoint: bool = True,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
# layer norms
|
||||
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
||||
if self.use_ada_layer_norm:
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
self.attn1._slice_size = slice_size
|
||||
self.attn2._slice_size = slice_size
|
||||
|
||||
def forward(self, hidden_states, context=None):
|
||||
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
|
||||
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
|
||||
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
if not is_xformers_available():
|
||||
print("Here is how to install it")
|
||||
raise ModuleNotFoundError(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
" xformers",
|
||||
name="xformers",
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
raise ValueError(
|
||||
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
||||
" available for GPU "
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Make sure we can run the memory efficient attention
|
||||
_ = xformers.ops.memory_efficient_attention(
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
torch.randn((1, 2, 40), device="cuda"),
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
|
||||
def forward(self, hidden_states, context=None, timestep=None):
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
||||
)
|
||||
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -218,20 +420,28 @@ class CrossAttention(nn.Module):
|
||||
A cross attention layer.
|
||||
|
||||
Parameters:
|
||||
query_dim (:obj:`int`): The number of channels in the query.
|
||||
context_dim (:obj:`int`, *optional*):
|
||||
query_dim (`int`): The number of channels in the query.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the context. If not given, defaults to `query_dim`.
|
||||
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
||||
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
||||
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
bias (`bool`, *optional*, defaults to False):
|
||||
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias=False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = context_dim if context_dim is not None else query_dim
|
||||
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
@@ -239,10 +449,11 @@ class CrossAttention(nn.Module):
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
self._slice_size = None
|
||||
self._use_memory_efficient_attention_xformers = False
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
||||
@@ -279,11 +490,15 @@ class CrossAttention(nn.Module):
|
||||
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
|
||||
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
||||
hidden_states = self._attention(query, key, value)
|
||||
if self._use_memory_efficient_attention_xformers:
|
||||
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
|
||||
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
else:
|
||||
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
||||
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
||||
hidden_states = self._attention(query, key, value)
|
||||
else:
|
||||
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
@@ -341,29 +556,47 @@ class CrossAttention(nn.Module):
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _memory_efficient_attention_xformers(self, query, key, value):
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
Parameters:
|
||||
dim (:obj:`int`): The number of channels in the input.
|
||||
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
||||
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
||||
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
dim (`int`): The number of channels in the input.
|
||||
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
||||
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
self.net = nn.ModuleList([])
|
||||
|
||||
if activation_fn == "geglu":
|
||||
geglu = GEGLU(dim, inner_dim)
|
||||
elif activation_fn == "geglu-approximate":
|
||||
geglu = ApproximateGELU(dim, inner_dim)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
self.net.append(GEGLU(dim, inner_dim))
|
||||
self.net.append(geglu)
|
||||
# project dropout
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
@@ -381,8 +614,8 @@ class GEGLU(nn.Module):
|
||||
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
||||
|
||||
Parameters:
|
||||
dim_in (:obj:`int`): The number of channels in the input.
|
||||
dim_out (:obj:`int`): The number of channels in the output.
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
@@ -398,3 +631,38 @@ class GEGLU(nn.Module):
|
||||
def forward(self, hidden_states):
|
||||
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
||||
return hidden_states * self.gelu(gate)
|
||||
|
||||
|
||||
class ApproximateGELU(nn.Module):
|
||||
"""
|
||||
The approximate form of Gaussian Error Linear Unit (GELU)
|
||||
|
||||
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, num_embeddings):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
||||
|
||||
def forward(self, x, timestep):
|
||||
emb = self.linear(self.silu(self.emb(timestep)))
|
||||
scale, shift = torch.chunk(emb, 2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
@@ -142,7 +142,7 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxSpatialTransformer(nn.Module):
|
||||
class FlaxTransformer2DModel(nn.Module):
|
||||
r"""
|
||||
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
|
||||
https://arxiv.org/pdf/1506.02025.pdf
|
||||
|
||||
@@ -62,14 +62,21 @@ def get_timestep_embedding(
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(channel, time_embed_dim)
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.act = None
|
||||
if act_fn == "silu":
|
||||
self.act = nn.SiLU()
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
elif act_fn == "mish":
|
||||
self.act = nn.Mish()
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
||||
|
||||
def forward(self, sample):
|
||||
sample = self.linear_1(sample)
|
||||
@@ -126,3 +133,68 @@ class GaussianFourierProjection(nn.Module):
|
||||
else:
|
||||
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
||||
return out
|
||||
|
||||
|
||||
class ImagePositionalEmbeddings(nn.Module):
|
||||
"""
|
||||
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
||||
height and width of the latent space.
|
||||
|
||||
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
||||
|
||||
For VQ-diffusion:
|
||||
|
||||
Output vector embeddings are used as input for the transformer.
|
||||
|
||||
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
||||
|
||||
Args:
|
||||
num_embed (`int`):
|
||||
Number of embeddings for the latent pixels embeddings.
|
||||
height (`int`):
|
||||
Height of the latent image i.e. the number of height embeddings.
|
||||
width (`int`):
|
||||
Width of the latent image i.e. the number of width embeddings.
|
||||
embed_dim (`int`):
|
||||
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embed: int,
|
||||
height: int,
|
||||
width: int,
|
||||
embed_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.num_embed = num_embed
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
||||
self.height_emb = nn.Embedding(self.height, embed_dim)
|
||||
self.width_emb = nn.Embedding(self.width, embed_dim)
|
||||
|
||||
def forward(self, index):
|
||||
emb = self.emb(index)
|
||||
|
||||
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
|
||||
|
||||
# 1 x H x D -> 1 x H x 1 x D
|
||||
height_emb = height_emb.unsqueeze(2)
|
||||
|
||||
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
|
||||
|
||||
# 1 x W x D -> 1 x 1 x W x D
|
||||
width_emb = width_emb.unsqueeze(1)
|
||||
|
||||
pos_emb = height_emb + width_emb
|
||||
|
||||
# 1 x H x W x D -> 1 x L xD
|
||||
pos_emb = pos_emb.view(1, self.height * self.width, -1)
|
||||
|
||||
emb = emb + pos_emb[:, : emb.shape[1], :]
|
||||
|
||||
return emb
|
||||
|
||||
@@ -17,23 +17,41 @@ import flax.linen as nn
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
|
||||
# less general (only handles the case we currently need).
|
||||
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
|
||||
def get_sinusoidal_embeddings(
|
||||
timesteps: jnp.ndarray,
|
||||
embedding_dim: int,
|
||||
freq_shift: float = 1,
|
||||
min_timescale: float = 1,
|
||||
max_timescale: float = 1.0e4,
|
||||
flip_sin_to_cos: bool = False,
|
||||
scale: float = 1.0,
|
||||
) -> jnp.ndarray:
|
||||
"""Returns the positional encoding (same as Tensor2Tensor).
|
||||
Args:
|
||||
timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
embedding_dim: The number of output channels.
|
||||
min_timescale: The smallest time unit (should probably be 0.0).
|
||||
max_timescale: The largest time unit.
|
||||
Returns:
|
||||
a Tensor of timing signals [N, num_channels]
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
|
||||
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
|
||||
num_timescales = float(embedding_dim // 2)
|
||||
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
|
||||
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
|
||||
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
|
||||
|
||||
:param timesteps: a 1-D tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
||||
embeddings. :return: an [N x dim] tensor of positional embeddings.
|
||||
"""
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - freq_shift)
|
||||
emb = jnp.exp(jnp.arange(half_dim) * -emb)
|
||||
emb = timesteps[:, None] * emb[None, :]
|
||||
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
|
||||
return emb
|
||||
# scale embeddings
|
||||
scaled_time = scale * emb
|
||||
|
||||
if flip_sin_to_cos:
|
||||
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
|
||||
else:
|
||||
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
|
||||
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
|
||||
return signal
|
||||
|
||||
|
||||
class FlaxTimestepEmbedding(nn.Module):
|
||||
@@ -70,4 +88,6 @@ class FlaxTimesteps(nn.Module):
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, timesteps):
|
||||
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
|
||||
return get_sinusoidal_embeddings(
|
||||
timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True
|
||||
)
|
||||
|
||||
@@ -5,6 +5,75 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels: channels in the inputs and outputs.
|
||||
use_conv: a bool determining if a convolution is applied.
|
||||
use_conv_transpose:
|
||||
out_channels:
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
self.conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(x)
|
||||
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels: channels in the inputs and outputs.
|
||||
use_conv: a bool determining if a convolution is applied.
|
||||
out_channels:
|
||||
padding:
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
@@ -12,7 +81,8 @@ class Upsample2D(nn.Module):
|
||||
Parameters:
|
||||
channels: channels in the inputs and outputs.
|
||||
use_conv: a bool determining if a convolution is applied.
|
||||
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
|
||||
use_conv_transpose:
|
||||
out_channels:
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
||||
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
|
||||
Parameters:
|
||||
channels: channels in the inputs and outputs.
|
||||
use_conv: a bool determining if a convolution is applied.
|
||||
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
|
||||
out_channels:
|
||||
padding:
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
||||
@@ -415,6 +486,69 @@ class Mish(torch.nn.Module):
|
||||
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
def rearrange_dims(tensor):
|
||||
if len(tensor.shape) == 2:
|
||||
return tensor[:, :, None]
|
||||
if len(tensor.shape) == 3:
|
||||
return tensor[:, :, None, :]
|
||||
elif len(tensor.shape) == 4:
|
||||
return tensor[:, :, 0, :]
|
||||
else:
|
||||
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
super().__init__()
|
||||
|
||||
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.group_norm = nn.GroupNorm(n_groups, out_channels)
|
||||
self.mish = nn.Mish()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1d(x)
|
||||
x = rearrange_dims(x)
|
||||
x = self.group_norm(x)
|
||||
x = rearrange_dims(x)
|
||||
x = self.mish(x)
|
||||
return x
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
class ResidualTemporalBlock1D(nn.Module):
|
||||
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
|
||||
super().__init__()
|
||||
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
|
||||
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
|
||||
|
||||
self.time_emb_act = nn.Mish()
|
||||
self.time_emb = nn.Linear(embed_dim, out_channels)
|
||||
|
||||
self.residual_conv = (
|
||||
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, x, t):
|
||||
"""
|
||||
Args:
|
||||
x : [ batch_size x inp_channels x horizon ]
|
||||
t : [ batch_size x embed_dim ]
|
||||
|
||||
returns:
|
||||
out : [ batch_size x out_channels x horizon ]
|
||||
"""
|
||||
t = self.time_emb_act(t)
|
||||
t = self.time_emb(t)
|
||||
out = self.conv_in(x) + rearrange_dims(t)
|
||||
out = self.conv_out(out)
|
||||
return out + self.residual_conv(x)
|
||||
|
||||
|
||||
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
||||
r"""Upsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
@@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..utils import BaseOutput
|
||||
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block
|
||||
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
implements for all the model (such as downloading or saving, etc.)
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime.
|
||||
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
|
||||
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
|
||||
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
|
||||
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to :
|
||||
obj:`False`): Whether to flip sin to cos for fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
@@ -46,6 +60,12 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
up_down_block_layers (`int`, defaults to 2):
|
||||
number of resnet, attention, or other layers in the up and down blocks.
|
||||
mid_block_layers (`int`, defaults to 5): number of resnet, attention, or other layers in the mid block.
|
||||
mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet.
|
||||
out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet.
|
||||
act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks.
|
||||
norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks.
|
||||
downsample_each_block (`int`, *optional*, defaults to False:
|
||||
experimental feature for using a UNet without upsampling.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -57,18 +77,21 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
out_channels: int = 2,
|
||||
extra_in_channels: int = 0,
|
||||
time_embedding_type: str = "fourier",
|
||||
freq_shift: int = 0,
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
mid_block_type: str = "UNetMidBlock1D",
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
up_down_block_layers: int = 2,
|
||||
mid_block_layers: int = 5,
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
downsample_each_block: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
# time
|
||||
@@ -78,12 +101,19 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
timestep_input_dim = 2 * block_out_channels[0]
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
self.time_proj = Timesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
|
||||
)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
if use_timestep_embedding:
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=timestep_input_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
out_dim=block_out_channels[0],
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
@@ -99,41 +129,66 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
if i == 0:
|
||||
input_channel += extra_in_channels
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=up_down_block_layers,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_downsample=not is_final_block or downsample_each_block,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = get_mid_block(
|
||||
mid_block_type=mid_block_type,
|
||||
mid_channels=block_out_channels[-1],
|
||||
mid_block_type,
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=None,
|
||||
num_layers=mid_block_layers,
|
||||
mid_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
embed_dim=block_out_channels[0],
|
||||
add_downsample=downsample_each_block,
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
if out_block_type is None:
|
||||
final_upsample_channels = out_channels
|
||||
else:
|
||||
final_upsample_channels = block_out_channels[0]
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels
|
||||
output_channel = (
|
||||
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
|
||||
)
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=up_down_block_layers,
|
||||
temb_channels=block_out_channels[0],
|
||||
add_upsample=not is_final_block,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# TODO(PVP, Nathan) placeholder for RL application to be merged shortly
|
||||
# Totally fine to add another layer with a if statement - no need for nn.Identity here
|
||||
# out
|
||||
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
||||
self.out_block = get_out_block(
|
||||
out_block_type=out_block_type,
|
||||
num_groups_out=num_groups_out,
|
||||
embed_dim=block_out_channels[0],
|
||||
out_channels=out_channels,
|
||||
act_fn=act_fn,
|
||||
fc_dim=block_out_channels[-1] // 4,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -152,12 +207,20 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
[`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True,
|
||||
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# 1. time
|
||||
if len(timestep.shape) == 0:
|
||||
timestep = timestep[None]
|
||||
|
||||
timestep_embed = self.time_proj(timestep)[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
timestep_embed = self.time_proj(timesteps)
|
||||
if self.config.use_timestep_embedding:
|
||||
timestep_embed = self.time_mlp(timestep_embed)
|
||||
else:
|
||||
timestep_embed = timestep_embed[..., None]
|
||||
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
|
||||
|
||||
# 2. down
|
||||
down_block_res_samples = ()
|
||||
@@ -166,13 +229,18 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 3. mid
|
||||
sample = self.mid_block(sample)
|
||||
if self.mid_block:
|
||||
sample = self.mid_block(sample, timestep_embed)
|
||||
|
||||
# 4. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
res_samples = down_block_res_samples[-1:]
|
||||
down_block_res_samples = down_block_res_samples[:-1]
|
||||
sample = upsample_block(sample, res_samples)
|
||||
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)
|
||||
|
||||
# 5. post-process
|
||||
if self.out_block:
|
||||
sample = self.out_block(sample, timestep_embed)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
@@ -17,6 +17,256 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims
|
||||
|
||||
|
||||
class DownResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
num_layers=1,
|
||||
conv_shortcut=False,
|
||||
temb_channels=32,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
non_linearity=None,
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_downsample = add_downsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = lambda x: F.silu(x)
|
||||
elif non_linearity == "mish":
|
||||
self.nonlinearity = nn.Mish()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = None
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True, padding=1)
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
output_states = ()
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.downsample is not None:
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class UpResnetBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
num_layers=1,
|
||||
temb_channels=32,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
non_linearity=None,
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.add_upsample = add_upsample
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = lambda x: F.silu(x)
|
||||
elif non_linearity == "mish":
|
||||
self.nonlinearity = nn.Mish()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = None
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None):
|
||||
if res_hidden_states_tuple is not None:
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1)
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.nonlinearity is not None:
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ValueFunctionMidBlock1D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, embed_dim):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
|
||||
self.down1 = Downsample1D(out_channels // 2, use_conv=True)
|
||||
self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
|
||||
self.down2 = Downsample1D(out_channels // 4, use_conv=True)
|
||||
|
||||
def forward(self, x, temb=None):
|
||||
x = self.res1(x, temb)
|
||||
x = self.down1(x)
|
||||
x = self.res2(x, temb)
|
||||
x = self.down2(x)
|
||||
return x
|
||||
|
||||
|
||||
class MidResTemporalBlock1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
embed_dim,
|
||||
num_layers: int = 1,
|
||||
add_downsample: bool = False,
|
||||
add_upsample: bool = False,
|
||||
non_linearity=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.add_downsample = add_downsample
|
||||
|
||||
# there will always be at least one resnet
|
||||
resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
|
||||
|
||||
for _ in range(num_layers):
|
||||
resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = lambda x: F.silu(x)
|
||||
elif non_linearity == "mish":
|
||||
self.nonlinearity = nn.Mish()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = None
|
||||
|
||||
self.upsample = None
|
||||
if add_upsample:
|
||||
self.upsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
self.downsample = None
|
||||
if add_downsample:
|
||||
self.downsample = Downsample1D(out_channels, use_conv=True)
|
||||
|
||||
if self.upsample and self.downsample:
|
||||
raise ValueError("Block cannot downsample and upsample")
|
||||
|
||||
def forward(self, hidden_states, temb):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for resnet in self.resnets[1:]:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
if self.downsample:
|
||||
self.downsample = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutConv1DBlock(nn.Module):
|
||||
def __init__(self, num_groups_out, out_channels, embed_dim, act_fn):
|
||||
super().__init__()
|
||||
self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2)
|
||||
self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
|
||||
if act_fn == "silu":
|
||||
self.final_conv1d_act = nn.SiLU()
|
||||
if act_fn == "mish":
|
||||
self.final_conv1d_act = nn.Mish()
|
||||
self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1)
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = self.final_conv1d_1(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_gn(hidden_states)
|
||||
hidden_states = rearrange_dims(hidden_states)
|
||||
hidden_states = self.final_conv1d_act(hidden_states)
|
||||
hidden_states = self.final_conv1d_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OutValueFunctionBlock(nn.Module):
|
||||
def __init__(self, fc_dim, embed_dim):
|
||||
super().__init__()
|
||||
self.final_block = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(fc_dim + embed_dim, fc_dim // 2),
|
||||
nn.Mish(),
|
||||
nn.Linear(fc_dim // 2, 1),
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, temb):
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], -1)
|
||||
hidden_states = torch.cat((hidden_states, temb), dim=-1)
|
||||
for layer in self.final_block:
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
_kernels = {
|
||||
"linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
|
||||
@@ -78,7 +328,7 @@ class KernelUpsample1D(nn.Module):
|
||||
self.pad = kernel_1d.shape[0] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
|
||||
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
|
||||
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
|
||||
@@ -178,34 +428,6 @@ class ResConvBlock(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def get_down_block(down_block_type, num_layers, out_channels, in_channels):
|
||||
if down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels, num_layers=num_layers)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels, num_layers=num_layers)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels, num_layers=num_layers)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(up_block_type, num_layers, in_channels, out_channels):
|
||||
if up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels, num_layers=num_layers)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels, num_layers=num_layers)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels, num_layers=num_layers)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels):
|
||||
if mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(
|
||||
in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels, num_layers=num_layers
|
||||
)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
|
||||
|
||||
class UNetMidBlock1D(nn.Module):
|
||||
def __init__(self, mid_channels: int, in_channels: int, num_layers: int = 5, out_channels: int = None):
|
||||
super().__init__()
|
||||
@@ -235,7 +457,7 @@ class UNetMidBlock1D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = self.down(hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets):
|
||||
hidden_states = resnet(hidden_states)
|
||||
@@ -369,7 +591,7 @@ class AttnUpBlock1D(nn.Module):
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = KernelUpsample1D(kernel="cubic")
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple):
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
@@ -403,7 +625,7 @@ class UpBlock1D(nn.Module):
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.up = KernelUpsample1D(kernel="cubic")
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple):
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
@@ -435,7 +657,7 @@ class UpBlock1DNoSkip(nn.Module):
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple):
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
@@ -443,3 +665,63 @@ class UpBlock1DNoSkip(nn.Module):
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample):
|
||||
if down_block_type == "DownResnetBlock1D":
|
||||
return DownResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif down_block_type == "DownBlock1D":
|
||||
return DownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "AttnDownBlock1D":
|
||||
return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels)
|
||||
elif down_block_type == "DownBlock1DNoSkip":
|
||||
return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample):
|
||||
if up_block_type == "UpResnetBlock1D":
|
||||
return UpResnetBlock1D(
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
)
|
||||
elif up_block_type == "UpBlock1D":
|
||||
return UpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "AttnUpBlock1D":
|
||||
return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels)
|
||||
elif up_block_type == "UpBlock1DNoSkip":
|
||||
return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample):
|
||||
if mid_block_type == "MidResTemporalBlock1D":
|
||||
return MidResTemporalBlock1D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
embed_dim=embed_dim,
|
||||
add_downsample=add_downsample,
|
||||
)
|
||||
elif mid_block_type == "ValueFunctionMidBlock1D":
|
||||
return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim)
|
||||
elif mid_block_type == "UNetMidBlock1D":
|
||||
return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels)
|
||||
raise ValueError(f"{mid_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim):
|
||||
if out_block_type == "OutConv1DBlock":
|
||||
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
|
||||
elif out_block_type == "ValueFunction":
|
||||
return OutValueFunctionBlock(fc_dim, embed_dim)
|
||||
return None
|
||||
|
||||
@@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
||||
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to :
|
||||
obj:`False`): Whether to flip sin to cos for fourier time embedding.
|
||||
obj:`True`): Whether to flip sin to cos for fourier time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
|
||||
types.
|
||||
|
||||
@@ -15,7 +15,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlock, SpatialTransformer
|
||||
from .attention import AttentionBlock, Transformer2DModel
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
|
||||
|
||||
@@ -109,6 +109,19 @@ def get_down_block(
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
elif down_block_type == "AttnDownEncoderBlock2D":
|
||||
return AttnDownEncoderBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(
|
||||
@@ -200,6 +213,17 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
)
|
||||
elif up_block_type == "AttnUpDecoderBlock2D":
|
||||
return AttnUpDecoderBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
@@ -249,7 +273,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
@@ -325,13 +349,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
SpatialTransformer(
|
||||
in_channels,
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
num_groups=resnet_groups,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
@@ -367,10 +391,14 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
for attn in self.attentions:
|
||||
attn._set_attention_slice(slice_size)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for attn in self.attentions:
|
||||
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
@@ -423,7 +451,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -434,7 +462,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -502,13 +530,13 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
SpatialTransformer(
|
||||
out_channels,
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
num_groups=resnet_groups,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
@@ -518,7 +546,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -542,25 +570,32 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
for attn in self.attentions:
|
||||
attn._set_attention_slice(slice_size)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for attn in self.attentions:
|
||||
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
||||
)
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
@@ -616,7 +651,7 @@ class DownBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -694,7 +729,7 @@ class DownEncoderBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -755,7 +790,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -766,7 +801,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample2D(
|
||||
in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -851,7 +886,7 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
down=True,
|
||||
kernel="fir",
|
||||
)
|
||||
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
|
||||
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
|
||||
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
||||
else:
|
||||
self.resnet_down = None
|
||||
@@ -931,7 +966,7 @@ class SkipDownBlock2D(nn.Module):
|
||||
down=True,
|
||||
kernel="fir",
|
||||
)
|
||||
self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
|
||||
self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
|
||||
self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
|
||||
else:
|
||||
self.resnet_down = None
|
||||
@@ -1006,7 +1041,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1081,13 +1116,13 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
SpatialTransformer(
|
||||
out_channels,
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
num_groups=resnet_groups,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
@@ -1117,6 +1152,10 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for attn in self.attentions:
|
||||
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@@ -1133,19 +1172,22 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
||||
)
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -1325,7 +1367,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import flax.linen as nn
|
||||
import jax.numpy as jnp
|
||||
|
||||
from .attention_flax import FlaxSpatialTransformer
|
||||
from .attention_flax import FlaxTransformer2DModel
|
||||
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
||||
attn_block = FlaxSpatialTransformer(
|
||||
attn_block = FlaxTransformer2DModel(
|
||||
in_channels=self.out_channels,
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.out_channels // self.attn_num_head_channels,
|
||||
@@ -196,7 +196,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
||||
attn_block = FlaxSpatialTransformer(
|
||||
attn_block = FlaxTransformer2DModel(
|
||||
in_channels=self.out_channels,
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.out_channels // self.attn_num_head_channels,
|
||||
@@ -326,7 +326,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
attentions = []
|
||||
|
||||
for _ in range(self.num_layers):
|
||||
attn_block = FlaxSpatialTransformer(
|
||||
attn_block = FlaxTransformer2DModel(
|
||||
in_channels=self.in_channels,
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.in_channels // self.attn_num_head_channels,
|
||||
|
||||
@@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
||||
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
||||
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
@@ -225,6 +225,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
if hasattr(block, "attentions") and block.attentions is not None:
|
||||
block.set_attention_slice(slice_size)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for block in self.down_blocks:
|
||||
if hasattr(block, "attentions") and block.attentions is not None:
|
||||
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
for block in self.up_blocks:
|
||||
if hasattr(block, "attentions") and block.attentions is not None:
|
||||
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@@ -233,14 +233,16 @@ class VectorQuantizer(nn.Module):
|
||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||
# backwards compatibility we use the buggy version by default, but you can
|
||||
# specify legacy=False to fix it.
|
||||
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
|
||||
def __init__(
|
||||
self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
|
||||
):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.vq_embed_dim = vq_embed_dim
|
||||
self.beta = beta
|
||||
self.legacy = legacy
|
||||
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
|
||||
self.remap = remap
|
||||
@@ -287,7 +289,7 @@ class VectorQuantizer(nn.Module):
|
||||
def forward(self, z):
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
z_flattened = z.view(-1, self.vq_embed_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = (
|
||||
@@ -409,6 +411,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): TODO
|
||||
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
|
||||
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -425,6 +428,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
sample_size: int = 32,
|
||||
num_vq_embeddings: int = 256,
|
||||
norm_num_groups: int = 32,
|
||||
vq_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -440,11 +444,11 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
double_z=False,
|
||||
)
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
self.quantize = VectorQuantizer(
|
||||
num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
|
||||
)
|
||||
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1)
|
||||
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
|
||||
self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder(
|
||||
|
||||
@@ -24,7 +24,7 @@ import numpy as np
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
|
||||
from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
|
||||
|
||||
|
||||
if is_onnx_available():
|
||||
@@ -33,13 +33,28 @@ if is_onnx_available():
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
ORT_TO_NP_TYPE = {
|
||||
"tensor(bool)": np.bool_,
|
||||
"tensor(int8)": np.int8,
|
||||
"tensor(uint8)": np.uint8,
|
||||
"tensor(int16)": np.int16,
|
||||
"tensor(uint16)": np.uint16,
|
||||
"tensor(int32)": np.int32,
|
||||
"tensor(uint32)": np.uint32,
|
||||
"tensor(int64)": np.int64,
|
||||
"tensor(uint64)": np.uint64,
|
||||
"tensor(float16)": np.float16,
|
||||
"tensor(float)": np.float32,
|
||||
"tensor(double)": np.float64,
|
||||
}
|
||||
|
||||
|
||||
class OnnxRuntimeModel:
|
||||
def __init__(self, model=None, **kwargs):
|
||||
logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
|
||||
self.model = model
|
||||
self.model_save_dir = kwargs.get("model_save_dir", None)
|
||||
self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
|
||||
self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
inputs = {k: np.array(v) for k, v in kwargs.items()}
|
||||
@@ -84,6 +99,15 @@ class OnnxRuntimeModel:
|
||||
except shutil.SameFileError:
|
||||
pass
|
||||
|
||||
# copy external weights (for models >2GB)
|
||||
src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
|
||||
if src_path.exists():
|
||||
dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
|
||||
try:
|
||||
shutil.copyfile(src_path, dst_path)
|
||||
except shutil.SameFileError:
|
||||
pass
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
|
||||
@@ -55,6 +55,8 @@ LOADABLE_CLASSES = {
|
||||
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
|
||||
"FlaxPreTrainedModel": ["save_pretrained", "from_pretrained"],
|
||||
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
}
|
||||
|
||||
@@ -161,6 +163,10 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
if sub_model is None:
|
||||
# edge case for saving a pipeline with safety_checker=None
|
||||
continue
|
||||
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
@@ -168,8 +174,8 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
for library_name, library_classes in LOADABLE_CLASSES.items():
|
||||
library = importlib.import_module(library_name)
|
||||
for base_class, save_load_methods in library_classes.items():
|
||||
class_candidate = getattr(library, base_class)
|
||||
if issubclass(model_cls, class_candidate):
|
||||
class_candidate = getattr(library, base_class, None)
|
||||
if class_candidate is not None and issubclass(model_cls, class_candidate):
|
||||
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
|
||||
save_method_name = save_load_methods[0]
|
||||
break
|
||||
@@ -262,18 +268,27 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
>>> from diffusers import FlaxDiffusionPipeline
|
||||
|
||||
>>> # Download pipeline from huggingface.co and cache.
|
||||
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
|
||||
>>> # Requires to be logged in to Hugging Face hub,
|
||||
>>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
|
||||
>>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5",
|
||||
... revision="bf16",
|
||||
... dtype=jnp.bfloat16,
|
||||
... )
|
||||
|
||||
>>> # Download pipeline that requires an authorization token
|
||||
>>> # For more information on access tokens, please refer to this section
|
||||
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
|
||||
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
>>> # Download pipeline, but use a different scheduler
|
||||
>>> from diffusers import FlaxDPMSolverMultistepScheduler
|
||||
|
||||
>>> # Download pipeline, but overwrite scheduler
|
||||
>>> from diffusers import LMSDiscreteScheduler
|
||||
>>> model_id = "runwayml/stable-diffusion-v1-5"
|
||||
>>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config(
|
||||
... model_id,
|
||||
... subfolder="scheduler",
|
||||
... )
|
||||
|
||||
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
|
||||
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
|
||||
>>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
... model_id, revision="bf16", dtype=jnp.bfloat16, scheduler=dpmpp
|
||||
... )
|
||||
>>> dpm_params["scheduler"] = dpmpp_state
|
||||
```
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
@@ -302,10 +317,19 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
|
||||
|
||||
# make sure we don't download PyTorch weights
|
||||
ignore_patterns = "*.bin"
|
||||
|
||||
if cls != FlaxDiffusionPipeline:
|
||||
requested_pipeline_class = cls.__name__
|
||||
else:
|
||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
||||
requested_pipeline_class = (
|
||||
requested_pipeline_class
|
||||
if requested_pipeline_class.startswith("Flax")
|
||||
else "Flax" + requested_pipeline_class
|
||||
)
|
||||
|
||||
user_agent = {"pipeline_class": requested_pipeline_class}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
@@ -319,6 +343,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
else:
|
||||
@@ -337,7 +362,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
if config_dict["_class_name"].startswith("Flax")
|
||||
else "Flax" + config_dict["_class_name"]
|
||||
)
|
||||
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
|
||||
pipeline_class = getattr(diffusers_module, class_name)
|
||||
|
||||
# some modules can be passed directly to the init
|
||||
# in this case they are already instantiated in `kwargs`
|
||||
@@ -357,6 +382,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name is None:
|
||||
# edge case for when the pipeline was saved with safety_checker=None
|
||||
init_kwargs[name] = None
|
||||
continue
|
||||
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
loaded_sub_model = None
|
||||
sub_model_should_be_defined = True
|
||||
@@ -368,11 +398,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
expected_class_obj = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if issubclass(class_obj, class_candidate):
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
expected_class_obj = class_candidate
|
||||
|
||||
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
|
||||
@@ -406,12 +436,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
class_obj = import_flax_or_no_model(library, class_name)
|
||||
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
if loaded_sub_model is None and sub_model_should_be_defined:
|
||||
load_method_name = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if issubclass(class_obj, class_candidate):
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
load_method_name = importable_classes[class_name][1]
|
||||
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
|
||||
@@ -33,6 +33,7 @@ from tqdm.auto import tqdm
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||
from .hub_utils import http_user_agent
|
||||
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
@@ -41,6 +42,8 @@ from .utils import (
|
||||
WEIGHTS_NAME,
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
@@ -71,6 +74,8 @@ LOADABLE_CLASSES = {
|
||||
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
|
||||
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
|
||||
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
|
||||
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
|
||||
},
|
||||
}
|
||||
|
||||
@@ -176,6 +181,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
if sub_model is None:
|
||||
# edge case for saving a pipeline with safety_checker=None
|
||||
continue
|
||||
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
@@ -183,8 +192,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||
for library_name, library_classes in LOADABLE_CLASSES.items():
|
||||
library = importlib.import_module(library_name)
|
||||
for base_class, save_load_methods in library_classes.items():
|
||||
class_candidate = getattr(library, base_class)
|
||||
if issubclass(model_cls, class_candidate):
|
||||
class_candidate = getattr(library, base_class, None)
|
||||
if class_candidate is not None and issubclass(model_cls, class_candidate):
|
||||
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
|
||||
save_method_name = save_load_methods[0]
|
||||
break
|
||||
@@ -202,13 +211,13 @@ class DiffusionPipeline(ConfigMixin):
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
|
||||
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
|
||||
logger.warning(
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
|
||||
" is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
|
||||
" sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
|
||||
" `float16` operations on those devices in PyTorch. Please remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
|
||||
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
|
||||
" is not recommended to move them to `cpu` as running them will fail. Please make"
|
||||
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
|
||||
" support for`float16` operations on this device in PyTorch. Please, remove the"
|
||||
" `torch_dtype=torch.float16` argument, or use another device for inference."
|
||||
)
|
||||
module.to(torch_device)
|
||||
return self
|
||||
@@ -223,8 +232,6 @@ class DiffusionPipeline(ConfigMixin):
|
||||
for name in module_names.keys():
|
||||
module = getattr(self, name)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if module.device == torch.device("meta"):
|
||||
return torch.device("cpu")
|
||||
return module.device
|
||||
return torch.device("cpu")
|
||||
|
||||
@@ -296,8 +303,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||
</Tip>
|
||||
|
||||
For more information on how to load and create custom pipelines, please have a look at [Loading and
|
||||
Creating Custom
|
||||
Pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/custom_pipelines)
|
||||
Adding Custom
|
||||
Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
|
||||
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
@@ -324,6 +331,19 @@ class DiffusionPipeline(ConfigMixin):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information. specify the folder name here.
|
||||
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
||||
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
||||
same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
||||
more information about each option see [designing a device
|
||||
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
||||
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
||||
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
||||
setting this argument to `True` will raise an error.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
|
||||
@@ -376,6 +396,34 @@ class DiffusionPipeline(ConfigMixin):
|
||||
provider = kwargs.pop("provider", None)
|
||||
sess_options = kwargs.pop("sess_options", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warn(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `device_map=None`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is False and device_map is not None:
|
||||
raise ValueError(
|
||||
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
||||
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
||||
)
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
@@ -395,6 +443,9 @@ class DiffusionPipeline(ConfigMixin):
|
||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
|
||||
|
||||
# make sure we don't download flax weights
|
||||
ignore_patterns = "*.msgpack"
|
||||
|
||||
if custom_pipeline is not None:
|
||||
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
||||
|
||||
@@ -417,6 +468,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
else:
|
||||
@@ -473,6 +525,11 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name is None:
|
||||
# edge case for when the pipeline was saved with safety_checker=None
|
||||
init_kwargs[name] = None
|
||||
continue
|
||||
|
||||
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if class_name.startswith("Flax"):
|
||||
class_name = class_name[4:]
|
||||
@@ -484,15 +541,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
if name in passed_class_obj:
|
||||
# 1. check that passed_class_obj has correct parent class
|
||||
if not is_pipeline_module:
|
||||
if not is_pipeline_module and passed_class_obj[name] is not None:
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
expected_class_obj = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if issubclass(class_obj, class_candidate):
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
expected_class_obj = class_candidate
|
||||
|
||||
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
|
||||
@@ -522,14 +579,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
if loaded_sub_model is None and sub_model_should_be_defined:
|
||||
load_method_name = None
|
||||
for class_name, class_candidate in class_candidates.items():
|
||||
if issubclass(class_obj, class_candidate):
|
||||
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
||||
load_method_name = importable_classes[class_name][1]
|
||||
|
||||
if load_method_name is None:
|
||||
@@ -559,8 +617,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
loading_kwargs["device_map"] = device_map
|
||||
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||
|
||||
@@ -16,7 +16,7 @@ or created independently from each other.
|
||||
|
||||
To that end, we strive to offer all open-sourced, state-of-the-art diffusion system under a unified API.
|
||||
More specifically, we strive to provide pipelines that
|
||||
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LatentDiffusionPipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
|
||||
- 1. can load the officially published weights and yield 1-to-1 the same outputs as the original implementation according to the corresponding paper (*e.g.* [LDMTextToImagePipeline](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/latent_diffusion), uses the officially released weights of [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)),
|
||||
- 2. have a simple user interface to run the model in inference (see the [Pipelines API](#pipelines-api) section),
|
||||
- 3. are easy to understand with code that is self-explanatory and can be read along-side the official paper (see [Pipelines summary](#pipelines-summary)),
|
||||
- 4. can easily be contributed by the community (see the [Contribution](#contribution) section).
|
||||
|
||||
@@ -5,8 +5,10 @@ if is_torch_available():
|
||||
from .dance_diffusion import DanceDiffusionPipeline
|
||||
from .ddim import DDIMPipeline
|
||||
from .ddpm import DDPMPipeline
|
||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pndm import PNDMPipeline
|
||||
from .repaint import RePaintPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .stochastic_karras_ve import KarrasVePipeline
|
||||
else:
|
||||
@@ -15,11 +17,13 @@ else:
|
||||
if is_torch_available() and is_transformers_available():
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .stable_diffusion import (
|
||||
CycleDiffusionPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
|
||||
if is_transformers_available() and is_onnx_available():
|
||||
from .stable_diffusion import (
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
@@ -10,15 +10,14 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...utils import deprecate
|
||||
|
||||
|
||||
class DDIMPipeline(DiffusionPipeline):
|
||||
@@ -44,6 +43,7 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
generator: Optional[torch.Generator] = None,
|
||||
eta: float = 0.0,
|
||||
num_inference_steps: int = 50,
|
||||
use_clipped_model_output: Optional[bool] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
@@ -60,6 +60,9 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
use_clipped_model_output (`bool`, *optional*, defaults to `None`):
|
||||
if `True` or `False`, see documentation for `DDIMScheduler.step`. If `None`, nothing is passed
|
||||
downstream to the scheduler. So use `None` for schedulers which don't support this argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -72,12 +75,27 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
|
||||
message = (
|
||||
f"The `generator` device is `{generator.device}` and does not match the pipeline "
|
||||
f"device `{self.device}`, so the `generator` will be ignored. "
|
||||
f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
|
||||
)
|
||||
deprecate(
|
||||
"generator.device == 'cpu'",
|
||||
"0.11.0",
|
||||
message,
|
||||
)
|
||||
generator = None
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
)
|
||||
image = image.to(self.device)
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator)
|
||||
image = image.to(self.device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
@@ -89,7 +107,9 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
# 2. predict previous mean of image x_t-1 and add variance depending on eta
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
# do x_t -> x_t-1
|
||||
image = self.scheduler.step(model_output, t, image, eta).prev_sample
|
||||
image = self.scheduler.step(
|
||||
model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
|
||||
).prev_sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
@@ -18,7 +17,9 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...utils import deprecate
|
||||
|
||||
|
||||
class DDPMPipeline(DiffusionPipeline):
|
||||
@@ -68,13 +69,38 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
message = (
|
||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
||||
" DDPMScheduler.from_config(<model_id>, predict_epsilon=True)`."
|
||||
)
|
||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||
|
||||
if predict_epsilon is not None:
|
||||
new_config = dict(self.scheduler.config)
|
||||
new_config["predict_epsilon"] = predict_epsilon
|
||||
self.scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
|
||||
message = (
|
||||
f"The `generator` device is `{generator.device}` and does not match the pipeline "
|
||||
f"device `{self.device}`, so the `generator` will be ignored. "
|
||||
f'Please use `torch.Generator(device="{self.device}")` instead.'
|
||||
)
|
||||
deprecate(
|
||||
"generator.device == 'cpu'",
|
||||
"0.11.0",
|
||||
message,
|
||||
)
|
||||
generator = None
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
)
|
||||
image = image.to(self.device)
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator)
|
||||
image = image.to(self.device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
@@ -84,7 +110,9 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
model_output = self.unet(image, t).sample
|
||||
|
||||
# 2. compute previous image: x_t -> x_t-1
|
||||
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
|
||||
image = self.scheduler.step(
|
||||
model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
|
||||
).prev_sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# flake8: noqa
|
||||
from ...utils import is_transformers_available
|
||||
from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -32,7 +46,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||
[BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
|
||||
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
import inspect
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
||||
import PIL
|
||||
|
||||
from ...models import UNet2DModel, VQModel
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
|
||||
|
||||
def preprocess(image):
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.0 * image - 1.0
|
||||
|
||||
|
||||
class LDMSuperResolutionPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
A pipeline for image super-resolution using Latent
|
||||
|
||||
This class inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Parameters:
|
||||
vqvae ([`VQModel`]):
|
||||
Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations.
|
||||
unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
|
||||
[`EulerAncestralDiscreteScheduler`], [`DPMSolverMultistepScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vqvae: VQModel,
|
||||
unet: UNet2DModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
init_image: Union[torch.Tensor, PIL.Image.Image],
|
||||
batch_size: Optional[int] = 1,
|
||||
num_inference_steps: Optional[int] = 100,
|
||||
eta: Optional[float] = 0.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
**kwargs,
|
||||
) -> Union[Tuple, ImagePipelineOutput]:
|
||||
r"""
|
||||
Args:
|
||||
init_image (`torch.Tensor` or `PIL.Image.Image`):
|
||||
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
||||
process.
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
Number of images to generate.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
batch_size = 1
|
||||
elif isinstance(init_image, torch.Tensor):
|
||||
batch_size = init_image.shape[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`init_image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(init_image)}"
|
||||
)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image)
|
||||
|
||||
height, width = init_image.shape[-2:]
|
||||
|
||||
# in_channels should be 6: 3 for latents, 3 for low resolution image
|
||||
latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
|
||||
latents_dtype = next(self.unet.parameters()).dtype
|
||||
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype)
|
||||
latents = latents.to(self.device)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
|
||||
init_image = init_image.to(device=self.device, dtype=latents_dtype)
|
||||
|
||||
# set timesteps and move to the correct device
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
|
||||
timesteps_tensor = self.scheduler.timesteps
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_kwargs["eta"] = eta
|
||||
|
||||
for t in self.progress_bar(timesteps_tensor):
|
||||
# concat latents and low resolution image in the channel dimension.
|
||||
latents_input = torch.cat([latents, init_image], dim=1)
|
||||
latents_input = self.scheduler.scale_model_input(latents_input, t)
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latents_input, t).sample
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
|
||||
|
||||
# decode the image latents with the VQVAE
|
||||
image = self.vqvae.decode(latents).sample
|
||||
image = torch.clamp(image, -1.0, 1.0)
|
||||
image = image / 2 + 0.5
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
@@ -18,7 +32,7 @@ class LDMPipeline(DiffusionPipeline):
|
||||
Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
|
||||
unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
[`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latens.
|
||||
[`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
1
src/diffusers/pipelines/repaint/__init__.py
Normal file
1
src/diffusers/pipelines/repaint/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .pipeline_repaint import RePaintPipeline
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user