diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 509ca4f9f8..6c6925bba0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import PIL.Image import torch from transformers import ( @@ -46,6 +44,12 @@ from ...utils import ( from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import StableDiffusionXLPipelineOutput +from .pipeline_stable_diffusion_xl_utils import ( + StableDiffusionXLMixin, + rescale_noise_cfg, + retrieve_latents, + retrieve_timesteps, +) if is_invisible_watermark_available(): @@ -91,126 +95,10 @@ EXAMPLE_DOC_STRING = """ """ -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.rescale_noise_cfg -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - r""" - Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on - Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are - Flawed](https://huggingface.co/papers/2305.08891). - - Args: - noise_cfg (`torch.Tensor`): - The predicted noise tensor for the guided diffusion process. - noise_pred_text (`torch.Tensor`): - The predicted noise tensor for the text-guided diffusion process. - guidance_rescale (`float`, *optional*, defaults to 0.0): - A rescale factor applied to the noise predictions. - - Returns: - noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -def mask_pil_to_torch(mask, height, width): - # preprocess mask - if isinstance(mask, (PIL.Image.Image, np.ndarray)): - mask = [mask] - - if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): - mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] - mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) - mask = mask.astype(np.float32) / 255.0 - elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): - mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - - mask = torch.from_numpy(mask) - return mask - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, - **kwargs, -): - r""" - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` - must be `None`. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, - `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): - Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - class StableDiffusionXLInpaintPipeline( DiffusionPipeline, StableDiffusionMixin, + StableDiffusionXLMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin,