1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

IP-Adapter support for StableDiffusionXLControlNetInpaintPipeline (#6941)

* add ip-adapter support

* support ip image embeds

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Vinh H. Pham
2024-02-20 01:20:24 +07:00
committed by GitHub
parent 15b125bb0e
commit e0f33dfca4

View File

@@ -19,11 +19,17 @@ import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -195,6 +201,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
feature_extractor: Optional[CLIPImageProcessor] = None,
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
):
super().__init__()
@@ -210,6 +218,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
@@ -497,6 +507,66 @@ class StableDiffusionXLControlNetInpaintPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -566,6 +636,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
@@ -752,6 +824,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)
def prepare_control_image(
self,
image,
@@ -1100,6 +1177,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
@@ -1194,6 +1273,10 @@ class StableDiffusionXLControlNetInpaintPipeline(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
@@ -1326,6 +1409,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
@@ -1378,6 +1463,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
clip_skip=self.clip_skip,
)
# 3.1 Encode ip_adapter_image
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
)
# 4. set timesteps
def denoising_value_valid(dnv):
return isinstance(denoising_end, float) and 0 < dnv < 1
@@ -1649,6 +1740,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
if ip_adapter_image is not None:
added_cond_kwargs["image_embeds"] = image_embeds
if num_channels_unet == 9:
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)