diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 98baaa874a..89137ee815 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -182,6 +182,8 @@ title: Overview - local: api/schedulers/ddim title: DDIM + - local: api/schedulers/ddim_inverse + title: DDIMInverse - local: api/schedulers/ddpm title: DDPM - local: api/schedulers/deis diff --git a/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx b/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx index a0ae3a6083..47f080c98f 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx @@ -138,14 +138,15 @@ caption = pipeline.generate_caption(raw_image) Then we employ the generated caption and the input image to get the inverted noise: ```py -inv_latents, inv_image = pipeline.invert(caption, image=raw_image) +generator = torch.manual_seed(0) +inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents ``` Now, generate the image with edit directions: ```py # See the "Generating source and target embeddings" section below to -# automate the generation of these captions with a pre-trained model like Flan-T5. +# automate the generation of these captions with a pre-trained model like Flan-T5 as explained below. source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] diff --git a/docs/source/en/api/schedulers/ddim_inverse.mdx b/docs/source/en/api/schedulers/ddim_inverse.mdx new file mode 100644 index 0000000000..5096a3cee2 --- /dev/null +++ b/docs/source/en/api/schedulers/ddim_inverse.mdx @@ -0,0 +1,21 @@ + + +# Inverse Denoising Diffusion Implicit Models (DDIMInverse) + +## Overview + +This scheduler is the inverted scheduler of [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon. 
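+
+A minimal usage sketch, assuming the Pix2Pix Zero pipeline and the `CompVis/stable-diffusion-v1-4` checkpoint used elsewhere in these docs:
+
+```py
+import requests
+import torch
+from PIL import Image
+
+from diffusers import DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline
+
+pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
+).to("cuda")
+pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
+
+img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
+raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
+
+# DDIM inversion: run the diffusion process forward to recover the noise latents of a real image
+inv_latents = pipeline.invert("a photography of a cat with flowers", image=raw_image).latents
+```
+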
+The implementation is mostly based on the DDIM inversion definition of [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/pdf/2211.09794.pdf) + +## DDIMInverseScheduler +[[autodoc]] DDIMInverseScheduler diff --git a/docs/source/en/api/schedulers/overview.mdx b/docs/source/en/api/schedulers/overview.mdx index cfb319b342..44f3a64b25 100644 --- a/docs/source/en/api/schedulers/overview.mdx +++ b/docs/source/en/api/schedulers/overview.mdx @@ -46,6 +46,7 @@ The following table summarizes all officially supported schedulers, their corres | Scheduler | Paper | |---|---| | [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | +| [ddim_inverse](./ddim_inverse) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | | [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | | [deis](./deis) | [**DEISMultistepScheduler**](https://arxiv.org/abs/2204.13902) | | [singlestep_dpm_solver](./singlestep_dpm_solver) | [**Singlestep DPM-Solver**](https://arxiv.org/abs/2206.00927) | diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0465f358b6..9d23c284db 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -67,6 +67,7 @@ else: ScoreSdeVePipeline, ) from .schedulers import ( + DDIMInverseScheduler, DDIMScheduler, DDPMScheduler, DEISMultistepScheduler, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index cc208d34a3..0484ca8891 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -929,7 +929,7 @@ class DiffusionPipeline(ConfigMixin): if set(components.keys()) != expected_modules: raise ValueError( f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected" - f" {expected_modules} to be defined, but {components} are defined." + f" {expected_modules} to be defined, but {components.keys()} are defined." ) return components diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 338f69446e..81807fd5c2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -13,16 +13,34 @@ # limitations under the License. import inspect +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np import PIL import torch -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +import torch.nn.functional as F +from transformers import ( + BlipForConditionalGeneration, + BlipProcessor, + CLIPFeatureExtractor, + CLIPTextModel, + CLIPTokenizer, +) from ...models import AutoencoderKL, UNet2DConditionModel from ...models.cross_attention import CrossAttention from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler -from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring +from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler +from ...utils import ( + PIL_INTERPOLATION, + BaseOutput, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from . 
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -30,6 +48,24 @@ from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@dataclass +class Pix2PixInversionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.FloatTensor`) + inverted latents tensor + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + latents: torch.FloatTensor + images: Union[List[PIL.Image.Image], np.ndarray] + + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -46,9 +82,7 @@ EXAMPLE_DOC_STRING = """ >>> model_ckpt = "CompVis/stable-diffusion-v1-4" - >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( - ... model_ckpt, conditions_input_image=False, torch_dtype=torch.float16 - ... ) + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16) >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) >>> pipeline.to("cuda") @@ -68,10 +102,94 @@ EXAMPLE_DOC_STRING = """ ... num_inference_steps=50, ... cross_attention_guidance_amount=0.15, ... ).images + >>> images[0].save("edited_image_dog.png") ``` """ +EXAMPLE_INVERT_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from transformers import BlipForConditionalGeneration, BlipProcessor + >>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline + + >>> import requests + >>> from PIL import Image + + >>> captioner_id = "Salesforce/blip-image-captioning-base" + >>> processor = BlipProcessor.from_pretrained(captioner_id) + >>> model = BlipForConditionalGeneration.from_pretrained( + ... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ... ) + + >>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4" + >>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained( + ... sd_model_ckpt, + ... caption_generator=model, + ... caption_processor=processor, + ... torch_dtype=torch.float16, + ... safety_checker=None, + ... ) + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" + + >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) + >>> # generate caption + >>> caption = pipeline.generate_caption(raw_image) + + >>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii" + >>> inv_latents = pipeline.invert(caption, image=raw_image).latents + >>> # we need to generate source and target embeds + + >>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] + + >>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] + + >>> source_embeds = pipeline.get_embeds(source_prompts) + >>> target_embeds = pipeline.get_embeds(target_prompts) + >>> # the latents can then be used to edit a real image + + >>> image = pipeline( + ... caption, + ... source_embeds=source_embeds, + ... target_embeds=target_embeds, + ... num_inference_steps=50, + ... 
cross_attention_guidance_amount=0.15, + ... generator=generator, + ... latents=inv_latents, + ... negative_prompt=caption, + ... ).images[0] + >>> image.save("edited_image.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + def prepare_unet(unet: UNet2DConditionModel): """Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations.""" @@ -179,13 +297,17 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. - conditions_input_image (bool): - Whether to condition the pipeline with an input image to compute an inverted noise latent. requires_safety_checker (bool): Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the pipeline publicly. """ - _optional_components = ["safety_checker", "feature_extractor"] + _optional_components = [ + "safety_checker", + "feature_extractor", + "caption_generator", + "caption_processor", + "inverse_scheduler", + ] def __init__( self, @@ -194,9 +316,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler], - safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, - conditions_input_image: bool = False, + safety_checker: StableDiffusionSafetyChecker, + inverse_scheduler: DDIMInverseScheduler, + caption_generator: BlipForConditionalGeneration, + caption_processor: BlipProcessor, requires_safety_checker: bool = True, ): super().__init__() @@ -211,16 +335,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) - if conditions_input_image: - raise NotImplementedError - # logger.info("Loading caption generator since `conditions_input_image` is True.") - # checkpoint = "Salesforce/blip-image-captioning-base" - # captioner_processor = AutoProcessor.from_pretrained(checkpoint) - # captioner = BlipForConditionalGeneration.from_pretrained(checkpoint, dtype=unet.dtype) - else: - captioner_processor = None - captioner = None - if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" @@ -232,19 +346,15 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, - _captioner_processor=captioner_processor, - _captioner=captioner, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + caption_processor=caption_processor, + caption_generator=caption_generator, + inverse_scheduler=inverse_scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.conditions_input_image = conditions_input_image - self.register_to_config( - _captioner=captioner, - _captioner_processor=captioner_processor, - requires_safety_checker=requires_safety_checker, - ) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): @@ -268,6 +378,30 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.vae, self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. 
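+        # keeping a handle to the last hook lets `__call__` and `invert` flush the final model back to CPU via `self.final_offload_hook.offload()`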
+ self.final_offload_hook = hook + @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): @@ -467,7 +601,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): def check_inputs( self, prompt, - conditions_input_image, image, source_embeds, target_embeds, @@ -484,14 +617,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): if source_embeds is None and target_embeds is None: raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.") - if prompt is None and not conditions_input_image: - raise ValueError(f"`prompt` cannot be None when `conditions_input_image` is {conditions_input_image}") - - elif prompt is not None and conditions_input_image: - raise ValueError( - f"`prompt` should not be provided when `conditions_input_image` is {conditions_input_image}" - ) - if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -504,12 +629,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if conditions_input_image: - if image is None: - raise ValueError("`image` cannot be None when `conditions_input_image` is True.") - elif isinstance(image, (torch.FloatTensor, PIL.Image.Image)): - raise ValueError("Invalid image provided. Supported formats: torch.FloatTensor, PIL.Image.Image.}") - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -528,15 +647,25 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): latents = latents * self.scheduler.init_noise_sigma return latents - def generate_caption(self, image, return_image=True): + @torch.no_grad() + def generate_caption(self, images): """Generates caption for a given image.""" - inputs = self._captioner_processor(images=image, return_tensors="pt") - outputs = self._captioner.generate(inputs) - caption = self._captioner_processor.batch_deocde(outputs, skip_special_tokens=True)[0] - if return_image: - return caption, inputs["pixel_values"] - else: - return caption + text = "a photography of" + + prev_device = self.caption_generator.device + + device = self._execution_device + inputs = self.caption_processor(images, text, return_tensors="pt").to( + device=device, dtype=self.caption_generator.dtype + ) + self.caption_generator.to(device) + outputs = self.caption_generator.generate(**inputs, max_new_tokens=128) + + # offload caption generator + self.caption_generator.to(prev_device) + + caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] + return caption def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor): """Constructs the edit direction to steer the image generation process semantically.""" @@ -562,6 +691,66 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): return torch.cat(embeds, dim=0).mean(0)[None] + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, 
PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + latents = init_latents + + return latents + + def auto_corr_loss(self, hidden_states, generator=None): + batch_size, channel, height, width = hidden_states.shape + if batch_size > 1: + raise ValueError("Only batch_size 1 is supported for now") + + hidden_states = hidden_states.squeeze(0) + # hidden_states must be shape [C,H,W] now + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + noise = hidden_states[i][None, None, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = F.avg_pool2d(noise, kernel_size=2) + return reg_loss + + def kl_divergence(self, hidden_states): + mean = hidden_states.mean() + var = hidden_states.var() + return var + mean**2 - 1 - torch.log(var + 1e-7) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -595,8 +784,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`PIL.Image.Image`, *optional*): - `Image`, or tensor representing an image batch which will be used for conditioning. source_embeds (`torch.Tensor`): Source concept embeddings. Generation of the embeddings as per the [original paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction. @@ -670,25 +857,12 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, - self.conditions_input_image, image, source_embeds, target_embeds, callback_steps, prompt_embeds, ) - if self.conditions_input_image and prompt_embeds: - logger.warning( - f"You have set `conditions_input_image` to {self.conditions_input_image} and" - " passed `prompt_embeds`. `prompt_embeds` will be ignored. " - ) - - # 2. Generate a caption for the input image if we are conditioning the - # pipeline based on some input image. - if self.conditions_input_image: - prompt, preprocessed_image = self.generate_caption(image) - height, width = preprocessed_image.shape[-2:] - logger.info(f"Generated prompt for the input image: {prompt}.") # 3. 
Define call parameters if prompt is not None and isinstance(prompt, str): @@ -723,24 +897,17 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): # 5. Generate the inverted noise from the input image or any other image # generated from the input prompt. - if self.conditions_input_image: - # TODO (sayakpaul): Generate this using DDIM inversion. - # We need to get the inverted noise from the input image and this requires - # us to do a sort of `inverse_step()` in DDIM and then regularize the - # noise to enforce the statistical properties of Gaussian. - pass - else: - num_channels_latents = self.unet.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) latents_init = latents.clone() # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -852,7 +1019,206 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): if output_type == "pil": edited_image = self.numpy_to_pil(edited_image) + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + if not return_dict: return (edited_image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: Optional[str] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + num_inference_steps: int = 50, + guidance_scale: float = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + cross_attention_guidance_amount: float = 0.1, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 5, + num_auto_corr_rolls: int = 5, + ): + r""" + Function used to generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`, *optional*): + `Image`, or tensor representing an image batch which will be used for conditioning. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. 
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + cross_attention_guidance_amount (`float`, defaults to 0.1): + Amount of guidance needed from the reference cross-attention maps. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback–Leibler divergence output + num_reg_steps (`int`, *optional*, defaults to 5): + Number of regularization loss steps + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or + `tuple`: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted + latents tensor and then second is the corresponding decoded image. + """ + # 1. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = preprocess(image) + + # 4. Prepare latent variables + latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) + + # 5. Encode input prompt + num_images_per_prompt = 1 + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + ) + + # 4. 
Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.inverse_scheduler.timesteps + + # 6. Rejig the UNet so that we can obtain the cross-attenion maps and + # use them for guiding the subsequent image generation. + self.unet = prepare_unet(self.unet) + + # 7. Denoising loop where we obtain the cross-attention maps. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + with self.progress_bar(total=num_inference_steps - 2) as progress_bar: + for i, t in enumerate(timesteps[1:-1]): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs={"timestep": t}, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + l_ac = self.auto_corr_loss(var, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + l_kld = self.kl_divergence(var) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + inverted_latents = latents.detach().clone() + + # 8. Post-processing + image = self.decode_latents(latents.detach()) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + # 9. Convert to PIL. 
+ if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (inverted_latents, image) + + return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 95af68be60..f479e06ca9 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: from .scheduling_ddim import DDIMScheduler + from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py new file mode 100644 index 0000000000..bbe9f3c23a --- /dev/null +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -0,0 +1,227 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. 
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): + """ + DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`]. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample between -1 and 1 for numerical stability. + set_alpha_to_one (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
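+            # i.e. the betas are spaced linearly in sqrt-space and then squared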
+ self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." 
+ ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.config.steps_offset + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + e_t = model_output + + x = sample + prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps + + a_t = self.alphas_cumprod[timestep - 1] + a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod + + pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt() + + dir_xt = (1.0 - a_prev).sqrt() * e_t + + prev_sample = a_prev.sqrt() * pred_x0 + dir_xt + + if not return_dict: + return (prev_sample, pred_x0) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0) + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index ea0cb77b5b..67e537a635 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -345,6 +345,21 @@ class ScoreSdeVePipeline(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class DDIMInverseScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DDIMScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 92a4906a5e..103430aefb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -19,10 +19,12 @@ import unittest import numpy as np import requests import torch +from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import ( AutoencoderKL, + DDIMInverseScheduler, DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, @@ -30,7 +32,7 @@ from diffusers import ( StableDiffusionPix2PixZeroPipeline, UNet2DConditionModel, ) -from diffusers.utils import slow, torch_device +from diffusers.utils import load_numpy, slow, torch_device from diffusers.utils.testing_utils import require_torch_gpu, skip_mps from ...test_pipelines_common import PipelineTesterMixin @@ -94,6 +96,9 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest. 
"tokenizer": tokenizer, "safety_checker": None, "feature_extractor": None, + "inverse_scheduler": None, + "caption_generator": None, + "caption_processor": None, } return components @@ -344,3 +349,83 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase): mem_bytes = torch.cuda.max_memory_allocated() # make sure that less than 8.2 GB is allocated assert mem_bytes < 8.2 * 10**9 + + +@slow +@require_torch_gpu +class InversionPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_stable_diffusion_pix2pix_inversion(self): + img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) + + pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) + + caption = "a photography of a cat with flowers" + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + output = pipe.invert(caption, image=raw_image, generator=generator, num_inference_steps=10) + inv_latents = output[0] + + image_slice = inv_latents[0, -3:, -3:, -1].flatten() + + assert inv_latents.shape == (1, 4, 64, 64) + expected_slice = np.array([0.8877, 0.0587, 0.7700, -1.6035, -0.5962, 0.4827, -0.6265, 1.0498, -0.8599]) + + assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3 + + def test_stable_diffusion_pix2pix_full(self): + img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png" + raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512)) + + # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy" + ) + + pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) + + caption = "a photography of a cat with flowers" + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + output = pipe.invert(caption, image=raw_image, generator=generator) + inv_latents = output[0] + + source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"] + target_prompts = 4 * ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"] + + source_embeds = pipe.get_embeds(source_prompts) + target_embeds = pipe.get_embeds(target_prompts) + + image = pipe( + caption, + source_embeds=source_embeds, + target_embeds=target_embeds, + num_inference_steps=50, + cross_attention_guidance_amount=0.15, + generator=generator, + latents=inv_latents, + negative_prompt=caption, + output_type="np", + ).images + + max_diff = np.abs(expected_image - image).max() + assert 
max_diff < 1e-3
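
A minimal sketch of how the new `DDIMInverseScheduler` steps latents forward in time, mirroring the `step()` implementation added above (shapes and values are illustrative):

```py
import torch

from diffusers import DDIMInverseScheduler

scheduler = DDIMInverseScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)  # ascending timesteps: 0, 20, 40, ..., 980

sample = torch.randn(1, 4, 64, 64)        # current latents x_t
model_output = torch.randn(1, 4, 64, 64)  # predicted noise e_t (epsilon prediction)
t = int(scheduler.timesteps[10])          # t = 200

# x_{t+1} = sqrt(a_{t+1}) * pred_x0 + sqrt(1 - a_{t+1}) * e_t
out = scheduler.step(model_output, t, sample)
print(out.prev_sample.shape)  # torch.Size([1, 4, 64, 64])
```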