1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

img2img.multiple.controlnets.pipeline (#2833)

* img2img.multiple.controlnets.pipeline

* remove comments

---------

Co-authored-by: mishka <gartsocial@gmail.com>
This commit is contained in:
Michael Gartsbein
2023-03-30 20:00:12 +03:00
committed by GitHub
parent 49609768b4
commit 1d033a95f6

View File

@@ -1,7 +1,7 @@
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
@@ -10,6 +10,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
@@ -86,7 +87,14 @@ def prepare_image(image):
def prepare_controlnet_conditioning_image(
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
controlnet_conditioning_image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance,
):
if not isinstance(controlnet_conditioning_image, torch.Tensor):
if isinstance(controlnet_conditioning_image, PIL.Image.Image):
@@ -116,6 +124,9 @@ def prepare_controlnet_conditioning_image(
controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
return controlnet_conditioning_image
@@ -132,7 +143,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
controlnet: ControlNetModel,
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
@@ -156,6 +167,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -424,6 +438,42 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
)
if image_is_pil:
image_batch_size = 1
elif image_is_tensor:
image_batch_size = image.shape[0]
elif image_is_pil_list:
image_batch_size = len(image)
elif image_is_tensor_list:
image_batch_size = len(image)
else:
raise ValueError("controlnet condition image is not valid")
if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
else:
raise ValueError("prompt or prompt_embeds are not valid")
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)
def check_inputs(
self,
prompt,
@@ -438,6 +488,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
strength=None,
controlnet_guidance_start=None,
controlnet_guidance_end=None,
controlnet_conditioning_scale=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -476,58 +527,51 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}."
)
controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
controlnet_conditioning_image[0], PIL.Image.Image
)
controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
controlnet_conditioning_image[0], torch.Tensor
)
# check controlnet condition image
if (
not controlnet_cond_image_is_pil
and not controlnet_cond_image_is_tensor
and not controlnet_cond_image_is_pil_list
and not controlnet_cond_image_is_tensor_list
):
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
)
if isinstance(self.controlnet, ControlNetModel):
self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
elif isinstance(self.controlnet, MultiControlNetModel):
if not isinstance(controlnet_conditioning_image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")
if controlnet_cond_image_is_pil:
controlnet_cond_image_batch_size = 1
elif controlnet_cond_image_is_tensor:
controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
elif controlnet_cond_image_is_pil_list:
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
elif controlnet_cond_image_is_tensor_list:
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
if len(controlnet_conditioning_image) != len(self.controlnet.nets):
raise ValueError(
"For multiple controlnets: `image` must have the same length as the number of controlnets."
)
if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
for image_ in controlnet_conditioning_image:
self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
else:
assert False
if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
)
# Check `controlnet_conditioning_scale`
if isinstance(self.controlnet, ControlNetModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(self.controlnet, MultiControlNetModel):
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
raise ValueError(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False
if isinstance(image, torch.Tensor):
if image.ndim != 3 and image.ndim != 4:
raise ValueError("`image` must have 3 or 4 dimensions")
# if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
# raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")
if image.ndim == 3:
image_batch_size = 1
image_channels, image_height, image_width = image.shape
elif image.ndim == 4:
image_batch_size, image_channels, image_height, image_width = image.shape
else:
assert False
if image_channels != 3:
raise ValueError("`image` must have 3 channels")
@@ -659,7 +703,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: float = 1.0,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
controlnet_guidance_start: float = 0.0,
controlnet_guidance_end: float = 1.0,
):
@@ -759,7 +803,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
self.check_inputs(
prompt,
image,
# mask_image,
controlnet_conditioning_image,
height,
width,
@@ -770,6 +813,7 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
strength,
controlnet_guidance_start,
controlnet_guidance_end,
controlnet_conditioning_scale,
)
# 2. Define call parameters
@@ -786,6 +830,9 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
@@ -797,22 +844,41 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare mask, image, and controlnet_conditioning_image
# 4. Prepare image, and controlnet_conditioning_image
image = prepare_image(image)
# mask_image = prepare_mask_image(mask_image)
# condition image(s)
if isinstance(self.controlnet, ControlNetModel):
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=controlnet_conditioning_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
elif isinstance(self.controlnet, MultiControlNetModel):
controlnet_conditioning_images = []
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
controlnet_conditioning_image,
width,
height,
batch_size * num_images_per_prompt,
num_images_per_prompt,
device,
self.controlnet.dtype,
)
for image_ in controlnet_conditioning_image:
image_ = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
# masked_image = image * (mask_image < 0.5)
controlnet_conditioning_images.append(image_)
controlnet_conditioning_image = controlnet_conditioning_images
else:
assert False
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -830,9 +896,6 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
generator,
)
if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -862,15 +925,10 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=controlnet_conditioning_image,
conditioning_scale=controlnet_conditioning_scale,
return_dict=False,
)
down_block_res_samples = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale
# predict the noise residual
noise_pred = self.unet(
latent_model_input,