Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Add ddim inversion pix2pix (#2397)
* add
* finish
* add tests
* add tests
* up
* up
* pull from main
* uP
* Apply suggestions from code review
* finish
* Update docs/source/en/_toctree.yml

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* finish
* clean docs
* next
* next
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up
* up

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Committed via GitHub. Parent commit: 01a80807de. This commit: 14b950705a.
@@ -182,6 +182,8 @@
title: Overview
- local: api/schedulers/ddim
title: DDIM
- local: api/schedulers/ddim_inverse
title: DDIMInverse
- local: api/schedulers/ddpm
title: DDPM
- local: api/schedulers/deis
@@ -138,14 +138,15 @@ caption = pipeline.generate_caption(raw_image)
Then we employ the generated caption and the input image to get the inverted noise:

```py
inv_latents, inv_image = pipeline.invert(caption, image=raw_image)
generator = torch.manual_seed(0)
inv_latents = pipeline.invert(caption, image=raw_image, generator=generator).latents
```

Now, generate the image with edit directions:

```py
# See the "Generating source and target embeddings" section below to
# automate the generation of these captions with a pre-trained model like Flan-T5.
# automate the generation of these captions with a pre-trained model like Flan-T5 as explained below.
source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]
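# Editor's sketch (not part of this commit's diff): the rest of the flow, mirroring the
# EXAMPLE_INVERT_DOC_STRING added further below. It assumes the `pipeline`, `generator`,
# and `inv_latents` objects from the snippets above; the prompt lists are turned into
# text embeddings and the inverted latents are reused for the actual edit.
source_embeds = pipeline.get_embeds(source_prompts)
target_embeds = pipeline.get_embeds(target_prompts)

image = pipeline(
    caption,
    source_embeds=source_embeds,
    target_embeds=target_embeds,
    num_inference_steps=50,
    cross_attention_guidance_amount=0.15,
    generator=generator,
    latents=inv_latents,
    negative_prompt=caption,
).images[0]
image.save("edited_image.png")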
docs/source/en/api/schedulers/ddim_inverse.mdx (new file, 21 lines)
@@ -0,0 +1,21 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Inverse Denoising Diffusion Implicit Models (DDIMInverse)

## Overview

This scheduler is the inverted scheduler of [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) (DDIM) by Jiaming Song, Chenlin Meng and Stefano Ermon.
The implementation is mostly based on the DDIM inversion definition of [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/pdf/2211.09794.pdf)

## DDIMInverseScheduler

[[autodoc]] DDIMInverseScheduler
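A minimal usage sketch (not part of the new doc page): the pix2pix-zero pipeline added in this commit pairs the inverse scheduler with a regular `DDIMScheduler` built from the same config, which is how the example docstring and tests below use it.

```py
import torch

from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionPix2PixZeroPipeline

pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
)
# Forward (denoising) scheduler and its inverse (noising) counterpart share one config.
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
```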
@@ -46,6 +46,7 @@ The following table summarizes all officially supported schedulers, their corres
| Scheduler | Paper |
|---|---|
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) |
| [ddim_inverse](./ddim_inverse) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) |
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) |
| [deis](./deis) | [**DEISMultistepScheduler**](https://arxiv.org/abs/2204.13902) |
| [singlestep_dpm_solver](./singlestep_dpm_solver) | [**Singlestep DPM-Solver**](https://arxiv.org/abs/2206.00927) |
@@ -67,6 +67,7 @@ else:
ScoreSdeVePipeline,
)
from .schedulers import (
DDIMInverseScheduler,
DDIMScheduler,
DDPMScheduler,
DEISMultistepScheduler,
@@ -929,7 +929,7 @@ class DiffusionPipeline(ConfigMixin):
if set(components.keys()) != expected_modules:
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
f" {expected_modules} to be defined, but {components} are defined."
f" {expected_modules} to be defined, but {components.keys()} are defined."
)

return components
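For context, a hedged sketch of the code path this message guards, assuming the usual `components` pattern from `DiffusionPipeline`: the property collects the registered modules so another pipeline can be built from them, and the sharpened message above now reports just the offending keys instead of dumping every module object.

```py
# Sketch, not part of the diff: re-create a pipeline from another pipeline's components.
# If a required module were missing, the ValueError above would list the mismatching keys.
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline

text2img = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
```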
@@ -13,16 +13,34 @@
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import torch.nn.functional as F
from transformers import (
BlipForConditionalGeneration,
BlipProcessor,
CLIPFeatureExtractor,
CLIPTextModel,
CLIPTokenizer,
)

from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.cross_attention import CrossAttention
from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
from ...schedulers.scheduling_ddim_inverse import DDIMInverseScheduler
from ...utils import (
PIL_INTERPOLATION,
BaseOutput,
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@@ -30,6 +48,24 @@ from .safety_checker import StableDiffusionSafetyChecker

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@dataclass
class Pix2PixInversionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.

Args:
latents (`torch.FloatTensor`)
inverted latents tensor
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""

latents: torch.FloatTensor
images: Union[List[PIL.Image.Image], np.ndarray]


EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -46,9 +82,7 @@ EXAMPLE_DOC_STRING = """

>>> model_ckpt = "CompVis/stable-diffusion-v1-4"
>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
... model_ckpt, conditions_input_image=False, torch_dtype=torch.float16
... )
>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
>>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.to("cuda")
@@ -68,10 +102,94 @@ EXAMPLE_DOC_STRING = """
... num_inference_steps=50,
... cross_attention_guidance_amount=0.15,
... ).images

>>> images[0].save("edited_image_dog.png")
```
"""

EXAMPLE_INVERT_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from transformers import BlipForConditionalGeneration, BlipProcessor
>>> from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline

>>> import requests
>>> from PIL import Image

>>> captioner_id = "Salesforce/blip-image-captioning-base"
>>> processor = BlipProcessor.from_pretrained(captioner_id)
>>> model = BlipForConditionalGeneration.from_pretrained(
... captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
... )

>>> sd_model_ckpt = "CompVis/stable-diffusion-v1-4"
>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
... sd_model_ckpt,
... caption_generator=model,
... caption_processor=processor,
... torch_dtype=torch.float16,
... safety_checker=None,
... )

>>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.enable_model_cpu_offload()

>>> img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"

>>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))
>>> # generate caption
>>> caption = pipeline.generate_caption(raw_image)

>>> # "a photography of a cat with flowers and dai dai daie - daie - daie kasaii"
>>> inv_latents = pipeline.invert(caption, image=raw_image).latents
>>> # we need to generate source and target embeds

>>> source_prompts = ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]

>>> target_prompts = ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]

>>> source_embeds = pipeline.get_embeds(source_prompts)
>>> target_embeds = pipeline.get_embeds(target_prompts)
>>> # the latents can then be used to edit a real image

>>> image = pipeline(
... caption,
... source_embeds=source_embeds,
... target_embeds=target_embeds,
... num_inference_steps=50,
... cross_attention_guidance_amount=0.15,
... generator=generator,
... latents=inv_latents,
... negative_prompt=caption,
... ).images[0]
>>> image.save("edited_image.png")
```
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
def preprocess(image):
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]

if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8

image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image


def prepare_unet(unet: UNet2DConditionModel):
"""Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations."""
@@ -179,13 +297,17 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
conditions_input_image (bool):
Whether to condition the pipeline with an input image to compute an inverted noise latent.
requires_safety_checker (bool):
Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the
pipeline publicly.
"""
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = [
"safety_checker",
"feature_extractor",
"caption_generator",
"caption_processor",
"inverse_scheduler",
]

def __init__(
self,
@@ -194,9 +316,11 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
conditions_input_image: bool = False,
safety_checker: StableDiffusionSafetyChecker,
inverse_scheduler: DDIMInverseScheduler,
caption_generator: BlipForConditionalGeneration,
caption_processor: BlipProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -211,16 +335,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)

if conditions_input_image:
raise NotImplementedError
# logger.info("Loading caption generator since `conditions_input_image` is True.")
# checkpoint = "Salesforce/blip-image-captioning-base"
# captioner_processor = AutoProcessor.from_pretrained(checkpoint)
# captioner = BlipForConditionalGeneration.from_pretrained(checkpoint, dtype=unet.dtype)
else:
captioner_processor = None
captioner = None

if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
@@ -232,19 +346,15 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
_captioner_processor=captioner_processor,
_captioner=captioner,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
caption_processor=caption_processor,
caption_generator=caption_generator,
inverse_scheduler=inverse_scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.conditions_input_image = conditions_input_image
self.register_to_config(
_captioner=captioner,
_captioner_processor=captioner_processor,
requires_safety_checker=requires_safety_checker,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -268,6 +378,30 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if self.safety_checker is not None:
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)

def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
from accelerate import cpu_offload_with_hook
else:
raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")

device = torch.device(f"cuda:{gpu_id}")

hook = None
for cpu_offloaded_model in [self.vae, self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

if self.safety_checker is not None:
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

# We'll offload the last model manually.
self.final_offload_hook = hook

@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
@@ -467,7 +601,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
def check_inputs(
self,
prompt,
conditions_input_image,
image,
source_embeds,
target_embeds,
@@ -484,14 +617,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if source_embeds is None and target_embeds is None:
raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.")

if prompt is None and not conditions_input_image:
raise ValueError(f"`prompt` cannot be None when `conditions_input_image` is {conditions_input_image}")

elif prompt is not None and conditions_input_image:
raise ValueError(
f"`prompt` should not be provided when `conditions_input_image` is {conditions_input_image}"
)

if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -504,12 +629,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

if conditions_input_image:
if image is None:
raise ValueError("`image` cannot be None when `conditions_input_image` is True.")
elif isinstance(image, (torch.FloatTensor, PIL.Image.Image)):
raise ValueError("Invalid image provided. Supported formats: torch.FloatTensor, PIL.Image.Image.}")

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -528,15 +647,25 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents

def generate_caption(self, image, return_image=True):
@torch.no_grad()
def generate_caption(self, images):
"""Generates caption for a given image."""
inputs = self._captioner_processor(images=image, return_tensors="pt")
outputs = self._captioner.generate(inputs)
caption = self._captioner_processor.batch_deocde(outputs, skip_special_tokens=True)[0]
if return_image:
return caption, inputs["pixel_values"]
else:
return caption
text = "a photography of"

prev_device = self.caption_generator.device

device = self._execution_device
inputs = self.caption_processor(images, text, return_tensors="pt").to(
device=device, dtype=self.caption_generator.dtype
)
self.caption_generator.to(device)
outputs = self.caption_generator.generate(**inputs, max_new_tokens=128)

# offload caption generator
self.caption_generator.to(prev_device)

caption = self.caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]
return caption

def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor):
"""Constructs the edit direction to steer the image generation process semantically."""
@@ -562,6 +691,66 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):

return torch.cat(embeds, dim=0).mean(0)[None]

def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)

image = image.to(device=device, dtype=dtype)

if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

if isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)

init_latents = self.vae.config.scaling_factor * init_latents

if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents], dim=0)

latents = init_latents

return latents

def auto_corr_loss(self, hidden_states, generator=None):
batch_size, channel, height, width = hidden_states.shape
if batch_size > 1:
raise ValueError("Only batch_size 1 is supported for now")

hidden_states = hidden_states.squeeze(0)
# hidden_states must be shape [C,H,W] now
reg_loss = 0.0
for i in range(hidden_states.shape[0]):
noise = hidden_states[i][None, None, :, :]
while True:
roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2

if noise.shape[2] <= 8:
break
noise = F.avg_pool2d(noise, kernel_size=2)
return reg_loss

def kl_divergence(self, hidden_states):
mean = hidden_states.mean()
var = hidden_states.var()
return var + mean**2 - 1 - torch.log(var + 1e-7)
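In equation form (a sketch, not part of the diff): `auto_corr_loss` penalizes spatial auto-correlation of the predicted noise at several pooled scales, while `kl_divergence` is, up to a factor of 2 and the `1e-7` stabilizer, the KL divergence between a Gaussian fitted to the noise and a standard normal:

```latex
\mathrm{KL}\!\left(\mathcal{N}(\mu,\sigma^{2})\,\middle\|\,\mathcal{N}(0,1)\right)
  = \tfrac{1}{2}\left(\sigma^{2} + \mu^{2} - 1 - \log\sigma^{2}\right),
\qquad
\texttt{kl\_divergence}(\cdot) = \sigma^{2} + \mu^{2} - 1 - \log\sigma^{2} = 2\,\mathrm{KL}.
```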
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(

@@ -595,8 +784,6 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch which will be used for conditioning.
source_embeds (`torch.Tensor`):
Source concept embeddings. Generation of the embeddings as per the [original
paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
@@ -670,25 +857,12 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
self.conditions_input_image,
image,
source_embeds,
target_embeds,
callback_steps,
prompt_embeds,
)
if self.conditions_input_image and prompt_embeds:
logger.warning(
f"You have set `conditions_input_image` to {self.conditions_input_image} and"
" passed `prompt_embeds`. `prompt_embeds` will be ignored. "
)

# 2. Generate a caption for the input image if we are conditioning the
# pipeline based on some input image.
if self.conditions_input_image:
prompt, preprocessed_image = self.generate_caption(image)
height, width = preprocessed_image.shape[-2:]
logger.info(f"Generated prompt for the input image: {prompt}.")

# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -723,24 +897,17 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):

# 5. Generate the inverted noise from the input image or any other image
# generated from the input prompt.
if self.conditions_input_image:
# TODO (sayakpaul): Generate this using DDIM inversion.
# We need to get the inverted noise from the input image and this requires
# us to do a sort of `inverse_step()` in DDIM and then regularize the
# noise to enforce the statistical properties of Gaussian.
pass
else:
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
latents_init = latents.clone()

# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -852,7 +1019,206 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if output_type == "pil":
edited_image = self.numpy_to_pil(edited_image)

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()

if not return_dict:
return (edited_image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept)

@torch.no_grad()
@replace_example_docstring(EXAMPLE_INVERT_DOC_STRING)
def invert(
self,
prompt: Optional[str] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
num_inference_steps: int = 50,
guidance_scale: float = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
cross_attention_guidance_amount: float = 0.1,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
lambda_auto_corr: float = 20.0,
lambda_kl: float = 20.0,
num_reg_steps: int = 5,
num_auto_corr_rolls: int = 5,
):
r"""
Function used to generate inverted latents given a prompt and image.

Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch which will be used for conditioning.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
cross_attention_guidance_amount (`float`, defaults to 0.1):
Amount of guidance needed from the reference cross-attention maps.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
lambda_auto_corr (`float`, *optional*, defaults to 20.0):
Lambda parameter to control auto correction
lambda_kl (`float`, *optional*, defaults to 20.0):
Lambda parameter to control Kullback–Leibler divergence output
num_reg_steps (`int`, *optional*, defaults to 5):
Number of regularization loss steps
num_auto_corr_rolls (`int`, *optional*, defaults to 5):
Number of auto correction roll steps

Examples:

Returns:
[`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] or
`tuple`:
[`~pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.Pix2PixInversionPipelineOutput`] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is the inverted
latents tensor and then second is the corresponding decoded image.
"""
# 1. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if cross_attention_kwargs is None:
cross_attention_kwargs = {}

device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0

# 3. Preprocess image
image = preprocess(image)

# 4. Prepare latent variables
latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)

# 5. Encode input prompt
num_images_per_prompt = 1
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
)

# 4. Prepare timesteps
self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.inverse_scheduler.timesteps

# 6. Rejig the UNet so that we can obtain the cross-attenion maps and
# use them for guiding the subsequent image generation.
self.unet = prepare_unet(self.unet)

# 7. Denoising loop where we obtain the cross-attention maps.
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
with self.progress_bar(total=num_inference_steps - 2) as progress_bar:
for i, t in enumerate(timesteps[1:-1]):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs={"timestep": t},
).sample

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# regularization of the noise prediction
with torch.enable_grad():
for _ in range(num_reg_steps):
if lambda_auto_corr > 0:
for _ in range(num_auto_corr_rolls):
var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
l_ac = self.auto_corr_loss(var, generator=generator)
l_ac.backward()

grad = var.grad.detach() / num_auto_corr_rolls
noise_pred = noise_pred - lambda_auto_corr * grad

if lambda_kl > 0:
var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)

l_kld = self.kl_divergence(var)
l_kld.backward()

grad = var.grad.detach()
noise_pred = noise_pred - lambda_kl * grad

noise_pred = noise_pred.detach()

# compute the previous noisy sample x_t -> x_t-1
latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample

# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)

inverted_latents = latents.detach().clone()

# 8. Post-processing
image = self.decode_latents(latents.detach())

# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()

# 9. Convert to PIL.
if output_type == "pil":
image = self.numpy_to_pil(image)

if not return_dict:
return (inverted_latents, image)

return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image)
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
src/diffusers/schedulers/scheduling_ddim_inverse.py (new file, 227 lines)
@@ -0,0 +1,227 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.

Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].

Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.


Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.

Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""

def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
"""
DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`].

[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.

For more details, see the original paper: https://arxiv.org/abs/2010.02502

Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""

order = 1

@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0

# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64))

def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.

Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep

Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""

if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)

self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.timesteps += self.config.steps_offset

def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
variance_noise: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
e_t = model_output

x = sample
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps

a_t = self.alphas_cumprod[timestep - 1]
a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod

pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt()

dir_xt = (1.0 - a_prev).sqrt() * e_t

prev_sample = a_prev.sqrt() * pred_x0 + dir_xt

if not return_dict:
return (prev_sample, pred_x0)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)

def __len__(self):
return self.config.num_train_timesteps
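In equation form (a sketch, not part of the diff): with cumulative products taken from `alphas_cumprod` as in `step` above and `eta` fixed to 0, the update is the deterministic DDIM rule run forward in time, from `x_t` to `x_{t+1}`, matching `pred_x0`, `dir_xt`, and `prev_sample` in the code:

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}},
\qquad
x_{t+1} = \sqrt{\bar\alpha_{t+1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t+1}}\,\epsilon_\theta(x_t, t).
```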
@@ -345,6 +345,21 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])


class DDIMInverseScheduler(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -19,10 +19,12 @@ import unittest
import numpy as np
import requests
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
AutoencoderKL,
DDIMInverseScheduler,
DDIMScheduler,
DDPMScheduler,
EulerAncestralDiscreteScheduler,

@@ -30,7 +32,7 @@ from diffusers import (
StableDiffusionPix2PixZeroPipeline,
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps

from ...test_pipelines_common import PipelineTesterMixin
@@ -94,6 +96,9 @@ class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"inverse_scheduler": None,
"caption_generator": None,
"caption_processor": None,
}
return components
@@ -344,3 +349,83 @@ class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase):
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 8.2 GB is allocated
assert mem_bytes < 8.2 * 10**9


@slow
@require_torch_gpu
class InversionPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()

def test_stable_diffusion_pix2pix_inversion(self):
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))

pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

caption = "a photography of a cat with flowers"
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.manual_seed(0)
output = pipe.invert(caption, image=raw_image, generator=generator, num_inference_steps=10)
inv_latents = output[0]

image_slice = inv_latents[0, -3:, -3:, -1].flatten()

assert inv_latents.shape == (1, 4, 64, 64)
expected_slice = np.array([0.8877, 0.0587, 0.7700, -1.6035, -0.5962, 0.4827, -0.6265, 1.0498, -0.8599])

assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 1e-3

def test_stable_diffusion_pix2pix_full(self):
img_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/test_images/cats/cat_6.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB").resize((512, 512))

# numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog.png
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.npy"
)

pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe.inverse_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

caption = "a photography of a cat with flowers"
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)

generator = torch.manual_seed(0)
output = pipe.invert(caption, image=raw_image, generator=generator)
inv_latents = output[0]

source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
target_prompts = 4 * ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]

source_embeds = pipe.get_embeds(source_prompts)
target_embeds = pipe.get_embeds(target_prompts)

image = pipe(
caption,
source_embeds=source_embeds,
target_embeds=target_embeds,
num_inference_steps=50,
cross_attention_guidance_amount=0.15,
generator=generator,
latents=inv_latents,
negative_prompt=caption,
output_type="np",
).images

max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3