Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Fix tests for equivalence of DDIM and DDPM pipelines (#1069)
* Fix equality test for ddim and ddpm
* add docs for use_clipped_model_output in DDIM
* fix inline comment
* reorder imports in test_pipelines.py
* Ignore use_clipped_model_output if scheduler doesn't take it
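The last bullet works by inspecting the scheduler's `step` signature at call time. A minimal sketch of that pattern, assuming a hypothetical `scheduler_step_kwargs` helper (the diff below inlines the same logic directly in `DDIMPipeline.__call__`):

    import inspect

    def scheduler_step_kwargs(scheduler, use_clipped_model_output=None):
        # Forward use_clipped_model_output only if the scheduler's step()
        # signature declares that parameter; otherwise pass nothing downstream.
        accepts = "use_clipped_model_output" in set(
            inspect.signature(scheduler.step).parameters.keys()
        )
        extra_kwargs = {}
        if accepts:
            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
        return extra_kwargs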
src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -13,7 +13,7 @@
 # limitations under the License.


+import inspect
 from typing import Optional, Tuple, Union

 import torch
@@ -44,6 +44,7 @@ class DDIMPipeline(DiffusionPipeline):
         generator: Optional[torch.Generator] = None,
         eta: float = 0.0,
         num_inference_steps: int = 50,
+        use_clipped_model_output: Optional[bool] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         **kwargs,
@@ -60,6 +61,9 @@ class DDIMPipeline(DiffusionPipeline):
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
+            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
+                If `True` or `False`, see the documentation for `DDIMScheduler.step`. If `None`, nothing is passed
+                downstream to the scheduler, so use `None` for schedulers that don't support this argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -82,6 +86,14 @@ class DDIMPipeline(DiffusionPipeline):
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)

+        # Ignore use_clipped_model_output if the scheduler doesn't accept this argument
+        accepts_use_clipped_model_output = "use_clipped_model_output" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
+        extra_kwargs = {}
+        if accepts_use_clipped_model_output:
+            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
+
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t).sample
@@ -89,7 +101,7 @@ class DDIMPipeline(DiffusionPipeline):
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # eta corresponds to η in paper and should be between [0, 1]
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta).prev_sample
+            image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
src/diffusers/schedulers/scheduling_ddim.py
@@ -220,7 +220,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             eta (`float`): weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`): TODO
+            use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+                predicted original sample. Necessary because the predicted original sample is clipped to [-1, 1] when
+                `self.config.clip_sample` is `True`. If no clipping has happened, the "corrected" `model_output` would
+                coincide with the one provided as input and `use_clipped_model_output` will have no effect.
             generator: random number generator.
             return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
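For reference, the "corrected" `model_output` described in the docstring above re-derives the noise prediction from the (possibly clipped) predicted original sample. A sketch of the relationship, with variable names in the scheduler's style (`corrected_model_output` is a hypothetical helper, not part of the library):

    def corrected_model_output(sample, pred_original_sample, alpha_prod_t):
        # Forward-process identity: x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps,
        # solved for eps with the clipped x_0 prediction plugged back in.
        beta_prod_t = 1 - alpha_prod_t
        return (sample - alpha_prod_t ** 0.5 * pred_original_sample) / beta_prod_t ** 0.5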
tests/test_pipelines.py
@@ -42,6 +42,7 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
+from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -445,7 +446,9 @@ class PipelineSlowTests(unittest.TestCase):
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)

-    def test_ddpm_ddim_equality(self):
+    # Make sure the test passes for different values of random seed
+    @parameterized.expand([(0,), (4,)])
+    def test_ddpm_ddim_equality(self, seed):
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
@@ -459,17 +462,24 @@ class PipelineSlowTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.manual_seed(seed)
         ddpm_image = ddpm(generator=generator, output_type="numpy").images

-        generator = torch.manual_seed(0)
-        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
+        generator = torch.manual_seed(seed)
+        ddim_image = ddim(
+            generator=generator,
+            num_inference_steps=1000,
+            eta=1.0,
+            output_type="numpy",
+            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
+        ).images

         # the values aren't exactly equal, but the images look the same visually
         assert np.abs(ddpm_image - ddim_image).max() < 1e-1

     @unittest.skip("(Anton) The test is failing for large batch sizes, needs investigation")
-    def test_ddpm_ddim_equality_batched(self):
+    # Make sure the test passes for different values of random seed
+    @parameterized.expand([(0,), (4,)])
+    def test_ddpm_ddim_equality_batched(self, seed):
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id, device_map="auto")
@@ -484,12 +494,17 @@ class PipelineSlowTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.manual_seed(seed)
         ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.manual_seed(seed)
         ddim_images = ddim(
-            batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy"
+            batch_size=4,
+            generator=generator,
+            num_inference_steps=1000,
+            eta=1.0,
+            output_type="numpy",
+            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
         ).images

         # the values aren't exactly equal, but the images look the same visually
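Taken together, the change lets a caller reproduce DDPM sampling with the DDIM pipeline. A usage sketch built from the test above (the model id and argument values are taken from the test; the scheduler is assumed to use its default config):

    import torch
    from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel

    unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
    ddim = DDIMPipeline(unet=unet, scheduler=DDIMScheduler())

    generator = torch.manual_seed(0)
    # eta=1.0 makes every DDIM step fully stochastic (DDPM-like), and
    # use_clipped_model_output=True matches DDPM's sample-clipping behaviour.
    images = ddim(
        generator=generator,
        num_inference_steps=1000,
        eta=1.0,
        output_type="numpy",
        use_clipped_model_output=True,
    ).images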