
Fix tests for equivalence of DDIM and DDPM pipelines (#1069)

* Fix equality test for ddim and ddpm

* add docs for use_clipped_model_output in DDIM

* fix inline comment

* reorder imports in test_pipelines.py

* Ignore use_clipped_model_output if scheduler doesn't take it
Author: Grigory Sizov
Date: 2022-11-02 14:50:32 +01:00
Committed by: GitHub
Parent: 1216a3b122
Commit: 5cd29d623a
3 changed files with 42 additions and 12 deletions
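
For orientation (not part of the diff): the commit adds a `use_clipped_model_output` argument to `DDIMPipeline.__call__` and forwards it to the scheduler only when the scheduler's `step` accepts it. A minimal usage sketch mirroring the slow test below; the scheduler construction is an assumption, since those lines fall outside the hunks shown here:

import torch
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel

model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id)

# same UNet, two schedulers
ddpm = DDPMPipeline(unet=unet, scheduler=DDPMScheduler())
ddim = DDIMPipeline(unet=unet, scheduler=DDIMScheduler())

ddpm_image = ddpm(generator=torch.manual_seed(0), output_type="numpy").images

# with all 1000 steps, eta=1.0, and the new flag, DDIM closely reproduces the DDPM sample
ddim_image = ddim(
    generator=torch.manual_seed(0),
    num_inference_steps=1000,
    eta=1.0,
    use_clipped_model_output=True,  # forwarded to DDIMScheduler.step; dropped for schedulers that don't take it
    output_type="numpy",
).images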

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from typing import Optional, Tuple, Union
import torch
@@ -44,6 +44,7 @@ class DDIMPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None,
eta: float = 0.0,
num_inference_steps: int = 50,
use_clipped_model_output: Optional[bool] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
@@ -60,6 +61,9 @@ class DDIMPipeline(DiffusionPipeline):
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
use_clipped_model_output (`bool`, *optional*, defaults to `None`):
If `True` or `False`, see the documentation of `DDIMScheduler.step`. If `None`, nothing is passed
downstream to the scheduler, so use `None` for schedulers that don't support this argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -82,6 +86,14 @@ class DDIMPipeline(DiffusionPipeline):
# set step values
self.scheduler.set_timesteps(num_inference_steps)
# Ignore use_clipped_model_output if the scheduler doesn't accept this argument
accepts_use_clipped_model_output = "use_clipped_model_output" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_kwargs = {}
if accepts_use_clipped_model_output:
extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
model_output = self.unet(image, t).sample
@@ -89,7 +101,7 @@ class DDIMPipeline(DiffusionPipeline):
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# eta corresponds to η in paper and should be between [0, 1]
# do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta).prev_sample
image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
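
A quick illustration of the signature check above (a sketch, not part of the diff): `DDIMScheduler.step` takes the argument, while e.g. `DDPMScheduler.step` does not, so the pipeline silently drops it for such schedulers.

import inspect
from diffusers import DDIMScheduler, DDPMScheduler

def accepts_clipped_output(scheduler_cls):
    # mirrors the check in DDIMPipeline.__call__, but on the class instead of a bound method
    return "use_clipped_model_output" in inspect.signature(scheduler_cls.step).parameters

print(accepts_clipped_output(DDIMScheduler))  # expected: True
print(accepts_clipped_output(DDPMScheduler))  # expected: False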

View File

@@ -220,7 +220,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): TODO
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
predicted original sample. Necessary because the predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping has happened, the "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have no effect.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
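
For reference, a rough standalone sketch of the correction described above (variable names are illustrative, not the scheduler's internals), assuming the epsilon-prediction parameterization these models use; `sample` and `model_output` are torch tensors and `alpha_prod_t` is a float in (0, 1):

def corrected_model_output(sample, model_output, alpha_prod_t, clip_sample=True):
    # predict x_0 from the current sample and the predicted noise
    beta_prod_t = 1 - alpha_prod_t
    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    if clip_sample:
        # the scheduler clips the x_0 prediction to [-1, 1] when clip_sample is enabled
        pred_original_sample = pred_original_sample.clamp(-1, 1)
    # re-derive the noise prediction from the (possibly clipped) x_0; with no clipping
    # this coincides with the model_output that was passed in
    return (sample - alpha_prod_t ** 0.5 * pred_original_sample) / beta_prod_t ** 0.5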

View File

@@ -42,6 +42,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
from parameterized import parameterized
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -445,7 +446,9 @@ class PipelineSlowTests(unittest.TestCase):
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)
def test_ddpm_ddim_equality(self):
# Make sure the test passes for different values of the random seed
@parameterized.expand([(0,), (4,)])
def test_ddpm_ddim_equality(self, seed):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
@@ -459,17 +462,24 @@ class PipelineSlowTests(unittest.TestCase):
ddim.to(torch_device)
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
generator = torch.manual_seed(seed)
ddpm_image = ddpm(generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
generator = torch.manual_seed(seed)
ddim_image = ddim(
generator=generator,
num_inference_steps=1000,
eta=1.0,
output_type="numpy",
use_clipped_model_output=True, # Need this to make DDIM match DDPM
).images
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_image - ddim_image).max() < 1e-1
@unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
def test_ddpm_ddim_equality_batched(self):
# Make sure the test passes for different values of the random seed
@parameterized.expand([(0,), (4,)])
def test_ddpm_ddim_equality_batched(self, seed):
model_id = "google/ddpm-cifar10-32"
unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
@@ -484,12 +494,17 @@ class PipelineSlowTests(unittest.TestCase):
ddim.to(torch_device)
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
generator = torch.manual_seed(seed)
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
generator = torch.manual_seed(seed)
ddim_images = ddim(
batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy"
batch_size=4,
generator=generator,
num_inference_steps=1000,
eta=1.0,
output_type="numpy",
use_clipped_model_output=True, # Need this to make DDIM match DDPM
).images
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1