diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index a18468f72c..9f3009cede 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -683,9 +683,11 @@ class StableDiffusionControlNetInpaintPipeline(
         self,
         prompt,
         image,
+        mask_image,
         height,
         width,
         callback_steps,
+        output_type,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -693,6 +695,7 @@ class StableDiffusionControlNetInpaintPipeline(
         control_guidance_start=0.0,
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -736,6 +739,19 @@ class StableDiffusionControlNetInpaintPipeline(
                 f" {negative_prompt_embeds.shape}."
             )

+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
         # `prompt` needs more sophisticated handling when there are multiple
         # conditionings.
         if isinstance(self.controlnet, MultiControlNetModel):
@@ -862,7 +878,6 @@ class StableDiffusionControlNetInpaintPipeline(
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )

-    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_control_image(
         self,
         image,
@@ -872,10 +887,14 @@ class StableDiffusionControlNetInpaintPipeline(
         num_images_per_prompt,
         device,
         dtype,
+        crops_coords,
+        resize_mode,
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image = self.control_image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        ).to(dtype=torch.float32)
         image_batch_size = image.shape[0]

         if image_batch_size == 1:
@@ -1074,6 +1093,7 @@ class StableDiffusionControlNetInpaintPipeline(
         control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 1.0,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
@@ -1130,6 +1150,12 @@ class StableDiffusionControlNetInpaintPipeline(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The width in pixels of the generated image.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of margin in the crop to be applied to the image and mask_image. If `None`, no crop is applied to the image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains all masked areas, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant to inpainting, such as background.
             strength (`float`, *optional*, defaults to 1.0):
                 Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                 starting point and more noise is added the higher the `strength`. The number of denoising steps depends
@@ -1240,9 +1266,11 @@ class StableDiffusionControlNetInpaintPipeline(
         self.check_inputs(
             prompt,
             control_image,
+            mask_image,
             height,
             width,
             callback_steps,
+            output_type,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
@@ -1250,6 +1278,7 @@ class StableDiffusionControlNetInpaintPipeline(
             control_guidance_start,
             control_guidance_end,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )

         self._guidance_scale = guidance_scale
@@ -1264,6 +1293,14 @@ class StableDiffusionControlNetInpaintPipeline(
         else:
             batch_size = prompt_embeds.shape[0]

+        if padding_mask_crop is not None:
+            height, width = self.image_processor.get_default_height_width(image, height, width)
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
         device = self._execution_device

         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
@@ -1315,6 +1352,8 @@ class StableDiffusionControlNetInpaintPipeline(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
@@ -1330,6 +1369,8 @@ class StableDiffusionControlNetInpaintPipeline(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
+                    crops_coords=crops_coords,
+                    resize_mode=resize_mode,
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
@@ -1341,10 +1382,15 @@ class StableDiffusionControlNetInpaintPipeline(
             assert False

         # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)

-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )

         masked_image = init_image * (mask < 0.5)
         _, _, height, width = init_image.shape
@@ -1534,6 +1580,9 @@ class StableDiffusionControlNetInpaintPipeline(

         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 76b97b48f9..ceda744f4e 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -557,9 +557,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
         prompt,
         prompt_2,
         image,
+        mask_image,
         strength,
         num_inference_steps,
         callback_steps,
+        output_type,
         negative_prompt=None,
         negative_prompt_2=None,
         prompt_embeds=None,
@@ -570,6 +572,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
         control_guidance_start=0.0,
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -632,6 +635,19 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 f" {negative_prompt_embeds.shape}."
             )
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
@@ -745,10 +761,14 @@ class StableDiffusionXLControlNetInpaintPipeline(
         num_images_per_prompt,
         device,
         dtype,
+        crops_coords,
+        resize_mode,
         do_classifier_free_guidance=False,
         guess_mode=False,
     ):
-        image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+        image = self.control_image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        ).to(dtype=torch.float32)
         image_batch_size = image.shape[0]

         if image_batch_size == 1:
@@ -1066,6 +1086,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
         ] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 0.9999,
         num_inference_steps: int = 50,
         denoising_start: Optional[float] = None,
@@ -1121,6 +1142,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of margin in the crop to be applied to the image and mask_image. If `None`, no crop is applied to the image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains all masked areas, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant to inpainting, such as background.
             strength (`float`, *optional*, defaults to 0.9999):
                 Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                 between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1290,9 +1317,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
             prompt,
             prompt_2,
             control_image,
+            mask_image,
             strength,
             num_inference_steps,
             callback_steps,
+            output_type,
             negative_prompt,
             negative_prompt_2,
             prompt_embeds,
@@ -1303,6 +1332,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
             control_guidance_start,
             control_guidance_end,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )

         self._guidance_scale = guidance_scale
@@ -1370,7 +1400,18 @@ class StableDiffusionXLControlNetInpaintPipeline(

         # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
         # 5.1 Prepare init image
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        if padding_mask_crop is not None:
+            height, width = self.image_processor.get_default_height_width(image, height, width)
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)

         # 5.2 Prepare control images
@@ -1383,6 +1424,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
                 do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
@@ -1398,6 +1441,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
+                    crops_coords=crops_coords,
+                    resize_mode=resize_mode,
                     do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )
@@ -1409,7 +1454,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
             raise ValueError(f"{controlnet.__class__} is not supported.")

         # 5.3 Prepare mask
-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )

         masked_image = init_image * (mask < 0.5)
         _, _, height, width = init_image.shape
@@ -1684,6 +1731,9 @@ class StableDiffusionXLControlNetInpaintPipeline(

         image = self.image_processor.postprocess(image, output_type=output_type)

+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 58af756849..6751490abd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -642,6 +642,7 @@ class StableDiffusionInpaintPipeline(
         width,
         strength,
         callback_steps,
+        output_type,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -693,11 +694,6 @@ class StableDiffusionInpaintPipeline(
                 f" {negative_prompt_embeds.shape}."
             )
         if padding_mask_crop is not None:
-            if self.unet.config.in_channels != 4:
-                raise ValueError(
-                    f"The UNet should have 4 input channels for inpainting mask crop, but has"
-                    f" {self.unet.config.in_channels} input channels."
-                )
             if not isinstance(image, PIL.Image.Image):
                 raise ValueError(
                     f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
@@ -707,6 +703,8 @@ class StableDiffusionInpaintPipeline(
                     f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                     f" {type(mask_image)}."
                 )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

     def prepare_latents(
         self,
@@ -1166,6 +1164,7 @@ class StableDiffusionInpaintPipeline(
             width,
             strength,
             callback_steps,
+            output_type,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 2f02a213b8..f9468adba9 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -744,15 +744,19 @@ class StableDiffusionXLInpaintPipeline(
         self,
         prompt,
         prompt_2,
+        image,
+        mask_image,
         height,
         width,
         strength,
         callback_steps,
+        output_type,
         negative_prompt=None,
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -810,6 +814,18 @@ class StableDiffusionXLInpaintPipeline(
                 f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                 f" {negative_prompt_embeds.shape}."
             )
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

     def prepare_latents(
         self,
@@ -1225,6 +1241,7 @@ class StableDiffusionXLInpaintPipeline(
         masked_image_latents: torch.FloatTensor = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
         strength: float = 0.9999,
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
@@ -1287,6 +1304,12 @@ class StableDiffusionXLInpaintPipeline(
                 Anything below 512 pixels won't work well for
                 [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                 and checkpoints that are not specifically fine-tuned on low resolutions.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of margin in the crop to be applied to the image and mask_image. If `None`, no crop is applied to the image and mask_image. If
+                `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that
+                contains all masked areas, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
+                the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
+                and contains information irrelevant to inpainting, such as background.
             strength (`float`, *optional*, defaults to 0.9999):
                 Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                 between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1449,15 +1472,19 @@ class StableDiffusionXLInpaintPipeline(
         self.check_inputs(
             prompt,
             prompt_2,
+            image,
+            mask_image,
             height,
             width,
             strength,
             callback_steps,
+            output_type,
             negative_prompt,
             negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
         )

         self._guidance_scale = guidance_scale
@@ -1527,10 +1554,22 @@ class StableDiffusionXLInpaintPipeline(
         is_strength_max = strength == 1.0

         # 5. Preprocess mask and image
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        if padding_mask_crop is not None:
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
         init_image = init_image.to(dtype=torch.float32)

-        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )

         if masked_image_latents is not None:
             masked_image = masked_image_latents
@@ -1791,6 +1830,9 @@ class StableDiffusionXLInpaintPipeline(

         image = self.image_processor.postprocess(image, output_type=output_type)

+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
         # Offload all models
         self.maybe_free_model_hooks()
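
Not part of the patch itself, but for orientation: a minimal usage sketch of the new `padding_mask_crop` argument with the ControlNet inpaint pipeline touched above. The checkpoint names and image URLs are illustrative assumptions, not taken from this diff; any inpaint-capable checkpoint with a matching ControlNet should work the same way, and the crop-then-overlay behavior follows the docstring added here.

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

# Illustrative checkpoints (assumed, not specified by the patch).
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# padding_mask_crop requires PIL inputs and output_type="pil"; the new check_inputs logic enforces this.
image = load_image("https://example.com/room.png")  # hypothetical URL
mask_image = load_image("https://example.com/room_mask.png")  # hypothetical URL
control_image = load_image("https://example.com/room_condition.png")  # hypothetical conditioning image for the chosen ControlNet

result = pipe(
    "a red leather sofa",
    image=image,
    mask_image=mask_image,
    control_image=control_image,
    padding_mask_crop=32,  # crop image/mask/control to the masked region plus a 32-pixel margin before inpainting
).images[0]
result.save("inpainted.png")  # the inpainted crop is pasted back onto the original image via apply_overlay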