mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Support of ip-adapter to the StableDiffusionControlNetInpaintPipeline (#5887)
* Change pipeline_controlnet_inpaint.py to add ip-adapter support. Changes are similar to those in pipeline_controlnet * Change tests for the StableDiffusionControlNetInpaintPipeline by adding image_encoder: None * Update src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py Co-authored-by: YiYi Xu <yixu310@gmail.com> --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -21,10 +21,10 @@ import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
@@ -241,7 +241,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
|
||||
|
||||
|
||||
class StableDiffusionControlNetInpaintPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for image inpainting using Stable Diffusion with ControlNet guidance.
|
||||
@@ -251,6 +251,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -288,7 +289,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
@@ -302,6 +303,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -334,6 +336,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
@@ -593,6 +596,20 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
uncond_image_embeds = torch.zeros_like(image_embeds)
|
||||
return image_embeds, uncond_image_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
if self.safety_checker is None:
|
||||
@@ -1053,6 +1070,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
@@ -1131,6 +1149,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -1264,6 +1283,11 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
control_image = self.prepare_control_image(
|
||||
@@ -1299,7 +1323,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
else:
|
||||
assert False
|
||||
|
||||
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
|
||||
# 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
|
||||
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
@@ -1360,7 +1384,10 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7.1 Create tensor stating which controlnets to keep
|
||||
# 7.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
|
||||
# 7.2 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
@@ -1423,6 +1450,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
|
||||
@@ -132,6 +132,7 @@ class ControlNetInpaintPipelineFastTests(
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
return components
|
||||
|
||||
@@ -248,6 +249,7 @@ class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTest
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
return components
|
||||
|
||||
@@ -342,6 +344,7 @@ class MultiControlNetInpaintPipelineFastTests(
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
return components
|
||||
|
||||
|
||||
Reference in New Issue
Block a user