Mirror of https://github.com/huggingface/diffusers.git
add padding_mask_crop to all inpaint pipelines (#6360)
* add padding_mask_crop

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

@@ -683,9 +683,11 @@ class StableDiffusionControlNetInpaintPipeline(
         self,
         prompt,
         image,
+        mask_image,
         height,
         width,
         callback_steps,
+        output_type,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -693,6 +695,7 @@ class StableDiffusionControlNetInpaintPipeline(
         control_guidance_start=0.0,
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -736,6 +739,19 @@ class StableDiffusionControlNetInpaintPipeline(
                 f" {negative_prompt_embeds.shape}."
             )
 
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
         # `prompt` needs more sophisticated handling when there are multiple
         # conditionings.
         if isinstance(self.controlnet, MultiControlNetModel):
@@ -862,7 +878,6 @@ class StableDiffusionControlNetInpaintPipeline(
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )
 
-    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_control_image(
         self,
         image,
@@ -872,10 +887,14 @@ class StableDiffusionControlNetInpaintPipeline(
         num_images_per_prompt,
         device,
         dtype,
+        crops_coords,
+        resize_mode,
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image = self.control_image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        ).to(dtype=torch.float32)
         image_batch_size = image.shape[0]
 
         if image_batch_size == 1:
@@ -1074,6 +1093,7 @@ class StableDiffusionControlNetInpaintPipeline(
         control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 1.0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
@@ -1130,6 +1150,12 @@ class StableDiffusionControlNetInpaintPipeline(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The width in pixels of the generated image.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains all of the masked area, and then expand that region based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant for inpainting, such as background.
             strength (`float`, *optional*, defaults to 1.0):
                 Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                 starting point and more noise is added the higher the `strength`. The number of denoising steps depends
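The docstring above describes the behavior; for orientation, here is a minimal sketch of a call that exercises it. The checkpoint names are real, but the image URLs, prompt, and margin value are placeholders, not part of this commit:

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

image = load_image("https://example.com/room.png")       # placeholder URL
mask_image = load_image("https://example.com/mask.png")  # placeholder URL

result = pipe(
    prompt="a wooden chair",
    image=image,
    mask_image=mask_image,
    control_image=image,   # a real run would build a proper inpaint conditioning image
    padding_mask_crop=32,  # inpaint only the masked region plus a 32-pixel margin
    output_type="pil",     # anything else raises in check_inputs when cropping
).images[0]

Note that `padding_mask_crop` requires PIL inputs and PIL output, which is exactly what the new `check_inputs` block enforces.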
@@ -1240,9 +1266,11 @@ class StableDiffusionControlNetInpaintPipeline(
         self.check_inputs(
             prompt,
             control_image,
+            mask_image,
             height,
             width,
             callback_steps,
+            output_type,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
@@ -1250,6 +1278,7 @@ class StableDiffusionControlNetInpaintPipeline(
             control_guidance_start,
             control_guidance_end,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )
 
         self._guidance_scale = guidance_scale
@@ -1264,6 +1293,14 @@ class StableDiffusionControlNetInpaintPipeline(
         else:
             batch_size = prompt_embeds.shape[0]
 
+        if padding_mask_crop is not None:
+            height, width = self.image_processor.get_default_height_width(image, height, width)
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
         device = self._execution_device
 
         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
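The hunk above is the heart of the feature and can be exercised standalone. A sketch using the same processor setup the pipelines build in their constructors; the sizes, mask rectangle, and pad value are assumed here rather than taken from the diff:

import PIL.Image
from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor(vae_scale_factor=8)
mask_processor = VaeImageProcessor(
    vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)

image = PIL.Image.new("RGB", (1024, 768))
mask_image = PIL.Image.new("L", (1024, 768))
mask_image.paste(255, (450, 300, 550, 400))  # small masked square in a large image

# height/width fall back to the image size when the caller passed None
height, width = image_processor.get_default_height_width(image, None, None)

# (x1, y1, x2, y2) of a region with the image's aspect ratio that covers the
# mask plus the requested margin; resize_mode="fill" later scales that crop
# back up to (width, height) so denoising runs at full resolution on the crop
crops_coords = mask_processor.get_crop_region(mask_image, width, height, pad=32)
print(height, width, crops_coords)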
@@ -1315,6 +1352,8 @@ class StableDiffusionControlNetInpaintPipeline(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
@@ -1330,6 +1369,8 @@ class StableDiffusionControlNetInpaintPipeline(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
@@ -1341,10 +1382,15 @@ class StableDiffusionControlNetInpaintPipeline(
             assert False
 
         # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)
 
-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
 
         masked_image = init_image * (mask < 0.5)
         _, _, height, width = init_image.shape
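The `masked_image = init_image * (mask < 0.5)` context line above is easy to misread: in isolation it just zeroes out exactly the pixels that will be repainted, since multiplying a float tensor by a boolean tensor treats the booleans as 0/1. A toy illustration with invented values:

import torch

init_image = torch.tensor([[0.8, -0.2], [0.5, 0.1]])  # toy 2x2 "image"
mask = torch.tensor([[0.0, 1.0], [1.0, 0.0]])         # 1.0 marks the area to inpaint
masked_image = init_image * (mask < 0.5)              # keep pixels outside the mask
print(masked_image)  # tensor([[ 0.8000, -0.0000], [ 0.0000,  0.1000]])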
@@ -1534,6 +1580,9 @@ class StableDiffusionControlNetInpaintPipeline(
 
         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()
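The new paste-back step calls `VaeImageProcessor.apply_overlay` with the argument order shown in the diff (mask, original image, generated image, crop coordinates): it scales the generated image into the crop window and keeps the original pixels everywhere outside the mask. A standalone sketch with toy images; the coordinates are made up and would normally come from `get_crop_region`:

import PIL.Image
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)

original_image = PIL.Image.new("RGB", (1024, 768), "gray")
mask_image = PIL.Image.new("L", (1024, 768))
mask_image.paste(255, (450, 300, 550, 400))
generated = PIL.Image.new("RGB", (1024, 768), "red")  # stand-in for one pipeline output

crops_coords = (418, 268, 582, 432)  # made-up values; normally from get_crop_region
final = processor.apply_overlay(mask_image, original_image, generated, crops_coords)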
src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py

@@ -557,9 +557,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
         prompt,
         prompt_2,
         image,
+        mask_image,
         strength,
         num_inference_steps,
         callback_steps,
+        output_type,
         negative_prompt=None,
         negative_prompt_2=None,
         prompt_embeds=None,
@@ -570,6 +572,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
         control_guidance_start=0.0,
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -632,6 +635,19 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 f" {negative_prompt_embeds.shape}."
             )
 
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
@@ -745,10 +761,14 @@ class StableDiffusionXLControlNetInpaintPipeline(
         num_images_per_prompt,
         device,
         dtype,
+        crops_coords,
+        resize_mode,
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image = self.control_image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        ).to(dtype=torch.float32)
         image_batch_size = image.shape[0]
 
         if image_batch_size == 1:
@@ -1066,6 +1086,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
         ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 0.9999,
         num_inference_steps: int = 50,
         denoising_start: Optional[float] = None,
@@ -1121,6 +1142,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains all of the masked area, and then expand that region based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant for inpainting, such as background.
             strength (`float`, *optional*, defaults to 0.9999):
                 Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                 between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1290,9 +1317,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
             prompt,
             prompt_2,
             control_image,
+            mask_image,
             strength,
             num_inference_steps,
             callback_steps,
+            output_type,
             negative_prompt,
             negative_prompt_2,
             prompt_embeds,
@@ -1303,6 +1332,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
             control_guidance_start,
             control_guidance_end,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )
 
         self._guidance_scale = guidance_scale
@@ -1370,7 +1400,18 @@ class StableDiffusionXLControlNetInpaintPipeline(
 
         # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
         # 5.1 Prepare init image
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        if padding_mask_crop is not None:
+            height, width = self.image_processor.get_default_height_width(image, height, width)
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)
 
         # 5.2 Prepare control images
@@ -1383,6 +1424,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
@@ -1398,6 +1441,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
@@ -1409,7 +1454,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
             raise ValueError(f"{controlnet.__class__} is not supported.")
 
         # 5.3 Prepare mask
-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
 
         masked_image = init_image * (mask < 0.5)
         _, _, height, width = init_image.shape
@@ -1684,6 +1731,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
 
         image = self.image_processor.postprocess(image, output_type=output_type)
 
+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

@@ -642,6 +642,7 @@ class StableDiffusionInpaintPipeline(
         width,
         strength,
         callback_steps,
+        output_type,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -693,11 +694,6 @@ class StableDiffusionInpaintPipeline(
                 f" {negative_prompt_embeds.shape}."
             )
         if padding_mask_crop is not None:
-            if self.unet.config.in_channels != 4:
-                raise ValueError(
-                    f"The UNet should have 4 input channels for inpainting mask crop, but has"
-                    f" {self.unet.config.in_channels} input channels."
-                )
             if not isinstance(image, PIL.Image.Image):
                 raise ValueError(
                     f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
@@ -707,6 +703,8 @@ class StableDiffusionInpaintPipeline(
                     f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                     f" {type(mask_image)}."
                 )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
 
     def prepare_latents(
         self,
@@ -1166,6 +1164,7 @@ class StableDiffusionInpaintPipeline(
             width,
             strength,
             callback_steps,
+            output_type,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

@@ -744,15 +744,19 @@ class StableDiffusionXLInpaintPipeline(
         self,
         prompt,
         prompt_2,
+        image,
+        mask_image,
         height,
         width,
         strength,
         callback_steps,
+        output_type,
         negative_prompt=None,
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -810,6 +814,18 @@ class StableDiffusionXLInpaintPipeline(
                 f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                 f" {negative_prompt_embeds.shape}."
             )
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
 
     def prepare_latents(
         self,
@@ -1225,6 +1241,7 @@ class StableDiffusionXLInpaintPipeline(
         masked_image_latents: torch.FloatTensor = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 0.9999,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
@@ -1287,6 +1304,12 @@ class StableDiffusionXLInpaintPipeline(
                 Anything below 512 pixels won't work well for
                 [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                 and checkpoints that are not specifically fine-tuned on low resolutions.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of the margin in the crop to be applied to the image and mask. If `None`, no crop is applied to image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains all of the masked area, and then expand that region based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant for inpainting, such as background.
             strength (`float`, *optional*, defaults to 0.9999):
                 Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                 between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
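For the SDXL inpaint pipeline the call shape is the same; a minimal sketch (the base checkpoint is real, the image URLs, prompt, and values are placeholders):

import torch
from diffusers import StableDiffusionXLInpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = load_image("https://example.com/large_photo.png")  # placeholder URL
mask_image = load_image("https://example.com/mask.png")    # placeholder URL

result = pipe(
    prompt="a small red boat on the lake",
    image=image,
    mask_image=mask_image,
    strength=0.99,
    padding_mask_crop=64,  # denoise only around the mask, then paste the crop back
).images[0]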
@@ -1449,15 +1472,19 @@ class StableDiffusionXLInpaintPipeline(
         self.check_inputs(
             prompt,
             prompt_2,
+            image,
+            mask_image,
             height,
             width,
             strength,
             callback_steps,
+            output_type,
             negative_prompt,
             negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )
 
         self._guidance_scale = guidance_scale
@@ -1527,10 +1554,22 @@ class StableDiffusionXLInpaintPipeline(
         is_strength_max = strength == 1.0
 
         # 5. Preprocess mask and image
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        if padding_mask_crop is not None:
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)
 
-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
 
         if masked_image_latents is not None:
             masked_image = masked_image_latents
@@ -1791,6 +1830,9 @@ class StableDiffusionXLInpaintPipeline(
 
         image = self.image_processor.postprocess(image, output_type=output_type)
 
+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()