mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add multiple conditions to StableDiffusionControlNetInpaintPipeline (#3125)
* try multi controlnet inpaint * multi controlnet inpaint * multi controlnet inpaint
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
@@ -11,6 +11,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
@@ -184,7 +185,14 @@ def prepare_mask_image(mask_image):
|
||||
|
||||
|
||||
def prepare_controlnet_conditioning_image(
|
||||
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
|
||||
controlnet_conditioning_image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance,
|
||||
):
|
||||
if not isinstance(controlnet_conditioning_image, torch.Tensor):
|
||||
if isinstance(controlnet_conditioning_image, PIL.Image.Image):
|
||||
@@ -214,6 +222,9 @@ def prepare_controlnet_conditioning_image(
|
||||
|
||||
controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
|
||||
|
||||
return controlnet_conditioning_image
|
||||
|
||||
|
||||
@@ -230,7 +241,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet: ControlNetModel,
|
||||
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
@@ -254,6 +265,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
if isinstance(controlnet, (list, tuple)):
|
||||
controlnet = MultiControlNetModel(controlnet)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
@@ -264,6 +278,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
@@ -522,6 +537,42 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
|
||||
image_is_pil = isinstance(image, PIL.Image.Image)
|
||||
image_is_tensor = isinstance(image, torch.Tensor)
|
||||
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
||||
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
||||
|
||||
if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
|
||||
raise TypeError(
|
||||
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
|
||||
)
|
||||
|
||||
if image_is_pil:
|
||||
image_batch_size = 1
|
||||
elif image_is_tensor:
|
||||
image_batch_size = image.shape[0]
|
||||
elif image_is_pil_list:
|
||||
image_batch_size = len(image)
|
||||
elif image_is_tensor_list:
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
raise ValueError("controlnet condition image is not valid")
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
prompt_batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
prompt_batch_size = len(prompt)
|
||||
elif prompt_embeds is not None:
|
||||
prompt_batch_size = prompt_embeds.shape[0]
|
||||
else:
|
||||
raise ValueError("prompt or prompt_embeds are not valid")
|
||||
|
||||
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
|
||||
raise ValueError(
|
||||
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
|
||||
)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
@@ -534,6 +585,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
controlnet_conditioning_scale=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
@@ -572,45 +624,35 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
|
||||
controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
|
||||
controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
|
||||
controlnet_conditioning_image[0], PIL.Image.Image
|
||||
)
|
||||
controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
|
||||
controlnet_conditioning_image[0], torch.Tensor
|
||||
)
|
||||
# check controlnet condition image
|
||||
if isinstance(self.controlnet, ControlNetModel):
|
||||
self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
|
||||
elif isinstance(self.controlnet, MultiControlNetModel):
|
||||
if not isinstance(controlnet_conditioning_image, list):
|
||||
raise TypeError("For multiple controlnets: `image` must be type `list`")
|
||||
if len(controlnet_conditioning_image) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
"For multiple controlnets: `image` must have the same length as the number of controlnets."
|
||||
)
|
||||
for image_ in controlnet_conditioning_image:
|
||||
self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
|
||||
else:
|
||||
assert False
|
||||
|
||||
if (
|
||||
not controlnet_cond_image_is_pil
|
||||
and not controlnet_cond_image_is_tensor
|
||||
and not controlnet_cond_image_is_pil_list
|
||||
and not controlnet_cond_image_is_tensor_list
|
||||
):
|
||||
raise TypeError(
|
||||
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
|
||||
)
|
||||
|
||||
if controlnet_cond_image_is_pil:
|
||||
controlnet_cond_image_batch_size = 1
|
||||
elif controlnet_cond_image_is_tensor:
|
||||
controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
|
||||
elif controlnet_cond_image_is_pil_list:
|
||||
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
|
||||
elif controlnet_cond_image_is_tensor_list:
|
||||
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
prompt_batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
prompt_batch_size = len(prompt)
|
||||
elif prompt_embeds is not None:
|
||||
prompt_batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
|
||||
raise ValueError(
|
||||
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
|
||||
)
|
||||
# Check `controlnet_conditioning_scale`
|
||||
if isinstance(self.controlnet, ControlNetModel):
|
||||
if not isinstance(controlnet_conditioning_scale, float):
|
||||
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
|
||||
elif isinstance(self.controlnet, MultiControlNetModel):
|
||||
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
|
||||
self.controlnet.nets
|
||||
):
|
||||
raise ValueError(
|
||||
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
|
||||
" the same length as the number of controlnets"
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor):
|
||||
raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor")
|
||||
@@ -630,6 +672,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
image_channels, image_height, image_width = image.shape
|
||||
elif image.ndim == 4:
|
||||
image_batch_size, image_channels, image_height, image_width = image.shape
|
||||
else:
|
||||
assert False
|
||||
|
||||
if mask_image.ndim == 2:
|
||||
mask_image_batch_size = 1
|
||||
@@ -797,7 +841,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -897,6 +941,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
@@ -913,6 +958,9 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
@@ -929,15 +977,37 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
|
||||
mask_image = prepare_mask_image(mask_image)
|
||||
|
||||
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
|
||||
controlnet_conditioning_image,
|
||||
width,
|
||||
height,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
self.controlnet.dtype,
|
||||
)
|
||||
# condition image(s)
|
||||
if isinstance(self.controlnet, ControlNetModel):
|
||||
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
|
||||
controlnet_conditioning_image=controlnet_conditioning_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
elif isinstance(self.controlnet, MultiControlNetModel):
|
||||
controlnet_conditioning_images = []
|
||||
|
||||
for image_ in controlnet_conditioning_image:
|
||||
image_ = prepare_controlnet_conditioning_image(
|
||||
controlnet_conditioning_image=image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=self.controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
controlnet_conditioning_images.append(image_)
|
||||
|
||||
controlnet_conditioning_image = controlnet_conditioning_images
|
||||
else:
|
||||
assert False
|
||||
|
||||
masked_image = image * (mask_image < 0.5)
|
||||
|
||||
@@ -979,9 +1049,6 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
@@ -1007,15 +1074,10 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_cond=controlnet_conditioning_image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
down_block_res_samples = [
|
||||
down_block_res_sample * controlnet_conditioning_scale
|
||||
for down_block_res_sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample *= controlnet_conditioning_scale
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
inpainting_latent_model_input,
|
||||
|
||||
Reference in New Issue
Block a user