IP-Adapter for StableDiffusion3InpaintPipeline (#10581)
* Added support for IP-Adapter
* Added joint_attention_kwargs property
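For context, here is a minimal usage sketch of the feature this PR adds. The checkpoint id and image URLs are illustrative placeholders, not taken from this diff:

import torch
from diffusers import StableDiffusion3InpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3InpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
).to("cuda")
# load_ip_adapter() comes from the SD3IPAdapterMixin wired in below;
# "<ip-adapter-repo>" stands in for a real IP-Adapter checkpoint id.
pipe.load_ip_adapter("<ip-adapter-repo>")

image = load_image("<base-image-url>")          # image to inpaint
mask = load_image("<mask-image-url>")           # white = region to repaint
ip_image = load_image("<reference-image-url>")  # reference for the adapter

result = pipe(
    prompt="a photo of a cat",
    image=image,
    mask_image=mask,
    ip_adapter_image=ip_image,
).images[0]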
@@ -13,19 +13,21 @@
 # limitations under the License.
 
 import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -162,7 +164,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -194,10 +196,14 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         tokenizer_3 (`T5TokenizerFast`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained vision model that produces image features for the IP-Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor that prepares input images for the image encoder.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
@@ -211,6 +217,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
 
@@ -224,6 +232,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             tokenizer_3=tokenizer_3,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
@@ -818,6 +828,10 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
     def do_classifier_free_guidance(self):
         return self._guidance_scale > 1
 
+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
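The new read-only `joint_attention_kwargs` property mirrors the pipeline's other state accessors (`do_classifier_free_guidance`, `num_timesteps`), so a step-end callback can inspect what the denoising loop passes to the attention processors. A sketch with a hypothetical callback:

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # Once image embeddings are prepared (step 7 below), the dict exposed
    # by the new property contains "ip_adapter_image_embeds".
    if pipeline.joint_attention_kwargs is not None:
        print(step, sorted(pipeline.joint_attention_kwargs))
    return callback_kwargs

# result = pipe(..., callback_on_step_end=on_step_end)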
@@ -826,6 +840,84 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
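Note that `encode_image` returns the penultimate hidden state (`hidden_states[-2]`) rather than a pooled vector, so the IP-Adapter's projection layers receive the full token sequence. A standalone sketch of the same computation; the concrete `feature_extractor` and `image_encoder` depend on the loaded checkpoint:

pixel_values = pipe.feature_extractor(ref_image, return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device=pipe._execution_device, dtype=pipe.dtype)
features = pipe.image_encoder(pixel_values, output_hidden_states=True).hidden_states[-2]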
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
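One detail worth calling out: under classifier-free guidance, the negative (zero) embeddings are concatenated before the positive ones along the batch dimension, and `.chunk(2)` on a precomputed tensor assumes that same ordering. A shape sketch with illustrative dimensions (the real token count and feature width depend on the image encoder):

import torch

single_image_embeds = torch.randn(1, 64, 1280)  # (batch, tokens, dim), illustrative
single_negative_image_embeds = torch.zeros_like(single_image_embeds)

num_images_per_prompt = 2
image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)

# Final layout: [negative, positive] along dim 0 -> shape (4, 64, 1280)
image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)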
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
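Following the warning above, keeping the image encoder off the offload path is two lines, using exactly the calls the message suggests:

pipe._exclude_from_cpu_offload.append("image_encoder")
pipe.enable_sequential_cpu_offload()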
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -853,8 +945,11 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -890,9 +985,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
                 `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
                 latents tensor will be generated by `mask_image`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
             padding_mask_crop (`int`, *optional*, defaults to `None`):
                 The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
@@ -953,12 +1048,22 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
                 a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that is called at the end of each denoising step during inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -1006,6 +1111,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
 
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
 
         # 2. Define call parameters
@@ -1160,7 +1266,22 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
                 f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}."
             )
 
-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
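Because the image embeddings reach the attention processors inside `joint_attention_kwargs`, a user-supplied dict is updated rather than replaced. A call sketch; the `scale` entry is only an illustration of a caller-provided kwarg (e.g. LoRA scaling):

result = pipe(
    prompt="a photo of a cat",
    image=image,
    mask_image=mask,
    ip_adapter_image=ip_image,
    joint_attention_kwargs={"scale": 0.8},  # preserved; embeddings are merged in
).images[0]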
@@ -1181,6 +1302,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
In the accompanying fast tests, the new components are registered as optional and left unset:

@@ -106,6 +106,8 @@ class StableDiffusion3InpaintPipelineFastTests(PipelineLatentTesterMixin, unitte
             "tokenizer_3": tokenizer_3,
             "transformer": transformer,
             "vae": vae,
+            "image_encoder": None,
+            "feature_extractor": None,
         }
 
     def get_dummy_inputs(self, device, seed=0):