1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add ddim inversion pix2pix (#2397)

* add

* finish

* add tests

* add tests

* up

* up

* pull from main

* uP

* Apply suggestions from code review

* finish

* Update docs/source/en/_toctree.yml

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* finish

* clean docs

* next

* next

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

* up

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Patrick von Platen
2023-02-17 17:27:51 +02:00
committed by GitHub
parent 01a80807de
commit 14b950705a
11 changed files with 808 additions and 88 deletions

View File

@@ -182,6 +182,8 @@
title: Overview
- local: api/schedulers/ddim
title: DDIM
- local: api/schedulers/ddim_inverse
title: DDIMInverse
- local: api/schedulers/ddpm
title: DDPM
- local: api/schedulers/deis

View File

@@ -138,14 +138,15 @@ caption = pipeline.generate_caption(raw_image)
Then we employ the generated caption and the input image to get the inverted noise:
```py
inv_latents, inv_image = pipeline.invert(caption, image=raw_image)
generator = torch.manual_seed(0)
inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents
```
Now, generate the image with edit directions:
```py
# See the "Generating source and target embeddings" section below to
# automate the generation of these captions with a pre-trained model like Flan-T5.
# automate the generation of these captions with a pre-trained model like Flan-T5 as explained below.
source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]

View File

@@ -0,0 +1,21 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Inverse Denoising Diffusion Implicit Models (DDIMInverse)
## Overview
This scheduler is the inverted scheduler of [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.
The implementation is mostly based on the DDIM inversion definition of [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/pdf/2211.09794.pdf)
## DDIMInverseScheduler
[[autodoc]] DDIMInverseScheduler

View File

@@ -46,6 +46,7 @@ The following table summarizes all officially supported schedulers, their corres
| Scheduler | Paper |
|---|---|
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) |
| [ddim_inverse](./ddim_inverse) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) |
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) |
| [deis](./deis) | [**DEISMultistepScheduler**](https://arxiv.org/abs/2204.13902) |
| [singlestep_dpm_solver](./singlestep_dpm_solver) | [**Singlestep DPM-Solver**](https://arxiv.org/abs/2206.00927) |

View File

@@ -67,6 +67,7 @@ else:
ScoreSdeVePipeline,
)
from .schedulers import (
DDIMInverseScheduler,
DDIMScheduler,
DDPMScheduler,
DEISMultistepScheduler,

View File

@@ -929,7 +929,7 @@ class DiffusionPipeline(ConfigMixin):
if set(components.keys()) != expected_modules:
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
f" {expected_modules} to be defined, but {components} are defined."
f" {expected_modules} to be defined, but {components.keys()} are defined."
)
return components

View File

@@ -13,16 +13,34 @@
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from transformers import (
BlipForConditionalGeneration,
BlipProcessor,
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
)
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.cross_attention import CrossAttention
from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
from ...utils import (
PIL_INTERPOLATION,
BaseOutput,
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -30,6 +48,24 @@ from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Pix2PixInversionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
latents (`torch.FloatTensor`)
inverted latents tensor
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
latents: torch.FloatTensor
images: Union[List[PIL.Image.Image], np.ndarray]
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -46,9 +82,7 @@ EXAMPLE_DOC_STRING = """
>>> model_ckpt = "CompVis/stable-diffusion-v1-4"
>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
... model_ckpt, conditions_input_image=False, torch_dtype=torch.float16
... )
>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
>>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.to("cuda")
@@ -68,10 +102,94 @@ EXAMPLE_DOC_STRING = """
... num_inference_steps=50,
... cross_attention_guidance_amount=0.15,
... ).images
>>> images[0].save("edited_image_dog.png")
```
"""
EXAMPLE_INVERT_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from transformers import BlipForConditionalGeneration, BlipProcessor
>>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline
>>> import requests
>>> from PIL import Image
>>> captioner_id = "Salesforce/blip-image-captioning-base"
>>> processor = BlipProcessor.from_pretrained(captioner_id)
>>> model = BlipForConditionalGeneration.from_pretrained(
... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
... )
>>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4"
>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
... sd_model_ckpt,
... caption_generator=model,
... caption_processor=processor,
... torch_dtype=torch.float16,
... safety_checker=None,
... )
>>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.enable_model_cpu_offload()
>>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
>>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
>>> # generate caption
>>> caption = pipeline.generate_caption(raw_image)
>>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii"
>>> inv_latents = pipeline.invert(caption, image=raw_image).latents
>>> # we need to generate source and target embeds
>>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
>>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
>>> source_embeds = pipeline.get_embeds(source_prompts)
>>> target_embeds = pipeline.get_embeds(target_prompts)
>>> # the latents can then be used to edit a real image
>>> image = pipeline(
... caption,
... source_embeds=source_embeds,
... target_embeds=target_embeds,
... num_inference_steps=50,
... cross_attention_guidance_amount=0.15,
... generator=generator,
... latents=inv_latents,
... negative_prompt=caption,
... ).images[0]
>>> image.save("edited_image.png")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
def prepare_unet(unet: UNet2DConditionModel):
"""Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations."""
@@ -179,13 +297,17 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
conditions_input_image (bool):
Whether to condition the pipeline with an input image to compute an inverted noise latent.
requires_safety_checker (bool):
Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the
pipeline publicly.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = [
"safety_checker",
"feature_extractor",
"caption_generator",
"caption_processor",
"inverse_scheduler",
]
def __init__(
self,
@@ -194,9 +316,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
conditions_input_image: bool = False,
safety_checker: StableDiffusionSafetyChecker,
inverse_scheduler: DDIMInverseScheduler,
caption_generator: BlipForConditionalGeneration,
caption_processor: BlipProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -211,16 +335,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if conditions_input_image:
raise NotImplementedError
# logger.info("Loading caption generator since `conditions_input_image` is True.")
# checkpoint = "Salesforce/blip-image-captioning-base"
# captioner_processor = AutoProcessor.from_pretrained(checkpoint)
# captioner = BlipForConditionalGeneration.from_pretrained(checkpoint, dtype=unet.dtype)
else:
captioner_processor = None
captioner = None
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
@@ -232,19 +346,15 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
_captioner_processor=captioner_processor,
_captioner=captioner,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
caption_processor=caption_processor,
caption_generator=caption_generator,
inverse_scheduler=inverse_scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.conditions_input_image = conditions_input_image
self.register_to_config(
_captioner=captioner,
_captioner_processor=captioner_processor,
requires_safety_checker=requires_safety_checker,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -268,6 +378,30 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if self.safety_checker is not None:
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
device = torch.device(f"cuda:{gpu_id}")
hook = None
for cpu_offloaded_model in [self.vae, self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
if self.safety_checker is not None:
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
# We'll offload the last model manually.
self.final_offload_hook = hook
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -467,7 +601,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def check_inputs(
self,
prompt,
conditions_input_image,
image,
source_embeds,
target_embeds,
@@ -484,14 +617,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if source_embeds is None and target_embeds is None:
raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.")
if prompt is None and not conditions_input_image:
raise ValueError(f"`prompt` cannot be None when `conditions_input_image` is {conditions_input_image}")
elif prompt is not None and conditions_input_image:
raise ValueError(
f"`prompt` should not be provided when `conditions_input_image` is {conditions_input_image}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -504,12 +629,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if conditions_input_image:
if image is None:
raise ValueError("`image` cannot be None when `conditions_input_image` is True.")
elif isinstance(image, (torch.FloatTensor, PIL.Image.Image)):
raise ValueError("Invalid image provided. Supported formats: torch.FloatTensor, PIL.Image.Image.}")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -528,15 +647,25 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_caption(self, image, return_image=True):
@torch.no_grad()
def generate_caption(self, images):
"""Generates caption for a given image."""
inputs = self._captioner_processor(images=image, return_tensors="pt")
outputs = self._captioner.generate(inputs)
caption = self._captioner_processor.batch_deocde(outputs, skip_special_tokens=True)[0]
if return_image:
return caption, inputs["pixel_values"]
else:
return caption
text = "a photography of"
prev_device = self.caption_generator.device
device = self._execution_device
inputs = self.caption_processor(images, text, return_tensors="pt").to(
device=device, dtype=self.caption_generator.dtype
)
self.caption_generator.to(device)
outputs = self.caption_generator.generate(**inputs, max_new_tokens=128)
# offload caption generator
self.caption_generator.to(prev_device)
caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]
return caption
def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor):
"""Constructs the edit direction to steer the image generation process semantically."""
@@ -562,6 +691,66 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
return torch.cat(embeds, dim=0).mean(0)[None]
def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents], dim=0)
latents = init_latents
return latents
def auto_corr_loss(self, hidden_states, generator=None):
batch_size, channel, height, width = hidden_states.shape
if batch_size > 1:
raise ValueError("Only batch_size 1 is supported for now")
hidden_states = hidden_states.squeeze(0)
# hidden_states must be shape [C,H,W] now
reg_loss = 0.0
for i in range(hidden_states.shape[0]):
noise = hidden_states[i][None, None, :, :]
while True:
roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
if noise.shape[2] <= 8:
break
noise = F.avg_pool2d(noise, kernel_size=2)
return reg_loss
def kl_divergence(self, hidden_states):
mean = hidden_states.mean()
var = hidden_states.var()
return var + mean**2 - 1 - torch.log(var + 1e-7)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -595,8 +784,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch which will be used for conditioning.
source_embeds (`torch.Tensor`):
Source concept embeddings. Generation of the embeddings as per the [original
paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
@@ -670,25 +857,12 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
self.conditions_input_image,
image,
source_embeds,
target_embeds,
callback_steps,
prompt_embeds,
)
if self.conditions_input_image and prompt_embeds:
logger.warning(
f"You have set `conditions_input_image` to {self.conditions_input_image} and"
" passed `prompt_embeds`. `prompt_embeds` will be ignored. "
)
# 2. Generate a caption for the input image if we are conditioning the
# pipeline based on some input image.
if self.conditions_input_image:
prompt, preprocessed_image = self.generate_caption(image)
height, width = preprocessed_image.shape[-2:]
logger.info(f"Generated prompt for the input image: {prompt}.")
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -723,24 +897,17 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# 5. Generate the inverted noise from the input image or any other image
# generated from the input prompt.
if self.conditions_input_image:
# TODO (sayakpaul): Generate this using DDIM inversion.
# We need to get the inverted noise from the input image and this requires
# us to do a sort of `inverse_step()` in DDIM and then regularize the
# noise to enforce the statistical properties of Gaussian.
pass
else:
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
latents_init = latents.clone()
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -852,7 +1019,206 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if output_type == "pil":
edited_image = self.numpy_to_pil(edited_image)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (edited_image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
def invert(
self,
prompt: Optional[str] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
num_inference_steps: int = 50,
guidance_scale: float = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
cross_attention_guidance_amount: float = 0.1,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
lambda_auto_corr: float = 20.0,
lambda_kl: float = 20.0,
num_reg_steps: int = 5,
num_auto_corr_rolls: int = 5,
):
r"""
Function used to generate inverted latents given a prompt and image.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch which will be used for conditioning.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
cross_attention_guidance_amount (`float`, defaults to 0.1):
Amount of guidance needed from the reference cross-attention maps.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
lambda_auto_corr (`float`, *optional*, defaults to 20.0):
Lambda parameter to control auto correction
lambda_kl (`float`, *optional*, defaults to 20.0):
Lambda parameter to control KullbackLeibler divergence output
num_reg_steps (`int`, *optional*, defaults to 5):
Number of regularization loss steps
num_auto_corr_rolls (`int`, *optional*, defaults to 5):
Number of auto correction roll steps
Examples:
Returns:
[`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or
`tuple`:
[`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted
latents tensor and then second is the corresponding decoded image.
"""
# 1. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if cross_attention_kwargs is None:
cross_attention_kwargs = {}
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Preprocess image
image = preprocess(image)
# 4. Prepare latent variables
latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)
# 5. Encode input prompt
num_images_per_prompt = 1
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
)
# 4. Prepare timesteps
self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.inverse_scheduler.timesteps
# 6. Rejig the UNet so that we can obtain the cross-attenion maps and
# use them for guiding the subsequent image generation.
self.unet = prepare_unet(self.unet)
# 7. Denoising loop where we obtain the cross-attention maps.
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
with self.progress_bar(total=num_inference_steps - 2) as progress_bar:
for i, t in enumerate(timesteps[1:-1]):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs={"timestep": t},
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# regularization of the noise prediction
with torch.enable_grad():
for _ in range(num_reg_steps):
if lambda_auto_corr > 0:
for _ in range(num_auto_corr_rolls):
var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
l_ac = self.auto_corr_loss(var, generator=generator)
l_ac.backward()
grad = var.grad.detach() / num_auto_corr_rolls
noise_pred = noise_pred - lambda_auto_corr * grad
if lambda_kl > 0:
var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
l_kld = self.kl_divergence(var)
l_kld.backward()
grad = var.grad.detach()
noise_pred = noise_pred - lambda_kl * grad
noise_pred = noise_pred.detach()
# compute the previous noisy sample x_t -> x_t-1
latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
inverted_latents = latents.detach().clone()
# 8. Post-processing
image = self.decode_latents(latents.detach())
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
# 9. Convert to PIL.
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (inverted_latents, image)
return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image)

View File

@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler

View File

@@ -0,0 +1,227 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
"""
DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`].
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
For more details, see the original paper: https://arxiv.org/abs/2010.02502
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps += self.config.steps_offset
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
e_t = model_output
x = sample
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
a_t = self.alphas_cumprod[timestep - 1]
a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod
pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt()
dir_xt = (1.0 - a_prev).sqrt() * e_t
prev_sample = a_prev.sqrt() * pred_x0 + dir_xt
if not return_dict:
return (prev_sample, pred_x0)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -345,6 +345,21 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class DDIMInverseScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -19,10 +19,12 @@ import unittest
import numpy as np
import requests
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMInverseScheduler,
DDIMScheduler,
DDPMScheduler,
EulerAncestralDiscreteScheduler,
@@ -30,7 +32,7 @@ from diffusers import (
StableDiffusionPix2PixZeroPipeline,
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...test_pipelines_common import PipelineTesterMixin
@@ -94,6 +96,9 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"inverse_scheduler": None,
"caption_generator": None,
"caption_processor": None,
}
return components
@@ -344,3 +349,83 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase):
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 8.2 GB is allocated
assert mem_bytes < 8.2 * 10**9
@slow
@require_torch_gpu
class InversionPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_diffusion_pix2pix_inversion(self):
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
caption = "a photography of a cat with flowers"
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
output = pipe.invert(caption, image=raw_image, generator=generator, num_inference_steps=10)
inv_latents = output[0]
image_slice = inv_latents[0, -3:, -3:, -1].flatten()
assert inv_latents.shape == (1, 4, 64, 64)
expected_slice = np.array([0.8877, 0.0587, 0.7700, -1.6035, -0.5962, 0.4827, -0.6265, 1.0498, -0.8599])
assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3
def test_stable_diffusion_pix2pix_full(self):
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy"
)
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
caption = "a photography of a cat with flowers"
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
output = pipe.invert(caption, image=raw_image, generator=generator)
inv_latents = output[0]
source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
target_prompts = 4 * ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
source_embeds = pipe.get_embeds(source_prompts)
target_embeds = pipe.get_embeds(target_prompts)
image = pipe(
caption,
source_embeds=source_embeds,
target_embeds=target_embeds,
num_inference_steps=50,
cross_attention_guidance_amount=0.15,
generator=generator,
latents=inv_latents,
negative_prompt=caption,
output_type="np",
).images
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3