1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[Pipelines] Adds pix2pix zero (#2334)

* add: support for BLIP generation.

* add: support for editing synthetic images.

* remove unnecessary comments.

* add inits and run make fix-copies.

* version change of diffusers.

* fix: condition for loading the captioner.

* default conditions_input_image to False.

* guidance_amount -> cross_attention_guidance_amount

* fix inputs to check_inputs()

* fix: attribute.

* fix: prepare_attention_mask() call.

* debugging.

* better placement of references.

* remove torch.no_grad() decorations.

* put torch.no_grad() context before the first denoising loop.

* detach() latents before decoding them.

* put deocding in a torch.no_grad() context.

* add reconstructed image for debugging.

* no_grad(0

* apply formatting.

* address one-off suggestions from the draft PR.

* back to torch.no_grad() and add more elaborate comments.

* refactor prepare_unet() per Patrick's suggestions.

* more elaborate description for .

* formatting.

* add docstrings to the methods specific to pix2pix zero.

* suspecting a redundant noise prediction.

* needed for gradient computation chain.

* less hacks.

* fix: attention mask handling within the processor.

* remove attention reference map computation.

* fix: cross attn args.

* fix: prcoessor.

* store attention maps.

* fix: attention processor.

* update docs and better treatment to xa args.

* update the final noise computation call.

* change xa args call.

* remove xa args option from the pipeline.

* add: docs.

* first test.

* fix: url call.

* fix: argument call.

* remove image conditioning for now.

* 🚨 add: fast tests.

* explicit placement of the xa attn weights.

* add: slow tests 🐢

* fix: tests.

* edited direction embedding should be on the same device as prompt_embeds.

* debugging message.

* debugging.

* add pix2pix zero pipeline for a non-deterministic test.

* debugging/

* remove debugging message.

* make caption generation _

* address comments (part I).

* address PR comments (part II)

* fix: DDPM test assertion.

* refactor doc.

* address PR comments (part III).

* fix: type annotation for the scheduler.

* apply styling.

* skip_mps and add note on embeddings in the docs.
This commit is contained in:
Sayak Paul
2023-02-16 15:50:38 +05:30
committed by GitHub
parent e5810e686e
commit fd3d5502d4
11 changed files with 1320 additions and 4 deletions

View File

@@ -151,6 +151,8 @@
title: Stable-Diffusion-Latent-Upscaler
- local: api/pipelines/stable_diffusion/pix2pix
title: InstructPix2Pix
- local: api/pipelines/stable_diffusion/pix2pix_zero
title: Pix2Pix Zero
title: Stable Diffusion
- local: api/pipelines/stable_diffusion_2
title: Stable Diffusion 2

View File

@@ -33,6 +33,7 @@ For more details about how Stable Diffusion works and how it differs from the ba
| [StableDiffusionUpscalePipeline](./upscale) | **Experimental** *Text-Guided Image Super-Resolution * | | Coming soon
| [StableDiffusionLatentUpscalePipeline](./latent_upscale) | **Experimental** *Text-Guided Image Super-Resolution * | | Coming soon
| [StableDiffusionInstructPix2PixPipeline](./pix2pix) | **Experimental** *Text-Based Image Editing * | | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/spaces/timbrooks/instruct-pix2pix)
| [StableDiffusionPix2PixZeroPipeline](./pix2pix_zero) | **Experimental** *Text-Based Image Editing * | | [Zero-shot Image-to-Image Translation](https://arxiv.org/abs/2302.03027)

View File

@@ -0,0 +1,103 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Zero-shot Image-to-Image Translation
## Overview
[Zero-shot Image-to-Image Translation](https://arxiv.org/abs/2302.03027) by Gaurav Parmar, Krishna Kumar Singh, Richard Zhang, Yijun Li, Jingwan Lu, and Jun-Yan Zhu.
The abstract of the paper is the following:
*Large-scale text-to-image generative models have shown their remarkable ability to synthesize diverse and high-quality images. However, it is still challenging to directly apply these models for editing real images for two reasons. First, it is hard for users to come up with a perfect text prompt that accurately describes every visual detail in the input image. Second, while existing models can introduce desirable changes in certain regions, they often dramatically alter the input content and introduce unexpected changes in unwanted regions. In this work, we propose pix2pix-zero, an image-to-image translation method that can preserve the content of the original image without manual prompting. We first automatically discover editing directions that reflect desired edits in the text embedding space. To preserve the general content structure after editing, we further propose cross-attention guidance, which aims to retain the cross-attention maps of the input image throughout the diffusion process. In addition, our method does not need additional training for these edits and can directly use the existing pre-trained text-to-image diffusion model. We conduct extensive experiments and show that our method outperforms existing and concurrent works for both real and synthetic image editing.*
Resources:
* [Project Page](https://pix2pixzero.github.io/).
* [Paper](https://arxiv.org/abs/2302.03027).
* [Original Code](https://github.com/pix2pixzero/pix2pix-zero).
## Tips
* The pipeline exposes two arguments namely `source_embeds` and `target_embeds`
that let you control the direction of the semantic edits in the final image to be generated. Let's say,
you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect
this in the pipeline, you simply have to set the embeddings related to the phrases including "cat" to
`source_embeds` and "dog" to `target_embeds`. Refer to the code example below for more details.
* When you're using this pipeline from a prompt, specify the _source_ concept in the prompt. Taking
the above example, a valid input prompt would be: "a high resolution painting of a **cat** in the style of van gough".
* If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to:
* Swap the `source_embeds` and `target_embeds`.
* Change the input prompt to include "dog".
* To learn more about how the source and target embeddings are generated, refer to the [original
paper](https://arxiv.org/abs/2302.03027).
## Available Pipelines:
| Pipeline | Tasks | Demo
|---|---|:---:|
| [StableDiffusionPix2PixZeroPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py) | *Text-Based Image Editing* | [🤗 Space] (soon) |
<!-- TODO: add Colab -->
## Usage example
**Based on an image generated with the input prompt**
```python
import requests
import torch
from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline
def download(embedding_url, local_filepath):
r = requests.get(embedding_url)
with open(local_filepath, "wb") as f:
f.write(r.content)
model_ckpt = "CompVis/stable-diffusion-v1-4"
pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
model_ckpt, conditions_input_image=False, torch_dtype=torch.float16
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.to("cuda")
prompt = "a high resolution painting of a cat in the style of van gough"
src_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/cat.pt"
target_embs_url = "https://github.com/pix2pixzero/pix2pix-zero/raw/main/assets/embeddings_sd_1.4/dog.pt"
for url in [src_embs_url, target_embs_url]:
download(url, url.split("/")[-1])
src_embeds = torch.load(src_embs_url.split("/")[-1])
target_embeds = torch.load(target_embs_url.split("/")[-1])
images = pipeline(
prompt,
source_embeds=src_embeds,
target_embeds=target_embeds,
num_inference_steps=50,
cross_attention_guidance_amount=0.15,
).images
images[0].save("edited_image_dog.png")
```
**Based on an input image**
_Coming soon_
## StableDiffusionPix2PixZeroPipeline
[[autodoc]] StableDiffusionPix2PixZeroPipeline
- __call__
- all

View File

@@ -118,6 +118,7 @@ else:
StableDiffusionLatentUpscalePipeline,
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionPix2PixZeroPipeline,
StableDiffusionUpscalePipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,

View File

@@ -54,6 +54,7 @@ else:
StableDiffusionInstructPix2PixPipeline,
StableDiffusionLatentUpscalePipeline,
StableDiffusionPipeline,
StableDiffusionPix2PixZeroPipeline,
StableDiffusionUpscalePipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,

View File

@@ -66,6 +66,7 @@ except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionDepth2ImgPipeline
else:
from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline
from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline
try:

View File

@@ -0,0 +1,836 @@
# Copyright 2023 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import PIL
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.cross_attention import CrossAttention
from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
from ...utils import is_accelerate_available, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import requests
>>> import torch
>>> from diffusers import DDIMScheduler, StableDiffusionPix2PixZeroPipeline
>>> def download(embedding_url, local_filepath):
... r = requests.get(embedding_url)
... with open(local_filepath, "wb") as f:
... f.write(r.content)
>>> model_ckpt = "CompVis/stable-diffusion-v1-4"
>>> pipeline = StableDiffusionPix2PixZeroPipeline.from_pretrained(
... model_ckpt, conditions_input_image=False, torch_dtype=torch.float16
... )
>>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.to("cuda")
>>> prompt = "a high resolution painting of a cat in the style of van gough"
>>> source_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
>>> target_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"
>>> for url in [source_emb_url, target_emb_url]:
... download(url, url.split("/")[-1])
>>> src_embeds = torch.load(source_emb_url.split("/")[-1])
>>> target_embeds = torch.load(target_emb_url.split("/")[-1])
>>> images = pipeline(
... prompt,
... source_embeds=src_embeds,
... target_embeds=target_embeds,
... num_inference_steps=50,
... cross_attention_guidance_amount=0.15,
... ).images
>>> images[0].save("edited_image_dog.png")
```
"""
def prepare_unet(unet: UNet2DConditionModel):
"""Modifies the UNet (`unet`) to perform Pix2Pix Zero optimizations."""
pix2pix_zero_attn_procs = {}
for name in unet.attn_processors.keys():
module_name = name.replace(".processor", "")
module = unet.get_submodule(module_name)
if "attn2" in name:
pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=True)
module.requires_grad_(True)
else:
pix2pix_zero_attn_procs[name] = Pix2PixZeroCrossAttnProcessor(is_pix2pix_zero=False)
module.requires_grad_(False)
unet.set_attn_processor(pix2pix_zero_attn_procs)
return unet
class Pix2PixZeroL2Loss:
def __init__(self):
self.loss = 0.0
def compute_loss(self, predictions, targets):
self.loss += ((predictions - targets) ** 2).sum((1, 2)).mean(0)
class Pix2PixZeroCrossAttnProcessor:
"""An attention processor class to store the attention weights.
In Pix2Pix Zero, it happens during computations in the cross-attention blocks."""
def __init__(self, is_pix2pix_zero=False):
self.is_pix2pix_zero = is_pix2pix_zero
if self.is_pix2pix_zero:
self.reference_cross_attn_map = {}
def __call__(
self,
attn: CrossAttention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
timestep=None,
loss=None,
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.cross_attention_norm:
encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
if self.is_pix2pix_zero and timestep is not None:
# new bookkeeping to save the attention weights.
if loss is None:
self.reference_cross_attn_map[timestep.item()] = attention_probs.detach().cpu()
# compute loss
elif loss is not None:
prev_attn_probs = self.reference_cross_attn_map.pop(timestep.item())
loss.compute_loss(attention_probs, prev_attn_probs.to(attention_probs.device))
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
r"""
Pipeline for pixel-levl image editing using Pix2Pix Zero. Based on Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
conditions_input_image (bool):
Whether to condition the pipeline with an input image to compute an inverted noise latent.
requires_safety_checker (bool):
Whether the pipeline requires a safety checker. We recommend setting it to True if you're using the
pipeline publicly.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDPMScheduler, DDIMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
conditions_input_image: bool = False,
requires_safety_checker: bool = True,
):
super().__init__()
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if conditions_input_image:
raise NotImplementedError
# logger.info("Loading caption generator since `conditions_input_image` is True.")
# checkpoint = "Salesforce/blip-image-captioning-base"
# captioner_processor = AutoProcessor.from_pretrained(checkpoint)
# captioner = BlipForConditionalGeneration.from_pretrained(checkpoint, dtype=unet.dtype)
else:
captioner_processor = None
captioner = None
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
_captioner_processor=captioner_processor,
_captioner=captioner,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.conditions_input_image = conditions_input_image
self.register_to_config(
_captioner=captioner,
_captioner_processor=captioner_processor,
requires_safety_checker=requires_safety_checker,
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
`torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
"""
if is_accelerate_available():
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
r"""
Returns the device on which the pipeline's models will be executed. After calling
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is not None:
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
else:
has_nsfw_concept = None
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
conditions_input_image,
image,
source_embeds,
target_embeds,
callback_steps,
prompt_embeds=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if source_embeds is None and target_embeds is None:
raise ValueError("`source_embeds` and `target_embeds` cannot be undefined.")
if prompt is None and not conditions_input_image:
raise ValueError(f"`prompt` cannot be None when `conditions_input_image` is {conditions_input_image}")
elif prompt is not None and conditions_input_image:
raise ValueError(
f"`prompt` should not be provided when `conditions_input_image` is {conditions_input_image}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if conditions_input_image:
if image is None:
raise ValueError("`image` cannot be None when `conditions_input_image` is True.")
elif isinstance(image, (torch.FloatTensor, PIL.Image.Image)):
raise ValueError("Invalid image provided. Supported formats: torch.FloatTensor, PIL.Image.Image.}")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def generate_caption(self, image, return_image=True):
"""Generates caption for a given image."""
inputs = self._captioner_processor(images=image, return_tensors="pt")
outputs = self._captioner.generate(inputs)
caption = self._captioner_processor.batch_deocde(outputs, skip_special_tokens=True)[0]
if return_image:
return caption, inputs["pixel_values"]
else:
return caption
def construct_direction(self, embs_source: torch.Tensor, embs_target: torch.Tensor):
"""Constructs the edit direction to steer the image generation process semantically."""
return (embs_target.mean(0) - embs_source.mean(0)).unsqueeze(0)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
source_embeds: torch.Tensor = None,
target_embeds: torch.Tensor = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
cross_attention_guidance_amount: float = 0.1,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch which will be used for conditioning.
source_embeds (`torch.Tensor`):
Source concept embeddings. Generation of the embeddings as per the [original
paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
target_embeds (`torch.Tensor`):
Target concept embeddings. Generation of the embeddings as per the [original
paper](https://arxiv.org/abs/2302.03027). Used in discovering the edit direction.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
cross_attention_guidance_amount (`float`, defaults to 0.1):
Amount of guidance needed from the reference cross-attention maps.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Define the spatial resolutions.
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
self.conditions_input_image,
image,
source_embeds,
target_embeds,
callback_steps,
prompt_embeds,
)
if self.conditions_input_image and prompt_embeds:
logger.warning(
f"You have set `conditions_input_image` to {self.conditions_input_image} and"
" passed `prompt_embeds`. `prompt_embeds` will be ignored. "
)
# 2. Generate a caption for the input image if we are conditioning the
# pipeline based on some input image.
if self.conditions_input_image:
prompt, preprocessed_image = self.generate_caption(image)
height, width = preprocessed_image.shape[-2:]
logger.info(f"Generated prompt for the input image: {prompt}.")
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if cross_attention_kwargs is None:
cross_attention_kwargs = {}
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Generate the inverted noise from the input image or any other image
# generated from the input prompt.
if self.conditions_input_image:
# TODO (sayakpaul): Generate this using DDIM inversion.
# We need to get the inverted noise from the input image and this requires
# us to do a sort of `inverse_step()` in DDIM and then regularize the
# noise to enforce the statistical properties of Gaussian.
pass
else:
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
latents_init = latents.clone()
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Rejig the UNet so that we can obtain the cross-attenion maps and
# use them for guiding the subsequent image generation.
self.unet = prepare_unet(self.unet)
# 7. Denoising loop where we obtain the cross-attention maps.
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs={"timestep": t},
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# 8. Compute the edit directions.
edit_direction = self.construct_direction(source_embeds, target_embeds).to(prompt_embeds.device)
# 9. Edit the prompt embeddings as per the edit directions discovered.
prompt_embeds_edit = prompt_embeds.clone()
prompt_embeds_edit[1:2] += edit_direction
# 10. Second denoising loop to generate the edited image.
latents = latents_init
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# we want to learn the latent such that it steers the generation
# process towards the edited direction, so make the make initial
# noise learnable
x_in = latent_model_input.detach().clone()
x_in.requires_grad = True
# optimizer
opt = torch.optim.SGD([x_in], lr=cross_attention_guidance_amount)
with torch.enable_grad():
# initialize loss
loss = Pix2PixZeroL2Loss()
# predict the noise residual
noise_pred = self.unet(
x_in,
t,
encoder_hidden_states=prompt_embeds_edit.detach(),
cross_attention_kwargs={"timestep": t, "loss": loss},
).sample
loss.loss.backward(retain_graph=False)
opt.step()
# recompute the noise
noise_pred = self.unet(
x_in.detach(),
t,
encoder_hidden_states=prompt_embeds_edit,
cross_attention_kwargs={"timestep": None},
).sample
latents = x_in.detach().chunk(2)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# 11. Post-process the latents.
edited_image = self.decode_latents(latents)
# 12. Run the safety checker.
edited_image, has_nsfw_concept = self.run_safety_checker(edited_image, device, prompt_embeds.dtype)
# 13. Convert to PIL.
if output_type == "pil":
edited_image = self.numpy_to_pil(edited_image)
if not return_dict:
return (edited_image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept)

View File

@@ -212,6 +212,21 @@ class StableDiffusionPipelineSafe(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionPix2PixZeroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionUpscalePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -408,9 +408,9 @@ def requires_backends(obj, backends):
" --upgrade transformers \n```"
)
if name in [
"StableDiffusionDepth2ImgPipeline",
] and is_transformers_version("<", "4.26.0"):
if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
"<", "4.26.0"
):
raise ImportError(
f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
" --upgrade transformers \n```"

View File

@@ -0,0 +1,350 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import requests
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DDPMScheduler,
EulerAncestralDiscreteScheduler,
LMSDiscreteScheduler,
StableDiffusionPix2PixZeroPipeline,
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
def download_from_url(embedding_url, local_filepath):
r = requests.get(embedding_url)
with open(local_filepath, "wb") as f:
f.write(r.content)
@skip_mps
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionPix2PixZeroPipeline
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=32,
)
scheduler = DDIMScheduler()
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
}
return components
def get_dummy_inputs(self, device, seed=0):
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/src_emb_0.pt"
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/tgt_emb_0.pt"
for url in [src_emb_url, tgt_emb_url]:
download_from_url(url, url.split("/")[-1])
src_embeds = torch.load(src_emb_url.split("/")[-1])
target_embeds = torch.load(tgt_emb_url.split("/")[-1])
generator = torch.manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"cross_attention_guidance_amount": 0.15,
"source_embeds": src_embeds,
"target_embeds": target_embeds,
"output_type": "numpy",
}
return inputs
def test_stable_diffusion_pix2pix_zero_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5184, 0.503, 0.4917, 0.4022, 0.3455, 0.464, 0.5324, 0.5323, 0.4894])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_zero_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
negative_prompt = "french fries"
output = sd_pipe(**inputs, negative_prompt=negative_prompt)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5464, 0.5072, 0.5012, 0.4124, 0.3624, 0.466, 0.5413, 0.5468, 0.4927])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_zero_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
components["scheduler"] = EulerAncestralDiscreteScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5114, 0.5051, 0.5222, 0.5279, 0.5037, 0.5156, 0.4604, 0.4966, 0.504])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_zero_ddpm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
components["scheduler"] = DDPMScheduler()
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5185, 0.5027, 0.492, 0.401, 0.3445, 0.464, 0.5321, 0.5327, 0.4892])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_zero_num_images_per_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
# test num_images_per_prompt=1 (default)
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs).images
assert images.shape == (1, 64, 64, 3)
# test num_images_per_prompt=2 for a single prompt
num_images_per_prompt = 2
inputs = self.get_dummy_inputs(device)
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (num_images_per_prompt, 64, 64, 3)
# test num_images_per_prompt for batch of prompts
batch_size = 2
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"]] * batch_size
images = sd_pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images
assert images.shape == (batch_size * num_images_per_prompt, 64, 64, 3)
@slow
@require_torch_gpu
class StableDiffusionPix2PixZeroPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, seed=0):
generator = torch.manual_seed(seed)
src_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/cat.pt"
tgt_emb_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/dog.pt"
for url in [src_emb_url, tgt_emb_url]:
download_from_url(url, url.split("/")[-1])
src_embeds = torch.load(src_emb_url.split("/1")[-1])
target_embeds = torch.load(tgt_emb_url.split("/1")[-1])
inputs = {
"prompt": "turn him into a cyborg",
"generator": generator,
"num_inference_steps": 3,
"guidance_scale": 7.5,
"cross_attention_guidance_amount": 0.15,
"source_embeds": src_embeds,
"target_embeds": target_embeds,
"output_type": "numpy",
}
return inputs
def test_stable_diffusion_pix2pix_zero_default(self):
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs()
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.4705, 0.4771, 0.4832, 0.4783, 0.4495, 0.447, 0.4658, 0.4568, 0.438])
assert np.abs(expected_slice - image_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_zero_k_lms(self):
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs()
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.6514, 0.5571, 0.5244, 0.5591, 0.4998, 0.4834, 0.502, 0.468, 0.4663])
assert np.abs(expected_slice - image_slice).max() < 1e-3
def test_stable_diffusion_pix2pix_zero_intermediate_state(self):
number_of_steps = 0
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
callback_fn.has_been_called = True
nonlocal number_of_steps
number_of_steps += 1
if step == 1:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[-0.5176, 0.0669, -0.1963, -0.1653, -0.7856, -0.2871, -0.5562, -0.0096, -0.012]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
elif step == 2:
latents = latents.detach().cpu().numpy()
assert latents.shape == (1, 4, 64, 64)
latents_slice = latents[0, -3:, -3:, -1]
expected_slice = np.array(
[-0.5127, 0.0613, -0.1937, -0.1622, -0.7856, -0.2849, -0.5601, -0.0111, -0.0137]
)
assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
callback_fn.has_been_called = False
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
inputs = self.get_inputs()
pipe(**inputs, callback=callback_fn, callback_steps=1)
assert callback_fn.has_been_called
assert number_of_steps == 3
def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()
inputs = self.get_inputs()
_ = pipe(**inputs)
mem_bytes = torch.cuda.max_memory_allocated()
# make sure that less than 8.2 GB is allocated
assert mem_bytes < 8.2 * 10**9

View File

@@ -191,10 +191,16 @@ class PipelineTesterMixin:
def _test_inference_batch_single_identical(
self, test_max_difference=None, test_mean_pixel_difference=None, relax_max_difference=False
):
if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]:
if self.pipeline_class.__name__ in [
"CycleDiffusionPipeline",
"RePaintPipeline",
"StableDiffusionPix2PixZeroPipeline",
]:
# RePaint can hardly be made deterministic since the scheduler is currently always
# nondeterministic
# CycleDiffusion is also slightly nondeterministic
# There's a training loop inside Pix2PixZero and is guided by edit directions. This is
# why the slight non-determinism.
return
if test_max_difference is None: