Merge branch 'main' into peftpart-1
2  .github/workflows/pr_dependency_test.yml  vendored
@@ -20,7 +20,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
-         python-version: "3.7"
+         python-version: "3.8"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
4  .github/workflows/pr_quality.yml  vendored
@@ -20,7 +20,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
-         python-version: "3.7"
+         python-version: "3.8"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
@@ -38,7 +38,7 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
-         python-version: "3.7"
+         python-version: "3.8"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
2  .github/workflows/stale.yml  vendored
@@ -17,7 +17,7 @@ jobs:
      - name: Setup Python
        uses: actions/setup-python@v1
        with:
-         python-version: 3.7
+         python-version: 3.8

      - name: Install requirements
        run: |
@@ -216,6 +216,8 @@
    title: AudioLDM 2
  - local: api/pipelines/auto_pipeline
    title: AutoPipeline
+ - local: api/pipelines/blip_diffusion
+   title: BLIP Diffusion
  - local: api/pipelines/consistency_models
    title: Consistency Models
  - local: api/pipelines/controlnet
29  docs/source/en/api/pipelines/blip_diffusion.md  Normal file
@@ -0,0 +1,29 @@
# Blip Diffusion

Blip Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.

The abstract from the paper is:

*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.*

The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.

`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).

<Tip>

Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.

</Tip>

## BlipDiffusionPipeline
[[autodoc]] BlipDiffusionPipeline
- all
- __call__

## BlipDiffusionControlNetPipeline
[[autodoc]] BlipDiffusionControlNetPipeline
- all
- __call__
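For orientation only (not part of the diff), a minimal usage sketch of the new pipeline. The checkpoint id and the keyword names are assumptions based on the documentation above; check the autodoc'd `__call__` signature before relying on them.

# Minimal sketch, assuming the "Salesforce/blipdiffusion" checkpoint and these keyword names.
import torch
from diffusers import BlipDiffusionPipeline
from diffusers.utils import load_image

pipe = BlipDiffusionPipeline.from_pretrained(
    "Salesforce/blipdiffusion", torch_dtype=torch.float16  # assumed checkpoint id
).to("cuda")

subject_image = load_image("path/to/dog.png")  # hypothetical reference image of the subject
image = pipe(
    "a dog wearing sunglasses",        # text prompt
    subject_image,                     # subject reference image
    source_subject_category="dog",     # assumed keyword names
    target_subject_category="dog",
    guidance_scale=7.5,
    num_inference_steps=25,
).images[0]
image.save("blip_diffusion_out.png")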
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.

Install 🤗 Diffusers for whichever deep learning library you're working with.

-🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
+🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:

- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
@@ -106,7 +106,7 @@ pip install -e ".[flax]"

These commands will link the folder you cloned the repository to and your Python library paths.
Python will now look inside the folder you cloned to in addition to the normal library paths.
-For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
+For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.

<Tip warning={true}>
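A quick way to confirm that the editable install described above resolves to the cloned folder (a generic check, not part of the diff):

# Minimal sketch: verify the editable install points at the clone.
import diffusers

print(diffusers.__version__)  # version reported by the editable install
print(diffusers.__file__)     # should resolve somewhere under the cloned ~/diffusers/ folder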
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.

Install the 🤗 Diffusers build that matches the library you are using.

-🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+, and Flax. Follow the installation guide below for the deep learning library you are using.
+🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+, and Flax. Follow the installation guide below for the deep learning library you are using.

- [PyTorch installation guide](https://pytorch.org/get-started/locally/)
- [Flax installation guide](https://flax.readthedocs.io/en/latest/)
@@ -105,7 +105,7 @@ pip install -e ".[flax]"

These commands link the folder you cloned the repository into with your Python library paths.
Python will now look inside the cloned folder in addition to the normal library paths.
-For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the cloned `~/diffusers/` folder.
+For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the cloned `~/diffusers/` folder.

<Tip warning={true}>
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.

Install 🤗 Diffusers in whichever deep learning framework you are using.

-🤗 Diffusers has been tested on Python 3.7+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning framework you are using:
+🤗 Diffusers has been tested on Python 3.8+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning framework you are using:

- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
@@ -107,7 +107,7 @@ pip install -e ".[flax]"

These commands link the repository you cloned with your Python library paths.
Python will now look for packages inside the cloned folder as well, not only in the usual library paths.
-For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/Site-packages/`, Python will also search the cloned `~/diffusers/` folder.
+For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/Site-packages/`, Python will also search the cloned `~/diffusers/` folder.

<Tip warning={true}>
@@ -908,6 +908,9 @@ def main():
    if args.snr_gamma is not None:
        snr = jnp.array(compute_snr(timesteps))
        snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+       if noise_scheduler.config.prediction_type == "v_prediction":
+           # velocity objective prediction requires SNR weights to be floored to a min value of 1.
+           snr_loss_weights = snr_loss_weights + 1
        loss = loss * snr_loss_weights

    loss = loss.mean()
@@ -224,6 +224,30 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
        raise ValueError(f"{model_class} is not supported.")


def compute_snr(timesteps, noise_scheduler):
    """
    Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
    # Expand the tensors.
    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    # Compute SNR
    snr = (alpha / sigma) ** 2
    return snr


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
@@ -524,6 +548,13 @@ def parse_args(input_args=None):
            " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
        ),
    )
+   parser.add_argument(
+       "--snr_gamma",
+       type=float,
+       default=None,
+       help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+       "More details here: https://arxiv.org/abs/2303.09556.",
+   )
    parser.add_argument(
        "--pre_compute_text_embeddings",
        action="store_true",
@@ -1261,17 +1292,34 @@ def main(args):
    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    # Compute instance loss
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # Compute prior loss
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

    # Compute instance loss
    if args.snr_gamma is None:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    else:
        # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
        # Since we predict the noise instead of x_0, the original formulation is slightly changed.
        # This is discussed in Section 4.2 of the same paper.
        snr = compute_snr(timesteps, noise_scheduler)
        base_weight = (
            torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
        )

        if noise_scheduler.config.prediction_type == "v_prediction":
            # Velocity objective needs to be floored to an SNR weight of one.
            mse_loss_weights = base_weight + 1
        else:
            # Epsilon and sample both use the same loss weights.
            mse_loss_weights = base_weight
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
        loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
        loss = loss.mean()

    if args.with_prior_preservation:
        # Add the prior loss to the instance loss.
        loss = loss + args.prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

    accelerator.backward(loss)
    if accelerator.sync_gradients:
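To make the weighting concrete, here is a small standalone sketch (not part of the diff) of the min-SNR-gamma rule used in the hunks above: weight = min(SNR(t), gamma) / SNR(t), with an extra +1 when the model predicts velocity.

# Standalone sketch of min-SNR-gamma loss rebalancing on example tensors.
import torch

snr = torch.tensor([0.5, 2.0, 30.0])   # example per-sample SNR values
snr_gamma = 5.0
prediction_type = "epsilon"            # or "v_prediction"

base_weight = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
mse_loss_weights = base_weight + 1 if prediction_type == "v_prediction" else base_weight

per_sample_mse = torch.tensor([0.9, 0.7, 0.4])     # stand-ins for loss.mean over non-batch dims
loss = (per_sample_mse * mse_loss_weights).mean()  # rebalance per sample, then average
print(mse_loss_weights, loss)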
@@ -875,6 +875,9 @@ def main():
    mse_loss_weights = (
        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
    )
+   if noise_scheduler.config.prediction_type == "v_prediction":
+       # velocity objective prediction requires SNR weights to be floored to a min value of 1.
+       mse_loss_weights = mse_loss_weights + 1
    # We first calculate the original loss. Then we mean over the non-batch dimensions and
    # rebalance the sample-wise losses with their respective loss weights.
    # Finally, we take the mean of the rebalanced loss.
@@ -955,6 +955,9 @@ def main():
    mse_loss_weights = (
        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
    )
+   if noise_scheduler.config.prediction_type == "v_prediction":
+       # velocity objective prediction requires SNR weights to be floored to a min value of 1.
+       mse_loss_weights = mse_loss_weights + 1
    # We first calculate the original loss. Then we mean over the non-batch dimensions and
    # rebalance the sample-wise losses with their respective loss weights.
    # Finally, we take the mean of the rebalanced loss.
@@ -786,6 +786,9 @@ def main():
    mse_loss_weights = (
        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
    )
+   if noise_scheduler.config.prediction_type == "v_prediction":
+       # velocity objective prediction requires SNR weights to be floored to a min value of 1.
+       mse_loss_weights = mse_loss_weights + 1
    # We first calculate the original loss. Then we mean over the non-batch dimensions and
    # rebalance the sample-wise losses with their respective loss weights.
    # Finally, we take the mean of the rebalanced loss.
@@ -1075,6 +1075,9 @@ def main(args):
    mse_loss_weights = (
        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
    )
+   if noise_scheduler.config.prediction_type == "v_prediction":
+       # velocity objective prediction requires SNR weights to be floored to a min value of 1.
+       mse_loss_weights = mse_loss_weights + 1
    # We first calculate the original loss. Then we mean over the non-batch dimensions and
    # rebalance the sample-wise losses with their respective loss weights.
    # Finally, we take the mean of the rebalanced loss.
343  scripts/convert_blipdiffusion_to_diffusers.py  Normal file
@@ -0,0 +1,343 @@
"""
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
"""

import argparse
import os
import tempfile

import torch
from lavis.models import load_model_and_preprocess
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config

from diffusers import (
    AutoencoderKL,
    PNDMScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel


BLIP2_CONFIG = {
    "vision_config": {
        "hidden_size": 1024,
        "num_hidden_layers": 23,
        "num_attention_heads": 16,
        "image_size": 224,
        "patch_size": 14,
        "intermediate_size": 4096,
        "hidden_act": "quick_gelu",
    },
    "qformer_config": {
        "cross_attention_frequency": 1,
        "encoder_hidden_size": 1024,
        "vocab_size": 30523,
    },
    "num_query_tokens": 16,
}
blip2config = Blip2Config(**BLIP2_CONFIG)


def qformer_model_from_original_config():
    qformer = Blip2QFormerModel(blip2config)
    return qformer


def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
    embeddings = {}
    embeddings.update(
        {
            f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
                f"{original_embeddings_prefix}.word_embeddings.weight"
            ]
        }
    )
    embeddings.update(
        {
            f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
                f"{original_embeddings_prefix}.position_embeddings.weight"
            ]
        }
    )
    embeddings.update(
        {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
    )
    embeddings.update(
        {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
    )
    return embeddings


def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
    proj_layer = {}
    proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
    proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
    proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
    proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
    proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
    proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
    return proj_layer


def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
    attention = {}
    attention.update(
        {
            f"{diffuser_attention_prefix}.attention.query.weight": model[
                f"{original_attention_prefix}.self.query.weight"
            ]
        }
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
    )
    attention.update(
        {
            f"{diffuser_attention_prefix}.attention.value.weight": model[
                f"{original_attention_prefix}.self.value.weight"
            ]
        }
    )
    attention.update(
        {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
    )
    attention.update(
        {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
    )
    attention.update(
        {
            f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
                f"{original_attention_prefix}.output.LayerNorm.weight"
            ]
        }
    )
    attention.update(
        {
            f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
                f"{original_attention_prefix}.output.LayerNorm.bias"
            ]
        }
    )
    return attention


def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
    output_layers = {}
    output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
    output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
    output_layers.update(
        {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
    )
    output_layers.update(
        {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
    )
    return output_layers


def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
    encoder = {}
    for i in range(blip2config.qformer_config.num_hidden_layers):
        encoder.update(
            attention_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
            )
        )
        encoder.update(
            attention_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
            )
        )

        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
                    f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
                ]
            }
        )
        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
                    f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
                ]
            }
        )
        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
                    f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
                ]
            }
        )
        encoder.update(
            {
                f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
                    f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
                ]
            }
        )

        encoder.update(
            output_layers_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
            )
        )
        encoder.update(
            output_layers_from_original_checkpoint(
                model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
            )
        )
    return encoder


def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
    visual_encoder_layer = {}

    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
    )
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
    )
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
    )
    visual_encoder_layer.update(
        {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
    )
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
    visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})

    return visual_encoder_layer


def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
    visual_encoder = {}

    visual_encoder.update(
        {
            f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
            .unsqueeze(0)
            .unsqueeze(0)
        }
    )
    visual_encoder.update(
        {
            f"{diffuser_prefix}.embeddings.position_embedding": model[
                f"{original_prefix}.positional_embedding"
            ].unsqueeze(0)
        }
    )
    visual_encoder.update(
        {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
    )
    visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
    visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})

    for i in range(blip2config.vision_config.num_hidden_layers):
        visual_encoder.update(
            visual_encoder_layer_from_original_checkpoint(
                model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
            )
        )

    visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
    visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})

    return visual_encoder


def qformer_original_checkpoint_to_diffusers_checkpoint(model):
    qformer_checkpoint = {}
    qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
    qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
    qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
    qformer_checkpoint.update(
        encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
    )
    qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
    return qformer_checkpoint


def get_qformer(model):
    print("loading qformer")

    qformer = qformer_model_from_original_config()
    qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)

    load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)

    print("done loading qformer")
    return qformer


def load_checkpoint_to_model(checkpoint, model):
    with tempfile.NamedTemporaryFile(delete=False) as file:
        torch.save(checkpoint, file.name)
        del checkpoint
        model.load_state_dict(torch.load(file.name), strict=False)

    os.remove(file.name)


def save_blip_diffusion_model(model, args):
    qformer = get_qformer(model)
    qformer.eval()

    text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

    unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
    vae.eval()
    text_encoder.eval()
    scheduler = PNDMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        set_alpha_to_one=False,
        skip_prk_steps=True,
    )
    tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
    image_processor = BlipImageProcessor()
    blip_diffusion = BlipDiffusionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        scheduler=scheduler,
        qformer=qformer,
        image_processor=image_processor,
    )
    blip_diffusion.save_pretrained(args.checkpoint_path)


def main(args):
    model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
    save_blip_diffusion_model(model.state_dict(), args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
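For context (not part of the diff), the converter above would typically be driven like the hypothetical invocation below; the output directory name is made up, and the script requires LAVIS built from source as its docstring notes.

# Hypothetical invocation of the converter, then loading the converted folder back:
#   python scripts/convert_blipdiffusion_to_diffusers.py --checkpoint_path ./blip-diffusion-converted
from diffusers import BlipDiffusionPipeline

pipe = BlipDiffusionPipeline.from_pretrained("./blip-diffusion-converted")  # path chosen above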
@@ -35,6 +35,12 @@ if __name__ == "__main__":
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
+   parser.add_argument(
+       "--config_files",
+       default=None,
+       type=str,
+       help="The YAML config file corresponding to the architecture.",
+   )
    parser.add_argument(
        "--num_in_channels",
        default=None,
3  setup.py
@@ -257,7 +257,7 @@ setup(
    package_dir={"": "src"},
    packages=find_packages("src"),
    include_package_data=True,
-   python_requires=">=3.7.0",
+   python_requires=">=3.8.0",
    install_requires=list(install_requires),
    extras_require=extras,
    entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
@@ -269,7 +269,6 @@ setup(
    "License :: OSI Approved :: Apache Software License",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
-   "Programming Language :: Python :: 3.7",
    "Programming Language :: Python :: 3.8",
    "Programming Language :: Python :: 3.9",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
@@ -197,6 +197,8 @@ else:
        "AudioLDM2ProjectionModel",
        "AudioLDM2UNet2DConditionModel",
        "AudioLDMPipeline",
+       "BlipDiffusionControlNetPipeline",
+       "BlipDiffusionPipeline",
        "CLIPImageProjection",
        "CycleDiffusionPipeline",
        "IFImg2ImgPipeline",
@@ -458,6 +460,8 @@ if TYPE_CHECKING:
        AutoPipelineForImage2Image,
        AutoPipelineForInpainting,
        AutoPipelineForText2Image,
+       BlipDiffusionControlNetPipeline,
+       BlipDiffusionPipeline,
        CLIPImageProjection,
        ConsistencyModelPipeline,
        DanceDiffusionPipeline,
@@ -1,456 +1,460 @@
from typing import TYPE_CHECKING

from ..utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_flax_available,
    is_k_diffusion_available,
    is_librosa_available,
    is_note_seq_available,
    is_onnx_available,
    is_torch_available,
    is_transformers_available,
)


# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_pt_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
    _import_structure["auto_pipeline"] = [
        "AutoPipelineForImage2Image",
        "AutoPipelineForInpainting",
        "AutoPipelineForText2Image",
    ]
    _import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
    _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
    _import_structure["ddim"] = ["DDIMPipeline"]
    _import_structure["ddpm"] = ["DDPMPipeline"]
    _import_structure["dit"] = ["DiTPipeline"]
    _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
    _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
    _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
    _import_structure["pndm"] = ["PNDMPipeline"]
    _import_structure["repaint"] = ["RePaintPipeline"]
    _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
    _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
try:
    if not (is_torch_available() and is_librosa_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_librosa_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
else:
    _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
try:
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
    _import_structure["audioldm"] = ["AudioLDMPipeline"]
    _import_structure["audioldm2"] = [
        "AudioLDM2Pipeline",
        "AudioLDM2ProjectionModel",
        "AudioLDM2UNet2DConditionModel",
    ]
    _import_structure["controlnet"].extend(
        [
            "StableDiffusionControlNetImg2ImgPipeline",
            "StableDiffusionControlNetInpaintPipeline",
            "StableDiffusionControlNetPipeline",
            "StableDiffusionXLControlNetImg2ImgPipeline",
            "StableDiffusionXLControlNetInpaintPipeline",
            "StableDiffusionXLControlNetPipeline",
        ]
    )
    _import_structure["deepfloyd_if"] = [
        "IFImg2ImgPipeline",
        "IFImg2ImgSuperResolutionPipeline",
        "IFInpaintingPipeline",
        "IFInpaintingSuperResolutionPipeline",
        "IFPipeline",
        "IFSuperResolutionPipeline",
    ]
    _import_structure["kandinsky"] = [
        "KandinskyCombinedPipeline",
        "KandinskyImg2ImgCombinedPipeline",
        "KandinskyImg2ImgPipeline",
        "KandinskyInpaintCombinedPipeline",
        "KandinskyInpaintPipeline",
        "KandinskyPipeline",
        "KandinskyPriorPipeline",
    ]
    _import_structure["kandinsky2_2"] = [
        "KandinskyV22CombinedPipeline",
        "KandinskyV22ControlnetImg2ImgPipeline",
        "KandinskyV22ControlnetPipeline",
        "KandinskyV22Img2ImgCombinedPipeline",
        "KandinskyV22Img2ImgPipeline",
        "KandinskyV22InpaintCombinedPipeline",
        "KandinskyV22InpaintPipeline",
        "KandinskyV22Pipeline",
        "KandinskyV22PriorEmb2EmbPipeline",
        "KandinskyV22PriorPipeline",
    ]
    _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
    _import_structure["musicldm"] = ["MusicLDMPipeline"]
    _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
    _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
    _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
    _import_structure["stable_diffusion"].extend(
        [
            "CLIPImageProjection",
            "CycleDiffusionPipeline",
            "StableDiffusionAttendAndExcitePipeline",
            "StableDiffusionDepth2ImgPipeline",
            "StableDiffusionDiffEditPipeline",
            "StableDiffusionGLIGENPipeline",
            "StableDiffusionGLIGENPipeline",
            "StableDiffusionGLIGENTextImagePipeline",
            "StableDiffusionImageVariationPipeline",
            "StableDiffusionImg2ImgPipeline",
            "StableDiffusionInpaintPipeline",
            "StableDiffusionInpaintPipelineLegacy",
            "StableDiffusionInstructPix2PixPipeline",
            "StableDiffusionLatentUpscalePipeline",
            "StableDiffusionLDM3DPipeline",
            "StableDiffusionModelEditingPipeline",
            "StableDiffusionPanoramaPipeline",
            "StableDiffusionParadigmsPipeline",
            "StableDiffusionPipeline",
            "StableDiffusionPix2PixZeroPipeline",
            "StableDiffusionSAGPipeline",
            "StableDiffusionUpscalePipeline",
            "StableUnCLIPImg2ImgPipeline",
            "StableUnCLIPPipeline",
        ]
    )
    _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
    _import_structure["stable_diffusion_xl"] = [
        "StableDiffusionXLImg2ImgPipeline",
        "StableDiffusionXLInpaintPipeline",
        "StableDiffusionXLInstructPix2PixPipeline",
        "StableDiffusionXLPipeline",
    ]
    _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
    _import_structure["text_to_video_synthesis"] = [
        "TextToVideoSDPipeline",
        "TextToVideoZeroPipeline",
        "VideoToVideoSDPipeline",
    ]
    _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
    _import_structure["unidiffuser"] = [
        "ImageTextPipelineOutput",
        "UniDiffuserModel",
        "UniDiffuserPipeline",
        "UniDiffuserTextDecoder",
    ]
    _import_structure["versatile_diffusion"] = [
        "VersatileDiffusionDualGuidedPipeline",
        "VersatileDiffusionImageVariationPipeline",
        "VersatileDiffusionPipeline",
        "VersatileDiffusionTextToImagePipeline",
    ]
    _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
    _import_structure["wuerstchen"] = [
        "WuerstchenCombinedPipeline",
        "WuerstchenDecoderPipeline",
        "WuerstchenPriorPipeline",
    ]
try:
    if not is_onnx_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_onnx_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
else:
    _import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
try:
    if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_transformers_and_onnx_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
else:
    _import_structure["stable_diffusion"].extend(
        [
            "OnnxStableDiffusionImg2ImgPipeline",
            "OnnxStableDiffusionInpaintPipeline",
            "OnnxStableDiffusionInpaintPipelineLegacy",
            "OnnxStableDiffusionPipeline",
            "OnnxStableDiffusionUpscalePipeline",
            "StableDiffusionOnnxPipeline",
        ]
    )
try:
    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
    _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_flax_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_flax_objects))
else:
    _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
try:
    if not (is_flax_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_flax_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
    _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
    _import_structure["stable_diffusion"].extend(
        [
            "FlaxStableDiffusionImg2ImgPipeline",
            "FlaxStableDiffusionInpaintPipeline",
            "FlaxStableDiffusionPipeline",
        ]
    )
try:
    if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_transformers_and_torch_and_note_seq_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
    _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]

if TYPE_CHECKING:
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_pt_objects import *  # noqa F403

    else:
        from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
        from .consistency_models import ConsistencyModelPipeline
        from .dance_diffusion import DanceDiffusionPipeline
        from .ddim import DDIMPipeline
        from .ddpm import DDPMPipeline
        from .dit import DiTPipeline
        from .latent_diffusion import LDMSuperResolutionPipeline
        from .latent_diffusion_uncond import LDMPipeline
        from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
        from .pndm import PNDMPipeline
        from .repaint import RePaintPipeline
        from .score_sde_ve import ScoreSdeVePipeline
        from .stochastic_karras_ve import KarrasVePipeline

    try:
        if not (is_torch_available() and is_librosa_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_torch_and_librosa_objects import *
    else:
        from .audio_diffusion import AudioDiffusionPipeline, Mel

    try:
        if not (is_torch_available() and is_transformers_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_torch_and_transformers_objects import *
    else:
        from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
        from .audioldm import AudioLDMPipeline
        from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
        from .controlnet import (
            StableDiffusionControlNetImg2ImgPipeline,
            StableDiffusionControlNetInpaintPipeline,
            StableDiffusionControlNetPipeline,
            StableDiffusionXLControlNetImg2ImgPipeline,
            StableDiffusionXLControlNetInpaintPipeline,
            StableDiffusionXLControlNetPipeline,
        )
        from .deepfloyd_if import (
            IFImg2ImgPipeline,
            IFImg2ImgSuperResolutionPipeline,
            IFInpaintingPipeline,
            IFInpaintingSuperResolutionPipeline,
            IFPipeline,
            IFSuperResolutionPipeline,
        )
        from .kandinsky import (
            KandinskyCombinedPipeline,
            KandinskyImg2ImgCombinedPipeline,
            KandinskyImg2ImgPipeline,
            KandinskyInpaintCombinedPipeline,
            KandinskyInpaintPipeline,
            KandinskyPipeline,
            KandinskyPriorPipeline,
        )
        from .kandinsky2_2 import (
            KandinskyV22CombinedPipeline,
            KandinskyV22ControlnetImg2ImgPipeline,
            KandinskyV22ControlnetPipeline,
            KandinskyV22Img2ImgCombinedPipeline,
            KandinskyV22Img2ImgPipeline,
            KandinskyV22InpaintCombinedPipeline,
            KandinskyV22InpaintPipeline,
            KandinskyV22Pipeline,
            KandinskyV22PriorEmb2EmbPipeline,
            KandinskyV22PriorPipeline,
        )
        from .latent_diffusion import LDMTextToImagePipeline
        from .musicldm import MusicLDMPipeline
        from .paint_by_example import PaintByExamplePipeline
        from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
        from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
        from .stable_diffusion import (
            CLIPImageProjection,
            CycleDiffusionPipeline,
            StableDiffusionAttendAndExcitePipeline,
            StableDiffusionDepth2ImgPipeline,
            StableDiffusionDiffEditPipeline,
            StableDiffusionGLIGENPipeline,
            StableDiffusionGLIGENTextImagePipeline,
            StableDiffusionImageVariationPipeline,
            StableDiffusionImg2ImgPipeline,
            StableDiffusionInpaintPipeline,
            StableDiffusionInpaintPipelineLegacy,
            StableDiffusionInstructPix2PixPipeline,
            StableDiffusionLatentUpscalePipeline,
            StableDiffusionLDM3DPipeline,
            StableDiffusionModelEditingPipeline,
            StableDiffusionPanoramaPipeline,
            StableDiffusionParadigmsPipeline,
            StableDiffusionPipeline,
            StableDiffusionPix2PixZeroPipeline,
            StableDiffusionSAGPipeline,
            StableDiffusionUpscalePipeline,
            StableUnCLIPImg2ImgPipeline,
            StableUnCLIPPipeline,
        )
        from .stable_diffusion_safe import StableDiffusionPipelineSafe
        from .stable_diffusion_xl import (
            StableDiffusionXLImg2ImgPipeline,
            StableDiffusionXLInpaintPipeline,
            StableDiffusionXLInstructPix2PixPipeline,
            StableDiffusionXLPipeline,
        )
        from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
        from .text_to_video_synthesis import (
            TextToVideoSDPipeline,
            TextToVideoZeroPipeline,
            VideoToVideoSDPipeline,
        )
        from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
        from .unidiffuser import (
            ImageTextPipelineOutput,
            UniDiffuserModel,
            UniDiffuserPipeline,
            UniDiffuserTextDecoder,
        )
        from .versatile_diffusion import (
            VersatileDiffusionDualGuidedPipeline,
            VersatileDiffusionImageVariationPipeline,
            VersatileDiffusionPipeline,
            VersatileDiffusionTextToImagePipeline,
        )
        from .vq_diffusion import VQDiffusionPipeline
        from .wuerstchen import (
            WuerstchenCombinedPipeline,
            WuerstchenDecoderPipeline,
            WuerstchenPriorPipeline,
        )

    try:
        if not is_onnx_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_onnx_objects import *  # noqa F403

    else:
        from .onnx_utils import OnnxRuntimeModel

    try:
        if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_torch_and_transformers_and_onnx_objects import *
    else:
        from .stable_diffusion import (
            OnnxStableDiffusionImg2ImgPipeline,
            OnnxStableDiffusionInpaintPipeline,
            OnnxStableDiffusionInpaintPipelineLegacy,
            OnnxStableDiffusionPipeline,
            OnnxStableDiffusionUpscalePipeline,
            StableDiffusionOnnxPipeline,
        )

    try:
        if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
    else:
        from .stable_diffusion import StableDiffusionKDiffusionPipeline

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_flax_objects import *  # noqa F403
    else:
        from .pipeline_flax_utils import FlaxDiffusionPipeline

    try:
        if not (is_flax_available() and is_transformers_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_flax_and_transformers_objects import *
    else:
        from .controlnet import FlaxStableDiffusionControlNetPipeline
        from .stable_diffusion import (
            FlaxStableDiffusionImg2ImgPipeline,
            FlaxStableDiffusionInpaintPipeline,
            FlaxStableDiffusionPipeline,
        )

    try:
        if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_transformers_and_torch_and_note_seq_objects import *  # noqa F403

    else:
        from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
from typing import TYPE_CHECKING

from ..utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_flax_available,
    is_k_diffusion_available,
    is_librosa_available,
    is_note_seq_available,
    is_onnx_available,
    is_torch_available,
    is_transformers_available,
)


# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_pt_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
    _import_structure["auto_pipeline"] = [
        "AutoPipelineForImage2Image",
        "AutoPipelineForInpainting",
        "AutoPipelineForText2Image",
    ]
    _import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
    _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
    _import_structure["ddim"] = ["DDIMPipeline"]
    _import_structure["ddpm"] = ["DDPMPipeline"]
    _import_structure["dit"] = ["DiTPipeline"]
    _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
    _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
    _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
    _import_structure["pndm"] = ["PNDMPipeline"]
    _import_structure["repaint"] = ["RePaintPipeline"]
    _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
    _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
try:
    if not (is_torch_available() and is_librosa_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_librosa_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
else:
    _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
try:
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
    _import_structure["audioldm"] = ["AudioLDMPipeline"]
    _import_structure["audioldm2"] = [
        "AudioLDM2Pipeline",
        "AudioLDM2ProjectionModel",
        "AudioLDM2UNet2DConditionModel",
    ]
    _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
    _import_structure["controlnet"].extend(
        [
            "BlipDiffusionControlNetPipeline",
            "StableDiffusionControlNetImg2ImgPipeline",
            "StableDiffusionControlNetInpaintPipeline",
            "StableDiffusionControlNetPipeline",
            "StableDiffusionXLControlNetImg2ImgPipeline",
            "StableDiffusionXLControlNetInpaintPipeline",
            "StableDiffusionXLControlNetPipeline",
        ]
    )
    _import_structure["deepfloyd_if"] = [
        "IFImg2ImgPipeline",
        "IFImg2ImgSuperResolutionPipeline",
        "IFInpaintingPipeline",
        "IFInpaintingSuperResolutionPipeline",
        "IFPipeline",
        "IFSuperResolutionPipeline",
    ]
    _import_structure["kandinsky"] = [
        "KandinskyCombinedPipeline",
        "KandinskyImg2ImgCombinedPipeline",
        "KandinskyImg2ImgPipeline",
        "KandinskyInpaintCombinedPipeline",
        "KandinskyInpaintPipeline",
        "KandinskyPipeline",
        "KandinskyPriorPipeline",
    ]
    _import_structure["kandinsky2_2"] = [
        "KandinskyV22CombinedPipeline",
        "KandinskyV22ControlnetImg2ImgPipeline",
        "KandinskyV22ControlnetPipeline",
        "KandinskyV22Img2ImgCombinedPipeline",
        "KandinskyV22Img2ImgPipeline",
        "KandinskyV22InpaintCombinedPipeline",
        "KandinskyV22InpaintPipeline",
        "KandinskyV22Pipeline",
        "KandinskyV22PriorEmb2EmbPipeline",
        "KandinskyV22PriorPipeline",
    ]
    _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
    _import_structure["musicldm"] = ["MusicLDMPipeline"]
    _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
    _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
    _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
    _import_structure["stable_diffusion"].extend(
        [
            "CLIPImageProjection",
            "CycleDiffusionPipeline",
            "StableDiffusionAttendAndExcitePipeline",
            "StableDiffusionDepth2ImgPipeline",
            "StableDiffusionDiffEditPipeline",
            "StableDiffusionGLIGENPipeline",
            "StableDiffusionGLIGENPipeline",
            "StableDiffusionGLIGENTextImagePipeline",
            "StableDiffusionImageVariationPipeline",
            "StableDiffusionImg2ImgPipeline",
            "StableDiffusionInpaintPipeline",
            "StableDiffusionInpaintPipelineLegacy",
            "StableDiffusionInstructPix2PixPipeline",
            "StableDiffusionLatentUpscalePipeline",
            "StableDiffusionLDM3DPipeline",
            "StableDiffusionModelEditingPipeline",
            "StableDiffusionPanoramaPipeline",
            "StableDiffusionParadigmsPipeline",
            "StableDiffusionPipeline",
            "StableDiffusionPix2PixZeroPipeline",
            "StableDiffusionSAGPipeline",
            "StableDiffusionUpscalePipeline",
            "StableUnCLIPImg2ImgPipeline",
            "StableUnCLIPPipeline",
        ]
    )
    _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
    _import_structure["stable_diffusion_xl"] = [
        "StableDiffusionXLImg2ImgPipeline",
        "StableDiffusionXLInpaintPipeline",
        "StableDiffusionXLInstructPix2PixPipeline",
        "StableDiffusionXLPipeline",
    ]
    _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
    _import_structure["text_to_video_synthesis"] = [
        "TextToVideoSDPipeline",
        "TextToVideoZeroPipeline",
        "VideoToVideoSDPipeline",
    ]
    _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
    _import_structure["unidiffuser"] = [
        "ImageTextPipelineOutput",
        "UniDiffuserModel",
        "UniDiffuserPipeline",
        "UniDiffuserTextDecoder",
    ]
    _import_structure["versatile_diffusion"] = [
        "VersatileDiffusionDualGuidedPipeline",
        "VersatileDiffusionImageVariationPipeline",
        "VersatileDiffusionPipeline",
        "VersatileDiffusionTextToImagePipeline",
    ]
    _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
    _import_structure["wuerstchen"] = [
        "WuerstchenCombinedPipeline",
        "WuerstchenDecoderPipeline",
        "WuerstchenPriorPipeline",
    ]
try:
    if not is_onnx_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_onnx_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
else:
    _import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
try:
    if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_transformers_and_onnx_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
else:
    _import_structure["stable_diffusion"].extend(
        [
            "OnnxStableDiffusionImg2ImgPipeline",
            "OnnxStableDiffusionInpaintPipeline",
            "OnnxStableDiffusionInpaintPipelineLegacy",
            "OnnxStableDiffusionPipeline",
            "OnnxStableDiffusionUpscalePipeline",
            "StableDiffusionOnnxPipeline",
        ]
    )
try:
    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
    _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_flax_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_flax_objects))
else:
    _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
try:
    if not (is_flax_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_flax_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
    _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
    _import_structure["stable_diffusion"].extend(
        [
            "FlaxStableDiffusionImg2ImgPipeline",
            "FlaxStableDiffusionInpaintPipeline",
            "FlaxStableDiffusionPipeline",
        ]
    )
try:
    if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_transformers_and_torch_and_note_seq_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
    _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]

if TYPE_CHECKING:
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_pt_objects import *  # noqa F403

    else:
        from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
        from .consistency_models import ConsistencyModelPipeline
        from .dance_diffusion import DanceDiffusionPipeline
        from .ddim import DDIMPipeline
        from .ddpm import DDPMPipeline
        from .dit import DiTPipeline
|
||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
|
||||
from .pndm import PNDMPipeline
|
||||
from .repaint import RePaintPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .stochastic_karras_ve import KarrasVePipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_librosa_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_librosa_objects import *
|
||||
else:
|
||||
from .audio_diffusion import AudioDiffusionPipeline, Mel
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
|
||||
from .audioldm import AudioLDMPipeline
|
||||
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .controlnet import (
|
||||
BlipDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
)
|
||||
from .deepfloyd_if import (
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
IFInpaintingPipeline,
|
||||
IFInpaintingSuperResolutionPipeline,
|
||||
IFPipeline,
|
||||
IFSuperResolutionPipeline,
|
||||
)
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
KandinskyImg2ImgCombinedPipeline,
|
||||
KandinskyImg2ImgPipeline,
|
||||
KandinskyInpaintCombinedPipeline,
|
||||
KandinskyInpaintPipeline,
|
||||
KandinskyPipeline,
|
||||
KandinskyPriorPipeline,
|
||||
)
|
||||
from .kandinsky2_2 import (
|
||||
KandinskyV22CombinedPipeline,
|
||||
KandinskyV22ControlnetImg2ImgPipeline,
|
||||
KandinskyV22ControlnetPipeline,
|
||||
KandinskyV22Img2ImgCombinedPipeline,
|
||||
KandinskyV22Img2ImgPipeline,
|
||||
KandinskyV22InpaintCombinedPipeline,
|
||||
KandinskyV22InpaintPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .musicldm import MusicLDMPipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_diffusion import (
|
||||
CLIPImageProjection,
|
||||
CycleDiffusionPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionInstructPix2PixPipeline,
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
|
||||
from .text_to_video_synthesis import (
|
||||
TextToVideoSDPipeline,
|
||||
TextToVideoZeroPipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
)
|
||||
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
||||
from .unidiffuser import (
|
||||
ImageTextPipelineOutput,
|
||||
UniDiffuserModel,
|
||||
UniDiffuserPipeline,
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from .versatile_diffusion import (
|
||||
VersatileDiffusionDualGuidedPipeline,
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
from .wuerstchen import (
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_onnx_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .onnx_utils import OnnxRuntimeModel
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_onnx_objects import *
|
||||
else:
|
||||
from .stable_diffusion import (
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
OnnxStableDiffusionInpaintPipeline,
|
||||
OnnxStableDiffusionInpaintPipelineLegacy,
|
||||
OnnxStableDiffusionPipeline,
|
||||
OnnxStableDiffusionUpscalePipeline,
|
||||
StableDiffusionOnnxPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .stable_diffusion import StableDiffusionKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_utils import FlaxDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_flax_available() and is_transformers_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_flax_and_transformers_objects import *
|
||||
else:
|
||||
from .controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
from .stable_diffusion import (
|
||||
FlaxStableDiffusionImg2ImgPipeline,
|
||||
FlaxStableDiffusionInpaintPipeline,
|
||||
FlaxStableDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
|
||||
|
||||
else:
|
||||
from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
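The block above only registers class names in `_import_structure`; the real submodules are imported the first time one of those names is accessed, after the module is swapped for a `_LazyModule`. A minimal sketch of that general pattern (an illustration of the idea, not the actual `_LazyModule` used above):

import importlib
import sys
import types


class LazyModule(types.ModuleType):
    """Illustrative lazy module: attributes are resolved on first access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # map "attribute" -> module that actually defines it
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, item):
        if item in self._attr_to_module:
            value = getattr(importlib.import_module(self._attr_to_module[item]), item)
            setattr(self, item, value)  # cache so later lookups bypass __getattr__
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {item!r}")


# Toy usage: 'json' and 'math' stand in for the pipeline submodules registered above.
lazy = LazyModule("toy", {"json": ["dumps"], "math": ["sqrt"]})
sys.modules["toy"] = lazy
print(lazy.sqrt(2.0))  # the math module is imported only at this point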
20 src/diffusers/pipelines/blip_diffusion/__init__.py Normal file
@@ -0,0 +1,20 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL
from PIL import Image

from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
    from .blip_image_processing import BlipImageProcessor
    from .modeling_blip2 import Blip2QFormerModel
    from .modeling_ctx_clip import ContextCLIPTextModel
    from .pipeline_blip_diffusion import BlipDiffusionPipeline
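With the exports above in place, `BlipDiffusionPipeline` can be loaded like any other diffusers pipeline. A minimal loading sketch; the checkpoint id is an assumption for illustration (use a BLIP-Diffusion checkpoint published under the Salesforce organization on the Hub), and the call signature itself lives in pipeline_blip_diffusion.py, which is not part of this excerpt:

import torch

from diffusers.pipelines import BlipDiffusionPipeline

# "Salesforce/blipdiffusion" is assumed here for illustration; substitute the
# published BLIP-Diffusion checkpoint you actually want to use.
pipe = BlipDiffusionPipeline.from_pretrained(
    "Salesforce/blipdiffusion", torch_dtype=torch.float16
)
pipe.to("cuda")
# Generation takes a subject (reference) image plus text prompts; see
# pipeline_blip_diffusion.py for the exact __call__ arguments.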
318 src/diffusers/pipelines/blip_diffusion/blip_image_processing.py Normal file
@@ -0,0 +1,318 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Image processor class for BLIP."""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
|
||||
from transformers.image_utils import (
|
||||
OPENAI_CLIP_MEAN,
|
||||
OPENAI_CLIP_STD,
|
||||
ChannelDimension,
|
||||
ImageInput,
|
||||
PILImageResampling,
|
||||
infer_channel_dimension_format,
|
||||
is_scaled_image,
|
||||
make_list_of_images,
|
||||
to_numpy_array,
|
||||
valid_images,
|
||||
)
|
||||
from transformers.utils import TensorType, is_vision_available, logging
|
||||
|
||||
from diffusers.utils import numpy_to_pil
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop
|
||||
# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor
|
||||
class BlipImageProcessor(BaseImageProcessor):
|
||||
r"""
|
||||
Constructs a BLIP image processor.
|
||||
|
||||
Args:
|
||||
do_resize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
|
||||
`do_resize` parameter in the `preprocess` method.
|
||||
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
|
||||
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
|
||||
method.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
|
||||
overridden by the `resample` parameter in the `preprocess` method.
|
||||
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
|
||||
`do_rescale` parameter in the `preprocess` method.
|
||||
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
|
||||
overridden by the `rescale_factor` parameter in the `preprocess` method.
|
||||
do_normalize (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert the image to RGB.
|
||||
"""
|
||||
|
||||
model_input_names = ["pixel_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
size: Dict[str, int] = None,
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
do_rescale: bool = True,
|
||||
rescale_factor: Union[int, float] = 1 / 255,
|
||||
do_normalize: bool = True,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
do_convert_rgb: bool = True,
|
||||
do_center_crop: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
size = size if size is not None else {"height": 224, "width": 224}
|
||||
size = get_size_dict(size, default_to_square=True)
|
||||
|
||||
self.do_resize = do_resize
|
||||
self.size = size
|
||||
self.resample = resample
|
||||
self.do_rescale = do_rescale
|
||||
self.rescale_factor = rescale_factor
|
||||
self.do_normalize = do_normalize
|
||||
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
|
||||
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
|
||||
self.do_convert_rgb = do_convert_rgb
|
||||
self.do_center_crop = do_center_crop
|
||||
|
||||
# Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
|
||||
def resize(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
size: Dict[str, int],
|
||||
resample: PILImageResampling = PILImageResampling.BICUBIC,
|
||||
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize an image to `(size["height"], size["width"])`.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
Image to resize.
|
||||
size (`Dict[str, int]`):
|
||||
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
|
||||
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the output image. If unset, the channel dimension format of the input
|
||||
image is used. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The resized image.
|
||||
"""
|
||||
size = get_size_dict(size)
|
||||
if "height" not in size or "width" not in size:
|
||||
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(
|
||||
image,
|
||||
size=output_size,
|
||||
resample=resample,
|
||||
data_format=data_format,
|
||||
input_data_format=input_data_format,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
images: ImageInput,
|
||||
do_resize: Optional[bool] = None,
|
||||
size: Optional[Dict[str, int]] = None,
|
||||
resample: PILImageResampling = None,
|
||||
do_rescale: Optional[bool] = None,
|
||||
do_center_crop: Optional[bool] = None,
|
||||
rescale_factor: Optional[float] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
image_mean: Optional[Union[float, List[float]]] = None,
|
||||
image_std: Optional[Union[float, List[float]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
do_convert_rgb: bool = None,
|
||||
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
**kwargs,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Preprocess an image or batch of images.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
|
||||
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
|
||||
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||
Whether to resize the image.
|
||||
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Controls the size of the image after `resize`. The image is resized to `(size["height"], size["width"])`.
|
||||
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
|
||||
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||
Whether to rescale the image values between [0 - 1].
|
||||
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
|
||||
Whether to normalize the image.
|
||||
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
|
||||
Image mean to normalize the image by if `do_normalize` is set to `True`.
|
||||
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
|
||||
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
|
||||
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
|
||||
Whether to convert the image to RGB.
|
||||
return_tensors (`str` or `TensorType`, *optional*):
|
||||
The type of tensors to return. Can be one of:
|
||||
- Unset: Return a list of `np.ndarray`.
|
||||
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||
The channel dimension format for the output image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- Unset: Use the channel dimension format of the input image.
|
||||
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||
from the input image. Can be one of:
|
||||
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||
"""
|
||||
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||
resample = resample if resample is not None else self.resample
|
||||
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
|
||||
image_mean = image_mean if image_mean is not None else self.image_mean
|
||||
image_std = image_std if image_std is not None else self.image_std
|
||||
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
|
||||
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
|
||||
|
||||
size = size if size is not None else self.size
|
||||
size = get_size_dict(size, default_to_square=False)
|
||||
images = make_list_of_images(images)
|
||||
|
||||
if not valid_images(images):
|
||||
raise ValueError(
|
||||
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||
)
|
||||
|
||||
if do_resize and (size is None or resample is None):
|
||||
raise ValueError("Size and resample must be specified if do_resize is True.")
|
||||
|
||||
if do_rescale and rescale_factor is None:
|
||||
raise ValueError("Rescale factor must be specified if do_rescale is True.")
|
||||
|
||||
if do_normalize and (image_mean is None or image_std is None):
|
||||
raise ValueError("Image mean and std must be specified if do_normalize is True.")
|
||||
|
||||
# PIL RGBA images are converted to RGB
|
||||
if do_convert_rgb:
|
||||
images = [convert_to_rgb(image) for image in images]
|
||||
|
||||
# All transformations expect numpy arrays.
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if is_scaled_image(images[0]) and do_rescale:
|
||||
logger.warning_once(
|
||||
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||
)
|
||||
if input_data_format is None:
|
||||
# We assume that all images have the same channel dimension format.
|
||||
input_data_format = infer_channel_dimension_format(images[0])
|
||||
|
||||
if do_resize:
|
||||
images = [
|
||||
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
|
||||
if do_rescale:
|
||||
images = [
|
||||
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
if do_normalize:
|
||||
images = [
|
||||
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
|
||||
for image in images
|
||||
]
|
||||
if do_center_crop:
|
||||
images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]
|
||||
|
||||
images = [
|
||||
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
|
||||
]
|
||||
|
||||
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
|
||||
return encoded_outputs
|
||||
|
||||
# Follows diffusers.VaeImageProcessor.postprocess
|
||||
def postprocess(self, sample: torch.FloatTensor, output_type: str = "pil"):
|
||||
if output_type not in ["pt", "np", "pil"]:
|
||||
raise ValueError(
|
||||
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
|
||||
)
|
||||
|
||||
# Equivalent to diffusers.VaeImageProcessor.denormalize
|
||||
sample = (sample / 2 + 0.5).clamp(0, 1)
|
||||
if output_type == "pt":
|
||||
return sample
|
||||
|
||||
# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
|
||||
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
|
||||
if output_type == "np":
|
||||
return sample
|
||||
# Output_type must be 'pil'
|
||||
sample = numpy_to_pil(sample)
|
||||
return sample
|
||||
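A quick sanity check of the processor above: `preprocess` maps PIL images to a normalized [batch, 3, 224, 224] tensor with the defaults from `__init__`, and `postprocess` maps decoded samples in [-1, 1] back to PIL images. A minimal sketch, assuming the module path added in this diff:

import numpy as np
import torch
from PIL import Image

from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor

processor = BlipImageProcessor()  # defaults: 224x224 resize, CLIP mean/std, center crop

image = Image.fromarray(np.random.randint(0, 255, (300, 400, 3), dtype=np.uint8))
batch = processor.preprocess(image, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])

# postprocess expects samples in [-1, 1] (e.g. decoded VAE output) and returns PIL images
fake_sample = torch.rand(1, 3, 224, 224) * 2 - 1
pil_images = processor.postprocess(fake_sample, output_type="pil")
print(len(pil_images), pil_images[0].size)  # 1 (224, 224)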
642 src/diffusers/pipelines/blip_diffusion/modeling_blip2.py Normal file
@@ -0,0 +1,642 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from transformers import BertTokenizer
|
||||
from transformers.activations import QuickGELUActivation as QuickGELU
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPooling,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
)
|
||||
from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig
|
||||
from transformers.models.blip_2.modeling_blip_2 import (
|
||||
Blip2Encoder,
|
||||
Blip2PreTrainedModel,
|
||||
Blip2QFormerAttention,
|
||||
Blip2QFormerIntermediate,
|
||||
Blip2QFormerOutput,
|
||||
)
|
||||
from transformers.pytorch_utils import apply_chunking_to_forward
|
||||
from transformers.utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py.
|
||||
# But it doesn't support getting multimodal embeddings. So, this module can be
# replaced once a future `transformers` version supports that.
|
||||
class Blip2TextEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word and position embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
||||
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
query_embeds=None,
|
||||
past_key_values_length=0,
|
||||
):
|
||||
if input_ids is not None:
|
||||
seq_length = input_ids.size()[1]
|
||||
else:
|
||||
seq_length = 0
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
|
||||
|
||||
if input_ids is not None:
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.position_embedding_type == "absolute":
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings = embeddings + position_embeddings
|
||||
|
||||
if query_embeds is not None:
|
||||
batch_size = embeddings.shape[0]
|
||||
# repeat the query embeddings for batch size
|
||||
query_embeds = query_embeds.repeat(batch_size, 1, 1)
|
||||
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
||||
else:
|
||||
embeddings = query_embeds
|
||||
embeddings = embeddings.to(query_embeds.dtype)
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
|
||||
class Blip2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Blip2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
|
||||
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
|
||||
return embeddings
|
||||
|
||||
|
||||
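The patch embedding in Blip2VisionEmbeddings above relies on a Conv2d whose kernel size equals its stride, so the image is cut into non-overlapping patches and each patch becomes one token, with a class token prepended. A standalone shape check with illustrative sizes (not tied to any specific checkpoint):

import torch
from torch import nn

patch_size, embed_dim = 14, 1408  # illustrative values only
patch_embedding = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)

pixel_values = torch.randn(2, 3, 224, 224)
patches = patch_embedding(pixel_values)        # [2, 1408, 16, 16] -> a 16x16 grid of patches
tokens = patches.flatten(2).transpose(1, 2)    # [2, 256, 1408], one token per patch
class_token = torch.randn(2, 1, embed_dim)
print(torch.cat([class_token, tokens], dim=1).shape)  # torch.Size([2, 257, 1408])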
# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings
|
||||
class Blip2QFormerEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList(
|
||||
[Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
query_length=0,
|
||||
):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions else None
|
||||
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for i in range(self.config.num_hidden_layers):
|
||||
layer_module = self.layer[i]
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, past_key_value, output_attentions, query_length)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer_module),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
query_length,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[-1],)
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if layer_module.has_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_decoder_cache,
|
||||
all_hidden_states,
|
||||
all_self_attentions,
|
||||
all_cross_attentions,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
# The layers making up the Qformer encoder
|
||||
class Blip2QFormerLayer(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = Blip2QFormerAttention(config)
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
if layer_idx % config.cross_attention_frequency == 0:
|
||||
self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True)
|
||||
self.has_cross_attention = True
|
||||
else:
|
||||
self.has_cross_attention = False
|
||||
|
||||
self.intermediate = Blip2QFormerIntermediate(config)
|
||||
self.intermediate_query = Blip2QFormerIntermediate(config)
|
||||
self.output_query = Blip2QFormerOutput(config)
|
||||
self.output = Blip2QFormerOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
query_length=0,
|
||||
):
|
||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||
self_attention_outputs = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
output_attentions=output_attentions,
|
||||
past_key_value=self_attn_past_key_value,
|
||||
)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:-1]
|
||||
|
||||
present_key_value = self_attention_outputs[-1]
|
||||
|
||||
if query_length > 0:
|
||||
query_attention_output = attention_output[:, :query_length, :]
|
||||
|
||||
if self.has_cross_attention:
|
||||
if encoder_hidden_states is None:
|
||||
raise ValueError("encoder_hidden_states must be given for cross-attention layers")
|
||||
cross_attention_outputs = self.crossattention(
|
||||
query_attention_output,
|
||||
attention_mask,
|
||||
head_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
query_attention_output = cross_attention_outputs[0]
|
||||
# add cross attentions if we output attention weights
|
||||
outputs = outputs + cross_attention_outputs[1:-1]
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk_query,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
query_attention_output,
|
||||
)
|
||||
|
||||
if attention_output.shape[1] > query_length:
|
||||
layer_output_text = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output[:, query_length:, :],
|
||||
)
|
||||
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
||||
else:
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk,
|
||||
self.chunk_size_feed_forward,
|
||||
self.seq_len_dim,
|
||||
attention_output,
|
||||
)
|
||||
outputs = (layer_output,) + outputs
|
||||
|
||||
outputs = outputs + (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
def feed_forward_chunk_query(self, attention_output):
|
||||
intermediate_output = self.intermediate_query(attention_output)
|
||||
layer_output = self.output_query(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder
|
||||
class ProjLayer(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
|
||||
super().__init__()
|
||||
|
||||
# LayerNorm -> Dense1 -> Act -> Dense2 -> Drop -> residual add of the pre-norm input
|
||||
self.dense1 = nn.Linear(in_dim, hidden_dim)
|
||||
self.act_fn = QuickGELU()
|
||||
self.dense2 = nn.Linear(hidden_dim, out_dim)
|
||||
self.dropout = nn.Dropout(drop_p)
|
||||
|
||||
self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)
|
||||
|
||||
def forward(self, x):
|
||||
x_in = x
|
||||
|
||||
x = self.LayerNorm(x)
|
||||
x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in
|
||||
|
||||
return x
|
||||
|
||||
|
||||
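ProjLayer above mixes features token-by-token and keeps the sequence length, so the Q-Former query embeddings stay aligned with their query tokens; the residual connection requires in_dim == out_dim, which is how Blip2QFormerModel constructs it later in this file. A quick shape check, assuming the module path added in this diff:

import torch

from diffusers.pipelines.blip_diffusion.modeling_blip2 import ProjLayer

proj = ProjLayer(in_dim=768, out_dim=768, hidden_dim=3072, drop_p=0.1, eps=1e-12)
queries = torch.randn(2, 16, 768)   # [batch, num_query_tokens, hidden_size]
print(proj(queries).shape)          # torch.Size([2, 16, 768])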
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
|
||||
class Blip2VisionModel(Blip2PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
config_class = Blip2VisionConfig
|
||||
|
||||
def __init__(self, config: Blip2VisionConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = Blip2VisionEmbeddings(config)
|
||||
self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.encoder = Blip2Encoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.post_init()
|
||||
|
||||
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig)
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if pixel_values is None:
|
||||
raise ValueError("You have to specify pixel_values")
|
||||
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
hidden_states = self.pre_layernorm(hidden_states)
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
|
||||
# Qformer model, used to get multimodal embeddings from the text and image inputs
|
||||
class Blip2QFormerModel(Blip2PreTrainedModel):
|
||||
"""
|
||||
Querying Transformer (Q-Former), used in BLIP-2.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Blip2Config):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.embeddings = Blip2TextEmbeddings(config.qformer_config)
|
||||
self.visual_encoder = Blip2VisionModel(config.vision_config)
|
||||
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
||||
if not hasattr(config, "tokenizer") or config.tokenizer is None:
|
||||
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
|
||||
else:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right")
|
||||
self.tokenizer.add_special_tokens({"bos_token": "[DEC]"})
|
||||
self.proj_layer = ProjLayer(
|
||||
in_dim=config.qformer_config.hidden_size,
|
||||
out_dim=config.qformer_config.hidden_size,
|
||||
hidden_dim=config.qformer_config.hidden_size * 4,
|
||||
drop_p=0.1,
|
||||
eps=1e-12,
|
||||
)
|
||||
|
||||
self.encoder = Blip2QFormerEncoder(config.qformer_config)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
||||
class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def get_extended_attention_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_shape: Tuple[int],
|
||||
device: torch.device,
|
||||
has_query: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
||||
|
||||
Arguments:
|
||||
attention_mask (`torch.Tensor`):
|
||||
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
||||
input_shape (`Tuple[int]`):
|
||||
The shape of the input to the model.
|
||||
device (`torch.device`):
|
||||
The device of the input to the model.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The extended attention mask, with the same dtype as `attention_mask.dtype`.
|
||||
"""
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
if attention_mask.dim() == 3:
|
||||
extended_attention_mask = attention_mask[:, None, :, :]
|
||||
elif attention_mask.dim() == 2:
|
||||
# Provided a padding mask of dimensions [batch_size, seq_length]
|
||||
# - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
extended_attention_mask = attention_mask[:, None, None, :]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
||||
input_shape, attention_mask.shape
|
||||
)
|
||||
)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
return extended_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_input=None,
|
||||
image_input=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
the model is configured as a decoder.
|
||||
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
||||
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
|
||||
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
|
||||
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
|
||||
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
|
||||
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
|
||||
`(batch_size, sequence_length)`.
|
||||
use_cache (`bool`, `optional`):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
||||
`past_key_values`).
|
||||
"""
|
||||
|
||||
text = self.tokenizer(text_input, return_tensors="pt", padding=True)
|
||||
text = text.to(self.device)
|
||||
input_ids = text.input_ids
|
||||
batch_size = input_ids.shape[0]
|
||||
query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device)
|
||||
attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# past_key_values_length
|
||||
past_key_values_length = (
|
||||
past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
|
||||
)
|
||||
|
||||
query_length = self.query_tokens.shape[1]
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
query_embeds=self.query_tokens,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# embedding_output = self.layernorm(query_embeds)
|
||||
# embedding_output = self.dropout(embedding_output)
|
||||
|
||||
input_shape = embedding_output.size()[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
device = embedding_output.device
|
||||
|
||||
image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state
|
||||
# image_embeds_frozen = torch.ones_like(image_embeds_frozen)
|
||||
encoder_hidden_states = image_embeds_frozen
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, list):
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
|
||||
else:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
|
||||
if isinstance(encoder_attention_mask, list):
|
||||
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
||||
elif encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
query_length=query_length,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
|
||||
if not return_dict:
|
||||
return self.proj_layer(sequence_output[:, :query_length, :])
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
cross_attentions=encoder_outputs.cross_attentions,
|
||||
)
|
||||
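The get_extended_attention_mask helper above turns a 0/1 padding mask into an additive bias that is added to the raw attention scores before the softmax: kept positions get 0.0 and masked positions get -10000.0. The trick in isolation, independent of the classes in this file:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])      # 1 = attend, 0 = padding
extended = attention_mask[:, None, None, :].to(torch.float32)
extended = (1.0 - extended) * -10000.0                # 0.0 for kept tokens, -10000.0 for masked ones

scores = torch.randn(1, 1, 5, 5)                      # [batch, heads, query_len, key_len]
probs = torch.softmax(scores + extended, dim=-1)
print(probs[0, 0, 0])                                 # the two padded key positions get ~0 probability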
212 src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright 2023 Salesforce.com, inc.
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import CLIPPreTrainedModel
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
||||
from transformers.models.clip.configuration_clip import CLIPTextConfig
|
||||
from transformers.models.clip.modeling_clip import (
|
||||
CLIPEncoder,
|
||||
_expand_mask,
|
||||
)
|
||||
|
||||
|
||||
# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip,
# which allows for an extra input of "context embeddings": the query embeddings produced by the Qformer.
# They pass through the CLIP model along with the text embeddings and interact with them via self-attention.
|
||||
class ContextCLIPTextModel(CLIPPreTrainedModel):
|
||||
config_class = CLIPTextConfig
|
||||
|
||||
_no_split_modules = ["CLIPEncoderLayer"]
|
||||
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__(config)
|
||||
self.text_model = ContextCLIPTextTransformer(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_embeddings: torch.Tensor = None,
|
||||
ctx_begin_pos: list = None,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
return self.text_model(
|
||||
ctx_embeddings=ctx_embeddings,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
|
||||
class ContextCLIPTextTransformer(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
self.embeddings = ContextCLIPTextEmbeddings(config)
|
||||
self.encoder = CLIPEncoder(config)
|
||||
self.final_layer_norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_embeddings: torch.Tensor,
|
||||
ctx_begin_pos: list,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
r"""
|
||||
Returns:
[`~transformers.modeling_outputs.BaseModelOutputWithPooling`] or `tuple`
|
||||
|
||||
"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is None:
|
||||
raise ValueError("You have to specify input_ids")
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_shape[-1])
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
ctx_embeddings=ctx_embeddings,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
)
|
||||
|
||||
bsz, seq_len = input_shape
|
||||
if ctx_embeddings is not None:
|
||||
seq_len += ctx_embeddings.size(1)
|
||||
# CLIP's text model uses causal mask, prepare it here.
|
||||
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
||||
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
|
||||
hidden_states.device
|
||||
)
|
||||
# expand attention_mask
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
causal_attention_mask=causal_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs[0]
|
||||
last_hidden_state = self.final_layer_norm(last_hidden_state)
|
||||
|
||||
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0], device=input_ids.device),
|
||||
input_ids.to(torch.int).argmax(dim=-1),
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
|
||||
mask.fill_(torch.tensor(torch.finfo(dtype).min))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
mask = mask.unsqueeze(1) # expand mask
|
||||
return mask
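To visualize the additive causal mask built above, the following standalone sketch uses a hypothetical `bsz=1`, `seq_len=4`, and float32 dtype:

```py
import torch

# Standalone sketch of _build_causal_attention_mask with hypothetical sizes.
bsz, seq_len, dtype = 1, 4, torch.float32
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.finfo(dtype).min)  # additive mask: large negative bias everywhere
mask.triu_(1)                       # zero out the diagonal and below, keep the bias strictly above
mask = mask.unsqueeze(1)            # [bsz, 1, seq_len, seq_len], broadcast over attention heads
print(mask[0, 0])                   # row i only attends to positions 0..i
```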
|
||||
|
||||
|
||||
class ContextCLIPTextEmbeddings(nn.Module):
|
||||
def __init__(self, config: CLIPTextConfig):
|
||||
super().__init__()
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
|
||||
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
|
||||
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_embeddings: torch.Tensor,
|
||||
ctx_begin_pos: list,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if ctx_embeddings is None:
|
||||
ctx_len = 0
|
||||
else:
|
||||
ctx_len = ctx_embeddings.shape[1]
|
||||
|
||||
seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, :seq_length]
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.token_embedding(input_ids)
|
||||
|
||||
# for each input embeddings, add the ctx embeddings at the correct position
|
||||
input_embeds_ctx = []
|
||||
bsz = inputs_embeds.shape[0]
|
||||
|
||||
if ctx_embeddings is not None:
|
||||
for i in range(bsz):
|
||||
cbp = ctx_begin_pos[i]
|
||||
|
||||
prefix = inputs_embeds[i, :cbp]
|
||||
# the tokens from ctx_begin_pos onwards are kept and shifted right by the inserted context embeddings
|
||||
suffix = inputs_embeds[i, cbp:]
|
||||
|
||||
input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0))
|
||||
|
||||
inputs_embeds = torch.stack(input_embeds_ctx, dim=0)
|
||||
|
||||
position_embeddings = self.position_embedding(position_ids)
|
||||
embeddings = inputs_embeds + position_embeddings
|
||||
|
||||
return embeddings
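The splicing performed above can be reproduced in isolation. This is a minimal sketch with hypothetical tensor sizes, independent of the classes in this file:

```py
import torch

# Hypothetical sizes: batch of 2, 8 text tokens, 3 context embeddings, hidden size 4.
bsz, seq_len, ctx_len, dim = 2, 8, 3, 4
inputs_embeds = torch.rand(bsz, seq_len, dim)
ctx_embeddings = torch.rand(bsz, ctx_len, dim)
ctx_begin_pos = [2, 2]  # in the pipelines this is config.ctx_begin_pos repeated per sample

input_embeds_ctx = []
for i in range(bsz):
    cbp = ctx_begin_pos[i]
    prefix, suffix = inputs_embeds[i, :cbp], inputs_embeds[i, cbp:]
    input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0))

inputs_embeds = torch.stack(input_embeds_ctx, dim=0)
print(inputs_embeds.shape)  # torch.Size([2, 11, 4]); the sequence grows by ctx_len
```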
|
||||
@@ -0,0 +1,339 @@
|
||||
# Copyright 2023 Salesforce.com, inc.
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import PNDMScheduler
|
||||
from ...utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from .blip_image_processing import BlipImageProcessor
|
||||
from .modeling_blip2 import Blip2QFormerModel
|
||||
from .modeling_ctx_clip import ContextCLIPTextModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers.pipelines import BlipDiffusionPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> import torch
|
||||
|
||||
>>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
|
||||
... "Salesforce/blipdiffusion", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
|
||||
>>> cond_subject = "dog"
|
||||
>>> tgt_subject = "dog"
|
||||
>>> text_prompt_input = "swimming underwater"
|
||||
|
||||
>>> cond_image = load_image(
|
||||
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
|
||||
... )
|
||||
>>> guidance_scale = 7.5
|
||||
>>> num_inference_steps = 25
|
||||
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
|
||||
|
||||
|
||||
>>> output = blip_diffusion_pipe(
|
||||
... text_prompt_input,
|
||||
... cond_image,
|
||||
... cond_subject,
|
||||
... tgt_subject,
|
||||
... guidance_scale=guidance_scale,
|
||||
... num_inference_steps=num_inference_steps,
|
||||
... neg_prompt=negative_prompt,
|
||||
... height=512,
|
||||
... width=512,
|
||||
... ).images
|
||||
>>> output[0].save("image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class BlipDiffusionPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for zero-shot subject-driven generation using Blip Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer ([`CLIPTokenizer`]):
|
||||
Tokenizer for the text encoder
|
||||
text_encoder ([`ContextCLIPTextModel`]):
|
||||
Text encoder to encode the text prompt
|
||||
vae ([`AutoencoderKL`]):
|
||||
VAE model to map the latents to the image
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
scheduler ([`PNDMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
qformer ([`Blip2QFormerModel`]):
|
||||
QFormer model to get multi-modal embeddings from the text and image.
|
||||
image_processor ([`BlipImageProcessor`]):
|
||||
Image Processor to preprocess and postprocess the image.
|
||||
ctx_begin_pos (int, `optional`, defaults to 2):
|
||||
Position of the context token in the text encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: ContextCLIPTextModel,
|
||||
vae: AutoencoderKL,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: PNDMScheduler,
|
||||
qformer: Blip2QFormerModel,
|
||||
image_processor: BlipImageProcessor,
|
||||
ctx_begin_pos: int = 2,
|
||||
mean: List[float] = None,
|
||||
std: List[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
qformer=qformer,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
|
||||
|
||||
def get_query_embeddings(self, input_image, src_subject):
|
||||
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
|
||||
|
||||
# from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
|
||||
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
|
||||
rv = []
|
||||
for prompt, tgt_subject in zip(prompts, tgt_subjects):
|
||||
prompt = f"a {tgt_subject} {prompt.strip()}"
|
||||
# a trick to amplify the prompt
|
||||
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
|
||||
|
||||
return rv
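The effect of the amplification trick above is easiest to see with concrete values; the snippet below is a standalone sketch with hypothetical inputs and `prompt_reps=2` for brevity:

```py
# Standalone sketch of the prompt amplification performed by _build_prompt (hypothetical inputs).
prompts, tgt_subjects = ["swimming underwater"], ["dog"]
prompt_strength, prompt_reps = 1.0, 2

rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
    prompt = f"a {tgt_subject} {prompt.strip()}"
    rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))

print(rv)  # ['a dog swimming underwater, a dog swimming underwater']
```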
|
||||
|
||||
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def encode_prompt(self, query_embeds, prompt):
|
||||
# embeddings for prompt, with query_embeds as context
|
||||
max_len = self.text_encoder.text_model.config.max_position_embeddings
|
||||
max_len -= self.qformer.config.num_query_tokens
|
||||
|
||||
tokenized_prompt = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_len,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
batch_size = query_embeds.shape[0]
|
||||
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
|
||||
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=tokenized_prompt.input_ids,
|
||||
ctx_embeddings=query_embeds,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
)[0]
|
||||
|
||||
return text_embeddings
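The `max_len` bookkeeping above reserves room for the Q-Former query tokens so that the spliced sequence still fits in the text encoder. A quick sketch with hypothetical sizes:

```py
# Hypothetical numbers: a CLIP text encoder with 77 positions and a Q-Former emitting 16 query tokens.
max_position_embeddings = 77
num_query_tokens = 16
max_len = max_position_embeddings - num_query_tokens  # the prompt is truncated to this length
print(max_len, max_len + num_query_tokens)  # 61 77 -> fits exactly once the context embeddings are spliced in
```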
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: List[str],
|
||||
reference_image: PIL.Image.Image,
|
||||
source_subject_category: List[str],
|
||||
target_subject_category: List[str],
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
neg_prompt: Optional[str] = "",
|
||||
prompt_strength: float = 1.0,
|
||||
prompt_reps: int = 20,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
reference_image (`PIL.Image.Image`):
|
||||
The reference image to condition the generation on.
|
||||
source_subject_category (`List[str]`):
|
||||
The source subject category.
|
||||
target_subject_category (`List[str]`):
|
||||
The target subject category.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by random sampling.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
neg_prompt (`str`, *optional*, defaults to ""):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_strength (`float`, *optional*, defaults to 1.0):
|
||||
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
|
||||
to amplify the prompt.
|
||||
prompt_reps (`int`, *optional*, defaults to 20):
|
||||
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
|
||||
(`np.array`) or `"pt"` (`torch.Tensor`).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
reference_image = self.image_processor.preprocess(
|
||||
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
|
||||
)["pixel_values"]
|
||||
reference_image = reference_image.to(self.device)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
if isinstance(source_subject_category, str):
|
||||
source_subject_category = [source_subject_category]
|
||||
if isinstance(target_subject_category, str):
|
||||
target_subject_category = [target_subject_category]
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = self._build_prompt(
|
||||
prompts=prompt,
|
||||
tgt_subjects=target_subject_category,
|
||||
prompt_strength=prompt_strength,
|
||||
prompt_reps=prompt_reps,
|
||||
)
|
||||
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
|
||||
text_embeddings = self.encode_prompt(query_embeds, prompt)
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
if do_classifier_free_guidance:
|
||||
max_length = self.text_encoder.text_model.config.max_position_embeddings
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
[neg_prompt] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.to(self.device),
|
||||
ctx_embeddings=None,
|
||||
)[0]
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
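# e.g. a UNet with 4 resolution levels gives scale_down_factor = 2**3 = 8, so a 512x512 output uses 64x64 latents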
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
num_channels=self.unet.config.in_channels,
|
||||
height=height // scale_down_factor,
|
||||
width=width // scale_down_factor,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
dtype=self.unet.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
# set timesteps
|
||||
extra_set_kwargs = {}
|
||||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
down_block_additional_residuals=None,
|
||||
mid_block_additional_residual=None,
|
||||
)["sample"]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
)["prev_sample"]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -1,77 +1,79 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
|
||||
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_flax_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .pipeline_controlnet import StableDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
|
||||
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
|
||||
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
|
||||
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
|
||||
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
|
||||
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
|
||||
@@ -0,0 +1,405 @@
|
||||
# Copyright 2023 Salesforce.com, inc.
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...schedulers import PNDMScheduler
|
||||
from ...utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
|
||||
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
||||
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from controlnet_aux import CannyDetector
|
||||
>>> import torch
|
||||
|
||||
>>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
|
||||
... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> style_subject = "flower"
|
||||
>>> tgt_subject = "teapot"
|
||||
>>> text_prompt = "on a marble table"
|
||||
|
||||
>>> cldm_cond_image = load_image(
|
||||
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
|
||||
... ).resize((512, 512))
|
||||
>>> canny = CannyDetector()
|
||||
>>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
|
||||
>>> style_image = load_image(
|
||||
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
|
||||
... )
|
||||
>>> guidance_scale = 7.5
|
||||
>>> num_inference_steps = 50
|
||||
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
|
||||
|
||||
|
||||
>>> output = blip_diffusion_pipe(
|
||||
... text_prompt,
|
||||
... style_image,
|
||||
... cldm_cond_image,
|
||||
... style_subject,
|
||||
... tgt_subject,
|
||||
... guidance_scale=guidance_scale,
|
||||
... num_inference_steps=num_inference_steps,
|
||||
... neg_prompt=negative_prompt,
|
||||
... height=512,
|
||||
... width=512,
|
||||
... ).images
|
||||
>>> output[0].save("image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class BlipDiffusionControlNetPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for Canny-edge-based controlled subject-driven generation using Blip Diffusion.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
tokenizer ([`CLIPTokenizer`]):
|
||||
Tokenizer for the text encoder
|
||||
text_encoder ([`ContextCLIPTextModel`]):
|
||||
Text encoder to encode the text prompt
|
||||
vae ([`AutoencoderKL`]):
|
||||
VAE model to map the latents to the image
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
Conditional U-Net architecture to denoise the image embedding.
|
||||
scheduler ([`PNDMScheduler`]):
|
||||
A scheduler to be used in combination with `unet` to generate image latents.
|
||||
qformer ([`Blip2QFormerModel`]):
|
||||
QFormer model to get multi-modal embeddings from the text and image.
|
||||
controlnet ([`ControlNetModel`]):
|
||||
ControlNet model to get the conditioning image embedding.
|
||||
image_processor ([`BlipImageProcessor`]):
|
||||
Image Processor to preprocess and postprocess the image.
|
||||
ctx_begin_pos (int, `optional`, defaults to 2):
|
||||
Position of the context token in the text encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: ContextCLIPTextModel,
|
||||
vae: AutoencoderKL,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: PNDMScheduler,
|
||||
qformer: Blip2QFormerModel,
|
||||
controlnet: ControlNetModel,
|
||||
image_processor: BlipImageProcessor,
|
||||
ctx_begin_pos: int = 2,
|
||||
mean: List[float] = None,
|
||||
std: List[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
qformer=qformer,
|
||||
controlnet=controlnet,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
|
||||
|
||||
def get_query_embeddings(self, input_image, src_subject):
|
||||
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
|
||||
|
||||
# from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
|
||||
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
|
||||
rv = []
|
||||
for prompt, tgt_subject in zip(prompts, tgt_subjects):
|
||||
prompt = f"a {tgt_subject} {prompt.strip()}"
|
||||
# a trick to amplify the prompt
|
||||
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
|
||||
|
||||
return rv
|
||||
|
||||
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels, height, width)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def encode_prompt(self, query_embeds, prompt):
|
||||
# embeddings for prompt, with query_embeds as context
|
||||
max_len = self.text_encoder.text_model.config.max_position_embeddings
|
||||
max_len -= self.qformer.config.num_query_tokens
|
||||
|
||||
tokenized_prompt = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_len,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
batch_size = query_embeds.shape[0]
|
||||
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
|
||||
|
||||
text_embeddings = self.text_encoder(
|
||||
input_ids=tokenized_prompt.input_ids,
|
||||
ctx_embeddings=query_embeds,
|
||||
ctx_begin_pos=ctx_begin_pos,
|
||||
)[0]
|
||||
|
||||
return text_embeddings
|
||||
|
||||
# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
|
||||
def prepare_control_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
):
|
||||
image = self.image_processor.preprocess(
|
||||
image,
|
||||
size={"width": width, "height": height},
|
||||
do_rescale=True,
|
||||
do_center_crop=False,
|
||||
do_normalize=False,
|
||||
return_tensors="pt",
|
||||
)["pixel_values"].to(self.device)
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
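As a sanity check on the batching above, this standalone sketch (hypothetical shapes, a random tensor standing in for the preprocessed image) shows how the conditioning batch is expanded and then doubled for classifier-free guidance:

```py
import torch

# Hypothetical shapes: one preprocessed 3x512x512 conditioning image, prompt batch of 2, CFG enabled.
image = torch.rand(1, 3, 512, 512)
batch_size = 2
image = image.repeat_interleave(batch_size, dim=0)  # -> [2, 3, 512, 512]
image = torch.cat([image] * 2)                      # CFG doubles the batch -> [4, 3, 512, 512]
print(image.shape)
```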
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: List[str],
|
||||
reference_image: PIL.Image.Image,
|
||||
condtioning_image: PIL.Image.Image,
|
||||
source_subject_category: List[str],
|
||||
target_subject_category: List[str],
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
num_inference_steps: int = 50,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
neg_prompt: Optional[str] = "",
|
||||
prompt_strength: float = 1.0,
|
||||
prompt_reps: int = 20,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
reference_image (`PIL.Image.Image`):
|
||||
The reference image to condition the generation on.
|
||||
condtioning_image (`PIL.Image.Image`):
|
||||
The conditioning canny edge image to condition the generation on.
|
||||
source_subject_category (`List[str]`):
|
||||
The source subject category.
|
||||
target_subject_category (`List[str]`):
|
||||
The target subject category.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by random sampling.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width of the generated image.
|
||||
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
neg_prompt (`str`, *optional*, defaults to ""):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_strength (`float`, *optional*, defaults to 1.0):
|
||||
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
|
||||
to amplify the prompt.
|
||||
prompt_reps (`int`, *optional*, defaults to 20):
|
||||
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ImagePipelineOutput`] or `tuple`
|
||||
"""
|
||||
|
||||
reference_image = self.image_processor.preprocess(
|
||||
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
|
||||
)["pixel_values"]
|
||||
reference_image = reference_image.to(self.device)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
if isinstance(source_subject_category, str):
|
||||
source_subject_category = [source_subject_category]
|
||||
if isinstance(target_subject_category, str):
|
||||
target_subject_category = [target_subject_category]
|
||||
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt = self._build_prompt(
|
||||
prompts=prompt,
|
||||
tgt_subjects=target_subject_category,
|
||||
prompt_strength=prompt_strength,
|
||||
prompt_reps=prompt_reps,
|
||||
)
|
||||
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
|
||||
text_embeddings = self.encode_prompt(query_embeds, prompt)
|
||||
# 3. unconditional embedding
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
if do_classifier_free_guidance:
|
||||
max_length = self.text_encoder.text_model.config.max_position_embeddings
|
||||
|
||||
uncond_input = self.tokenizer(
|
||||
[neg_prompt] * batch_size,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(
|
||||
input_ids=uncond_input.input_ids.to(self.device),
|
||||
ctx_embeddings=None,
|
||||
)[0]
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
num_channels=self.unet.config.in_channels,
|
||||
height=height // scale_down_factor,
|
||||
width=width // scale_down_factor,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
dtype=self.unet.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
# set timesteps
|
||||
extra_set_kwargs = {}
|
||||
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
|
||||
|
||||
cond_image = self.prepare_control_image(
|
||||
image=condtioning_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size,
|
||||
num_images_per_prompt=1,
|
||||
device=self.device,
|
||||
dtype=self.controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
controlnet_cond=cond_image,
|
||||
return_dict=False,
|
||||
)
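# the down/mid residuals computed by the ControlNet above are injected into the UNet below
# as additional skip-connection inputs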
|
||||
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
)["sample"]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred,
|
||||
t,
|
||||
latents,
|
||||
)["prev_sample"]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
|
||||
@@ -609,6 +609,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
|
||||
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
|
||||
sampler_kwargs["noise_sampler"] = noise_sampler
|
||||
|
||||
if "generator" in inspect.signature(self.sampler).parameters:
|
||||
sampler_kwargs["generator"] = generator
|
||||
|
||||
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
|
||||
|
||||
if not output_type == "latent":
|
||||
|
||||
@@ -787,8 +787,16 @@ class StableDiffusionXLAdapterPipeline(
|
||||
height, width = self._default_height_width(height, width, image)
|
||||
device = self._execution_device
|
||||
|
||||
adapter_input = _preprocess_adapter_image(image, height, width).to(device)
|
||||
if isinstance(self.adapter, MultiAdapter):
|
||||
adapter_input = []
|
||||
|
||||
for one_image in image:
|
||||
one_image = _preprocess_adapter_image(one_image, height, width)
|
||||
one_image = one_image.to(device=device, dtype=self.adapter.dtype)
|
||||
adapter_input.append(one_image)
|
||||
else:
|
||||
adapter_input = _preprocess_adapter_image(image, height, width)
|
||||
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
@@ -865,10 +873,14 @@ class StableDiffusionXLAdapterPipeline(
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Prepare added time ids & embeddings & adapter features
|
||||
adapter_input = adapter_input.type(latents.dtype)
|
||||
adapter_state = self.adapter(adapter_input)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v * adapter_conditioning_scale
|
||||
if isinstance(self.adapter, MultiAdapter):
|
||||
adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v
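# MultiAdapter applies adapter_conditioning_scale inside the call above, so no extra scaling is done here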
|
||||
else:
|
||||
adapter_state = self.adapter(adapter_input)
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v * adapter_conditioning_scale
|
||||
if num_images_per_prompt > 1:
|
||||
for k, v in enumerate(adapter_state):
|
||||
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
|
||||
|
||||
@@ -89,6 +89,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
|
||||
@@ -113,6 +116,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.012,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
@@ -243,9 +247,15 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
|
||||
log_sigmas = np.log(sigmas)
|
||||
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
|
||||
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
@@ -269,7 +279,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
|
||||
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
|
||||
sigmas_interpol = sigmas_interpol.cpu()
|
||||
log_sigmas = self.log_sigmas.cpu()
|
||||
timesteps_interpol = np.array(
|
||||
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
|
||||
)
|
||||
|
||||
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
|
||||
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
|
||||
|
||||
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
|
||||
@@ -282,29 +298,44 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = None
|
||||
|
||||
def sigma_to_t(self, sigma):
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = sigma.log()
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = self.log_sigmas[low_idx]
|
||||
high = self.log_sigmas[high_idx]
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.view(sigma.shape)
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
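The `_sigma_to_t` above performs a piecewise log-linear interpolation into the table of training sigmas. The following standalone NumPy sketch, using a small hypothetical ascending sigma table, reproduces the computation:

```py
import numpy as np

# Hypothetical ascending sigma table and a query sigma; mirrors _sigma_to_t above.
sigmas = np.array([0.1, 0.5, 1.0, 5.0, 10.0])
log_sigmas = np.log(sigmas)

def sigma_to_t(sigma, log_sigmas):
    log_sigma = np.log(sigma)
    dists = log_sigma - log_sigmas[:, np.newaxis]
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    return ((1 - w) * low_idx + w * high_idx).reshape(sigma.shape)

print(sigma_to_t(np.array([2.0]), log_sigmas))  # ~[2.43], i.e. between table indices 2 and 3
```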
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
|
||||
@@ -88,6 +88,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`linear` or `scaled_linear`.
|
||||
trained_betas (`np.ndarray`, *optional*):
|
||||
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
|
||||
@@ -112,6 +115,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end: float = 0.012,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
@@ -243,9 +247,14 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
|
||||
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
|
||||
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
|
||||
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||
|
||||
@@ -260,7 +269,12 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
# interpolate timesteps
|
||||
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
|
||||
sigmas_interpol = sigmas_interpol.cpu()
|
||||
log_sigmas = self.log_sigmas.cpu()
|
||||
timesteps_interpol = np.array(
|
||||
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
|
||||
)
|
||||
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
|
||||
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
|
||||
|
||||
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
|
||||
@@ -273,29 +287,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = None
|
||||
|
||||
def sigma_to_t(self, sigma):
|
||||
# get log sigma
|
||||
log_sigma = sigma.log()
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - self.log_sigmas[:, None]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = self.log_sigmas[low_idx]
|
||||
high = self.log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.view(sigma.shape)
|
||||
return t
|
||||
|
||||
@property
|
||||
def state_in_first_order(self):
|
||||
return self.sample is None
|
||||
@@ -318,6 +309,44 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
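The `_sigma_to_t` helper above maps a continuous sigma back onto a fractional training timestep by interpolating between the two bracketing entries of the log-sigma table. A minimal standalone sketch of the same interpolation (plain NumPy; the sigma table below is illustrative, not the scheduler's real one):

import numpy as np

# Ascending sigma table, as stored by the scheduler in training-timestep order.
sigmas = np.array([0.1, 0.5, 1.0, 4.0, 10.0])
log_sigmas = np.log(sigmas)

def sigma_to_fractional_t(sigma, log_sigmas):
    log_sigma = np.log(sigma)
    dists = log_sigma - log_sigmas[:, np.newaxis]
    # index of the last table entry still <= sigma (clipped so high_idx stays in range)
    low_idx = np.cumsum(dists >= 0, axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    return ((1 - w) * low_idx + w * high_idx).reshape(sigma.shape)

print(sigma_to_fractional_t(np.array([2.0]), log_sigmas))  # ~[2.5]: halfway between indices 2 and 3 in log-space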
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
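The Karras et al. (2022) schedule spaces noise levels so that sigma^(1/rho) is linear between sigma_max and sigma_min, which packs more steps near the low-noise end. A quick numeric sketch of the same formula with made-up endpoints:

import numpy as np

def karras_sigmas(sigma_min, sigma_max, num_steps, rho=7.0):
    # sigma_i = (sigma_max^(1/rho) + ramp_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    ramp = np.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(karras_sigmas(0.03, 14.6, num_steps=5))
# roughly [14.6, 4.8, 1.3, 0.25, 0.03] -- monotonically decreasing, denser near sigma_min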
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
|
||||
@@ -315,6 +315,36 @@ class AutoPipelineForText2Image(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlipDiffusionControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlipDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CLIPImageProjection(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
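These `DummyObject` classes only exist so that `from diffusers import BlipDiffusionPipeline` keeps working when torch is not installed, deferring the failure to a readable error at use time. A simplified sketch of the guard pattern, assuming a stripped-down `requires_backends` (the real helper checks optional-dependency groups such as `is_torch_available`):

import importlib.util

def requires_backends(obj, backends):
    # Illustrative stand-in for the real diffusers utility.
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
        raise ImportError(f"{name} requires the following backends: {missing}")

class DummyObject(type):
    # Metaclass: attribute access on the placeholder class triggers the backend check.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)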
0
tests/pipelines/blipdiffusion/__init__.py
Normal file
196
tests/pipelines/blipdiffusion/test_blipdiffusion.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
|
||||
from transformers.models.clip.configuration_clip import CLIPTextConfig
|
||||
|
||||
from diffusers import AutoencoderKL, BlipDiffusionPipeline, PNDMScheduler, UNet2DConditionModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
|
||||
from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
||||
from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = BlipDiffusionPipeline
|
||||
params = [
|
||||
"prompt",
|
||||
"reference_image",
|
||||
"source_subject_category",
|
||||
"target_subject_category",
|
||||
]
|
||||
batch_params = [
|
||||
"prompt",
|
||||
"reference_image",
|
||||
"source_subject_category",
|
||||
"target_subject_category",
|
||||
]
|
||||
required_optional_params = [
|
||||
"generator",
|
||||
"height",
|
||||
"width",
|
||||
"latents",
|
||||
"guidance_scale",
|
||||
"num_inference_steps",
|
||||
"neg_prompt",
|
||||
"guidance_scale",
|
||||
"prompt_strength",
|
||||
"prompt_reps",
|
||||
]
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
vocab_size=1000,
|
||||
hidden_size=16,
|
||||
intermediate_size=16,
|
||||
projection_dim=16,
|
||||
num_hidden_layers=1,
|
||||
num_attention_heads=1,
|
||||
max_position_embeddings=77,
|
||||
)
|
||||
text_encoder = ContextCLIPTextModel(text_encoder_config)
|
||||
|
||||
vae = AutoencoderKL(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(32,),
|
||||
layers_per_block=1,
|
||||
act_fn="silu",
|
||||
latent_channels=4,
|
||||
norm_num_groups=16,
|
||||
sample_size=16,
|
||||
)
|
||||
|
||||
blip_vision_config = {
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 1,
|
||||
"image_size": 224,
|
||||
"patch_size": 14,
|
||||
"hidden_act": "quick_gelu",
|
||||
}
|
||||
|
||||
blip_qformer_config = {
|
||||
"vocab_size": 1000,
|
||||
"hidden_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 1,
|
||||
"intermediate_size": 16,
|
||||
"max_position_embeddings": 512,
|
||||
"cross_attention_frequency": 1,
|
||||
"encoder_hidden_size": 16,
|
||||
}
|
||||
qformer_config = Blip2Config(
|
||||
vision_config=blip_vision_config,
|
||||
qformer_config=blip_qformer_config,
|
||||
num_query_tokens=16,
|
||||
tokenizer="hf-internal-testing/tiny-random-bert",
|
||||
)
|
||||
qformer = Blip2QFormerModel(qformer_config)
|
||||
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(16, 32),
|
||||
norm_num_groups=16,
|
||||
layers_per_block=1,
|
||||
sample_size=16,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=16,
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
scheduler = PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
set_alpha_to_one=False,
|
||||
skip_prk_steps=True,
|
||||
)
|
||||
|
||||
vae.eval()
|
||||
qformer.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
image_processor = BlipImageProcessor()
|
||||
|
||||
components = {
|
||||
"text_encoder": text_encoder,
|
||||
"vae": vae,
|
||||
"qformer": qformer,
|
||||
"unet": unet,
|
||||
"tokenizer": tokenizer,
|
||||
"scheduler": scheduler,
|
||||
"image_processor": image_processor,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
np.random.seed(seed)
|
||||
reference_image = np.random.rand(32, 32, 3) * 255
|
||||
reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "swimming underwater",
|
||||
"generator": generator,
|
||||
"reference_image": reference_image,
|
||||
"source_subject_category": "dog",
|
||||
"target_subject_category": "dog",
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"guidance_scale": 7.5,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_blipdiffusion(self):
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image = pipe(**self.get_dummy_inputs(device))[0]
|
||||
image_slice = image[0, -3:, -3:, 0]
|
||||
|
||||
assert image.shape == (1, 16, 16, 4)
|
||||
|
||||
expected_slice = np.array([0.7096, 0.5900, 0.6703, 0.4032, 0.7766, 0.3629, 0.5447, 0.4149, 0.8172])
|
||||
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
|
||||
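For orientation, the dummy inputs above mirror how the real pipeline is called. A hedged usage sketch: the checkpoint id `Salesforce/blipdiffusion` and the local image path are assumptions, while the keyword names come from the test inputs above.

import torch
from diffusers import BlipDiffusionPipeline
from diffusers.utils import load_image

pipe = BlipDiffusionPipeline.from_pretrained("Salesforce/blipdiffusion", torch_dtype=torch.float16).to("cuda")

image = pipe(
    prompt="swimming underwater",
    reference_image=load_image("dog.jpg"),   # photo of the source subject (path is illustrative)
    source_subject_category="dog",
    target_subject_category="dog",
    guidance_scale=7.5,
    num_inference_steps=25,
    height=512,
    width=512,
).images[0]
image.save("blip_diffusion_dog.png")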
216
tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
|
||||
from transformers.models.clip.configuration_clip import CLIPTextConfig
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
BlipDiffusionControlNetPipeline,
|
||||
ControlNetModel,
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
|
||||
from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
|
||||
from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = BlipDiffusionControlNetPipeline
|
||||
params = [
|
||||
"prompt",
|
||||
"reference_image",
|
||||
"source_subject_category",
|
||||
"target_subject_category",
|
||||
"condtioning_image",
|
||||
]
|
||||
batch_params = [
|
||||
"prompt",
|
||||
"reference_image",
|
||||
"source_subject_category",
|
||||
"target_subject_category",
|
||||
"condtioning_image",
|
||||
]
|
||||
required_optional_params = [
|
||||
"generator",
|
||||
"height",
|
||||
"width",
|
||||
"latents",
|
||||
"guidance_scale",
|
||||
"num_inference_steps",
|
||||
"neg_prompt",
|
||||
"guidance_scale",
|
||||
"prompt_strength",
|
||||
"prompt_reps",
|
||||
]
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
vocab_size=1000,
|
||||
hidden_size=16,
|
||||
intermediate_size=16,
|
||||
projection_dim=16,
|
||||
num_hidden_layers=1,
|
||||
num_attention_heads=1,
|
||||
max_position_embeddings=77,
|
||||
)
|
||||
text_encoder = ContextCLIPTextModel(text_encoder_config)
|
||||
|
||||
vae = AutoencoderKL(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(32,),
|
||||
layers_per_block=1,
|
||||
act_fn="silu",
|
||||
latent_channels=4,
|
||||
norm_num_groups=16,
|
||||
sample_size=16,
|
||||
)
|
||||
|
||||
blip_vision_config = {
|
||||
"hidden_size": 16,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 1,
|
||||
"image_size": 224,
|
||||
"patch_size": 14,
|
||||
"hidden_act": "quick_gelu",
|
||||
}
|
||||
|
||||
blip_qformer_config = {
|
||||
"vocab_size": 1000,
|
||||
"hidden_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 1,
|
||||
"intermediate_size": 16,
|
||||
"max_position_embeddings": 512,
|
||||
"cross_attention_frequency": 1,
|
||||
"encoder_hidden_size": 16,
|
||||
}
|
||||
qformer_config = Blip2Config(
|
||||
vision_config=blip_vision_config,
|
||||
qformer_config=blip_qformer_config,
|
||||
num_query_tokens=16,
|
||||
tokenizer="hf-internal-testing/tiny-random-bert",
|
||||
)
|
||||
qformer = Blip2QFormerModel(qformer_config)
|
||||
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(4, 16),
|
||||
layers_per_block=1,
|
||||
norm_num_groups=4,
|
||||
sample_size=16,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=16,
|
||||
)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
scheduler = PNDMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
set_alpha_to_one=False,
|
||||
skip_prk_steps=True,
|
||||
)
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=(4, 16),
|
||||
layers_per_block=1,
|
||||
in_channels=4,
|
||||
norm_num_groups=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
cross_attention_dim=16,
|
||||
conditioning_embedding_out_channels=(8, 16),
|
||||
)
|
||||
|
||||
vae.eval()
|
||||
qformer.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
image_processor = BlipImageProcessor()
|
||||
|
||||
components = {
|
||||
"text_encoder": text_encoder,
|
||||
"vae": vae,
|
||||
"qformer": qformer,
|
||||
"unet": unet,
|
||||
"tokenizer": tokenizer,
|
||||
"scheduler": scheduler,
|
||||
"controlnet": controlnet,
|
||||
"image_processor": image_processor,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
np.random.seed(seed)
|
||||
reference_image = np.random.rand(32, 32, 3) * 255
|
||||
reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
|
||||
cond_image = np.random.rand(32, 32, 3) * 255
|
||||
cond_image = Image.fromarray(cond_image.astype("uint8")).convert("RGBA")
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "swimming underwater",
|
||||
"generator": generator,
|
||||
"reference_image": reference_image,
|
||||
"condtioning_image": cond_image,
|
||||
"source_subject_category": "dog",
|
||||
"target_subject_category": "dog",
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"guidance_scale": 7.5,
|
||||
"num_inference_steps": 2,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_blipdiffusion_controlnet(self):
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
image = pipe(**self.get_dummy_inputs(device))[0]
|
||||
image_slice = image[0, -3:, -3:, 0]
|
||||
|
||||
assert image.shape == (1, 16, 16, 4)
|
||||
expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
|
||||
|
||||
assert (
|
||||
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
|
||||
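The ControlNet variant adds one more input, a spatial control image passed as `condtioning_image` (parameter name spelled as in the pipeline). A hedged sketch; the checkpoint id `Salesforce/blipdiffusion-controlnet` and the image paths are assumptions:

import torch
from diffusers import BlipDiffusionControlNetPipeline
from diffusers.utils import load_image

pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="on a marble table",
    reference_image=load_image("teapot.jpg"),           # subject photo (illustrative path)
    condtioning_image=load_image("teapot_canny.png"),   # e.g. canny edges of the target layout
    source_subject_category="teapot",
    target_subject_category="teapot",
    guidance_scale=7.5,
    num_inference_steps=25,
).images[0]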
@@ -20,17 +20,20 @@ import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
EulerDiscreteScheduler,
|
||||
MultiAdapter,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
T2IAdapter,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
@@ -41,7 +44,7 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, adapter_type="full_adapter_xl"):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
@@ -97,13 +100,38 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
|
||||
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
adapter = T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=4,
|
||||
adapter_type="full_adapter_xl",
|
||||
)
|
||||
if adapter_type == "full_adapter_xl":
|
||||
adapter = T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=4,
|
||||
adapter_type=adapter_type,
|
||||
)
|
||||
elif adapter_type == "multi_adapter":
|
||||
adapter = MultiAdapter(
|
||||
[
|
||||
T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=4,
|
||||
adapter_type="full_adapter_xl",
|
||||
),
|
||||
T2IAdapter(
|
||||
in_channels=3,
|
||||
channels=[32, 64],
|
||||
num_res_blocks=2,
|
||||
downscale_factor=4,
|
||||
adapter_type="full_adapter_xl",
|
||||
),
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter_xl', or 'multi_adapter''"
|
||||
)
|
||||
|
||||
components = {
|
||||
"adapter": adapter,
|
||||
"unet": unet,
|
||||
@@ -118,8 +146,12 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
def get_dummy_inputs(self, device, seed=0, num_images=1):
|
||||
if num_images == 1:
|
||||
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
|
||||
else:
|
||||
image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
@@ -150,3 +182,202 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
|
||||
[0.5752919, 0.6022097, 0.4728038, 0.49861962, 0.57084894, 0.4644975, 0.5193715, 0.5133664, 0.4729858]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
|
||||
|
||||
|
||||
class StableDiffusionXLMultiAdapterPipelineFastTests(
|
||||
StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
|
||||
):
|
||||
def get_dummy_components(self):
|
||||
return super().get_dummy_components("multi_adapter")
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
inputs = super().get_dummy_inputs(device, seed, num_images=2)
|
||||
inputs["adapter_conditioning_scale"] = [0.5, 0.5]
|
||||
return inputs
|
||||
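The override above mirrors the real multi-adapter calling convention: one conditioning image per adapter plus a matching list of `adapter_conditioning_scale` weights. A hedged end-to-end sketch; the checkpoint and adapter repo ids are assumptions:

import torch
from diffusers import MultiAdapter, StableDiffusionXLAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

adapter = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16),
        T2IAdapter.from_pretrained("TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16),
    ]
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", adapter=adapter, torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a photo of a house by a lake",
    image=[load_image("canny.png"), load_image("depth.png")],  # one control image per adapter
    adapter_conditioning_scale=[0.5, 0.5],                     # per-adapter weighting, as in the test
    num_inference_steps=30,
).images[0]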
|
||||
def test_stable_diffusion_adapter_default_case(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionXLAdapterPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
expected_slice = np.array(
|
||||
[0.5813032, 0.60995954, 0.47563356, 0.5056669, 0.57199144, 0.4631841, 0.5176794, 0.51252556, 0.47183886]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
|
||||
|
||||
def test_inference_batch_consistent(
|
||||
self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
|
||||
):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
for batch_size in batch_sizes:
|
||||
batched_inputs = {}
|
||||
for name, value in inputs.items():
|
||||
if name in self.batch_params:
|
||||
# prompt is string
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
elif name == "image":
|
||||
batched_images = []
|
||||
|
||||
for image in value:
|
||||
batched_images.append(batch_size * [image])
|
||||
|
||||
batched_inputs[name] = batched_images
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
elif name == "batch_size":
|
||||
batched_inputs[name] = batch_size
|
||||
else:
|
||||
batched_inputs[name] = value
|
||||
|
||||
for arg in additional_params_copy_to_batched_inputs:
|
||||
batched_inputs[arg] = inputs[arg]
|
||||
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
output = pipe(**batched_inputs)
|
||||
|
||||
assert len(output[0]) == batch_size
|
||||
|
||||
batched_inputs["output_type"] = "np"
|
||||
|
||||
output = pipe(**batched_inputs)[0]
|
||||
|
||||
assert output.shape[0] == batch_size
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
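The batching loop above deliberately builds prompts of unequal length (progressively shorter slices, plus one very long final entry) so that padding and truncation cannot silently change results. The prompt construction in isolation (prompt value illustrative):

prompt = "swimming underwater"
batch_size = 4

len_prompt = len(prompt)
batched_prompts = [prompt[: len_prompt // i] for i in range(1, batch_size + 1)]
batched_prompts[-1] = 100 * "very long"

print(batched_prompts[:3])
# ['swimming underwater', 'swimming ', 'swimmi'] -- each entry a shorter slice of the original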
|
||||
def test_num_images_per_prompt(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
if key == "image":
|
||||
batched_images = []
|
||||
|
||||
for image in inputs[key]:
|
||||
batched_images.append(batch_size * [image])
|
||||
|
||||
inputs[key] = batched_images
|
||||
else:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
def test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=3,
|
||||
test_max_difference=None,
|
||||
test_mean_pixel_difference=None,
|
||||
relax_max_difference=False,
|
||||
expected_max_diff=2e-3,
|
||||
additional_params_copy_to_batched_inputs=["num_inference_steps"],
|
||||
):
|
||||
if test_max_difference is None:
|
||||
# TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
|
||||
# make sure that batched and non-batched is identical
|
||||
test_max_difference = torch_device != "mps"
|
||||
|
||||
if test_mean_pixel_difference is None:
|
||||
# TODO same as above
|
||||
test_mean_pixel_difference = torch_device != "mps"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batch_size = batch_size
|
||||
for name, value in inputs.items():
|
||||
if name in self.batch_params:
|
||||
# prompt is string
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
# make unequal batch sizes
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
|
||||
# make last batch super long
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
elif name == "image":
|
||||
batched_images = []
|
||||
|
||||
for image in value:
|
||||
batched_images.append(batch_size * [image])
|
||||
|
||||
batched_inputs[name] = batched_images
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
elif name == "batch_size":
|
||||
batched_inputs[name] = batch_size
|
||||
elif name == "generator":
|
||||
batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)]
|
||||
else:
|
||||
batched_inputs[name] = value
|
||||
|
||||
for arg in additional_params_copy_to_batched_inputs:
|
||||
batched_inputs[arg] = inputs[arg]
|
||||
|
||||
output_batch = pipe(**batched_inputs)
|
||||
assert output_batch[0].shape[0] == batch_size
|
||||
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
output = pipe(**inputs)
|
||||
|
||||
logger.setLevel(level=diffusers.logging.WARNING)
|
||||
if test_max_difference:
|
||||
if relax_max_difference:
|
||||
# Taking the median of the largest <n> differences
|
||||
# is resilient to outliers
|
||||
diff = np.abs(output_batch[0][0] - output[0][0])
|
||||
diff = diff.flatten()
|
||||
diff.sort()
|
||||
max_diff = np.median(diff[-5:])
|
||||
else:
|
||||
max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
|
||||
assert max_diff < expected_max_diff
|
||||
|
||||
if test_mean_pixel_difference:
|
||||
assert_mean_pixel_difference(output_batch[0][0], output[0][0])
|
||||
|
||||