From 5cd29d623ac38ccb3bdd8a5f654b85d4765d9751 Mon Sep 17 00:00:00 2001
From: Grigory Sizov
Date: Wed, 2 Nov 2022 14:50:32 +0100
Subject: [PATCH] Fix tests for equivalence of DDIM and DDPM pipelines (#1069)

* Fix equality test for ddim and ddpm

* add docs for use_clipped_model_output in DDIM

* fix inline comment

* reorder imports in test_pipelines.py

* Ignore use_clipped_model_output if scheduler doesn't take it
---
 src/diffusers/pipelines/ddim/pipeline_ddim.py | 16 +++++++--
 src/diffusers/schedulers/scheduling_ddim.py   |  5 ++-
 tests/test_pipelines.py                       | 33 ++++++++++++++-----
 3 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 74607fe87a..733a28c9f3 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
-
+import inspect
 from typing import Optional, Tuple, Union
 
 import torch
 
@@ -44,6 +44,7 @@ class DDIMPipeline(DiffusionPipeline):
         generator: Optional[torch.Generator] = None,
         eta: float = 0.0,
         num_inference_steps: int = 50,
+        use_clipped_model_output: Optional[bool] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         **kwargs,
@@ -60,6 +61,9 @@ class DDIMPipeline(DiffusionPipeline):
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
+                If `True` or `False`, see the documentation for `DDIMScheduler.step`. If `None`, nothing is passed
+                downstream to the scheduler, so use `None` for schedulers that do not support this argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -82,6 +86,14 @@ class DDIMPipeline(DiffusionPipeline):
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
 
+        # Ignore use_clipped_model_output if the scheduler doesn't accept this argument
+        accepts_use_clipped_model_output = "use_clipped_model_output" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
+        extra_kwargs = {}
+        if accepts_use_clipped_model_output:
+            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
+
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t).sample
@@ -89,7 +101,7 @@ class DDIMPipeline(DiffusionPipeline):
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # eta corresponds to η in paper and should be between [0, 1]
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta).prev_sample
+            image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index f95c18d9fa..23648d1bc3 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -220,7 +220,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             eta (`float`): weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`): TODO
+            use_clipped_model_output (`bool`): if `True`, compute a "corrected" `model_output` from the clipped
+                predicted original sample. This is necessary because the predicted original sample is clipped to
+                [-1, 1] when `self.config.clip_sample` is `True`. If no clipping has happened, the "corrected"
+                `model_output` coincides with the one provided as input, and `use_clipped_model_output` has no effect.
             generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
 
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index e355a19493..c11287339a 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -42,6 +42,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
+from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -445,7 +446,9 @@ class PipelineSlowTests(unittest.TestCase):
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)
 
-    def test_ddpm_ddim_equality(self):
+    # Make sure the test passes for different values of random seed
+    @parameterized.expand([(0,), (4,)])
+    def test_ddpm_ddim_equality(self, seed):
         model_id = "google/ddpm-cifar10-32"
 
         unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
@@ -459,17 +462,24 @@ class PipelineSlowTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)
 
-        generator = torch.manual_seed(0)
+        generator = torch.manual_seed(seed)
         ddpm_image = ddpm(generator=generator, output_type="numpy").images
 
-        generator = torch.manual_seed(0)
-        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
+        generator = torch.manual_seed(seed)
+        ddim_image = ddim(
+            generator=generator,
+            num_inference_steps=1000,
+            eta=1.0,
+            output_type="numpy",
+            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
+        ).images
 
         # the values aren't exactly equal, but the images look the same visually
         assert np.abs(ddpm_image - ddim_image).max() < 1e-1
 
-    @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
-    def test_ddpm_ddim_equality_batched(self):
+    # Make sure the test passes for different values of random seed
+    @parameterized.expand([(0,), (4,)])
+    def test_ddpm_ddim_equality_batched(self, seed):
         model_id = "google/ddpm-cifar10-32"
 
         unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
@@ -484,12 +494,17 @@ class PipelineSlowTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)
 
-        generator = torch.manual_seed(0)
+        generator = torch.manual_seed(seed)
         ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
 
-        generator = torch.manual_seed(0)
+        generator = torch.manual_seed(seed)
         ddim_images = ddim(
-            batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy"
+            batch_size=4,
+            generator=generator,
+            num_inference_steps=1000,
+            eta=1.0,
+            output_type="numpy",
+            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
         ).images
 
         # the values aren't exactly equal, but the images look the same visually
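
Note on the pattern used in `DDIMPipeline.__call__` above: the pipeline inspects the signature of `self.scheduler.step` and forwards `use_clipped_model_output` only when the scheduler declares that parameter, so schedulers without it keep working unchanged. The sketch below isolates that feature-detection idiom in a self-contained form; the two scheduler classes and the `call_step` helper are hypothetical stand-ins written for illustration, not part of the diffusers API.

    import inspect

    class ClippingScheduler:
        # Hypothetical stand-in for a scheduler whose step() accepts
        # use_clipped_model_output (DDIM-style).
        def step(self, model_output, timestep, sample, eta, use_clipped_model_output=None):
            return {"got_flag": use_clipped_model_output}

    class PlainScheduler:
        # Hypothetical stand-in for a scheduler whose step() does not
        # accept the argument (DDPM-style).
        def step(self, model_output, timestep, sample, eta):
            return {"got_flag": None}

    def call_step(scheduler, model_output, timestep, sample, eta, use_clipped_model_output=None):
        # Forward use_clipped_model_output only if step() declares it,
        # mirroring the inspect.signature check added to DDIMPipeline.__call__.
        extra_kwargs = {}
        if "use_clipped_model_output" in inspect.signature(scheduler.step).parameters:
            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
        return scheduler.step(model_output, timestep, sample, eta, **extra_kwargs)

    # Both calls succeed: the flag reaches ClippingScheduler, and it is
    # silently dropped for PlainScheduler instead of raising a TypeError.
    print(call_step(ClippingScheduler(), None, 0, None, 0.0, use_clipped_model_output=True))
    print(call_step(PlainScheduler(), None, 0, None, 0.0, use_clipped_model_output=True))

This is also why the updated tests pass `use_clipped_model_output=True` alongside `eta=1.0`: DDPM clips the predicted original sample before computing the previous sample, so DDIM reproduces DDPM's trajectory only when it likewise recomputes `model_output` from the clipped prediction, as described in the `DDIMScheduler.step` docstring.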