Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Pipeline] Add LEDITS++ pipelines (#6074)
* Setup LEdits++ file structure
* Fix import
* LEditsPP Stable Diffusion pipeline
* Include variable image aspect ratios
* Implement LEDITS++ for SDXL
* clean up LEditsPPPipelineStableDiffusion
* Adjust inversion output
* Added docs, more cleanup for LEditsPPPipelineStableDiffusion
* clean up LEditsPPPipelineStableDiffusionXL
* Update documentation
* Fix documentation import
* Add skeleton IF implementation
* Fix documentation typo
* Add LEDITS docs to toctree
* Add missing title
* Finalize SD documentation
* Finalize SD-XL documentation
* Fix code style and quality
* Fix typo
* Fix return types
* added LEditsPPPipelineIF; minor changes for LEditsPPPipelineStableDiffusion and LEditsPPPipelineStableDiffusionXL
* Fix copy reference
* add documentation for IF
* Add first tests
* Fix batching for SD-XL
* Fix text encoding and perfect reconstruction for SD-XL
* Add tests for SD-XL, minor changes
* move user_mask to correct device, use cross_attention_kwargs also for inversion
* Example docstring
* Fix attention resolution for non-square images
* Refactoring for PR review
* Safely remove ledits_utils.py
* Style fixes
* Replace assertions with ValueError
* Remove LEditsPPPipelineIF
* Remove unnecessary input checks
* Refactoring of CrossAttnProcessor
* Revert unnecessary changes to scheduler
* Remove first progress bar in inversion
* Refactor scheduler usage and reset
* Use image processor instead of custom logic
* Fix scheduler init warning
* Fix error when running the pipeline in fp16
* Update documentation wrt perfect inversion
* Update tests
* Fix code quality and copy consistency
* Update LEditsPP import
* Remove enable/disable methods that are now in StableDiffusionMixin
* Change import in docs
* Revert import structure change
* Fix ledits imports

---------

Co-authored-by: Katharina Kornmeier <katharina.kornmeier@stud.tu-darmstadt.de>
docs/source/en/_toctree.yml
@@ -304,6 +304,8 @@
      title: Latent Consistency Models
    - local: api/pipelines/latent_diffusion
      title: Latent Diffusion
    - local: api/pipelines/ledits_pp
      title: LEDITS++
    - local: api/pipelines/panorama
      title: MultiDiffusion
    - local: api/pipelines/musicldm
docs/source/en/api/pipelines/ledits_pp.md (new file, 54 lines)
@@ -0,0 +1,54 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# LEDITS++

LEDITS++ was proposed in [LEDITS++: Limitless Image Editing using Text-to-Image Models](https://huggingface.co/papers/2311.16711) by Manuel Brack, Felix Friedrich, Katharina Kornmeier, Linoy Tsaban, Patrick Schramowski, Kristian Kersting, Apolinário Passos.

The abstract from the paper is:

*Text-to-image diffusion models have recently received increasing interest for their astonishing ability to produce high-fidelity images from solely text inputs. Subsequent research efforts aim to exploit and apply their capabilities to real image editing. However, existing image-to-image methods are often inefficient, imprecise, and of limited versatility. They either require time-consuming fine-tuning, deviate unnecessarily strongly from the input image, and/or lack support for multiple, simultaneous edits. To address these issues, we introduce LEDITS++, an efficient yet versatile and precise textual image manipulation technique. LEDITS++'s novel inversion approach requires no tuning nor optimization and produces high-fidelity results with a few diffusion steps. Second, our methodology supports multiple simultaneous edits and is architecture-agnostic. Third, we use a novel implicit masking technique that limits changes to relevant image regions. We propose the novel TEdBench++ benchmark as part of our exhaustive evaluation. Our results demonstrate the capabilities of LEDITS++ and its improvements over previous methods. The project page is available at https://leditsplusplus-project.static.hf.space .*

<Tip>

You can find additional information about LEDITS++ on the [project page](https://leditsplusplus-project.static.hf.space/index.html) and try it out in a [demo](https://huggingface.co/spaces/editing-images/leditsplusplus).

</Tip>

<Tip warning={true}>

Due to some backward compatibility issues with the current diffusers implementation of [`~schedulers.DPMSolverMultistepScheduler`], this implementation of LEdits++ can no longer guarantee perfect inversion.
This issue is unlikely to have any noticeable effect on applied use cases. However, we provide an alternative implementation that guarantees perfect inversion in a dedicated [GitHub repo](https://github.com/ml-research/ledits_pp).

</Tip>

We provide two distinct pipelines based on different pre-trained models.
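A minimal editing sketch with the Stable Diffusion variant; the checkpoint and parameter values are illustrative (they mirror the slow tests in this PR), and the full argument lists are documented in the pipeline references below.

```python
import torch
from diffusers import LEditsPPPipelineStableDiffusion
from diffusers.utils import load_image

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=torch.float16
).to("cuda")

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
).convert("RGB").resize((512, 512))

# Step 1: invert the input image. The inversion latents are cached on the pipeline.
_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.15)

# Step 2: apply two simultaneous edits on top of the inversion.
edited = pipe(
    editing_prompt=["cat", "dog"],
    reverse_editing_direction=[True, False],  # remove the "cat" concept, add the "dog" concept
    edit_guidance_scale=[5.0, 7.5],
    edit_threshold=[0.8, 0.8],
).images[0]
```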
## LEditsPPPipelineStableDiffusion
[[autodoc]] pipelines.ledits_pp.LEditsPPPipelineStableDiffusion
- all
- __call__
- invert

## LEditsPPPipelineStableDiffusionXL
[[autodoc]] pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL
- all
- __call__
- invert

## LEditsPPDiffusionPipelineOutput
[[autodoc]] pipelines.ledits_pp.pipeline_output.LEditsPPDiffusionPipelineOutput
- all

## LEditsPPInversionPipelineOutput
[[autodoc]] pipelines.ledits_pp.pipeline_output.LEditsPPInversionPipelineOutput
- all
docs/source/en/api/pipelines/overview.md
@@ -57,6 +57,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Latent Consistency Models](latent_consistency_models) | text2image |
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D, text-to-pano, upscaling |
| [LEDITS++](ledits_pp) | image editing |
| [MultiDiffusion](panorama) | text2image |
| [MusicLDM](musicldm) | text2audio |
| [Paint by Example](paint_by_example) | inpainting |
docs/source/en/api/pipelines/semantic_stable_diffusion.md
@@ -30,6 +30,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__

## StableDiffusionSafePipelineOutput
## SemanticStableDiffusionPipelineOutput
[[autodoc]] pipelines.semantic_stable_diffusion.pipeline_output.SemanticStableDiffusionPipelineOutput
- all
src/diffusers/__init__.py
@@ -253,6 +253,8 @@ else:
        "LatentConsistencyModelImg2ImgPipeline",
        "LatentConsistencyModelPipeline",
        "LDMTextToImagePipeline",
        "LEditsPPPipelineStableDiffusion",
        "LEditsPPPipelineStableDiffusionXL",
        "MusicLDMPipeline",
        "PaintByExamplePipeline",
        "PIAPipeline",

@@ -623,6 +625,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        LatentConsistencyModelImg2ImgPipeline,
        LatentConsistencyModelPipeline,
        LDMTextToImagePipeline,
        LEditsPPPipelineStableDiffusion,
        LEditsPPPipelineStableDiffusionXL,
        MusicLDMPipeline,
        PaintByExamplePipeline,
        PIAPipeline,
src/diffusers/pipelines/__init__.py
@@ -23,6 +23,7 @@ _import_structure = {
    "controlnet_xs": [],
    "deprecated": [],
    "latent_diffusion": [],
    "ledits_pp": [],
    "stable_diffusion": [],
    "stable_diffusion_xl": [],
}

@@ -171,6 +172,12 @@ else:
        "LatentConsistencyModelPipeline",
    ]
    _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
    _import_structure["ledits_pp"].extend(
        [
            "LEditsPPPipelineStableDiffusion",
            "LEditsPPPipelineStableDiffusionXL",
        ]
    )
    _import_structure["musicldm"] = ["MusicLDMPipeline"]
    _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
    _import_structure["pia"] = ["PIAPipeline"]

@@ -424,6 +431,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            LatentConsistencyModelPipeline,
        )
        from .latent_diffusion import LDMTextToImagePipeline
        from .ledits_pp import (
            LEditsPPDiffusionPipelineOutput,
            LEditsPPInversionPipelineOutput,
            LEditsPPPipelineStableDiffusion,
            LEditsPPPipelineStableDiffusionXL,
        )
        from .musicldm import MusicLDMPipeline
        from .paint_by_example import PaintByExamplePipeline
        from .pia import PIAPipeline
src/diffusers/pipelines/ledits_pp/__init__.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_leditspp_stable_diffusion"] = ["LEditsPPPipelineStableDiffusion"]
    _import_structure["pipeline_leditspp_stable_diffusion_xl"] = ["LEditsPPPipelineStableDiffusionXL"]

    _import_structure["pipeline_output"] = ["LEditsPPDiffusionPipelineOutput", "LEditsPPInversionPipelineOutput"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_leditspp_stable_diffusion import (
            LEditsPPDiffusionPipelineOutput,
            LEditsPPInversionPipelineOutput,
            LEditsPPPipelineStableDiffusion,
        )
        from .pipeline_leditspp_stable_diffusion_xl import LEditsPPPipelineStableDiffusionXL

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
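With the lazy module registration above (mirrored in `src/diffusers/__init__.py` and `src/diffusers/pipelines/__init__.py`), the new classes resolve from both the top-level package and the subpackage. A quick sanity-check sketch, assuming torch and transformers are installed:

```python
from diffusers import LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL
from diffusers.pipelines.ledits_pp import (
    LEditsPPDiffusionPipelineOutput,
    LEditsPPInversionPipelineOutput,
)

# Both import paths resolve to the same class defined in the pipeline module.
print(LEditsPPPipelineStableDiffusion.__module__)
# diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion
```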
Two file diffs suppressed because they are too large.
src/diffusers/pipelines/ledits_pp/pipeline_output.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class LEditsPPDiffusionPipelineOutput(BaseOutput):
    """
    Output class for LEdits++ Diffusion pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
        nsfw_content_detected (`List[bool]`):
            List indicating whether the corresponding generated image contains “not-safe-for-work” (nsfw) content or
            `None` if safety checking could not be performed.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    nsfw_content_detected: Optional[List[bool]]


@dataclass
class LEditsPPInversionPipelineOutput(BaseOutput):
    """
    Output class for LEdits++ inversion.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of the cropped and resized input images as PIL images of length `batch_size` or NumPy array of shape
            `(batch_size, height, width, num_channels)`.
        vae_reconstruction_images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of VAE reconstructions of all input images as PIL images of length `batch_size` or NumPy array of
            shape `(batch_size, height, width, num_channels)`.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    vae_reconstruction_images: Union[List[PIL.Image.Image], np.ndarray]
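Both output classes inherit from `BaseOutput`, so they can be used as dataclasses or as dictionaries. A small, self-contained illustration of that behavior (the array is a stand-in for real pipeline output):

```python
import numpy as np
from diffusers.pipelines.ledits_pp import LEditsPPDiffusionPipelineOutput

out = LEditsPPDiffusionPipelineOutput(
    images=np.zeros((1, 64, 64, 3), dtype=np.float32),
    nsfw_content_detected=[False],
)

print(out.images.shape)     # attribute access: (1, 64, 64, 3)
print(out["images"].shape)  # dict-style access works as well
print(list(out.keys()))     # ['images', 'nsfw_content_detected']
```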
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -899,6 +899,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """

@@ -914,6 +915,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.FloatTensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

@@ -948,11 +952,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
            )
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
        else:
            noise = None
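The new `variance_noise` argument lets a caller (such as the LEdits++ inversion) supply the stochastic term of the SDE solvers instead of having `step` draw it internally. A self-contained sketch with arbitrary shapes and step count; the same argument is added to `DPMSolverMultistepInverseScheduler.step` in the hunk below:

```python
import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++", solver_order=2)
scheduler.set_timesteps(num_inference_steps=10)

sample = torch.randn(1, 4, 64, 64)        # stand-in for a latent
model_output = torch.randn(1, 4, 64, 64)  # stand-in for a UNet noise prediction
noise = torch.randn(1, 4, 64, 64)         # externally supplied variance noise

t = scheduler.timesteps[0]
# With `variance_noise` given, the scheduler skips its internal randn_tensor call.
prev_sample = scheduler.step(model_output, t, sample, variance_noise=noise).prev_sample
print(prev_sample.shape)
```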
src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -792,6 +792,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
        """

@@ -807,6 +808,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.FloatTensor`):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`CycleDiffusion`].
            return_dict (`bool`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

@@ -837,10 +841,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
            self.model_outputs[i] = self.model_outputs[i + 1]
        self.model_outputs[-1] = model_output

        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
        if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
            noise = randn_tensor(
                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
            )
        elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
            noise = variance_noise
        else:
            noise = None
src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -647,6 +647,36 @@ class LDMTextToImagePipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class LEditsPPPipelineStableDiffusion(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class MusicLDMPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
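The dummy classes above exist so that `LEditsPPPipelineStableDiffusion` can still be imported when `torch`/`transformers` are missing, while any attempt to instantiate it or call `from_pretrained` raises a clear error about the missing backends. A standalone sketch of the same pattern (a simplified stand-in, not diffusers' actual `DummyObject`/`requires_backends` code):

```python
class DummyObject(type):
    """Metaclass for placeholder classes that raise on any attempt to construct them."""

    def __call__(cls, *args, **kwargs):
        raise ImportError(f"{cls.__name__} requires the backends {cls._backends} to be installed.")


class LEditsPPPipelineStableDiffusionStub(metaclass=DummyObject):
    _backends = ["torch", "transformers"]


try:
    LEditsPPPipelineStableDiffusionStub()
except ImportError as err:
    print(err)
```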
tests/pipelines/ledits_pp/__init__.py (new empty file)
tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py (new file, 244 lines)
@@ -0,0 +1,244 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    LEditsPPPipelineStableDiffusion,
    UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    load_image,
    require_torch_gpu,
    skip_mps,
    slow,
    torch_device,
)


enable_full_determinism()


@skip_mps
class LEditsPPPipelineStableDiffusionFastTests(unittest.TestCase):
    pipeline_class = LEditsPPPipelineStableDiffusion

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++", solver_order=2)
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "generator": generator,
            "editing_prompt": ["wearing glasses", "sunshine"],
            "reverse_editing_direction": [False, True],
            "edit_guidance_scale": [10.0, 5.0],
        }
        return inputs

    def get_dummy_inversion_inputs(self, device, seed=0):
        images = floats_tensor((2, 3, 32, 32), rng=random.Random(0)).cpu().permute(0, 2, 3, 1)
        images = 255 * images
        image_1 = Image.fromarray(np.uint8(images[0])).convert("RGB")
        image_2 = Image.fromarray(np.uint8(images[1])).convert("RGB")

        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "image": [image_1, image_2],
            "source_prompt": "",
            "source_guidance_scale": 3.5,
            "num_inversion_steps": 20,
            "skip": 0.15,
            "generator": generator,
        }
        return inputs

    def test_ledits_pp_inversion(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = LEditsPPPipelineStableDiffusion(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inversion_inputs(device)
        inputs["image"] = inputs["image"][0]
        sd_pipe.invert(**inputs)
        assert sd_pipe.init_latents.shape == (
            1,
            4,
            int(32 / sd_pipe.vae_scale_factor),
            int(32 / sd_pipe.vae_scale_factor),
        )

        latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
        print(latent_slice.flatten())
        expected_slice = np.array([-0.9084, -0.0367, 0.2940, 0.0839, 0.6890, 0.2651, -0.7104, 2.1090, -0.7822])
        assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3

    def test_ledits_pp_inversion_batch(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = LEditsPPPipelineStableDiffusion(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inversion_inputs(device)
        sd_pipe.invert(**inputs)
        assert sd_pipe.init_latents.shape == (
            2,
            4,
            int(32 / sd_pipe.vae_scale_factor),
            int(32 / sd_pipe.vae_scale_factor),
        )

        latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
        print(latent_slice.flatten())
        expected_slice = np.array([0.2528, 0.1458, -0.2166, 0.4565, -0.5657, -1.0286, -0.9961, 0.5933, 1.1173])
        assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3

        latent_slice = sd_pipe.init_latents[1, -1, -3:, -3:].to(device)
        print(latent_slice.flatten())
        expected_slice = np.array([-0.0796, 2.0583, 0.5501, 0.5358, 0.0282, -0.2803, -1.0470, 0.7023, -0.0072])
        assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3

    def test_ledits_pp_warmup_steps(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = LEditsPPPipelineStableDiffusion(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inversion_inputs = self.get_dummy_inversion_inputs(device)
        pipe.invert(**inversion_inputs)

        inputs = self.get_dummy_inputs(device)

        inputs["edit_warmup_steps"] = [0, 5]
        pipe(**inputs).images

        inputs["edit_warmup_steps"] = [5, 0]
        pipe(**inputs).images

        inputs["edit_warmup_steps"] = [5, 10]
        pipe(**inputs).images

        inputs["edit_warmup_steps"] = [10, 5]
        pipe(**inputs).images


@slow
@require_torch_gpu
class LEditsPPPipelineStableDiffusionSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @classmethod
    def setUpClass(cls):
        raw_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
        )
        raw_image = raw_image.convert("RGB").resize((512, 512))
        cls.raw_image = raw_image

    def test_ledits_pp_editing(self):
        pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
            "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=torch.float16
        )
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        _ = pipe.invert(image=self.raw_image, generator=generator)
        generator = torch.manual_seed(0)
        inputs = {
            "generator": generator,
            "editing_prompt": ["cat", "dog"],
            "reverse_editing_direction": [True, False],
            "edit_guidance_scale": [5.0, 5.0],
            "edit_threshold": [0.8, 0.8],
        }
        reconstruction = pipe(**inputs, output_type="np").images[0]

        output_slice = reconstruction[150:153, 140:143, -1]
        output_slice = output_slice.flatten()
        expected_slice = np.array(
            [0.9453125, 0.93310547, 0.84521484, 0.94628906, 0.9111328, 0.80859375, 0.93847656, 0.9042969, 0.8144531]
        )
        assert np.abs(output_slice - expected_slice).max() < 1e-2
tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py (new file, 289 lines)
@@ -0,0 +1,289 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import (
    CLIPImageProcessor,
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionConfig,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    LEditsPPPipelineStableDiffusionXL,
    UNet2DConditionModel,
)

# from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    load_image,
    require_torch_gpu,
    skip_mps,
    slow,
    torch_device,
)


enable_full_determinism()


@skip_mps
class LEditsPPPipelineStableDiffusionXLFastTests(unittest.TestCase):
    pipeline_class = LEditsPPPipelineStableDiffusionXL

    def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            time_cond_proj_dim=time_cond_proj_dim,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            # SD2-specific config below
            attention_head_dim=(2, 4),
            use_linear_projection=True,
            addition_embed_type="text_time",
            addition_time_embed_dim=8,
            transformer_layers_per_block=(1, 2),
            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
            cross_attention_dim=64 if not skip_first_text_encoder else 32,
        )
        scheduler = DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++", solver_order=2)
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )
        torch.manual_seed(0)
        image_encoder_config = CLIPVisionConfig(
            hidden_size=32,
            image_size=224,
            projection_dim=32,
            intermediate_size=37,
            num_attention_heads=4,
            num_channels=3,
            num_hidden_layers=5,
            patch_size=14,
        )
        image_encoder = CLIPVisionModelWithProjection(image_encoder_config)

        feature_extractor = CLIPImageProcessor(
            crop_size=224,
            do_center_crop=True,
            do_normalize=True,
            do_resize=True,
            image_mean=[0.48145466, 0.4578275, 0.40821073],
            image_std=[0.26862954, 0.26130258, 0.27577711],
            resample=3,
            size=224,
        )

        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=32,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder if not skip_first_text_encoder else None,
            "tokenizer": tokenizer if not skip_first_text_encoder else None,
            "text_encoder_2": text_encoder_2,
            "tokenizer_2": tokenizer_2,
            "image_encoder": image_encoder,
            "feature_extractor": feature_extractor,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "generator": generator,
            "editing_prompt": ["wearing glasses", "sunshine"],
            "reverse_editing_direction": [False, True],
            "edit_guidance_scale": [10.0, 5.0],
        }
        return inputs

    def get_dummy_inversion_inputs(self, device, seed=0):
        images = floats_tensor((2, 3, 32, 32), rng=random.Random(0)).cpu().permute(0, 2, 3, 1)
        images = 255 * images
        image_1 = Image.fromarray(np.uint8(images[0])).convert("RGB")
        image_2 = Image.fromarray(np.uint8(images[1])).convert("RGB")

        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "image": [image_1, image_2],
            "source_prompt": "",
            "source_guidance_scale": 3.5,
            "num_inversion_steps": 20,
            "skip": 0.15,
            "generator": generator,
        }
        return inputs

    def test_ledits_pp_inversion(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = LEditsPPPipelineStableDiffusionXL(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inversion_inputs(device)
        inputs["image"] = inputs["image"][0]
        sd_pipe.invert(**inputs)
        assert sd_pipe.init_latents.shape == (
            1,
            4,
            int(32 / sd_pipe.vae_scale_factor),
            int(32 / sd_pipe.vae_scale_factor),
        )

        latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
        expected_slice = np.array([-0.9084, -0.0367, 0.2940, 0.0839, 0.6890, 0.2651, -0.7103, 2.1090, -0.7821])
        assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3

    def test_ledits_pp_inversion_batch(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = LEditsPPPipelineStableDiffusionXL(**components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inversion_inputs(device)
        sd_pipe.invert(**inputs)
        assert sd_pipe.init_latents.shape == (
            2,
            4,
            int(32 / sd_pipe.vae_scale_factor),
            int(32 / sd_pipe.vae_scale_factor),
        )

        latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
        print(latent_slice.flatten())
        expected_slice = np.array([0.2528, 0.1458, -0.2166, 0.4565, -0.5656, -1.0286, -0.9961, 0.5933, 1.1172])
        assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3

        latent_slice = sd_pipe.init_latents[1, -1, -3:, -3:].to(device)
        print(latent_slice.flatten())
        expected_slice = np.array([-0.0796, 2.0583, 0.5500, 0.5358, 0.0282, -0.2803, -1.0470, 0.7024, -0.0072])
        print(latent_slice.flatten())
        assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3

    def test_ledits_pp_warmup_steps(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = LEditsPPPipelineStableDiffusionXL(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inversion_inputs = self.get_dummy_inversion_inputs(device)
        inversion_inputs["image"] = inversion_inputs["image"][0]
        pipe.invert(**inversion_inputs)

        inputs = self.get_dummy_inputs(device)

        inputs["edit_warmup_steps"] = [0, 5]
        pipe(**inputs).images

        inputs["edit_warmup_steps"] = [5, 0]
        pipe(**inputs).images

        inputs["edit_warmup_steps"] = [5, 10]
        pipe(**inputs).images

        inputs["edit_warmup_steps"] = [10, 5]
        pipe(**inputs).images


@slow
@require_torch_gpu
class LEditsPPPipelineStableDiffusionXLSlowTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        raw_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
        )
        raw_image = raw_image.convert("RGB").resize((512, 512))
        cls.raw_image = raw_image

    def test_ledits_pp_edit(self):
        pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", safety_checker=None, add_watermarker=None
        )
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        _ = pipe.invert(image=self.raw_image, generator=generator, num_zero_noise_steps=0)
        inputs = {
            "generator": generator,
            "editing_prompt": ["cat", "dog"],
            "reverse_editing_direction": [True, False],
            "edit_guidance_scale": [2.0, 4.0],
            "edit_threshold": [0.8, 0.8],
        }
        reconstruction = pipe(**inputs, output_type="np").images[0]

        output_slice = reconstruction[150:153, 140:143, -1]
        output_slice = output_slice.flatten()
        expected_slice = np.array(
            [0.56419, 0.44121838, 0.2765603, 0.5708484, 0.42763475, 0.30945742, 0.5387106, 0.4735807, 0.3547244]
        )
        assert np.abs(output_slice - expected_slice).max() < 1e-3