mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge branch 'main' into peftpart-1

Sourab Mangrulkar
2023-09-22 12:36:20 +05:30
35 changed files with 3740 additions and 604 deletions

View File

@@ -20,7 +20,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip

View File

@@ -20,7 +20,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -38,7 +38,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip

View File

@@ -17,7 +17,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8
- name: Install requirements
run: |

View File

@@ -216,6 +216,8 @@
title: AudioLDM 2
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP Diffusion
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet

View File

@@ -0,0 +1,29 @@
# BLIP Diffusion
BLIP Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
The abstract from the paper is:
*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.*
The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.
`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).
<Tip>
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
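A minimal zero-shot, subject-driven generation sketch is shown below. The checkpoint id `Salesforce/blipdiffusion`, the image path, and the exact call arguments are assumptions based on the pipeline added in this commit rather than values stated on this page, so verify them against the autodoc sections that follow.

```python
import torch

from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.utils import load_image

# Example checkpoint id under the SalesForce Hub organization (verify the exact repo name).
pipe = BlipDiffusionPipeline.from_pretrained(
    "Salesforce/blipdiffusion", torch_dtype=torch.float16
).to("cuda")

# Reference image of the subject (placeholder path) plus source/target subject categories.
reference_image = load_image("path/to/dog.png")

output = pipe(
    "swimming underwater",  # prompt describing the new rendition of the subject
    reference_image,        # subject image
    "dog",                  # source subject category
    "dog",                  # target subject category
    guidance_scale=7.5,
    num_inference_steps=25,
    neg_prompt="lowres, cropped, worst quality, low quality",
    height=512,
    width=512,
).images
output[0].save("blip_diffusion_out.png")
```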
## BlipDiffusionPipeline
[[autodoc]] BlipDiffusionPipeline
- all
- __call__
## BlipDiffusionControlNetPipeline
[[autodoc]] BlipDiffusionControlNetPipeline
- all
- __call__

View File

@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
Install 🤗 Diffusers for whichever deep learning library you're working with.
🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using:
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
@@ -106,7 +106,7 @@ pip install -e ".[flax]"
These commands will link the folder you cloned the repository to with your Python library paths.
Python will now look inside the folder you cloned to in addition to the normal library paths.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
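A quick way to check that the editable install is the one Python actually resolves is to print the package location (the expected path in the comment is just an example):

```python
import diffusers

# With an editable install this should point inside your clone, e.g.
# ~/diffusers/src/diffusers/__init__.py, rather than a site-packages copy.
print(diffusers.__file__)
```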
<Tip warning={true}>

View File

@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
Install 🤗 Diffusers for the deep learning library you are using.
🤗 Diffusers has been tested on Python 3.7+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning library you are using.
🤗 Diffusers has been tested on Python 3.8+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning library you are using.
- [PyTorch installation instructions](https://pytorch.org/get-started/locally/)
- [Flax installation instructions](https://flax.readthedocs.io/en/latest/)
@@ -105,7 +105,7 @@ pip install -e ".[flax]"
These commands link the folder you cloned the repository to with your Python library paths.
Python will now look inside the cloned folder in addition to the normal library paths.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the cloned `~/diffusers/` folder.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the cloned `~/diffusers/` folder.
<Tip warning={true}>

View File

@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
Install 🤗 Diffusers for whichever deep learning framework you are using.
🤗 Diffusers has been tested on Python 3.7+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning framework you are using.
🤗 Diffusers has been tested on Python 3.8+, PyTorch 1.7.0+, and Flax. Follow the installation instructions below for the deep learning framework you are using.
- [PyTorch](https://pytorch.org/get-started/locally/) installation instructions.
- [Flax](https://flax.readthedocs.io/en/latest/) installation instructions.
@@ -107,7 +107,7 @@ pip install -e ".[flax]"
These commands link the repository folder you cloned with your Python library paths.
Python will now look for packages inside the cloned folder in addition to the usual library paths.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/Site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/Site-packages/`, Python will also search the `~/diffusers/` folder you cloned to.
<Tip warning={true}>

View File

@@ -908,6 +908,9 @@ def main():
if args.snr_gamma is not None:
snr = jnp.array(compute_snr(timesteps))
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
snr_loss_weights = snr_loss_weights + 1
loss = loss * snr_loss_weights
loss = loss.mean()

View File

@@ -224,6 +224,30 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
raise ValueError(f"{model_class} is not supported.")
def compute_snr(timesteps, noise_scheduler):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
# Compute SNR
snr = (alpha / sigma) ** 2
return snr
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -524,6 +548,13 @@ def parse_args(input_args=None):
" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
),
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--pre_compute_text_embeddings",
action="store_true",
@@ -1261,17 +1292,34 @@ def main(args):
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Compute instance loss
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps, noise_scheduler)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
if args.with_prior_preservation:
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
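For reference, the Min-SNR weighting applied in the hunk above can be written compactly. With $\bar\alpha_t$ the scheduler's cumulative alpha product and $\gamma$ the value of `--snr_gamma`, `compute_snr` and the loss weights reduce to:

$$
\mathrm{SNR}(t) = \frac{\bar\alpha_t}{1 - \bar\alpha_t}, \qquad
w_t = \frac{\min\bigl(\mathrm{SNR}(t), \gamma\bigr)}{\mathrm{SNR}(t)}, \qquad
w_t^{\text{v-pred}} = \frac{\min\bigl(\mathrm{SNR}(t), \gamma\bigr)}{\mathrm{SNR}(t)} + 1
$$

The per-sample MSE is averaged over the non-batch dimensions, multiplied by $w_t$, and only then averaged over the batch; the `+ 1` term keeps the velocity objective's effective weight at a minimum of one, as noted in the code comments.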

View File

@@ -875,6 +875,9 @@ def main():
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.

View File

@@ -955,6 +955,9 @@ def main():
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.

View File

@@ -786,6 +786,9 @@ def main():
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.

View File

@@ -1075,6 +1075,9 @@ def main(args):
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.

View File

@@ -0,0 +1,343 @@
"""
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
"""
import argparse
import os
import tempfile
import torch
from lavis.models import load_model_and_preprocess
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from diffusers import (
AutoencoderKL,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
BLIP2_CONFIG = {
"vision_config": {
"hidden_size": 1024,
"num_hidden_layers": 23,
"num_attention_heads": 16,
"image_size": 224,
"patch_size": 14,
"intermediate_size": 4096,
"hidden_act": "quick_gelu",
},
"qformer_config": {
"cross_attention_frequency": 1,
"encoder_hidden_size": 1024,
"vocab_size": 30523,
},
"num_query_tokens": 16,
}
blip2config = Blip2Config(**BLIP2_CONFIG)
def qformer_model_from_original_config():
qformer = Blip2QFormerModel(blip2config)
return qformer
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
embeddings = {}
embeddings.update(
{
f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
f"{original_embeddings_prefix}.word_embeddings.weight"
]
}
)
embeddings.update(
{
f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
f"{original_embeddings_prefix}.position_embeddings.weight"
]
}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
)
return embeddings
def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
proj_layer = {}
proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
return proj_layer
def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
attention = {}
attention.update(
{
f"{diffuser_attention_prefix}.attention.query.weight": model[
f"{original_attention_prefix}.self.query.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.attention.value.weight": model[
f"{original_attention_prefix}.self.value.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
f"{original_attention_prefix}.output.LayerNorm.weight"
]
}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
f"{original_attention_prefix}.output.LayerNorm.bias"
]
}
)
return attention
def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
output_layers = {}
output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
)
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
)
return output_layers
def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
encoder = {}
for i in range(blip2config.qformer_config.num_hidden_layers):
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
)
)
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
)
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
]
}
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
)
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
)
)
return encoder
def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder_layer = {}
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
)
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
return visual_encoder_layer
def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder = {}
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
.unsqueeze(0)
.unsqueeze(0)
}
)
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.position_embedding": model[
f"{original_prefix}.positional_embedding"
].unsqueeze(0)
}
)
visual_encoder.update(
{f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
)
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
for i in range(blip2config.vision_config.num_hidden_layers):
visual_encoder.update(
visual_encoder_layer_from_original_checkpoint(
model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
)
)
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
return visual_encoder
def qformer_original_checkpoint_to_diffusers_checkpoint(model):
qformer_checkpoint = {}
qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
qformer_checkpoint.update(
encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
)
qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
return qformer_checkpoint
def get_qformer(model):
print("loading qformer")
qformer = qformer_model_from_original_config()
qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
print("done loading qformer")
return qformer
def load_checkpoint_to_model(checkpoint, model):
with tempfile.NamedTemporaryFile(delete=False) as file:
torch.save(checkpoint, file.name)
del checkpoint
model.load_state_dict(torch.load(file.name), strict=False)
os.remove(file.name)
def save_blip_diffusion_model(model, args):
qformer = get_qformer(model)
qformer.eval()
text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
vae.eval()
text_encoder.eval()
scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
image_processor = BlipImageProcessor()
blip_diffusion = BlipDiffusionPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
image_processor=image_processor,
)
blip_diffusion.save_pretrained(args.checkpoint_path)
def main(args):
model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
save_blip_diffusion_model(model.state_dict(), args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
main(args)
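After the script finishes, the directory passed as `--checkpoint_path` contains a standard diffusers pipeline; a small sketch of reloading it (the directory name below is only an example):

```python
from diffusers.pipelines import BlipDiffusionPipeline

# Directory that was passed as --checkpoint_path to the conversion script (example value).
pipe = BlipDiffusionPipeline.from_pretrained("./blip-diffusion-converted")
```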

View File

@@ -35,6 +35,12 @@ if __name__ == "__main__":
type=str,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument(
"--config_files",
default=None,
type=str,
help="The YAML config file corresponding to the architecture.",
)
parser.add_argument(
"--num_in_channels",
default=None,

View File

@@ -257,7 +257,7 @@ setup(
package_dir={"": "src"},
packages=find_packages("src"),
include_package_data=True,
python_requires=">=3.7.0",
python_requires=">=3.8.0",
install_requires=list(install_requires),
extras_require=extras,
entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]},
@@ -269,7 +269,6 @@ setup(
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",

View File

@@ -197,6 +197,8 @@ else:
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
"AudioLDMPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CycleDiffusionPipeline",
"IFImg2ImgPipeline",
@@ -458,6 +460,8 @@ if TYPE_CHECKING:
AutoPipelineForImage2Image,
AutoPipelineForInpainting,
AutoPipelineForText2Image,
BlipDiffusionControlNetPipeline,
BlipDiffusionPipeline,
CLIPImageProjection,
ConsistencyModelPipeline,
DanceDiffusionPipeline,

View File

@@ -1,456 +1,460 @@
from typing import TYPE_CHECKING
from ..utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_onnx_available,
is_torch_available,
is_transformers_available,
)
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_pt_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["auto_pipeline"] = [
"AutoPipelineForImage2Image",
"AutoPipelineForInpainting",
"AutoPipelineForText2Image",
]
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
_import_structure["ddim"] = ["DDIMPipeline"]
_import_structure["ddpm"] = ["DDPMPipeline"]
_import_structure["dit"] = ["DiTPipeline"]
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
_import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
_import_structure["pndm"] = ["PNDMPipeline"]
_import_structure["repaint"] = ["RePaintPipeline"]
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
_import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_librosa_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
else:
_import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
]
_import_structure["controlnet"].extend(
[
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPipeline",
]
)
_import_structure["deepfloyd_if"] = [
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
"IFInpaintingPipeline",
"IFInpaintingSuperResolutionPipeline",
"IFPipeline",
"IFSuperResolutionPipeline",
]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
"KandinskyInpaintCombinedPipeline",
"KandinskyInpaintPipeline",
"KandinskyPipeline",
"KandinskyPriorPipeline",
]
_import_structure["kandinsky2_2"] = [
"KandinskyV22CombinedPipeline",
"KandinskyV22ControlnetImg2ImgPipeline",
"KandinskyV22ControlnetPipeline",
"KandinskyV22Img2ImgCombinedPipeline",
"KandinskyV22Img2ImgPipeline",
"KandinskyV22InpaintCombinedPipeline",
"KandinskyV22InpaintPipeline",
"KandinskyV22Pipeline",
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_diffusion"].extend(
[
"CLIPImageProjection",
"CycleDiffusionPipeline",
"StableDiffusionAttendAndExcitePipeline",
"StableDiffusionDepth2ImgPipeline",
"StableDiffusionDiffEditPipeline",
"StableDiffusionGLIGENPipeline",
"StableDiffusionGLIGENPipeline",
"StableDiffusionGLIGENTextImagePipeline",
"StableDiffusionImageVariationPipeline",
"StableDiffusionImg2ImgPipeline",
"StableDiffusionInpaintPipeline",
"StableDiffusionInpaintPipelineLegacy",
"StableDiffusionInstructPix2PixPipeline",
"StableDiffusionLatentUpscalePipeline",
"StableDiffusionLDM3DPipeline",
"StableDiffusionModelEditingPipeline",
"StableDiffusionPanoramaPipeline",
"StableDiffusionParadigmsPipeline",
"StableDiffusionPipeline",
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionSAGPipeline",
"StableDiffusionUpscalePipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
]
)
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
_import_structure["stable_diffusion_xl"] = [
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPipeline",
]
_import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
_import_structure["text_to_video_synthesis"] = [
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
"VideoToVideoSDPipeline",
]
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
_import_structure["unidiffuser"] = [
"ImageTextPipelineOutput",
"UniDiffuserModel",
"UniDiffuserPipeline",
"UniDiffuserTextDecoder",
]
_import_structure["versatile_diffusion"] = [
"VersatileDiffusionDualGuidedPipeline",
"VersatileDiffusionImageVariationPipeline",
"VersatileDiffusionPipeline",
"VersatileDiffusionTextToImagePipeline",
]
_import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
_import_structure["wuerstchen"] = [
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_onnx_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
else:
_import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
else:
_import_structure["stable_diffusion"].extend(
[
"OnnxStableDiffusionImg2ImgPipeline",
"OnnxStableDiffusionInpaintPipeline",
"OnnxStableDiffusionInpaintPipelineLegacy",
"OnnxStableDiffusionPipeline",
"OnnxStableDiffusionUpscalePipeline",
"StableDiffusionOnnxPipeline",
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
_import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_flax_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_objects))
else:
_import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
_import_structure["stable_diffusion"].extend(
[
"FlaxStableDiffusionImg2ImgPipeline",
"FlaxStableDiffusionInpaintPipeline",
"FlaxStableDiffusionPipeline",
]
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
_import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
if TYPE_CHECKING:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
from .consistency_models import ConsistencyModelPipeline
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_librosa_objects import *
else:
from .audio_diffusion import AudioDiffusionPipeline, Mel
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
IFInpaintingPipeline,
IFInpaintingSuperResolutionPipeline,
IFPipeline,
IFSuperResolutionPipeline,
)
from .kandinsky import (
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
KandinskyInpaintCombinedPipeline,
KandinskyInpaintPipeline,
KandinskyPipeline,
KandinskyPriorPipeline,
)
from .kandinsky2_2 import (
KandinskyV22CombinedPipeline,
KandinskyV22ControlnetImg2ImgPipeline,
KandinskyV22ControlnetPipeline,
KandinskyV22Img2ImgCombinedPipeline,
KandinskyV22Img2ImgPipeline,
KandinskyV22InpaintCombinedPipeline,
KandinskyV22InpaintPipeline,
KandinskyV22Pipeline,
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
)
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_diffusion import (
CLIPImageProjection,
CycleDiffusionPipeline,
StableDiffusionAttendAndExcitePipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
StableDiffusionGLIGENTextImagePipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionLatentUpscalePipeline,
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
StableDiffusionPipeline,
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
from .text_to_video_synthesis import (
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
VideoToVideoSDPipeline,
)
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import (
ImageTextPipelineOutput,
UniDiffuserModel,
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_onnx_objects import * # noqa F403
else:
from .onnx_utils import OnnxRuntimeModel
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_onnx_objects import *
else:
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInpaintPipelineLegacy,
OnnxStableDiffusionPipeline,
OnnxStableDiffusionUpscalePipeline,
StableDiffusionOnnxPipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
from .stable_diffusion import StableDiffusionKDiffusionPipeline
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_objects import * # noqa F403
else:
from .pipeline_flax_utils import FlaxDiffusionPipeline
try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_and_transformers_objects import *
else:
from .controlnet import FlaxStableDiffusionControlNetPipeline
from .stable_diffusion import (
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
from typing import TYPE_CHECKING
from ..utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_onnx_available,
is_torch_available,
is_transformers_available,
)
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_pt_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["auto_pipeline"] = [
"AutoPipelineForImage2Image",
"AutoPipelineForInpainting",
"AutoPipelineForText2Image",
]
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
_import_structure["ddim"] = ["DDIMPipeline"]
_import_structure["ddpm"] = ["DDPMPipeline"]
_import_structure["dit"] = ["DiTPipeline"]
_import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
_import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
_import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
_import_structure["pndm"] = ["PNDMPipeline"]
_import_structure["repaint"] = ["RePaintPipeline"]
_import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
_import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_librosa_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
else:
_import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
"AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPipeline",
]
)
_import_structure["deepfloyd_if"] = [
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
"IFInpaintingPipeline",
"IFInpaintingSuperResolutionPipeline",
"IFPipeline",
"IFSuperResolutionPipeline",
]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
"KandinskyInpaintCombinedPipeline",
"KandinskyInpaintPipeline",
"KandinskyPipeline",
"KandinskyPriorPipeline",
]
_import_structure["kandinsky2_2"] = [
"KandinskyV22CombinedPipeline",
"KandinskyV22ControlnetImg2ImgPipeline",
"KandinskyV22ControlnetPipeline",
"KandinskyV22Img2ImgCombinedPipeline",
"KandinskyV22Img2ImgPipeline",
"KandinskyV22InpaintCombinedPipeline",
"KandinskyV22InpaintPipeline",
"KandinskyV22Pipeline",
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
_import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_diffusion"].extend(
[
"CLIPImageProjection",
"CycleDiffusionPipeline",
"StableDiffusionAttendAndExcitePipeline",
"StableDiffusionDepth2ImgPipeline",
"StableDiffusionDiffEditPipeline",
"StableDiffusionGLIGENPipeline",
"StableDiffusionGLIGENPipeline",
"StableDiffusionGLIGENTextImagePipeline",
"StableDiffusionImageVariationPipeline",
"StableDiffusionImg2ImgPipeline",
"StableDiffusionInpaintPipeline",
"StableDiffusionInpaintPipelineLegacy",
"StableDiffusionInstructPix2PixPipeline",
"StableDiffusionLatentUpscalePipeline",
"StableDiffusionLDM3DPipeline",
"StableDiffusionModelEditingPipeline",
"StableDiffusionPanoramaPipeline",
"StableDiffusionParadigmsPipeline",
"StableDiffusionPipeline",
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionSAGPipeline",
"StableDiffusionUpscalePipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
]
)
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
_import_structure["stable_diffusion_xl"] = [
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPipeline",
]
_import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
_import_structure["text_to_video_synthesis"] = [
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
"VideoToVideoSDPipeline",
]
_import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
_import_structure["unidiffuser"] = [
"ImageTextPipelineOutput",
"UniDiffuserModel",
"UniDiffuserPipeline",
"UniDiffuserTextDecoder",
]
_import_structure["versatile_diffusion"] = [
"VersatileDiffusionDualGuidedPipeline",
"VersatileDiffusionImageVariationPipeline",
"VersatileDiffusionPipeline",
"VersatileDiffusionTextToImagePipeline",
]
_import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
_import_structure["wuerstchen"] = [
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_onnx_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
else:
_import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
else:
_import_structure["stable_diffusion"].extend(
[
"OnnxStableDiffusionImg2ImgPipeline",
"OnnxStableDiffusionInpaintPipeline",
"OnnxStableDiffusionInpaintPipelineLegacy",
"OnnxStableDiffusionPipeline",
"OnnxStableDiffusionUpscalePipeline",
"StableDiffusionOnnxPipeline",
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
_import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_flax_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_objects))
else:
_import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
_import_structure["stable_diffusion"].extend(
[
"FlaxStableDiffusionImg2ImgPipeline",
"FlaxStableDiffusionInpaintPipeline",
"FlaxStableDiffusionPipeline",
]
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
_import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]
if TYPE_CHECKING:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
from .consistency_models import ConsistencyModelPipeline
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
from .pndm import PNDMPipeline
from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .stochastic_karras_ve import KarrasVePipeline
try:
if not (is_torch_available() and is_librosa_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_librosa_objects import *
else:
from .audio_diffusion import AudioDiffusionPipeline, Mel
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_objects import *
else:
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
from .audioldm import AudioLDMPipeline
from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
from .blip_diffusion import BlipDiffusionPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
IFInpaintingPipeline,
IFInpaintingSuperResolutionPipeline,
IFPipeline,
IFSuperResolutionPipeline,
)
from .kandinsky import (
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
KandinskyInpaintCombinedPipeline,
KandinskyInpaintPipeline,
KandinskyPipeline,
KandinskyPriorPipeline,
)
from .kandinsky2_2 import (
KandinskyV22CombinedPipeline,
KandinskyV22ControlnetImg2ImgPipeline,
KandinskyV22ControlnetPipeline,
KandinskyV22Img2ImgCombinedPipeline,
KandinskyV22Img2ImgPipeline,
KandinskyV22InpaintCombinedPipeline,
KandinskyV22InpaintPipeline,
KandinskyV22Pipeline,
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
)
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
from .paint_by_example import PaintByExamplePipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_diffusion import (
CLIPImageProjection,
CycleDiffusionPipeline,
StableDiffusionAttendAndExcitePipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
StableDiffusionGLIGENTextImagePipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionInpaintPipelineLegacy,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionLatentUpscalePipeline,
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionParadigmsPipeline,
StableDiffusionPipeline,
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .stable_diffusion_xl import (
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLPipeline,
)
from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
from .text_to_video_synthesis import (
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
VideoToVideoSDPipeline,
)
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
from .unidiffuser import (
ImageTextPipelineOutput,
UniDiffuserModel,
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_onnx_objects import * # noqa F403
else:
from .onnx_utils import OnnxRuntimeModel
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_onnx_objects import *
else:
from .stable_diffusion import (
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInpaintPipelineLegacy,
OnnxStableDiffusionPipeline,
OnnxStableDiffusionUpscalePipeline,
StableDiffusionOnnxPipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
from .stable_diffusion import StableDiffusionKDiffusionPipeline
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_objects import * # noqa F403
else:
from .pipeline_flax_utils import FlaxDiffusionPipeline
try:
if not (is_flax_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_and_transformers_objects import *
else:
from .controlnet import FlaxStableDiffusionControlNetPipeline
from .stable_diffusion import (
FlaxStableDiffusionImg2ImgPipeline,
FlaxStableDiffusionInpaintPipeline,
FlaxStableDiffusionPipeline,
)
try:
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403
else:
from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
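In short, exposing a pipeline through this module takes two mirrored entries: one in `_import_structure` so `_LazyModule` can resolve it lazily, and one under `TYPE_CHECKING` for static analysis. A minimal sketch with an illustrative (non-existent) module name:

```python
# Illustrative only: "my_pipeline" / MyPipeline are placeholders, not real diffusers modules.
_import_structure["my_pipeline"] = ["MyPipeline"]  # resolved lazily on first attribute access

if TYPE_CHECKING:
    from .my_pipeline import MyPipeline  # eager import only for type checkers
```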

View File

@@ -0,0 +1,20 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
from .blip_image_processing import BlipImageProcessor
from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
from .pipeline_blip_diffusion import BlipDiffusionPipeline

View File

@@ -0,0 +1,318 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for BLIP."""
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging
from diffusers.utils import numpy_to_pil
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop
# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor
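# plus a `postprocess` method (mirroring diffusers.VaeImageProcessor.postprocess) to turn model outputs back into images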
class BlipImageProcessor(BaseImageProcessor):
r"""
Constructs a BLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method.
image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
do_center_crop: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self.do_center_crop = do_center_crop
# Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
do_center_crop: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
do_convert_rgb: bool = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Controls the size of the image after `resize`. The image is resized to `(size["height"], size["width"])`.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to normalize the image by if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and (size is None or resample is None):
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize:
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
if do_center_crop:
images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_outputs
# Follows diffusers.VaeImageProcessor.postprocess
def postprocess(self, sample: torch.FloatTensor, output_type: str = "pil"):
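# `sample` is expected to be a batch of images in [-1, 1] with NCHW layout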
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
)
# Equivalent to diffusers.VaeImageProcessor.denormalize
sample = (sample / 2 + 0.5).clamp(0, 1)
if output_type == "pt":
return sample
# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "np":
return sample
# Output_type must be 'pil'
sample = numpy_to_pil(sample)
return sample

View File

@@ -0,0 +1,642 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers import BertTokenizer
from transformers.activations import QuickGELUActivation as QuickGELU
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPooling,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig
from transformers.models.blip_2.modeling_blip_2 import (
Blip2Encoder,
Blip2PreTrainedModel,
Blip2QFormerAttention,
Blip2QFormerIntermediate,
Blip2QFormerOutput,
)
from transformers.pytorch_utils import apply_chunking_to_forward
from transformers.utils import (
logging,
replace_return_docstrings,
)
logger = logging.get_logger(__name__)
# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py.
# But it doesn't support getting multimodal embeddings, so this module can be
# replaced once a future `transformers` version supports that.
class Blip2TextEmbeddings(nn.Module):
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
self.config = config
def forward(
self,
input_ids=None,
position_ids=None,
query_embeds=None,
past_key_values_length=0,
):
if input_ids is not None:
seq_length = input_ids.size()[1]
else:
seq_length = 0
if position_ids is None:
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()
if input_ids is not None:
embeddings = self.word_embeddings(input_ids)
if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
embeddings = embeddings + position_embeddings
if query_embeds is not None:
batch_size = embeddings.shape[0]
# repeat the query embeddings for batch size
query_embeds = query_embeds.repeat(batch_size, 1, 1)
embeddings = torch.cat((query_embeds, embeddings), dim=1)
else:
embeddings = query_embeds
embeddings = embeddings.to(query_embeds.dtype)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2
class Blip2VisionEmbeddings(nn.Module):
def __init__(self, config: Blip2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
return embeddings
# The QFormer encoder, which takes the visual embeddings and the text input to produce multimodal embeddings
class Blip2QFormerEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList(
[Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
query_length=0,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for i in range(self.config.num_hidden_layers):
layer_module = self.layer[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, query_length)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
query_length,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if layer_module.has_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
# The layers making up the Qformer encoder
class Blip2QFormerLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = Blip2QFormerAttention(config)
self.layer_idx = layer_idx
if layer_idx % config.cross_attention_frequency == 0:
self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True)
self.has_cross_attention = True
else:
self.has_cross_attention = False
self.intermediate = Blip2QFormerIntermediate(config)
self.intermediate_query = Blip2QFormerIntermediate(config)
self.output_query = Blip2QFormerOutput(config)
self.output = Blip2QFormerOutput(config)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=None,
output_attentions=False,
query_length=0,
):
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
hidden_states,
attention_mask,
head_mask,
output_attentions=output_attentions,
past_key_value=self_attn_past_key_value,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
if query_length > 0:
query_attention_output = attention_output[:, :query_length, :]
if self.has_cross_attention:
if encoder_hidden_states is None:
raise ValueError("encoder_hidden_states must be given for cross-attention layers")
cross_attention_outputs = self.crossattention(
query_attention_output,
attention_mask,
head_mask,
encoder_hidden_states,
encoder_attention_mask,
output_attentions=output_attentions,
)
query_attention_output = cross_attention_outputs[0]
# add cross attentions if we output attention weights
outputs = outputs + cross_attention_outputs[1:-1]
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk_query,
self.chunk_size_feed_forward,
self.seq_len_dim,
query_attention_output,
)
if attention_output.shape[1] > query_length:
layer_output_text = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output[:, query_length:, :],
)
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
else:
layer_output = apply_chunking_to_forward(
self.feed_forward_chunk,
self.chunk_size_feed_forward,
self.seq_len_dim,
attention_output,
)
outputs = (layer_output,) + outputs
outputs = outputs + (present_key_value,)
return outputs
def feed_forward_chunk(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def feed_forward_chunk_query(self, attention_output):
intermediate_output = self.intermediate_query(attention_output)
layer_output = self.output_query(intermediate_output, attention_output)
return layer_output
# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder
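# (its output on the query tokens is what the pipeline later feeds to the text encoder as `ctx_embeddings`)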
class ProjLayer(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12):
super().__init__()
# LayerNorm -> Dense1 -> Act -> Dense2 -> Drop -> Res (the residual uses the pre-norm input)
self.dense1 = nn.Linear(in_dim, hidden_dim)
self.act_fn = QuickGELU()
self.dense2 = nn.Linear(hidden_dim, out_dim)
self.dropout = nn.Dropout(drop_p)
self.LayerNorm = nn.LayerNorm(out_dim, eps=eps)
def forward(self, x):
x_in = x
x = self.LayerNorm(x)
x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in
return x
# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2
class Blip2VisionModel(Blip2PreTrainedModel):
main_input_name = "pixel_values"
config_class = Blip2VisionConfig
def __init__(self, config: Blip2VisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = Blip2VisionEmbeddings(config)
self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = Blip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.post_init()
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig)
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layernorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.embeddings
# Qformer model, used to get multimodal embeddings from the text and image inputs
class Blip2QFormerModel(Blip2PreTrainedModel):
"""
Querying Transformer (Q-Former), used in BLIP-2.
"""
def __init__(self, config: Blip2Config):
super().__init__(config)
self.config = config
self.embeddings = Blip2TextEmbeddings(config.qformer_config)
self.visual_encoder = Blip2VisionModel(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
if not hasattr(config, "tokenizer") or config.tokenizer is None:
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
else:
self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right")
self.tokenizer.add_special_tokens({"bos_token": "[DEC]"})
self.proj_layer = ProjLayer(
in_dim=config.qformer_config.hidden_size,
out_dim=config.qformer_config.hidden_size,
hidden_dim=config.qformer_config.hidden_size * 4,
drop_p=0.1,
eps=1e-12,
)
self.encoder = Blip2QFormerEncoder(config.qformer_config)
self.post_init()
def get_input_embeddings(self):
return self.embeddings.word_embeddings
def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def get_extended_attention_mask(
self,
attention_mask: torch.Tensor,
input_shape: Tuple[int],
device: torch.device,
has_query: bool = False,
) -> torch.Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
device (`torch.device`):
The device of the input to the model.
Returns:
`torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
input_shape, attention_mask.shape
)
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def forward(
self,
text_input=None,
image_input=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
`(batch_size, sequence_length)`.
use_cache (`bool`, `optional`):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
text = self.tokenizer(text_input, return_tensors="pt", padding=True)
text = text.to(self.device)
input_ids = text.input_ids
batch_size = input_ids.shape[0]
query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device)
attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
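# the learned query tokens are prepended to the text, so the mask covers [query tokens | text tokens]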
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# past_key_values_length
past_key_values_length = (
past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
)
query_length = self.query_tokens.shape[1]
embedding_output = self.embeddings(
input_ids=input_ids,
query_embeds=self.query_tokens,
past_key_values_length=past_key_values_length,
)
# embedding_output = self.layernorm(query_embeds)
# embedding_output = self.dropout(embedding_output)
input_shape = embedding_output.size()[:-1]
batch_size, seq_length = input_shape
device = embedding_output.device
image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state
# image_embeds_frozen = torch.ones_like(image_embeds_frozen)
encoder_hidden_states = image_embeds_frozen
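# the frozen image features serve as the cross-attention context (`encoder_hidden_states`) for the QFormer layers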
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, list):
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
else:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if isinstance(encoder_attention_mask, list):
encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
elif encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
query_length=query_length,
)
sequence_output = encoder_outputs[0]
pooled_output = sequence_output[:, 0, :]
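# with return_dict=False only the projected query-token outputs are returned; the pipeline consumes these as subject embeddings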
if not return_dict:
return self.proj_layer(sequence_output[:, :query_length, :])
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)

View File

@@ -0,0 +1,212 @@
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import CLIPPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.configuration_clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import (
CLIPEncoder,
_expand_mask,
)
# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
# which allows for an extra input of "context embeddings": the query embeddings produced by the QFormer.
# These pass through the CLIP model along with the text embeddings and interact with them via self-attention.
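# ContextCLIPTextEmbeddings below splices the query embeddings into the token sequence at `ctx_begin_pos`.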
class ContextCLIPTextModel(CLIPPreTrainedModel):
config_class = CLIPTextConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = ContextCLIPTextTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
ctx_embeddings: torch.Tensor = None,
ctx_begin_pos: list = None,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
return self.text_model(
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class ContextCLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = ContextCLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward(
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
)
bsz, seq_len = input_shape
if ctx_embeddings is not None:
seq_len += ctx_embeddings.size(1)
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=input_ids.device),
input_ids.to(torch.int).argmax(dim=-1),
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class ContextCLIPTextEmbeddings(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
if ctx_embeddings is None:
ctx_len = 0
else:
ctx_len = ctx_embeddings.shape[1]
seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
# for each input embeddings, add the ctx embeddings at the correct position
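# (BlipDiffusionPipeline passes ctx_begin_pos=2 for every sample by default)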
input_embeds_ctx = []
bsz = inputs_embeds.shape[0]
if ctx_embeddings is not None:
for i in range(bsz):
cbp = ctx_begin_pos[i]
prefix = inputs_embeds[i, :cbp]
# the remaining tokens (from `cbp` onward) are appended after the context embeddings
suffix = inputs_embeds[i, cbp:]
input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0))
inputs_embeds = torch.stack(input_embeds_ctx, dim=0)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings

View File

@@ -0,0 +1,339 @@
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import PIL
import torch
from transformers import CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .blip_image_processing import BlipImageProcessor
from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers.pipelines import BlipDiffusionPipeline
>>> from diffusers.utils import load_image
>>> import torch
>>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
... "Salesforce/blipdiffusion", torch_dtype=torch.float16
... ).to("cuda")
>>> cond_subject = "dog"
>>> tgt_subject = "dog"
>>> text_prompt_input = "swimming underwater"
>>> cond_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
... )
>>> guidance_scale = 7.5
>>> num_inference_steps = 25
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
>>> output = blip_diffusion_pipe(
... text_prompt_input,
... cond_image,
... cond_subject,
... tgt_subject,
... guidance_scale=guidance_scale,
... num_inference_steps=num_inference_steps,
... neg_prompt=negative_prompt,
... height=512,
... width=512,
... ).images
>>> output[0].save("image.png")
```
"""
class BlipDiffusionPipeline(DiffusionPipeline):
"""
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer ([`CLIPTokenizer`]):
Tokenizer for the text encoder
text_encoder ([`ContextCLIPTextModel`]):
Text encoder to encode the text prompt
vae ([`AutoencoderKL`]):
VAE model to map the latents to the image
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
scheduler ([`PNDMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
qformer ([`Blip2QFormerModel`]):
QFormer model to get multi-modal embeddings from the text and image.
image_processor ([`BlipImageProcessor`]):
Image Processor to preprocess and postprocess the image.
ctx_begin_pos (int, `optional`, defaults to 2):
Position of the context token in the text encoder.
"""
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: ContextCLIPTextModel,
vae: AutoencoderKL,
unet: UNet2DConditionModel,
scheduler: PNDMScheduler,
qformer: Blip2QFormerModel,
image_processor: BlipImageProcessor,
ctx_begin_pos: int = 2,
mean: List[float] = None,
std: List[float] = None,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
image_processor=image_processor,
)
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
# From the original Blip Diffusion code: specifies the target subject and augments the prompt by repeating it
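# e.g. prompt="swimming underwater", tgt_subject="dog", prompt_strength=1.0, prompt_reps=20 yields
# "a dog swimming underwater" repeated 20 times, joined with ", "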
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
prompt = f"a {tgt_subject} {prompt.strip()}"
# a trick to amplify the prompt
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
return rv
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
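# reserve room for the query embeddings that the text encoder splices in at `ctx_begin_pos`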
tokenized_prompt = self.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=ctx_begin_pos,
)[0]
return text_embeddings
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: List[str],
reference_image: PIL.Image.Image,
source_subject_category: List[str],
target_subject_category: List[str],
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 7.5,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
neg_prompt: Optional[str] = "",
prompt_strength: float = 1.0,
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`List[str]`):
The prompt or prompts to guide the image generation.
reference_image (`PIL.Image.Image`):
The reference image to condition the generation on.
source_subject_category (`List[str]`):
The source subject category.
target_subject_category (`List[str]`):
The target subject category.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
`prompt`, usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
The width of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
neg_prompt (`str`, *optional*, defaults to ""):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_strength (`float`, *optional*, defaults to 1.0):
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
to amplify the prompt.
prompt_reps (`int`, *optional*, defaults to 20):
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(source_subject_category, str):
source_subject_category = [source_subject_category]
if isinstance(target_subject_category, str):
target_subject_category = [target_subject_category]
batch_size = len(prompt)
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=target_subject_category,
prompt_strength=prompt_strength,
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
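# e.g. a UNet with 4 entries in block_out_channels downsamples by 2**3 = 8, so 512x512 images use 64x64 latents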
latents = self.prepare_latents(
batch_size=batch_size,
num_channels=self.unet.config.in_channels,
height=height // scale_down_factor,
width=width // scale_down_factor,
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
)
# set timesteps
extra_set_kwargs = {}
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
noise_pred = self.unet(
latent_model_input,
timestep=t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
)["sample"]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@@ -1,77 +1,79 @@
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .multicontrolnet import MultiControlNetModel
from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_flax_available,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
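# submodules and the objects they export are only registered by name here; _LazyModule below defers
# the actual imports until an attribute is first accessed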
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .multicontrolnet import MultiControlNetModel
from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
try:
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,405 @@
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import PIL
import torch
from transformers import CLIPTokenizer
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
>>> from diffusers.utils import load_image
>>> from controlnet_aux import CannyDetector
>>> import torch
>>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
... ).to("cuda")
>>> style_subject = "flower"
>>> tgt_subject = "teapot"
>>> text_prompt = "on a marble table"
>>> cldm_cond_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
... ).resize((512, 512))
>>> canny = CannyDetector()
>>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
>>> style_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
... )
>>> guidance_scale = 7.5
>>> num_inference_steps = 50
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
>>> output = blip_diffusion_pipe(
... text_prompt,
... style_image,
... cldm_cond_image,
... style_subject,
... tgt_subject,
... guidance_scale=guidance_scale,
... num_inference_steps=num_inference_steps,
... neg_prompt=negative_prompt,
... height=512,
... width=512,
... ).images
>>> output[0].save("image.png")
```
"""
class BlipDiffusionControlNetPipeline(DiffusionPipeline):
"""
Pipeline for Canny-edge-controlled, subject-driven generation using BLIP Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer ([`CLIPTokenizer`]):
Tokenizer for the text encoder
text_encoder ([`ContextCLIPTextModel`]):
Text encoder to encode the text prompt
vae ([`AutoencoderKL`]):
VAE model to map the latents to the image
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
scheduler ([`PNDMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
qformer ([`Blip2QFormerModel`]):
QFormer model to get multi-modal embeddings from the text and image.
controlnet ([`ControlNetModel`]):
ControlNet model to get the conditioning image embedding.
image_processor ([`BlipImageProcessor`]):
Image Processor to preprocess and postprocess the image.
ctx_begin_pos (`int`, *optional*, defaults to 2):
Position of the context token in the text encoder.
"""
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: ContextCLIPTextModel,
vae: AutoencoderKL,
unet: UNet2DConditionModel,
scheduler: PNDMScheduler,
qformer: Blip2QFormerModel,
controlnet: ControlNetModel,
image_processor: BlipImageProcessor,
ctx_begin_pos: int = 2,
mean: List[float] = None,
std: List[float] = None,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
controlnet=controlnet,
image_processor=image_processor,
)
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
# from the original BLIP Diffusion code: specifies the target subject and augments the prompt by repeating it
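# e.g. prompt "on a table" with target subject "teapot", prompt_strength=0.5 and prompt_reps=4 yields
# "a teapot on a table, a teapot on a table"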
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
prompt = f"a {tgt_subject} {prompt.strip()}"
# a trick to amplify the prompt
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
return rv
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
# embeddings for prompt, with query_embeds as context
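# the prompt is truncated so that the QFormer query embeddings, spliced in at `ctx_begin_pos`,
# still fit within the text encoder's maximum sequence length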
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
tokenized_prompt = self.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=ctx_begin_pos,
)[0]
return text_embeddings
# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_control_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
):
image = self.image_processor.preprocess(
image,
size={"width": width, "height": height},
do_rescale=True,
do_center_crop=False,
do_normalize=False,
return_tensors="pt",
)["pixel_values"].to(self.device)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
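# duplicate the conditioning image so it matches the doubled (unconditional + conditional) latent batch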
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
return image
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: List[str],
reference_image: PIL.Image.Image,
condtioning_image: PIL.Image.Image,
source_subject_category: List[str],
target_subject_category: List[str],
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 7.5,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
neg_prompt: Optional[str] = "",
prompt_strength: float = 1.0,
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`List[str]`):
The prompt or prompts to guide the image generation.
reference_image (`PIL.Image.Image`):
The reference image to condition the generation on.
condtioning_image (`PIL.Image.Image`):
The conditioning canny edge image to condition the generation on.
source_subject_category (`List[str]`):
The source subject category.
target_subject_category (`List[str]`):
The target subject category.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of the [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
The width of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
neg_prompt (`str`, *optional*, defaults to ""):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_strength (`float`, *optional*, defaults to 1.0):
The strength of the prompt. The prompt is repeated `int(prompt_strength * prompt_reps)` times to amplify it.
prompt_reps (`int`, *optional*, defaults to 20):
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(source_subject_category, str):
source_subject_category = [source_subject_category]
if isinstance(target_subject_category, str):
target_subject_category = [target_subject_category]
batch_size = len(prompt)
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=target_subject_category,
prompt_strength=prompt_strength,
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
# 3. unconditional embedding
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
latents = self.prepare_latents(
batch_size=batch_size,
num_channels=self.unet.config.in_channels,
height=height // scale_down_factor,
width=width // scale_down_factor,
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
)
# set timesteps
extra_set_kwargs = {}
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
cond_image = self.prepare_control_image(
image=condtioning_image,
width=width,
height=height,
batch_size=batch_size,
num_images_per_prompt=1,
device=self.device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
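# the ControlNet consumes the same latents and text embeddings plus the conditioning image and returns
# residuals that are added to the UNet's down and mid blocks below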
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=cond_image,
return_dict=False,
)
noise_pred = self.unet(
latent_model_input,
timestep=t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)["sample"]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@@ -609,6 +609,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed)
sampler_kwargs["noise_sampler"] = noise_sampler
if "generator" in inspect.signature(self.sampler).parameters:
sampler_kwargs["generator"] = generator
latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs)
if not output_type == "latent":

View File

@@ -787,8 +787,16 @@ class StableDiffusionXLAdapterPipeline(
height, width = self._default_height_width(height, width, image)
device = self._execution_device
adapter_input = _preprocess_adapter_image(image, height, width).to(device)
if isinstance(self.adapter, MultiAdapter):
adapter_input = []
for one_image in image:
one_image = _preprocess_adapter_image(one_image, height, width)
one_image = one_image.to(device=device, dtype=self.adapter.dtype)
adapter_input.append(one_image)
else:
adapter_input = _preprocess_adapter_image(image, height, width)
adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype)
original_size = original_size or (height, width)
target_size = target_size or (height, width)
@@ -865,10 +873,14 @@ class StableDiffusionXLAdapterPipeline(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings & adapter features
adapter_input = adapter_input.type(latents.dtype)
adapter_state = self.adapter(adapter_input)
for k, v in enumerate(adapter_state):
adapter_state[k] = v * adapter_conditioning_scale
if isinstance(self.adapter, MultiAdapter):
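# MultiAdapter applies the per-adapter conditioning scales internally, so its features are used as-is;
# the single-adapter branch below scales the features here instead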
adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
for k, v in enumerate(adapter_state):
adapter_state[k] = v
else:
adapter_state = self.adapter(adapter_input)
for k, v in enumerate(adapter_state):
adapter_state[k] = v * adapter_conditioning_scale
if num_images_per_prompt > 1:
for k, v in enumerate(adapter_state):
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)

View File

@@ -89,6 +89,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -113,6 +116,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
@@ -243,9 +247,15 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
log_sigmas = np.log(sigmas)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -269,7 +279,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]])
timesteps = torch.from_numpy(timesteps).to(device)
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array(
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
)
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
@@ -282,29 +298,44 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
def sigma_to_t(self, sigma):
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = sigma.log()
log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - self.log_sigmas[:, None]
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = self.log_sigmas[low_idx]
high = self.log_sigmas[high_idx]
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.view(sigma.shape)
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
@property
def state_in_first_order(self):
return self.sample is None

View File

@@ -88,6 +88,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -112,6 +115,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
@@ -243,9 +247,14 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device)
log_sigmas = np.log(sigmas)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -260,7 +269,12 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
timesteps = torch.from_numpy(timesteps).to(device)
# interpolate timesteps
timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype)
sigmas_interpol = sigmas_interpol.cpu()
log_sigmas = self.log_sigmas.cpu()
timesteps_interpol = np.array(
[self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol]
)
timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype)
interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
@@ -273,29 +287,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
def sigma_to_t(self, sigma):
# get log sigma
log_sigma = sigma.log()
# get distribution
dists = log_sigma - self.log_sigmas[:, None]
# get sigmas range
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = self.log_sigmas[low_idx]
high = self.log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.view(sigma.shape)
return t
@property
def state_in_first_order(self):
return self.sample is None
@@ -318,6 +309,44 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
self._step_index = step_index.item()
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],

View File

@@ -315,6 +315,36 @@ class AutoPipelineForText2Image(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BlipDiffusionControlNetPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BlipDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CLIPImageProjection(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -0,0 +1,196 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from transformers.models.clip.configuration_clip import CLIPTextConfig
from diffusers import AutoencoderKL, BlipDiffusionPipeline, PNDMScheduler, UNet2DConditionModel
from diffusers.utils.testing_utils import enable_full_determinism
from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BlipDiffusionPipeline
params = [
"prompt",
"reference_image",
"source_subject_category",
"target_subject_category",
]
batch_params = [
"prompt",
"reference_image",
"source_subject_category",
"target_subject_category",
]
required_optional_params = [
"generator",
"height",
"width",
"latents",
"guidance_scale",
"num_inference_steps",
"neg_prompt",
"guidance_scale",
"prompt_strength",
"prompt_reps",
]
def get_dummy_components(self):
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
vocab_size=1000,
hidden_size=16,
intermediate_size=16,
projection_dim=16,
num_hidden_layers=1,
num_attention_heads=1,
max_position_embeddings=77,
)
text_encoder = ContextCLIPTextModel(text_encoder_config)
vae = AutoencoderKL(
in_channels=4,
out_channels=4,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(32,),
layers_per_block=1,
act_fn="silu",
latent_channels=4,
norm_num_groups=16,
sample_size=16,
)
blip_vision_config = {
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 1,
"image_size": 224,
"patch_size": 14,
"hidden_act": "quick_gelu",
}
blip_qformer_config = {
"vocab_size": 1000,
"hidden_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 1,
"intermediate_size": 16,
"max_position_embeddings": 512,
"cross_attention_frequency": 1,
"encoder_hidden_size": 16,
}
qformer_config = Blip2Config(
vision_config=blip_vision_config,
qformer_config=blip_qformer_config,
num_query_tokens=16,
tokenizer="hf-internal-testing/tiny-random-bert",
)
qformer = Blip2QFormerModel(qformer_config)
unet = UNet2DConditionModel(
block_out_channels=(16, 32),
norm_num_groups=16,
layers_per_block=1,
sample_size=16,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=16,
)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
vae.eval()
qformer.eval()
text_encoder.eval()
image_processor = BlipImageProcessor()
components = {
"text_encoder": text_encoder,
"vae": vae,
"qformer": qformer,
"unet": unet,
"tokenizer": tokenizer,
"scheduler": scheduler,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
np.random.seed(seed)
reference_image = np.random.rand(32, 32, 3) * 255
reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "swimming underwater",
"generator": generator,
"reference_image": reference_image,
"source_subject_category": "dog",
"target_subject_category": "dog",
"height": 32,
"width": 32,
"guidance_scale": 7.5,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def test_blipdiffusion(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
image = pipe(**self.get_dummy_inputs(device))[0]
image_slice = image[0, -3:, -3:, 0]
assert image.shape == (1, 16, 16, 4)
expected_slice = np.array([0.7096, 0.5900, 0.6703, 0.4032, 0.7766, 0.3629, 0.5447, 0.4149, 0.8172])
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"

View File

@@ -0,0 +1,216 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from transformers.models.clip.configuration_clip import CLIPTextConfig
from diffusers import (
AutoencoderKL,
BlipDiffusionControlNetPipeline,
ControlNetModel,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import enable_full_determinism
from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BlipDiffusionControlNetPipeline
params = [
"prompt",
"reference_image",
"source_subject_category",
"target_subject_category",
"condtioning_image",
]
batch_params = [
"prompt",
"reference_image",
"source_subject_category",
"target_subject_category",
"condtioning_image",
]
required_optional_params = [
"generator",
"height",
"width",
"latents",
"guidance_scale",
"num_inference_steps",
"neg_prompt",
"guidance_scale",
"prompt_strength",
"prompt_reps",
]
def get_dummy_components(self):
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
vocab_size=1000,
hidden_size=16,
intermediate_size=16,
projection_dim=16,
num_hidden_layers=1,
num_attention_heads=1,
max_position_embeddings=77,
)
text_encoder = ContextCLIPTextModel(text_encoder_config)
vae = AutoencoderKL(
in_channels=4,
out_channels=4,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(32,),
layers_per_block=1,
act_fn="silu",
latent_channels=4,
norm_num_groups=16,
sample_size=16,
)
blip_vision_config = {
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 1,
"image_size": 224,
"patch_size": 14,
"hidden_act": "quick_gelu",
}
blip_qformer_config = {
"vocab_size": 1000,
"hidden_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 1,
"intermediate_size": 16,
"max_position_embeddings": 512,
"cross_attention_frequency": 1,
"encoder_hidden_size": 16,
}
qformer_config = Blip2Config(
vision_config=blip_vision_config,
qformer_config=blip_qformer_config,
num_query_tokens=16,
tokenizer="hf-internal-testing/tiny-random-bert",
)
qformer = Blip2QFormerModel(qformer_config)
unet = UNet2DConditionModel(
block_out_channels=(4, 16),
layers_per_block=1,
norm_num_groups=4,
sample_size=16,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=16,
)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
controlnet = ControlNetModel(
block_out_channels=(4, 16),
layers_per_block=1,
in_channels=4,
norm_num_groups=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
cross_attention_dim=16,
conditioning_embedding_out_channels=(8, 16),
)
vae.eval()
qformer.eval()
text_encoder.eval()
image_processor = BlipImageProcessor()
components = {
"text_encoder": text_encoder,
"vae": vae,
"qformer": qformer,
"unet": unet,
"tokenizer": tokenizer,
"scheduler": scheduler,
"controlnet": controlnet,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
np.random.seed(seed)
reference_image = np.random.rand(32, 32, 3) * 255
reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
cond_image = np.random.rand(32, 32, 3) * 255
cond_image = Image.fromarray(cond_image.astype("uint8")).convert("RGBA")
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "swimming underwater",
"generator": generator,
"reference_image": reference_image,
"condtioning_image": cond_image,
"source_subject_category": "dog",
"target_subject_category": "dog",
"height": 32,
"width": 32,
"guidance_scale": 7.5,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def test_blipdiffusion_controlnet(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
image = pipe(**self.get_dummy_inputs(device))[0]
image_slice = image[0, -3:, -3:, 0]
assert image.shape == (1, 16, 16, 4)
expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"

View File

@@ -20,17 +20,20 @@ import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
import diffusers
from diffusers import (
AutoencoderKL,
EulerDiscreteScheduler,
MultiAdapter,
StableDiffusionXLAdapterPipeline,
T2IAdapter,
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor
from diffusers.utils import logging
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
enable_full_determinism()
@@ -41,7 +44,7 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
def get_dummy_components(self):
def get_dummy_components(self, adapter_type="full_adapter_xl"):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
@@ -97,13 +100,38 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
adapter = T2IAdapter(
in_channels=3,
channels=[32, 64],
num_res_blocks=2,
downscale_factor=4,
adapter_type="full_adapter_xl",
)
if adapter_type == "full_adapter_xl":
adapter = T2IAdapter(
in_channels=3,
channels=[32, 64],
num_res_blocks=2,
downscale_factor=4,
adapter_type=adapter_type,
)
elif adapter_type == "multi_adapter":
adapter = MultiAdapter(
[
T2IAdapter(
in_channels=3,
channels=[32, 64],
num_res_blocks=2,
downscale_factor=4,
adapter_type="full_adapter_xl",
),
T2IAdapter(
in_channels=3,
channels=[32, 64],
num_res_blocks=2,
downscale_factor=4,
adapter_type="full_adapter_xl",
),
]
)
else:
raise ValueError(
f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter_xl', or 'multi_adapter''"
)
components = {
"adapter": adapter,
"unet": unet,
@@ -118,8 +146,12 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
}
return components
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
def get_dummy_inputs(self, device, seed=0, num_images=1):
if num_images == 1:
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
else:
image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)]
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
@@ -150,3 +182,202 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te
[0.5752919, 0.6022097, 0.4728038, 0.49861962, 0.57084894, 0.4644975, 0.5193715, 0.5133664, 0.4729858]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
class StableDiffusionXLMultiAdapterPipelineFastTests(
StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
):
def get_dummy_components(self):
return super().get_dummy_components("multi_adapter")
def get_dummy_inputs(self, device, seed=0):
inputs = super().get_dummy_inputs(device, seed, num_images=2)
inputs["adapter_conditioning_scale"] = [0.5, 0.5]
return inputs
def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLAdapterPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array(
[0.5813032, 0.60995954, 0.47563356, 0.5056669, 0.57199144, 0.4631841, 0.5176794, 0.51252556, 0.47183886]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
def test_inference_batch_consistent(
self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
for batch_size in batch_sizes:
batched_inputs = {}
for name, value in inputs.items():
if name in self.batch_params:
# prompt is string
if name == "prompt":
len_prompt = len(value)
# make unequal batch sizes
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
# make last batch super long
batched_inputs[name][-1] = 100 * "very long"
elif name == "image":
batched_images = []
for image in value:
batched_images.append(batch_size * [image])
batched_inputs[name] = batched_images
else:
batched_inputs[name] = batch_size * [value]
elif name == "batch_size":
batched_inputs[name] = batch_size
else:
batched_inputs[name] = value
for arg in additional_params_copy_to_batched_inputs:
batched_inputs[arg] = inputs[arg]
batched_inputs["output_type"] = "np"
output = pipe(**batched_inputs)
assert len(output[0]) == batch_size
batched_inputs["output_type"] = "np"
output = pipe(**batched_inputs)[0]
assert output.shape[0] == batch_size
logger.setLevel(level=diffusers.logging.WARNING)
def test_num_images_per_prompt(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_sizes = [1, 2]
num_images_per_prompts = [1, 2]
for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
if key == "image":
batched_images = []
for image in inputs[key]:
batched_images.append(batch_size * [image])
inputs[key] = batched_images
else:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
assert images.shape[0] == batch_size * num_images_per_prompt
def test_inference_batch_single_identical(
self,
batch_size=3,
test_max_difference=None,
test_mean_pixel_difference=None,
relax_max_difference=False,
expected_max_diff=2e-3,
additional_params_copy_to_batched_inputs=["num_inference_steps"],
):
if test_max_difference is None:
# TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems
# make sure that batched and non-batched is identical
test_max_difference = torch_device != "mps"
if test_mean_pixel_difference is None:
# TODO same as above
test_mean_pixel_difference = torch_device != "mps"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
logger = logging.get_logger(pipe.__module__)
logger.setLevel(level=diffusers.logging.FATAL)
# batchify inputs
batched_inputs = {}
batch_size = batch_size
for name, value in inputs.items():
if name in self.batch_params:
# prompt is string
if name == "prompt":
len_prompt = len(value)
# make unequal batch sizes
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
# make last batch super long
batched_inputs[name][-1] = 100 * "very long"
elif name == "image":
batched_images = []
for image in value:
batched_images.append(batch_size * [image])
batched_inputs[name] = batched_images
else:
batched_inputs[name] = batch_size * [value]
elif name == "batch_size":
batched_inputs[name] = batch_size
elif name == "generator":
batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)]
else:
batched_inputs[name] = value
for arg in additional_params_copy_to_batched_inputs:
batched_inputs[arg] = inputs[arg]
output_batch = pipe(**batched_inputs)
assert output_batch[0].shape[0] == batch_size
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs)
logger.setLevel(level=diffusers.logging.WARNING)
if test_max_difference:
if relax_max_difference:
# Taking the median of the largest <n> differences
# is resilient to outliers
diff = np.abs(output_batch[0][0] - output[0][0])
diff = diff.flatten()
diff.sort()
max_diff = np.median(diff[-5:])
else:
max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
assert max_diff < expected_max_diff
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_batch[0][0], output[0][0])