diff --git a/.github/workflows/pr_dependency_test.yml b/.github/workflows/pr_dependency_test.yml index baa83e20bf..102414076d 100644 --- a/.github/workflows/pr_dependency_test.yml +++ b/.github/workflows/pr_dependency_test.yml @@ -20,7 +20,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.7" + python-version: "3.8" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/pr_quality.yml b/.github/workflows/pr_quality.yml index 945c7bcc0b..9656cee341 100644 --- a/.github/workflows/pr_quality.yml +++ b/.github/workflows/pr_quality.yml @@ -20,7 +20,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.7" + python-version: "3.8" - name: Install dependencies run: | python -m pip install --upgrade pip @@ -38,7 +38,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.7" + python-version: "3.8" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index eaca39797e..ff609ee769 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -17,7 +17,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v1 with: - python-version: 3.7 + python-version: 3.8 - name: Install requirements run: | diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index b56d9c094d..cc50a95643 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -216,6 +216,8 @@ title: AudioLDM 2 - local: api/pipelines/auto_pipeline title: AutoPipeline + - local: api/pipelines/blip_diffusion + title: BLIP Diffusion - local: api/pipelines/consistency_models title: Consistency Models - local: api/pipelines/controlnet diff --git a/docs/source/en/api/pipelines/blip_diffusion.md b/docs/source/en/api/pipelines/blip_diffusion.md new file mode 100644 index 0000000000..698e1f05fd --- /dev/null +++ b/docs/source/en/api/pipelines/blip_diffusion.md @@ -0,0 +1,29 @@ +# Blip Diffusion + +Blip Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation. + + +The abstract from the paper is: + +*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.* + +The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization. + +`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/). + + + +Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines. + + + + +## BlipDiffusionPipeline +[[autodoc]] BlipDiffusionPipeline + - all + - __call__ + +## BlipDiffusionControlNetPipeline +[[autodoc]] BlipDiffusionControlNetPipeline + - all + - __call__ diff --git a/docs/source/en/installation.md b/docs/source/en/installation.md index 50df14be3f..1a0951bf7b 100644 --- a/docs/source/en/installation.md +++ b/docs/source/en/installation.md @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. Install 🤗 Diffusers for whichever deep learning library you're working with. -🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using: +🤗 Diffusers is tested on Python 3.8+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using: - [PyTorch](https://pytorch.org/get-started/locally/) installation instructions. - [Flax](https://flax.readthedocs.io/en/latest/) installation instructions. @@ -106,7 +106,7 @@ pip install -e ".[flax]" These commands will link the folder you cloned the repository to and your Python library paths. Python will now look inside the folder you cloned to in addition to the normal library paths. -For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to. +For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.8/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to. diff --git a/docs/source/ko/installation.md b/docs/source/ko/installation.md index a10f9f8d1b..4a9146a226 100644 --- a/docs/source/ko/installation.md +++ b/docs/source/ko/installation.md @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. 사용하시는 라이브러리에 맞는 🤗 Diffusers를 설치하세요. -🤗 Diffusers는 Python 3.7+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요. +🤗 Diffusers는 Python 3.8+, PyTorch 1.7.0+ 및 flax에서 테스트되었습니다. 사용중인 딥러닝 라이브러리에 대한 아래의 설치 안내를 따르세요. - [PyTorch 설치 안내](https://pytorch.org/get-started/locally/) - [Flax 설치 안내](https://flax.readthedocs.io/en/latest/) @@ -105,7 +105,7 @@ pip install -e ".[flax]" 이러한 명령어들은 저장소를 복제한 폴더와 Python 라이브러리 경로를 연결합니다. Python은 이제 일반 라이브러리 경로에 더하여 복제한 폴더 내부를 살펴봅니다. -예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.7/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다. +예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.8/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다. diff --git a/docs/source/zh/installation.md b/docs/source/zh/installation.md index 8cd3ad97cc..5777f1d286 100644 --- a/docs/source/zh/installation.md +++ b/docs/source/zh/installation.md @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. 在你正在使用的任意深度学习框架中安装 🤗 Diffusers 。 -🤗 Diffusers已在Python 3.7+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装: +🤗 Diffusers已在Python 3.8+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装: - [PyTorch](https://pytorch.org/get-started/locally/) installation instructions. - [Flax](https://flax.readthedocs.io/en/latest/) installation instructions. @@ -107,7 +107,7 @@ pip install -e ".[flax]" 这些命令将连接到你克隆的版本库和你的 Python 库路径。 现在,不只是在通常的库路径,Python 还会在你克隆的文件夹内寻找包。 -例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.7/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`。 +例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.8/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`。 diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 75087284cb..d04c616c57 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -908,6 +908,9 @@ def main(): if args.snr_gamma is not None: snr = jnp.array(compute_snr(timesteps)) snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + snr_loss_weights = snr_loss_weights + 1 loss = loss * snr_loss_weights loss = loss.mean() diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index e262fb82da..6f815c0f85 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -224,6 +224,30 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st raise ValueError(f"{model_class} is not supported.") +def compute_snr(timesteps, noise_scheduler): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR + snr = (alpha / sigma) ** 2 + return snr + + def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -524,6 +548,13 @@ def parse_args(input_args=None): " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." ), ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) parser.add_argument( "--pre_compute_text_embeddings", action="store_true", @@ -1261,17 +1292,34 @@ def main(args): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) - - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + # Compute instance loss + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps, noise_scheduler) + base_weight = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + if args.with_prior_preservation: # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 2548c3a286..b89de5e001 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -875,6 +875,9 @@ def main(): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index b00884bfb7..0d14e6ccd5 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -955,6 +955,9 @@ def main(): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index b9830a83ae..5845bda0e5 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -786,6 +786,9 @@ def main(): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 45ae1cc9ef..7a8c2c353e 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -1075,6 +1075,9 @@ def main(args): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # velocity objective prediction requires SNR weights to be floored to a min value of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/scripts/convert_blipdiffusion_to_diffusers.py b/scripts/convert_blipdiffusion_to_diffusers.py new file mode 100644 index 0000000000..03cf67e547 --- /dev/null +++ b/scripts/convert_blipdiffusion_to_diffusers.py @@ -0,0 +1,343 @@ +""" +This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main. +""" + +import argparse +import os +import tempfile + +import torch +from lavis.models import load_model_and_preprocess +from transformers import CLIPTokenizer +from transformers.models.blip_2.configuration_blip_2 import Blip2Config + +from diffusers import ( + AutoencoderKL, + PNDMScheduler, + UNet2DConditionModel, +) +from diffusers.pipelines import BlipDiffusionPipeline +from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor +from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel +from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel + + +BLIP2_CONFIG = { + "vision_config": { + "hidden_size": 1024, + "num_hidden_layers": 23, + "num_attention_heads": 16, + "image_size": 224, + "patch_size": 14, + "intermediate_size": 4096, + "hidden_act": "quick_gelu", + }, + "qformer_config": { + "cross_attention_frequency": 1, + "encoder_hidden_size": 1024, + "vocab_size": 30523, + }, + "num_query_tokens": 16, +} +blip2config = Blip2Config(**BLIP2_CONFIG) + + +def qformer_model_from_original_config(): + qformer = Blip2QFormerModel(blip2config) + return qformer + + +def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix): + embeddings = {} + embeddings.update( + { + f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[ + f"{original_embeddings_prefix}.word_embeddings.weight" + ] + } + ) + embeddings.update( + { + f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[ + f"{original_embeddings_prefix}.position_embeddings.weight" + ] + } + ) + embeddings.update( + {f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]} + ) + embeddings.update( + {f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]} + ) + return embeddings + + +def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix): + proj_layer = {} + proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]}) + proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]}) + proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]}) + proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]}) + proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]}) + proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]}) + return proj_layer + + +def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix): + attention = {} + attention.update( + { + f"{diffuser_attention_prefix}.attention.query.weight": model[ + f"{original_attention_prefix}.self.query.weight" + ] + } + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]} + ) + attention.update( + { + f"{diffuser_attention_prefix}.attention.value.weight": model[ + f"{original_attention_prefix}.self.value.weight" + ] + } + ) + attention.update( + {f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]} + ) + attention.update( + {f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]} + ) + attention.update( + { + f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[ + f"{original_attention_prefix}.output.LayerNorm.weight" + ] + } + ) + attention.update( + { + f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[ + f"{original_attention_prefix}.output.LayerNorm.bias" + ] + } + ) + return attention + + +def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix): + output_layers = {} + output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]}) + output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]}) + output_layers.update( + {f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]} + ) + output_layers.update( + {f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]} + ) + return output_layers + + +def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix): + encoder = {} + for i in range(blip2config.qformer_config.num_hidden_layers): + encoder.update( + attention_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention" + ) + ) + encoder.update( + attention_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention" + ) + ) + + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[ + f"{original_encoder_prefix}.{i}.intermediate.dense.weight" + ] + } + ) + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[ + f"{original_encoder_prefix}.{i}.intermediate.dense.bias" + ] + } + ) + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[ + f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight" + ] + } + ) + encoder.update( + { + f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[ + f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias" + ] + } + ) + + encoder.update( + output_layers_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output" + ) + ) + encoder.update( + output_layers_from_original_checkpoint( + model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query" + ) + ) + return encoder + + +def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix): + visual_encoder_layer = {} + + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]}) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]} + ) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]} + ) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]} + ) + visual_encoder_layer.update( + {f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]} + ) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]}) + visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]}) + + return visual_encoder_layer + + +def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix): + visual_encoder = {} + + visual_encoder.update( + { + f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"] + .unsqueeze(0) + .unsqueeze(0) + } + ) + visual_encoder.update( + { + f"{diffuser_prefix}.embeddings.position_embedding": model[ + f"{original_prefix}.positional_embedding" + ].unsqueeze(0) + } + ) + visual_encoder.update( + {f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]} + ) + visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]}) + visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]}) + + for i in range(blip2config.vision_config.num_hidden_layers): + visual_encoder.update( + visual_encoder_layer_from_original_checkpoint( + model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}" + ) + ) + + visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]}) + visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]}) + + return visual_encoder + + +def qformer_original_checkpoint_to_diffusers_checkpoint(model): + qformer_checkpoint = {} + qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings")) + qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]}) + qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer")) + qformer_checkpoint.update( + encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer") + ) + qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder")) + return qformer_checkpoint + + +def get_qformer(model): + print("loading qformer") + + qformer = qformer_model_from_original_config() + qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model) + + load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer) + + print("done loading qformer") + return qformer + + +def load_checkpoint_to_model(checkpoint, model): + with tempfile.NamedTemporaryFile(delete=False) as file: + torch.save(checkpoint, file.name) + del checkpoint + model.load_state_dict(torch.load(file.name), strict=False) + + os.remove(file.name) + + +def save_blip_diffusion_model(model, args): + qformer = get_qformer(model) + qformer.eval() + + text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae") + + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + vae.eval() + text_encoder.eval() + scheduler = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + set_alpha_to_one=False, + skip_prk_steps=True, + ) + tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer") + image_processor = BlipImageProcessor() + blip_diffusion = BlipDiffusionPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + image_processor=image_processor, + ) + blip_diffusion.save_pretrained(args.checkpoint_path) + + +def main(args): + model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True) + save_blip_diffusion_model(model.state_dict(), args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + args = parser.parse_args() + + main(args) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index bc4f06164c..2ca70963d1 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -35,6 +35,12 @@ if __name__ == "__main__": type=str, help="The YAML config file corresponding to the original architecture.", ) + parser.add_argument( + "--config_files", + default=None, + type=str, + help="The YAML config file corresponding to the architecture.", + ) parser.add_argument( "--num_in_channels", default=None, diff --git a/setup.py b/setup.py index 83ecff1b98..0d67f1b96b 100644 --- a/setup.py +++ b/setup.py @@ -257,7 +257,7 @@ setup( package_dir={"": "src"}, packages=find_packages("src"), include_package_data=True, - python_requires=">=3.7.0", + python_requires=">=3.8.0", install_requires=list(install_requires), extras_require=extras, entry_points={"console_scripts": ["diffusers-cli=diffusers.commands.diffusers_cli:main"]}, @@ -269,7 +269,6 @@ setup( "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Topic :: Scientific/Engineering :: Artificial Intelligence", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8fdf8df60f..cc82cb09f9 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -197,6 +197,8 @@ else: "AudioLDM2ProjectionModel", "AudioLDM2UNet2DConditionModel", "AudioLDMPipeline", + "BlipDiffusionControlNetPipeline", + "BlipDiffusionPipeline", "CLIPImageProjection", "CycleDiffusionPipeline", "IFImg2ImgPipeline", @@ -458,6 +460,8 @@ if TYPE_CHECKING: AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image, + BlipDiffusionControlNetPipeline, + BlipDiffusionPipeline, CLIPImageProjection, ConsistencyModelPipeline, DanceDiffusionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ef3566cf61..94e849a687 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,456 +1,460 @@ -from typing import TYPE_CHECKING - -from ..utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_flax_available, - is_k_diffusion_available, - is_librosa_available, - is_note_seq_available, - is_onnx_available, - is_torch_available, - is_transformers_available, -) - - -# These modules contain pipelines from multiple libraries/frameworks -_dummy_objects = {} -_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []} - -try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_pt_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) -else: - _import_structure["auto_pipeline"] = [ - "AutoPipelineForImage2Image", - "AutoPipelineForInpainting", - "AutoPipelineForText2Image", - ] - _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] - _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] - _import_structure["ddim"] = ["DDIMPipeline"] - _import_structure["ddpm"] = ["DDPMPipeline"] - _import_structure["dit"] = ["DiTPipeline"] - _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) - _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] - _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"] - _import_structure["pndm"] = ["PNDMPipeline"] - _import_structure["repaint"] = ["RePaintPipeline"] - _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] - _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] -try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_librosa_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) -else: - _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] -try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"] - _import_structure["audioldm"] = ["AudioLDMPipeline"] - _import_structure["audioldm2"] = [ - "AudioLDM2Pipeline", - "AudioLDM2ProjectionModel", - "AudioLDM2UNet2DConditionModel", - ] - _import_structure["controlnet"].extend( - [ - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - "StableDiffusionControlNetPipeline", - "StableDiffusionXLControlNetImg2ImgPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - "StableDiffusionXLControlNetPipeline", - ] - ) - _import_structure["deepfloyd_if"] = [ - "IFImg2ImgPipeline", - "IFImg2ImgSuperResolutionPipeline", - "IFInpaintingPipeline", - "IFInpaintingSuperResolutionPipeline", - "IFPipeline", - "IFSuperResolutionPipeline", - ] - _import_structure["kandinsky"] = [ - "KandinskyCombinedPipeline", - "KandinskyImg2ImgCombinedPipeline", - "KandinskyImg2ImgPipeline", - "KandinskyInpaintCombinedPipeline", - "KandinskyInpaintPipeline", - "KandinskyPipeline", - "KandinskyPriorPipeline", - ] - _import_structure["kandinsky2_2"] = [ - "KandinskyV22CombinedPipeline", - "KandinskyV22ControlnetImg2ImgPipeline", - "KandinskyV22ControlnetPipeline", - "KandinskyV22Img2ImgCombinedPipeline", - "KandinskyV22Img2ImgPipeline", - "KandinskyV22InpaintCombinedPipeline", - "KandinskyV22InpaintPipeline", - "KandinskyV22Pipeline", - "KandinskyV22PriorEmb2EmbPipeline", - "KandinskyV22PriorPipeline", - ] - _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"]) - _import_structure["musicldm"] = ["MusicLDMPipeline"] - _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] - _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] - _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] - _import_structure["stable_diffusion"].extend( - [ - "CLIPImageProjection", - "CycleDiffusionPipeline", - "StableDiffusionAttendAndExcitePipeline", - "StableDiffusionDepth2ImgPipeline", - "StableDiffusionDiffEditPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENPipeline", - "StableDiffusionGLIGENTextImagePipeline", - "StableDiffusionImageVariationPipeline", - "StableDiffusionImg2ImgPipeline", - "StableDiffusionInpaintPipeline", - "StableDiffusionInpaintPipelineLegacy", - "StableDiffusionInstructPix2PixPipeline", - "StableDiffusionLatentUpscalePipeline", - "StableDiffusionLDM3DPipeline", - "StableDiffusionModelEditingPipeline", - "StableDiffusionPanoramaPipeline", - "StableDiffusionParadigmsPipeline", - "StableDiffusionPipeline", - "StableDiffusionPix2PixZeroPipeline", - "StableDiffusionSAGPipeline", - "StableDiffusionUpscalePipeline", - "StableUnCLIPImg2ImgPipeline", - "StableUnCLIPPipeline", - ] - ) - _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] - _import_structure["stable_diffusion_xl"] = [ - "StableDiffusionXLImg2ImgPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLInstructPix2PixPipeline", - "StableDiffusionXLPipeline", - ] - _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"] - _import_structure["text_to_video_synthesis"] = [ - "TextToVideoSDPipeline", - "TextToVideoZeroPipeline", - "VideoToVideoSDPipeline", - ] - _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] - _import_structure["unidiffuser"] = [ - "ImageTextPipelineOutput", - "UniDiffuserModel", - "UniDiffuserPipeline", - "UniDiffuserTextDecoder", - ] - _import_structure["versatile_diffusion"] = [ - "VersatileDiffusionDualGuidedPipeline", - "VersatileDiffusionImageVariationPipeline", - "VersatileDiffusionPipeline", - "VersatileDiffusionTextToImagePipeline", - ] - _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] - _import_structure["wuerstchen"] = [ - "WuerstchenCombinedPipeline", - "WuerstchenDecoderPipeline", - "WuerstchenPriorPipeline", - ] -try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_onnx_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) -else: - _import_structure["onnx_utils"] = ["OnnxRuntimeModel"] -try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) -else: - _import_structure["stable_diffusion"].extend( - [ - "OnnxStableDiffusionImg2ImgPipeline", - "OnnxStableDiffusionInpaintPipeline", - "OnnxStableDiffusionInpaintPipelineLegacy", - "OnnxStableDiffusionPipeline", - "OnnxStableDiffusionUpscalePipeline", - "StableDiffusionOnnxPipeline", - ] - ) -try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) -else: - _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"]) -try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_flax_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_objects)) -else: - _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"] -try: - if not (is_flax_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_flax_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) -else: - _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"]) - _import_structure["stable_diffusion"].extend( - [ - "FlaxStableDiffusionImg2ImgPipeline", - "FlaxStableDiffusionInpaintPipeline", - "FlaxStableDiffusionPipeline", - ] - ) -try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) -else: - _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] - -if TYPE_CHECKING: - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_pt_objects import * # noqa F403 - - else: - from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image - from .consistency_models import ConsistencyModelPipeline - from .dance_diffusion import DanceDiffusionPipeline - from .ddim import DDIMPipeline - from .ddpm import DDPMPipeline - from .dit import DiTPipeline - from .latent_diffusion import LDMSuperResolutionPipeline - from .latent_diffusion_uncond import LDMPipeline - from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput - from .pndm import PNDMPipeline - from .repaint import RePaintPipeline - from .score_sde_ve import ScoreSdeVePipeline - from .stochastic_karras_ve import KarrasVePipeline - - try: - if not (is_torch_available() and is_librosa_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_librosa_objects import * - else: - from .audio_diffusion import AudioDiffusionPipeline, Mel - - try: - if not (is_torch_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_objects import * - else: - from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline - from .audioldm import AudioLDMPipeline - from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel - from .controlnet import ( - StableDiffusionControlNetImg2ImgPipeline, - StableDiffusionControlNetInpaintPipeline, - StableDiffusionControlNetPipeline, - StableDiffusionXLControlNetImg2ImgPipeline, - StableDiffusionXLControlNetInpaintPipeline, - StableDiffusionXLControlNetPipeline, - ) - from .deepfloyd_if import ( - IFImg2ImgPipeline, - IFImg2ImgSuperResolutionPipeline, - IFInpaintingPipeline, - IFInpaintingSuperResolutionPipeline, - IFPipeline, - IFSuperResolutionPipeline, - ) - from .kandinsky import ( - KandinskyCombinedPipeline, - KandinskyImg2ImgCombinedPipeline, - KandinskyImg2ImgPipeline, - KandinskyInpaintCombinedPipeline, - KandinskyInpaintPipeline, - KandinskyPipeline, - KandinskyPriorPipeline, - ) - from .kandinsky2_2 import ( - KandinskyV22CombinedPipeline, - KandinskyV22ControlnetImg2ImgPipeline, - KandinskyV22ControlnetPipeline, - KandinskyV22Img2ImgCombinedPipeline, - KandinskyV22Img2ImgPipeline, - KandinskyV22InpaintCombinedPipeline, - KandinskyV22InpaintPipeline, - KandinskyV22Pipeline, - KandinskyV22PriorEmb2EmbPipeline, - KandinskyV22PriorPipeline, - ) - from .latent_diffusion import LDMTextToImagePipeline - from .musicldm import MusicLDMPipeline - from .paint_by_example import PaintByExamplePipeline - from .semantic_stable_diffusion import SemanticStableDiffusionPipeline - from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline - from .stable_diffusion import ( - CLIPImageProjection, - CycleDiffusionPipeline, - StableDiffusionAttendAndExcitePipeline, - StableDiffusionDepth2ImgPipeline, - StableDiffusionDiffEditPipeline, - StableDiffusionGLIGENPipeline, - StableDiffusionGLIGENTextImagePipeline, - StableDiffusionImageVariationPipeline, - StableDiffusionImg2ImgPipeline, - StableDiffusionInpaintPipeline, - StableDiffusionInpaintPipelineLegacy, - StableDiffusionInstructPix2PixPipeline, - StableDiffusionLatentUpscalePipeline, - StableDiffusionLDM3DPipeline, - StableDiffusionModelEditingPipeline, - StableDiffusionPanoramaPipeline, - StableDiffusionParadigmsPipeline, - StableDiffusionPipeline, - StableDiffusionPix2PixZeroPipeline, - StableDiffusionSAGPipeline, - StableDiffusionUpscalePipeline, - StableUnCLIPImg2ImgPipeline, - StableUnCLIPPipeline, - ) - from .stable_diffusion_safe import StableDiffusionPipelineSafe - from .stable_diffusion_xl import ( - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLPipeline, - ) - from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline - from .text_to_video_synthesis import ( - TextToVideoSDPipeline, - TextToVideoZeroPipeline, - VideoToVideoSDPipeline, - ) - from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline - from .unidiffuser import ( - ImageTextPipelineOutput, - UniDiffuserModel, - UniDiffuserPipeline, - UniDiffuserTextDecoder, - ) - from .versatile_diffusion import ( - VersatileDiffusionDualGuidedPipeline, - VersatileDiffusionImageVariationPipeline, - VersatileDiffusionPipeline, - VersatileDiffusionTextToImagePipeline, - ) - from .vq_diffusion import VQDiffusionPipeline - from .wuerstchen import ( - WuerstchenCombinedPipeline, - WuerstchenDecoderPipeline, - WuerstchenPriorPipeline, - ) - - try: - if not is_onnx_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_onnx_objects import * # noqa F403 - - else: - from .onnx_utils import OnnxRuntimeModel - - try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_and_onnx_objects import * - else: - from .stable_diffusion import ( - OnnxStableDiffusionImg2ImgPipeline, - OnnxStableDiffusionInpaintPipeline, - OnnxStableDiffusionInpaintPipelineLegacy, - OnnxStableDiffusionPipeline, - OnnxStableDiffusionUpscalePipeline, - StableDiffusionOnnxPipeline, - ) - - try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * - else: - from .stable_diffusion import StableDiffusionKDiffusionPipeline - - try: - if not is_flax_available(): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_flax_objects import * # noqa F403 - else: - from .pipeline_flax_utils import FlaxDiffusionPipeline - - try: - if not (is_flax_available() and is_transformers_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_flax_and_transformers_objects import * - else: - from .controlnet import FlaxStableDiffusionControlNetPipeline - from .stable_diffusion import ( - FlaxStableDiffusionImg2ImgPipeline, - FlaxStableDiffusionInpaintPipeline, - FlaxStableDiffusionPipeline, - ) - - try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 - - else: - from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline - -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) +from typing import TYPE_CHECKING + +from ..utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_k_diffusion_available, + is_librosa_available, + is_note_seq_available, + is_onnx_available, + is_torch_available, + is_transformers_available, +) + + +# These modules contain pipelines from multiple libraries/frameworks +_dummy_objects = {} +_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["auto_pipeline"] = [ + "AutoPipelineForImage2Image", + "AutoPipelineForInpainting", + "AutoPipelineForText2Image", + ] + _import_structure["consistency_models"] = ["ConsistencyModelPipeline"] + _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"] + _import_structure["ddim"] = ["DDIMPipeline"] + _import_structure["ddpm"] = ["DDPMPipeline"] + _import_structure["dit"] = ["DiTPipeline"] + _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"]) + _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"] + _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"] + _import_structure["pndm"] = ["PNDMPipeline"] + _import_structure["repaint"] = ["RePaintPipeline"] + _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"] + _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"] +try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_librosa_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects)) +else: + _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"] +try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"] + _import_structure["audioldm"] = ["AudioLDMPipeline"] + _import_structure["audioldm2"] = [ + "AudioLDM2Pipeline", + "AudioLDM2ProjectionModel", + "AudioLDM2UNet2DConditionModel", + ] + _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] + _import_structure["controlnet"].extend( + [ + "BlipDiffusionControlNetPipeline", + "StableDiffusionControlNetImg2ImgPipeline", + "StableDiffusionControlNetInpaintPipeline", + "StableDiffusionControlNetPipeline", + "StableDiffusionXLControlNetImg2ImgPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + "StableDiffusionXLControlNetPipeline", + ] + ) + _import_structure["deepfloyd_if"] = [ + "IFImg2ImgPipeline", + "IFImg2ImgSuperResolutionPipeline", + "IFInpaintingPipeline", + "IFInpaintingSuperResolutionPipeline", + "IFPipeline", + "IFSuperResolutionPipeline", + ] + _import_structure["kandinsky"] = [ + "KandinskyCombinedPipeline", + "KandinskyImg2ImgCombinedPipeline", + "KandinskyImg2ImgPipeline", + "KandinskyInpaintCombinedPipeline", + "KandinskyInpaintPipeline", + "KandinskyPipeline", + "KandinskyPriorPipeline", + ] + _import_structure["kandinsky2_2"] = [ + "KandinskyV22CombinedPipeline", + "KandinskyV22ControlnetImg2ImgPipeline", + "KandinskyV22ControlnetPipeline", + "KandinskyV22Img2ImgCombinedPipeline", + "KandinskyV22Img2ImgPipeline", + "KandinskyV22InpaintCombinedPipeline", + "KandinskyV22InpaintPipeline", + "KandinskyV22Pipeline", + "KandinskyV22PriorEmb2EmbPipeline", + "KandinskyV22PriorPipeline", + ] + _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"]) + _import_structure["musicldm"] = ["MusicLDMPipeline"] + _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] + _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] + _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] + _import_structure["stable_diffusion"].extend( + [ + "CLIPImageProjection", + "CycleDiffusionPipeline", + "StableDiffusionAttendAndExcitePipeline", + "StableDiffusionDepth2ImgPipeline", + "StableDiffusionDiffEditPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENPipeline", + "StableDiffusionGLIGENTextImagePipeline", + "StableDiffusionImageVariationPipeline", + "StableDiffusionImg2ImgPipeline", + "StableDiffusionInpaintPipeline", + "StableDiffusionInpaintPipelineLegacy", + "StableDiffusionInstructPix2PixPipeline", + "StableDiffusionLatentUpscalePipeline", + "StableDiffusionLDM3DPipeline", + "StableDiffusionModelEditingPipeline", + "StableDiffusionPanoramaPipeline", + "StableDiffusionParadigmsPipeline", + "StableDiffusionPipeline", + "StableDiffusionPix2PixZeroPipeline", + "StableDiffusionSAGPipeline", + "StableDiffusionUpscalePipeline", + "StableUnCLIPImg2ImgPipeline", + "StableUnCLIPPipeline", + ] + ) + _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] + _import_structure["stable_diffusion_xl"] = [ + "StableDiffusionXLImg2ImgPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLInstructPix2PixPipeline", + "StableDiffusionXLPipeline", + ] + _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"] + _import_structure["text_to_video_synthesis"] = [ + "TextToVideoSDPipeline", + "TextToVideoZeroPipeline", + "VideoToVideoSDPipeline", + ] + _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"] + _import_structure["unidiffuser"] = [ + "ImageTextPipelineOutput", + "UniDiffuserModel", + "UniDiffuserPipeline", + "UniDiffuserTextDecoder", + ] + _import_structure["versatile_diffusion"] = [ + "VersatileDiffusionDualGuidedPipeline", + "VersatileDiffusionImageVariationPipeline", + "VersatileDiffusionPipeline", + "VersatileDiffusionTextToImagePipeline", + ] + _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"] + _import_structure["wuerstchen"] = [ + "WuerstchenCombinedPipeline", + "WuerstchenDecoderPipeline", + "WuerstchenPriorPipeline", + ] +try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_onnx_objects)) +else: + _import_structure["onnx_utils"] = ["OnnxRuntimeModel"] +try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) +else: + _import_structure["stable_diffusion"].extend( + [ + "OnnxStableDiffusionImg2ImgPipeline", + "OnnxStableDiffusionInpaintPipeline", + "OnnxStableDiffusionInpaintPipelineLegacy", + "OnnxStableDiffusionPipeline", + "OnnxStableDiffusionUpscalePipeline", + "StableDiffusionOnnxPipeline", + ] + ) +try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) +else: + _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"]) +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_flax_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_objects)) +else: + _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"] +try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"]) + _import_structure["stable_diffusion"].extend( + [ + "FlaxStableDiffusionImg2ImgPipeline", + "FlaxStableDiffusionInpaintPipeline", + "FlaxStableDiffusionPipeline", + ] + ) +try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) +else: + _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"] + +if TYPE_CHECKING: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + + else: + from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image + from .consistency_models import ConsistencyModelPipeline + from .dance_diffusion import DanceDiffusionPipeline + from .ddim import DDIMPipeline + from .ddpm import DDPMPipeline + from .dit import DiTPipeline + from .latent_diffusion import LDMSuperResolutionPipeline + from .latent_diffusion_uncond import LDMPipeline + from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput + from .pndm import PNDMPipeline + from .repaint import RePaintPipeline + from .score_sde_ve import ScoreSdeVePipeline + from .stochastic_karras_ve import KarrasVePipeline + + try: + if not (is_torch_available() and is_librosa_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_librosa_objects import * + else: + from .audio_diffusion import AudioDiffusionPipeline, Mel + + try: + if not (is_torch_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_objects import * + else: + from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline + from .audioldm import AudioLDMPipeline + from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel + from .blip_diffusion import BlipDiffusionPipeline + from .controlnet import ( + BlipDiffusionControlNetPipeline, + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + StableDiffusionXLControlNetImg2ImgPipeline, + StableDiffusionXLControlNetInpaintPipeline, + StableDiffusionXLControlNetPipeline, + ) + from .deepfloyd_if import ( + IFImg2ImgPipeline, + IFImg2ImgSuperResolutionPipeline, + IFInpaintingPipeline, + IFInpaintingSuperResolutionPipeline, + IFPipeline, + IFSuperResolutionPipeline, + ) + from .kandinsky import ( + KandinskyCombinedPipeline, + KandinskyImg2ImgCombinedPipeline, + KandinskyImg2ImgPipeline, + KandinskyInpaintCombinedPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, + KandinskyPriorPipeline, + ) + from .kandinsky2_2 import ( + KandinskyV22CombinedPipeline, + KandinskyV22ControlnetImg2ImgPipeline, + KandinskyV22ControlnetPipeline, + KandinskyV22Img2ImgCombinedPipeline, + KandinskyV22Img2ImgPipeline, + KandinskyV22InpaintCombinedPipeline, + KandinskyV22InpaintPipeline, + KandinskyV22Pipeline, + KandinskyV22PriorEmb2EmbPipeline, + KandinskyV22PriorPipeline, + ) + from .latent_diffusion import LDMTextToImagePipeline + from .musicldm import MusicLDMPipeline + from .paint_by_example import PaintByExamplePipeline + from .semantic_stable_diffusion import SemanticStableDiffusionPipeline + from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline + from .stable_diffusion import ( + CLIPImageProjection, + CycleDiffusionPipeline, + StableDiffusionAttendAndExcitePipeline, + StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, + StableDiffusionGLIGENPipeline, + StableDiffusionGLIGENTextImagePipeline, + StableDiffusionImageVariationPipeline, + StableDiffusionImg2ImgPipeline, + StableDiffusionInpaintPipeline, + StableDiffusionInpaintPipelineLegacy, + StableDiffusionInstructPix2PixPipeline, + StableDiffusionLatentUpscalePipeline, + StableDiffusionLDM3DPipeline, + StableDiffusionModelEditingPipeline, + StableDiffusionPanoramaPipeline, + StableDiffusionParadigmsPipeline, + StableDiffusionPipeline, + StableDiffusionPix2PixZeroPipeline, + StableDiffusionSAGPipeline, + StableDiffusionUpscalePipeline, + StableUnCLIPImg2ImgPipeline, + StableUnCLIPPipeline, + ) + from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .stable_diffusion_xl import ( + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPipeline, + ) + from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline + from .text_to_video_synthesis import ( + TextToVideoSDPipeline, + TextToVideoZeroPipeline, + VideoToVideoSDPipeline, + ) + from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline + from .unidiffuser import ( + ImageTextPipelineOutput, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, + ) + from .versatile_diffusion import ( + VersatileDiffusionDualGuidedPipeline, + VersatileDiffusionImageVariationPipeline, + VersatileDiffusionPipeline, + VersatileDiffusionTextToImagePipeline, + ) + from .vq_diffusion import VQDiffusionPipeline + from .wuerstchen import ( + WuerstchenCombinedPipeline, + WuerstchenDecoderPipeline, + WuerstchenPriorPipeline, + ) + + try: + if not is_onnx_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_onnx_objects import * # noqa F403 + + else: + from .onnx_utils import OnnxRuntimeModel + + try: + if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_onnx_objects import * + else: + from .stable_diffusion import ( + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInpaintPipelineLegacy, + OnnxStableDiffusionPipeline, + OnnxStableDiffusionUpscalePipeline, + StableDiffusionOnnxPipeline, + ) + + try: + if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * + else: + from .stable_diffusion import StableDiffusionKDiffusionPipeline + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_objects import * # noqa F403 + else: + from .pipeline_flax_utils import FlaxDiffusionPipeline + + try: + if not (is_flax_available() and is_transformers_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_flax_and_transformers_objects import * + else: + from .controlnet import FlaxStableDiffusionControlNetPipeline + from .stable_diffusion import ( + FlaxStableDiffusionImg2ImgPipeline, + FlaxStableDiffusionInpaintPipeline, + FlaxStableDiffusionPipeline, + ) + + try: + if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 + + else: + from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/blip_diffusion/__init__.py b/src/diffusers/pipelines/blip_diffusion/__init__.py new file mode 100644 index 0000000000..af6c879d5c --- /dev/null +++ b/src/diffusers/pipelines/blip_diffusion/__init__.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +from PIL import Image + +from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline +else: + from .blip_image_processing import BlipImageProcessor + from .modeling_blip2 import Blip2QFormerModel + from .modeling_ctx_clip import ContextCLIPTextModel + from .pipeline_blip_diffusion import BlipDiffusionPipeline diff --git a/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py b/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py new file mode 100644 index 0000000000..2c2911eb95 --- /dev/null +++ b/src/diffusers/pipelines/blip_diffusion/blip_image_processing.py @@ -0,0 +1,318 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for BLIP.""" + +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from transformers.image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, +) +from transformers.utils import TensorType, is_vision_available, logging + +from diffusers.utils import numpy_to_pil + + +if is_vision_available(): + import PIL + + +logger = logging.get_logger(__name__) + + +# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop +# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor +class BlipImageProcessor(BaseImageProcessor): + r""" + Constructs a BLIP image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the + `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be + overridden by the `resample` parameter in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the + `do_rescale` parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be + overridden by the `rescale_factor` parameter in the `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. Can be overridden by the `do_normalize` parameter in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be + overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + do_center_crop: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size, default_to_square=True) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self.do_center_crop = do_center_crop + + # Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + do_center_crop: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + do_convert_rgb: bool = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Controls the size of the image after `resize`. The shortest edge of the image is resized to + `size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image + is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest + edge equal to `int(size["shortest_edge"] * (1333 / 800))`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to normalize the image by if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to normalize the image by if `do_normalize` is set to `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if do_resize and size is None or resample is None: + raise ValueError("Size and resample must be specified if do_resize is True.") + + if do_rescale and rescale_factor is None: + raise ValueError("Rescale factor must be specified if do_rescale is True.") + + if do_normalize and (image_mean is None or image_std is None): + raise ValueError("Image mean and std must be specified if do_normalize is True.") + + # PIL RGBA images are converted to RGB + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + if do_center_crop: + images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + return encoded_outputs + + # Follows diffusers.VaeImageProcessor.postprocess + def postprocess(self, sample: torch.FloatTensor, output_type: str = "pil"): + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']" + ) + + # Equivalent to diffusers.VaeImageProcessor.denormalize + sample = (sample / 2 + 0.5).clamp(0, 1) + if output_type == "pt": + return sample + + # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy + sample = sample.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "np": + return sample + # Output_type must be 'pil' + sample = numpy_to_pil(sample) + return sample diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py new file mode 100644 index 0000000000..e2862af232 --- /dev/null +++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py @@ -0,0 +1,642 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers import BertTokenizer +from transformers.activations import QuickGELUActivation as QuickGELU +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPooling, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.models.blip_2.configuration_blip_2 import Blip2Config, Blip2VisionConfig +from transformers.models.blip_2.modeling_blip_2 import ( + Blip2Encoder, + Blip2PreTrainedModel, + Blip2QFormerAttention, + Blip2QFormerIntermediate, + Blip2QFormerOutput, +) +from transformers.pytorch_utils import apply_chunking_to_forward +from transformers.utils import ( + logging, + replace_return_docstrings, +) + + +logger = logging.get_logger(__name__) + + +# There is an implementation of Blip2 in `transformers` : https://github.com/huggingface/transformers/blob/main/src/transformers/models/blip_2/modeling_blip_2.py. +# But it doesn't support getting multimodal embeddings. So, this module can be +# replaced with a future `transformers` version supports that. +class Blip2TextEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + batch_size = embeddings.shape[0] + # repeat the query embeddings for batch size + query_embeds = query_embeds.repeat(batch_size, 1, 1) + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + embeddings = embeddings.to(query_embeds.dtype) + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Blip2 +class Blip2VisionEmbeddings(nn.Module): + def __init__(self, config: Blip2VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype) + return embeddings + + +# The Qformer encoder, which takes the visual embeddings, and the text input, to get multimodal embeddings +class Blip2QFormerEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [Blip2QFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions, query_length) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if layer_module.has_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# The layers making up the Qformer encoder +class Blip2QFormerLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention(config, is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = Blip2QFormerIntermediate(config) + self.intermediate_query = Blip2QFormerIntermediate(config) + self.output_query = Blip2QFormerOutput(config) + self.output = Blip2QFormerOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + if encoder_hidden_states is None: + raise ValueError("encoder_hidden_states must be given for cross-attention layers") + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +# ProjLayer used to project the multimodal Blip2 embeddings to be used in the text encoder +class ProjLayer(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, drop_p=0.1, eps=1e-12): + super().__init__() + + # Dense1 -> Act -> Dense2 -> Drop -> Res -> Norm + self.dense1 = nn.Linear(in_dim, hidden_dim) + self.act_fn = QuickGELU() + self.dense2 = nn.Linear(hidden_dim, out_dim) + self.dropout = nn.Dropout(drop_p) + + self.LayerNorm = nn.LayerNorm(out_dim, eps=eps) + + def forward(self, x): + x_in = x + + x = self.LayerNorm(x) + x = self.dropout(self.dense2(self.act_fn(self.dense1(x)))) + x_in + + return x + + +# Copy-pasted from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Blip2, BLIP->BLIP_2 +class Blip2VisionModel(Blip2PreTrainedModel): + main_input_name = "pixel_values" + config_class = Blip2VisionConfig + + def __init__(self, config: Blip2VisionConfig): + super().__init__(config) + self.config = config + embed_dim = config.hidden_size + self.embeddings = Blip2VisionEmbeddings(config) + self.pre_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = Blip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.post_init() + + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Blip2VisionConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layernorm(hidden_states) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.embeddings + + +# Qformer model, used to get multimodal embeddings from the text and image inputs +class Blip2QFormerModel(Blip2PreTrainedModel): + """ + Querying Transformer (Q-Former), used in BLIP-2. + """ + + def __init__(self, config: Blip2Config): + super().__init__(config) + self.config = config + self.embeddings = Blip2TextEmbeddings(config.qformer_config) + self.visual_encoder = Blip2VisionModel(config.vision_config) + self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)) + if not hasattr(config, "tokenizer") or config.tokenizer is None: + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right") + else: + self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer, truncation_side="right") + self.tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + self.proj_layer = ProjLayer( + in_dim=config.qformer_config.hidden_size, + out_dim=config.qformer_config.hidden_size, + hidden_dim=config.qformer_config.hidden_size * 4, + drop_p=0.1, + eps=1e-12, + ) + + self.encoder = Blip2QFormerEncoder(config.qformer_config) + + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int], + device: torch.device, + has_query: bool = False, + ) -> torch.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + device (`torch.device`): + The device of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + text_input=None, + image_input=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of: + shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and + value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are + used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key + value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape + `(batch_size, sequence_length)`. + use_cache (`bool`, `optional`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + + text = self.tokenizer(text_input, return_tensors="pt", padding=True) + text = text.to(self.device) + input_ids = text.input_ids + batch_size = input_ids.shape[0] + query_atts = torch.ones((batch_size, self.query_tokens.size()[1]), dtype=torch.long).to(self.device) + attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0 + ) + + query_length = self.query_tokens.shape[1] + + embedding_output = self.embeddings( + input_ids=input_ids, + query_embeds=self.query_tokens, + past_key_values_length=past_key_values_length, + ) + + # embedding_output = self.layernorm(query_embeds) + # embedding_output = self.dropout(embedding_output) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + image_embeds_frozen = self.visual_encoder(image_input).last_hidden_state + # image_embeds_frozen = torch.ones_like(image_embeds_frozen) + encoder_hidden_states = image_embeds_frozen + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if isinstance(encoder_hidden_states, list): + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if isinstance(encoder_attention_mask, list): + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.qformer_config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = sequence_output[:, 0, :] + + if not return_dict: + return self.proj_layer(sequence_output[:, :query_length, :]) + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py new file mode 100644 index 0000000000..53d5718874 --- /dev/null +++ b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py @@ -0,0 +1,212 @@ +# Copyright 2023 Salesforce.com, inc. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from transformers import CLIPPreTrainedModel +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.clip.configuration_clip import CLIPTextConfig +from transformers.models.clip.modeling_clip import ( + CLIPEncoder, + _expand_mask, +) + + +# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip +# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer +# They pass through the clip model, along with the text embeddings, and interact with them using self attention +class ContextCLIPTextModel(CLIPPreTrainedModel): + config_class = CLIPTextConfig + + _no_split_modules = ["CLIPEncoderLayer"] + + def __init__(self, config: CLIPTextConfig): + super().__init__(config) + self.text_model = ContextCLIPTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + ctx_embeddings: torch.Tensor = None, + ctx_begin_pos: list = None, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + return self.text_model( + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class ContextCLIPTextTransformer(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = ContextCLIPTextEmbeddings(config) + self.encoder = CLIPEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + ctx_embeddings=ctx_embeddings, + ctx_begin_pos=ctx_begin_pos, + ) + + bsz, seq_len = input_shape + if ctx_embeddings is not None: + seq_len += ctx_embeddings.size(1) + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=input_ids.device), + input_ids.to(torch.int).argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _build_causal_attention_mask(self, bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + + +class ContextCLIPTextEmbeddings(nn.Module): + def __init__(self, config: CLIPTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward( + self, + ctx_embeddings: torch.Tensor, + ctx_begin_pos: list, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + if ctx_embeddings is None: + ctx_len = 0 + else: + ctx_len = ctx_embeddings.shape[1] + + seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + # for each input embeddings, add the ctx embeddings at the correct position + input_embeds_ctx = [] + bsz = inputs_embeds.shape[0] + + if ctx_embeddings is not None: + for i in range(bsz): + cbp = ctx_begin_pos[i] + + prefix = inputs_embeds[i, :cbp] + # remove the special token embedding + suffix = inputs_embeds[i, cbp:] + + input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0)) + + inputs_embeds = torch.stack(input_embeds_ctx, dim=0) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings diff --git a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py new file mode 100644 index 0000000000..3ca456c6f4 --- /dev/null +++ b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py @@ -0,0 +1,339 @@ +# Copyright 2023 Salesforce.com, inc. +# Copyright 2023 The HuggingFace Team. All rights reserved.# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +import PIL +import torch +from transformers import CLIPTokenizer + +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .blip_image_processing import BlipImageProcessor +from .modeling_blip2 import Blip2QFormerModel +from .modeling_ctx_clip import ContextCLIPTextModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers.pipelines import BlipDiffusionPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained( + ... "Salesforce/blipdiffusion", torch_dtype=torch.float16 + ... ).to("cuda") + + + >>> cond_subject = "dog" + >>> tgt_subject = "dog" + >>> text_prompt_input = "swimming underwater" + + >>> cond_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg" + ... ) + >>> guidance_scale = 7.5 + >>> num_inference_steps = 25 + >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" + + + >>> output = blip_diffusion_pipe( + ... text_prompt_input, + ... cond_image, + ... cond_subject, + ... tgt_subject, + ... guidance_scale=guidance_scale, + ... num_inference_steps=num_inference_steps, + ... neg_prompt=negative_prompt, + ... height=512, + ... width=512, + ... ).images + >>> output[0].save("image.png") + ``` +""" + + +class BlipDiffusionPipeline(DiffusionPipeline): + """ + Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer ([`CLIPTokenizer`]): + Tokenizer for the text encoder + text_encoder ([`ContextCLIPTextModel`]): + Text encoder to encode the text prompt + vae ([`AutoencoderKL`]): + VAE model to map the latents to the image + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + scheduler ([`PNDMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + qformer ([`Blip2QFormerModel`]): + QFormer model to get multi-modal embeddings from the text and image. + image_processor ([`BlipImageProcessor`]): + Image Processor to preprocess and postprocess the image. + ctx_begin_pos (int, `optional`, defaults to 2): + Position of the context token in the text encoder. + """ + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: ContextCLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: PNDMScheduler, + qformer: Blip2QFormerModel, + image_processor: BlipImageProcessor, + ctx_begin_pos: int = 2, + mean: List[float] = None, + std: List[float] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + image_processor=image_processor, + ) + self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std) + + def get_query_embeddings(self, input_image, src_subject): + return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False) + + # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it + def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): + rv = [] + for prompt, tgt_subject in zip(prompts, tgt_subjects): + prompt = f"a {tgt_subject} {prompt.strip()}" + # a trick to amplify the prompt + rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) + + return rv + + # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt(self, query_embeds, prompt): + # embeddings for prompt, with query_embeds as context + max_len = self.text_encoder.text_model.config.max_position_embeddings + max_len -= self.qformer.config.num_query_tokens + + tokenized_prompt = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=max_len, + return_tensors="pt", + ).to(self.device) + + batch_size = query_embeds.shape[0] + ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size + + text_embeddings = self.text_encoder( + input_ids=tokenized_prompt.input_ids, + ctx_embeddings=query_embeds, + ctx_begin_pos=ctx_begin_pos, + )[0] + + return text_embeddings + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: List[str], + reference_image: PIL.Image.Image, + source_subject_category: List[str], + target_subject_category: List[str], + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 7.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + neg_prompt: Optional[str] = "", + prompt_strength: float = 1.0, + prompt_reps: int = 20, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`List[str]`): + The prompt or prompts to guide the image generation. + reference_image (`PIL.Image.Image`): + The reference image to condition the generation on. + source_subject_category (`List[str]`): + The source subject category. + target_subject_category (`List[str]`): + The target subject category. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by random sampling. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + height (`int`, *optional*, defaults to 512): + The height of the generated image. + width (`int`, *optional*, defaults to 512): + The width of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + neg_prompt (`str`, *optional*, defaults to ""): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_strength (`float`, *optional*, defaults to 1.0): + The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps + to amplify the prompt. + prompt_reps (`int`, *optional*, defaults to 20): + The number of times the prompt is repeated along with prompt_strength to amplify the prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + reference_image = self.image_processor.preprocess( + reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt" + )["pixel_values"] + reference_image = reference_image.to(self.device) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(source_subject_category, str): + source_subject_category = [source_subject_category] + if isinstance(target_subject_category, str): + target_subject_category = [target_subject_category] + + batch_size = len(prompt) + + prompt = self._build_prompt( + prompts=prompt, + tgt_subjects=target_subject_category, + prompt_strength=prompt_strength, + prompt_reps=prompt_reps, + ) + query_embeds = self.get_query_embeddings(reference_image, source_subject_category) + text_embeddings = self.encode_prompt(query_embeds, prompt) + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + max_length = self.text_encoder.text_model.config.max_position_embeddings + + uncond_input = self.tokenizer( + [neg_prompt] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.to(self.device), + ctx_embeddings=None, + )[0] + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1) + latents = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=height // scale_down_factor, + width=width // scale_down_factor, + generator=generator, + latents=latents, + dtype=self.unet.dtype, + device=self.device, + ) + # set timesteps + extra_set_kwargs = {} + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + noise_pred = self.unet( + latent_model_input, + timestep=t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=None, + mid_block_additional_residual=None, + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + )["prev_sample"] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py index 5c551533f3..e14fd438a5 100644 --- a/src/diffusers/pipelines/controlnet/__init__.py +++ b/src/diffusers/pipelines/controlnet/__init__.py @@ -1,77 +1,79 @@ -from typing import TYPE_CHECKING - -from ...utils import ( - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_flax_available, - is_torch_available, - is_transformers_available, -) - - -_dummy_objects = {} -_import_structure = {} - -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) -else: - _import_structure["multicontrolnet"] = ["MultiControlNetModel"] - _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"] - _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"] - _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"] - _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"] - _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"] - _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"] -try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_flax_and_transformers_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) -else: - _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] - - -if TYPE_CHECKING: - try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() - - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * - else: - from .multicontrolnet import MultiControlNetModel - from .pipeline_controlnet import StableDiffusionControlNetPipeline - from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline - from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline - from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline - from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline - from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline - - try: - if not (is_transformers_available() and is_flax_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 - else: - from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline - - -else: - import sys - - sys.modules[__name__] = _LazyModule( - __name__, - globals()["__file__"], - _import_structure, - module_spec=__spec__, - ) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["multicontrolnet"] = ["MultiControlNetModel"] + _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"] + _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"] + _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"] + _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"] + _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"] + _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"] + _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"] +try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_flax_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects)) +else: + _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] + + +if TYPE_CHECKING: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .multicontrolnet import MultiControlNetModel + from .pipeline_controlnet import StableDiffusionControlNetPipeline + from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline + from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline + from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline + from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline + from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline + from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline + + try: + if not (is_transformers_available() and is_flax_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 + else: + from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py new file mode 100644 index 0000000000..1a7efa3212 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py @@ -0,0 +1,405 @@ +# Copyright 2023 Salesforce.com, inc. +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +import PIL +import torch +from transformers import CLIPTokenizer + +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import PNDMScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..blip_diffusion.blip_image_processing import BlipImageProcessor +from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel +from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers.pipelines import BlipDiffusionControlNetPipeline + >>> from diffusers.utils import load_image + >>> from controlnet_aux import CannyDetector + >>> import torch + + >>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained( + ... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16 + ... ).to("cuda") + + >>> style_subject = "flower" + >>> tgt_subject = "teapot" + >>> text_prompt = "on a marble table" + + >>> cldm_cond_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg" + ... ).resize(512, 512) + >>> canny = CannyDetector() + >>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil") + >>> style_image = load_image( + ... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg" + ... ) + >>> guidance_scale = 7.5 + >>> num_inference_steps = 50 + >>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate" + + + >>> output = blip_diffusion_pipe( + ... text_prompt, + ... style_image, + ... cldm_cond_image, + ... style_subject, + ... tgt_subject, + ... guidance_scale=guidance_scale, + ... num_inference_steps=num_inference_steps, + ... neg_prompt=negative_prompt, + ... height=512, + ... width=512, + ... ).images + >>> output[0].save("image.png") + ``` +""" + + +class BlipDiffusionControlNetPipeline(DiffusionPipeline): + """ + Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + tokenizer ([`CLIPTokenizer`]): + Tokenizer for the text encoder + text_encoder ([`ContextCLIPTextModel`]): + Text encoder to encode the text prompt + vae ([`AutoencoderKL`]): + VAE model to map the latents to the image + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + scheduler ([`PNDMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + qformer ([`Blip2QFormerModel`]): + QFormer model to get multi-modal embeddings from the text and image. + controlnet ([`ControlNetModel`]): + ControlNet model to get the conditioning image embedding. + image_processor ([`BlipImageProcessor`]): + Image Processor to preprocess and postprocess the image. + ctx_begin_pos (int, `optional`, defaults to 2): + Position of the context token in the text encoder. + """ + + def __init__( + self, + tokenizer: CLIPTokenizer, + text_encoder: ContextCLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + scheduler: PNDMScheduler, + qformer: Blip2QFormerModel, + controlnet: ControlNetModel, + image_processor: BlipImageProcessor, + ctx_begin_pos: int = 2, + mean: List[float] = None, + std: List[float] = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + unet=unet, + scheduler=scheduler, + qformer=qformer, + controlnet=controlnet, + image_processor=image_processor, + ) + self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std) + + def get_query_embeddings(self, input_image, src_subject): + return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False) + + # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it + def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20): + rv = [] + for prompt, tgt_subject in zip(prompts, tgt_subjects): + prompt = f"a {tgt_subject} {prompt.strip()}" + # a trick to amplify the prompt + rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps))) + + return rv + + # Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels, height, width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def encode_prompt(self, query_embeds, prompt): + # embeddings for prompt, with query_embeds as context + max_len = self.text_encoder.text_model.config.max_position_embeddings + max_len -= self.qformer.config.num_query_tokens + + tokenized_prompt = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=max_len, + return_tensors="pt", + ).to(self.device) + + batch_size = query_embeds.shape[0] + ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size + + text_embeddings = self.text_encoder( + input_ids=tokenized_prompt.input_ids, + ctx_embeddings=query_embeds, + ctx_begin_pos=ctx_begin_pos, + )[0] + + return text_embeddings + + # Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.image_processor.preprocess( + image, + size={"width": width, "height": height}, + do_rescale=True, + do_center_crop=False, + do_normalize=False, + return_tensors="pt", + )["pixel_values"].to(self.device) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: List[str], + reference_image: PIL.Image.Image, + condtioning_image: PIL.Image.Image, + source_subject_category: List[str], + target_subject_category: List[str], + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 7.5, + height: int = 512, + width: int = 512, + num_inference_steps: int = 50, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + neg_prompt: Optional[str] = "", + prompt_strength: float = 1.0, + prompt_reps: int = 20, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`List[str]`): + The prompt or prompts to guide the image generation. + reference_image (`PIL.Image.Image`): + The reference image to condition the generation on. + condtioning_image (`PIL.Image.Image`): + The conditioning canny edge image to condition the generation on. + source_subject_category (`List[str]`): + The source subject category. + target_subject_category (`List[str]`): + The target subject category. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by random sampling. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + height (`int`, *optional*, defaults to 512): + The height of the generated image. + width (`int`, *optional*, defaults to 512): + The width of the generated image. + seed (`int`, *optional*, defaults to 42): + The seed to use for random generation. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + neg_prompt (`str`, *optional*, defaults to ""): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_strength (`float`, *optional*, defaults to 1.0): + The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps + to amplify the prompt. + prompt_reps (`int`, *optional*, defaults to 20): + The number of times the prompt is repeated along with prompt_strength to amplify the prompt. + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + reference_image = self.image_processor.preprocess( + reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt" + )["pixel_values"] + reference_image = reference_image.to(self.device) + + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(source_subject_category, str): + source_subject_category = [source_subject_category] + if isinstance(target_subject_category, str): + target_subject_category = [target_subject_category] + + batch_size = len(prompt) + + prompt = self._build_prompt( + prompts=prompt, + tgt_subjects=target_subject_category, + prompt_strength=prompt_strength, + prompt_reps=prompt_reps, + ) + query_embeds = self.get_query_embeddings(reference_image, source_subject_category) + text_embeddings = self.encode_prompt(query_embeds, prompt) + # 3. unconditional embedding + do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + max_length = self.text_encoder.text_model.config.max_position_embeddings + + uncond_input = self.tokenizer( + [neg_prompt] * batch_size, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + uncond_embeddings = self.text_encoder( + input_ids=uncond_input.input_ids.to(self.device), + ctx_embeddings=None, + )[0] + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1) + latents = self.prepare_latents( + batch_size=batch_size, + num_channels=self.unet.config.in_channels, + height=height // scale_down_factor, + width=width // scale_down_factor, + generator=generator, + latents=latents, + dtype=self.unet.dtype, + device=self.device, + ) + # set timesteps + extra_set_kwargs = {} + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + + cond_image = self.prepare_control_image( + image=condtioning_image, + width=width, + height=height, + batch_size=batch_size, + num_images_per_prompt=1, + device=self.device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + # expand the latents if we are doing classifier free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input, + t, + encoder_hidden_states=text_embeddings, + controlnet_cond=cond_image, + return_dict=False, + ) + + noise_pred = self.unet( + latent_model_input, + timestep=t, + encoder_hidden_states=text_embeddings, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + )["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step( + noise_pred, + t, + latents, + )["prev_sample"] + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 2912eaaee8..ff85af14fb 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -609,6 +609,9 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade noise_sampler = BrownianTreeNoiseSampler(latents, min_sigma, max_sigma, noise_sampler_seed) sampler_kwargs["noise_sampler"] = noise_sampler + if "generator" in inspect.signature(self.sampler).parameters: + sampler_kwargs["generator"] = generator + latents = self.sampler(model_fn, latents, sigmas, **sampler_kwargs) if not output_type == "latent": diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 1edd12904a..c7fbb0c86c 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -787,8 +787,16 @@ class StableDiffusionXLAdapterPipeline( height, width = self._default_height_width(height, width, image) device = self._execution_device - adapter_input = _preprocess_adapter_image(image, height, width).to(device) + if isinstance(self.adapter, MultiAdapter): + adapter_input = [] + for one_image in image: + one_image = _preprocess_adapter_image(one_image, height, width) + one_image = one_image.to(device=device, dtype=self.adapter.dtype) + adapter_input.append(one_image) + else: + adapter_input = _preprocess_adapter_image(image, height, width) + adapter_input = adapter_input.to(device=device, dtype=self.adapter.dtype) original_size = original_size or (height, width) target_size = target_size or (height, width) @@ -865,10 +873,14 @@ class StableDiffusionXLAdapterPipeline( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Prepare added time ids & embeddings & adapter features - adapter_input = adapter_input.type(latents.dtype) - adapter_state = self.adapter(adapter_input) - for k, v in enumerate(adapter_state): - adapter_state[k] = v * adapter_conditioning_scale + if isinstance(self.adapter, MultiAdapter): + adapter_state = self.adapter(adapter_input, adapter_conditioning_scale) + for k, v in enumerate(adapter_state): + adapter_state[k] = v + else: + adapter_state = self.adapter(adapter_input) + for k, v in enumerate(adapter_state): + adapter_state[k] = v * adapter_conditioning_scale if num_images_per_prompt > 1: for k, v in enumerate(adapter_state): adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index b44ff31379..a0137b83fd 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -89,6 +89,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -113,6 +116,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, @@ -243,9 +247,15 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) + log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + + self.log_sigmas = torch.from_numpy(log_sigmas).to(device) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(device=device) @@ -269,7 +279,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): self.sigmas_down = torch.cat([sigmas_down[:1], sigmas_down[1:].repeat_interleave(2), sigmas_down[-1:]]) timesteps = torch.from_numpy(timesteps).to(device) - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) + sigmas_interpol = sigmas_interpol.cpu() + log_sigmas = self.log_sigmas.cpu() + timesteps_interpol = np.array( + [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol] + ) + + timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) @@ -282,29 +298,44 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): self._step_index = None - def sigma_to_t(self, sigma): + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): # get log sigma - log_sigma = sigma.log() + log_sigma = np.log(sigma) # get distribution - dists = log_sigma - self.log_sigmas[:, None] + dists = log_sigma - log_sigmas[:, np.newaxis] # get sigmas range - low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) high_idx = low_idx + 1 - low = self.log_sigmas[low_idx] - high = self.log_sigmas[high_idx] + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] # interpolate sigmas w = (low - log_sigma) / (low - high) - w = w.clamp(0, 1) + w = np.clip(w, 0, 1) # transform interpolation to time range t = (1 - w) * low_idx + w * high_idx - t = t.view(sigma.shape) + t = t.reshape(sigma.shape) return t + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + @property def state_in_first_order(self): return self.sample is None diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index d8f54a21e2..ddea57e8c1 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -88,6 +88,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen @@ -112,6 +115,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", timestep_spacing: str = "linspace", steps_offset: int = 0, @@ -243,9 +247,14 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - self.log_sigmas = torch.from_numpy(np.log(sigmas)).to(device) - + log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + + self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) sigmas = torch.from_numpy(sigmas).to(device=device) @@ -260,7 +269,12 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): timesteps = torch.from_numpy(timesteps).to(device) # interpolate timesteps - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device, dtype=timesteps.dtype) + sigmas_interpol = sigmas_interpol.cpu() + log_sigmas = self.log_sigmas.cpu() + timesteps_interpol = np.array( + [self._sigma_to_t(sigma_interpol, log_sigmas) for sigma_interpol in sigmas_interpol] + ) + timesteps_interpol = torch.from_numpy(timesteps_interpol).to(device, dtype=timesteps.dtype) interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) @@ -273,29 +287,6 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): self._step_index = None - def sigma_to_t(self, sigma): - # get log sigma - log_sigma = sigma.log() - - # get distribution - dists = log_sigma - self.log_sigmas[:, None] - - # get sigmas range - low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = self.log_sigmas[low_idx] - high = self.log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = w.clamp(0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.view(sigma.shape) - return t - @property def state_in_first_order(self): return self.sample is None @@ -318,6 +309,44 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): self._step_index = step_index.item() + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def step( self, model_output: Union[torch.FloatTensor, np.ndarray], diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 0c7b3117fa..8e95dde52c 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -315,6 +315,36 @@ class AutoPipelineForText2Image(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class BlipDiffusionControlNetPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + +class BlipDiffusionPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CLIPImageProjection(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/blipdiffusion/__init__.py b/tests/pipelines/blipdiffusion/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py new file mode 100644 index 0000000000..480581928c --- /dev/null +++ b/tests/pipelines/blipdiffusion/test_blipdiffusion.py @@ -0,0 +1,196 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTokenizer +from transformers.models.blip_2.configuration_blip_2 import Blip2Config +from transformers.models.clip.configuration_clip import CLIPTextConfig + +from diffusers import AutoencoderKL, BlipDiffusionPipeline, PNDMScheduler, UNet2DConditionModel +from diffusers.utils.testing_utils import enable_full_determinism +from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor +from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel +from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BlipDiffusionPipeline + params = [ + "prompt", + "reference_image", + "source_subject_category", + "target_subject_category", + ] + batch_params = [ + "prompt", + "reference_image", + "source_subject_category", + "target_subject_category", + ] + required_optional_params = [ + "generator", + "height", + "width", + "latents", + "guidance_scale", + "num_inference_steps", + "neg_prompt", + "guidance_scale", + "prompt_strength", + "prompt_reps", + ] + + def get_dummy_components(self): + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + vocab_size=1000, + hidden_size=16, + intermediate_size=16, + projection_dim=16, + num_hidden_layers=1, + num_attention_heads=1, + max_position_embeddings=77, + ) + text_encoder = ContextCLIPTextModel(text_encoder_config) + + vae = AutoencoderKL( + in_channels=4, + out_channels=4, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(32,), + layers_per_block=1, + act_fn="silu", + latent_channels=4, + norm_num_groups=16, + sample_size=16, + ) + + blip_vision_config = { + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 1, + "image_size": 224, + "patch_size": 14, + "hidden_act": "quick_gelu", + } + + blip_qformer_config = { + "vocab_size": 1000, + "hidden_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 1, + "intermediate_size": 16, + "max_position_embeddings": 512, + "cross_attention_frequency": 1, + "encoder_hidden_size": 16, + } + qformer_config = Blip2Config( + vision_config=blip_vision_config, + qformer_config=blip_qformer_config, + num_query_tokens=16, + tokenizer="hf-internal-testing/tiny-random-bert", + ) + qformer = Blip2QFormerModel(qformer_config) + + unet = UNet2DConditionModel( + block_out_channels=(16, 32), + norm_num_groups=16, + layers_per_block=1, + sample_size=16, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=16, + ) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + scheduler = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + set_alpha_to_one=False, + skip_prk_steps=True, + ) + + vae.eval() + qformer.eval() + text_encoder.eval() + + image_processor = BlipImageProcessor() + + components = { + "text_encoder": text_encoder, + "vae": vae, + "qformer": qformer, + "unet": unet, + "tokenizer": tokenizer, + "scheduler": scheduler, + "image_processor": image_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + np.random.seed(seed) + reference_image = np.random.rand(32, 32, 3) * 255 + reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA") + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "swimming underwater", + "generator": generator, + "reference_image": reference_image, + "source_subject_category": "dog", + "target_subject_category": "dog", + "height": 32, + "width": 32, + "guidance_scale": 7.5, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_blipdiffusion(self): + device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + image = pipe(**self.get_dummy_inputs(device))[0] + image_slice = image[0, -3:, -3:, 0] + + assert image.shape == (1, 16, 16, 4) + + expected_slice = np.array([0.7096, 0.5900, 0.6703, 0.4032, 0.7766, 0.3629, 0.5447, 0.4149, 0.8172]) + + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}" diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py new file mode 100644 index 0000000000..f15da0a676 --- /dev/null +++ b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py @@ -0,0 +1,216 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTokenizer +from transformers.models.blip_2.configuration_blip_2 import Blip2Config +from transformers.models.clip.configuration_clip import CLIPTextConfig + +from diffusers import ( + AutoencoderKL, + BlipDiffusionControlNetPipeline, + ControlNetModel, + PNDMScheduler, + UNet2DConditionModel, +) +from diffusers.utils.testing_utils import enable_full_determinism +from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor +from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel +from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = BlipDiffusionControlNetPipeline + params = [ + "prompt", + "reference_image", + "source_subject_category", + "target_subject_category", + "condtioning_image", + ] + batch_params = [ + "prompt", + "reference_image", + "source_subject_category", + "target_subject_category", + "condtioning_image", + ] + required_optional_params = [ + "generator", + "height", + "width", + "latents", + "guidance_scale", + "num_inference_steps", + "neg_prompt", + "guidance_scale", + "prompt_strength", + "prompt_reps", + ] + + def get_dummy_components(self): + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + vocab_size=1000, + hidden_size=16, + intermediate_size=16, + projection_dim=16, + num_hidden_layers=1, + num_attention_heads=1, + max_position_embeddings=77, + ) + text_encoder = ContextCLIPTextModel(text_encoder_config) + + vae = AutoencoderKL( + in_channels=4, + out_channels=4, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(32,), + layers_per_block=1, + act_fn="silu", + latent_channels=4, + norm_num_groups=16, + sample_size=16, + ) + + blip_vision_config = { + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 1, + "image_size": 224, + "patch_size": 14, + "hidden_act": "quick_gelu", + } + + blip_qformer_config = { + "vocab_size": 1000, + "hidden_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 1, + "intermediate_size": 16, + "max_position_embeddings": 512, + "cross_attention_frequency": 1, + "encoder_hidden_size": 16, + } + qformer_config = Blip2Config( + vision_config=blip_vision_config, + qformer_config=blip_qformer_config, + num_query_tokens=16, + tokenizer="hf-internal-testing/tiny-random-bert", + ) + qformer = Blip2QFormerModel(qformer_config) + + unet = UNet2DConditionModel( + block_out_channels=(4, 16), + layers_per_block=1, + norm_num_groups=4, + sample_size=16, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=16, + ) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + scheduler = PNDMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + set_alpha_to_one=False, + skip_prk_steps=True, + ) + controlnet = ControlNetModel( + block_out_channels=(4, 16), + layers_per_block=1, + in_channels=4, + norm_num_groups=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=16, + conditioning_embedding_out_channels=(8, 16), + ) + + vae.eval() + qformer.eval() + text_encoder.eval() + + image_processor = BlipImageProcessor() + + components = { + "text_encoder": text_encoder, + "vae": vae, + "qformer": qformer, + "unet": unet, + "tokenizer": tokenizer, + "scheduler": scheduler, + "controlnet": controlnet, + "image_processor": image_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + np.random.seed(seed) + reference_image = np.random.rand(32, 32, 3) * 255 + reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA") + cond_image = np.random.rand(32, 32, 3) * 255 + cond_image = Image.fromarray(cond_image.astype("uint8")).convert("RGBA") + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "swimming underwater", + "generator": generator, + "reference_image": reference_image, + "condtioning_image": cond_image, + "source_subject_category": "dog", + "target_subject_category": "dog", + "height": 32, + "width": 32, + "guidance_scale": 7.5, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_blipdiffusion_controlnet(self): + device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + image = pipe(**self.get_dummy_inputs(device))[0] + image_slice = image[0, -3:, -3:, 0] + + assert image.shape == (1, 16, 16, 4) + expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422]) + + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index e71f103005..02bb354a97 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -20,17 +20,20 @@ import numpy as np import torch from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +import diffusers from diffusers import ( AutoencoderKL, EulerDiscreteScheduler, + MultiAdapter, StableDiffusionXLAdapterPipeline, T2IAdapter, UNet2DConditionModel, ) -from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor +from diffusers.utils import logging +from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference enable_full_determinism() @@ -41,7 +44,7 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - def get_dummy_components(self): + def get_dummy_components(self, adapter_type="full_adapter_xl"): torch.manual_seed(0) unet = UNet2DConditionModel( block_out_channels=(32, 64), @@ -97,13 +100,38 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") - adapter = T2IAdapter( - in_channels=3, - channels=[32, 64], - num_res_blocks=2, - downscale_factor=4, - adapter_type="full_adapter_xl", - ) + if adapter_type == "full_adapter_xl": + adapter = T2IAdapter( + in_channels=3, + channels=[32, 64], + num_res_blocks=2, + downscale_factor=4, + adapter_type=adapter_type, + ) + elif adapter_type == "multi_adapter": + adapter = MultiAdapter( + [ + T2IAdapter( + in_channels=3, + channels=[32, 64], + num_res_blocks=2, + downscale_factor=4, + adapter_type="full_adapter_xl", + ), + T2IAdapter( + in_channels=3, + channels=[32, 64], + num_res_blocks=2, + downscale_factor=4, + adapter_type="full_adapter_xl", + ), + ] + ) + else: + raise ValueError( + f"Unknown adapter type: {adapter_type}, must be one of 'full_adapter_xl', or 'multi_adapter''" + ) + components = { "adapter": adapter, "unet": unet, @@ -118,8 +146,12 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te } return components - def get_dummy_inputs(self, device, seed=0): - image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + def get_dummy_inputs(self, device, seed=0, num_images=1): + if num_images == 1: + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + else: + image = [floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) for _ in range(num_images)] + if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: @@ -150,3 +182,202 @@ class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.Te [0.5752919, 0.6022097, 0.4728038, 0.49861962, 0.57084894, 0.4644975, 0.5193715, 0.5133664, 0.4729858] ) assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3 + + +class StableDiffusionXLMultiAdapterPipelineFastTests( + StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase +): + def get_dummy_components(self): + return super().get_dummy_components("multi_adapter") + + def get_dummy_inputs(self, device, seed=0): + inputs = super().get_dummy_inputs(device, seed, num_images=2) + inputs["adapter_conditioning_scale"] = [0.5, 0.5] + return inputs + + def test_stable_diffusion_adapter_default_case(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLAdapterPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [0.5813032, 0.60995954, 0.47563356, 0.5056669, 0.57199144, 0.4631841, 0.5176794, 0.51252556, 0.47183886] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3 + + def test_inference_batch_consistent( + self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"] + ): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # batchify inputs + for batch_size in batch_sizes: + batched_inputs = {} + for name, value in inputs.items(): + if name in self.batch_params: + # prompt is string + if name == "prompt": + len_prompt = len(value) + # make unequal batch sizes + batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] + + # make last batch super long + batched_inputs[name][-1] = 100 * "very long" + elif name == "image": + batched_images = [] + + for image in value: + batched_images.append(batch_size * [image]) + + batched_inputs[name] = batched_images + else: + batched_inputs[name] = batch_size * [value] + + elif name == "batch_size": + batched_inputs[name] = batch_size + else: + batched_inputs[name] = value + + for arg in additional_params_copy_to_batched_inputs: + batched_inputs[arg] = inputs[arg] + + batched_inputs["output_type"] = "np" + + output = pipe(**batched_inputs) + + assert len(output[0]) == batch_size + + batched_inputs["output_type"] = "np" + + output = pipe(**batched_inputs)[0] + + assert output.shape[0] == batch_size + + logger.setLevel(level=diffusers.logging.WARNING) + + def test_num_images_per_prompt(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + batch_sizes = [1, 2] + num_images_per_prompts = [1, 2] + + for batch_size in batch_sizes: + for num_images_per_prompt in num_images_per_prompts: + inputs = self.get_dummy_inputs(torch_device) + + for key in inputs.keys(): + if key in self.batch_params: + if key == "image": + batched_images = [] + + for image in inputs[key]: + batched_images.append(batch_size * [image]) + + inputs[key] = batched_images + else: + inputs[key] = batch_size * [inputs[key]] + + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0] + + assert images.shape[0] == batch_size * num_images_per_prompt + + def test_inference_batch_single_identical( + self, + batch_size=3, + test_max_difference=None, + test_mean_pixel_difference=None, + relax_max_difference=False, + expected_max_diff=2e-3, + additional_params_copy_to_batched_inputs=["num_inference_steps"], + ): + if test_max_difference is None: + # TODO(Pedro) - not sure why, but not at all reproducible at the moment it seems + # make sure that batched and non-batched is identical + test_max_difference = torch_device != "mps" + + if test_mean_pixel_difference is None: + # TODO same as above + test_mean_pixel_difference = torch_device != "mps" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + + logger = logging.get_logger(pipe.__module__) + logger.setLevel(level=diffusers.logging.FATAL) + + # batchify inputs + batched_inputs = {} + batch_size = batch_size + for name, value in inputs.items(): + if name in self.batch_params: + # prompt is string + if name == "prompt": + len_prompt = len(value) + # make unequal batch sizes + batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] + + # make last batch super long + batched_inputs[name][-1] = 100 * "very long" + elif name == "image": + batched_images = [] + + for image in value: + batched_images.append(batch_size * [image]) + + batched_inputs[name] = batched_images + else: + batched_inputs[name] = batch_size * [value] + elif name == "batch_size": + batched_inputs[name] = batch_size + elif name == "generator": + batched_inputs[name] = [self.get_generator(i) for i in range(batch_size)] + else: + batched_inputs[name] = value + + for arg in additional_params_copy_to_batched_inputs: + batched_inputs[arg] = inputs[arg] + + output_batch = pipe(**batched_inputs) + assert output_batch[0].shape[0] == batch_size + + inputs["generator"] = self.get_generator(0) + + output = pipe(**inputs) + + logger.setLevel(level=diffusers.logging.WARNING) + if test_max_difference: + if relax_max_difference: + # Taking the median of the largest differences + # is resilient to outliers + diff = np.abs(output_batch[0][0] - output[0][0]) + diff = diff.flatten() + diff.sort() + max_diff = np.median(diff[-5:]) + else: + max_diff = np.abs(output_batch[0][0] - output[0][0]).max() + assert max_diff < expected_max_diff + + if test_mean_pixel_difference: + assert_mean_pixel_difference(output_batch[0][0], output[0][0])