From c2a38ef9df350c6e01a8d1e299bc082da26beb7e Mon Sep 17 00:00:00 2001
From: Anton Lozhkov
Date: Sun, 18 Dec 2022 11:49:53 +0100
Subject: [PATCH] Fix/update the LDM pipeline and tests (#1743)

* Fix/update LDM tests

* batched generators
---
 .../pipeline_latent_diffusion.py              |  31 ++-
 .../latent_diffusion/test_latent_diffusion.py | 201 ++++++++++--------
 2 files changed, 135 insertions(+), 97 deletions(-)

diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index 83f8d355a7..ec0c71af4f 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -128,29 +128,42 @@ class LDMTextToImagePipeline(DiffusionPipeline):

         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]

         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
         text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]

         # get the initial random noise unless the user supplied it
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         if latents is None:
-            if self.device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu").to(self.device)
+            rand_device = "cpu" if self.device.type == "mps" else self.device
+
+            if isinstance(generator, list):
+                latents_shape = (1,) + latents_shape[1:]
+                latents = [
+                    torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
+                    for i in range(batch_size)
+                ]
+                latents = torch.cat(latents, dim=0)
             else:
                 latents = torch.randn(
-                    latents_shape,
-                    generator=generator,
-                    device=self.device,
+                    latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
                 )
+            latents = latents.to(self.device)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(self.device)
+            latents = latents.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)
diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py
index 4b7c89977d..ef790f28e4 100644
--- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py
+++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py
@@ -13,24 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import unittest

 import numpy as np
 import torch

 from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

+from ...test_pipelines_common import PipelineTesterMixin
+

 torch.backends.cuda.matmul.allow_tf32 = False


-class LDMTextToImagePipelineFastTests(unittest.TestCase):
-    @property
-    def dummy_cond_unet(self):
+class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = LDMTextToImagePipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
         torch.manual_seed(0)
-        model = UNet2DConditionModel(
+        unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
             sample_size=32,
@@ -40,25 +45,24 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
         )
-        return model
-
-    @property
-    def dummy_vae(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
         torch.manual_seed(0)
-        model = AutoencoderKL(
-            block_out_channels=[32, 64],
+        vae = AutoencoderKL(
+            block_out_channels=(32, 64),
             in_channels=3,
             out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
+            up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
             latent_channels=4,
         )
-        return model
-
-    @property
-    def dummy_text_encoder(self):
         torch.manual_seed(0)
-        config = CLIPTextConfig(
+        text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
             hidden_size=32,
@@ -69,96 +73,117 @@ class LDMTextToImagePipelineFastTests(unittest.TestCase):
             pad_token_id=1,
             vocab_size=1000,
         )
-        return CLIPTextModel(config)
-
-    def test_inference_text2img(self):
-        if torch_device != "cpu":
-            return
-
-        unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler()
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

-        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vqvae": vae,
+            "bert": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components

-        prompt = "A painting of a squirrel eating a burger"
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(
-                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
-            ).images
+    def test_inference_text2img(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
-        ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image_from_tuple = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=2,
-            output_type="numpy",
-            return_dict=False,
-        )[0]
+        components = self.get_dummy_components()
+        pipe = LDMTextToImagePipeline(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

         assert image.shape == (1, 16, 16, 3)
-        expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3


 @slow
-@require_torch
-class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
-    def test_inference_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
+@require_torch_gpu
+class LDMTextToImagePipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()

-        prompt = "A painting of a squirrel eating a burger"
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs

-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)

-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
-        ).images
-
-        image_slice = image[0, -3:, -3:, -1]
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()

         assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        expected_slice = np.array([0.51825, 0.52850, 0.52543, 0.54258, 0.52304, 0.52569, 0.54363, 0.55276, 0.56878])
+        max_diff = np.abs(expected_slice - image_slice).max()
+        assert max_diff < 1e-3

-    def test_inference_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
+@nightly
+@require_torch_gpu
+class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()

-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 50,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs

-        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)

-        image_slice = image[0, -3:, -3:, -1]
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images[0]

-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/ldm_text2img/ldm_large_256_ddim.npy"
+        )
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 1e-3
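
Usage note (not part of the patch): a minimal sketch of the batched-generator behavior the pipeline change above enables. It mirrors the call arguments exercised in the tests; the prompts and seeds are illustrative, and the checkpoint is the one used in the slow tests.

# Sketch only: pass one torch.Generator per prompt for per-sample reproducible noise.
# The patched pipeline raises a ValueError if the generator list length does not
# match the effective batch size derived from the prompts.
import torch
from diffusers import LDMTextToImagePipeline

pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
pipe.set_progress_bar_config(disable=None)

prompts = [
    "A painting of a squirrel eating a burger",  # example prompt from the tests
    "A painting of a squirrel eating a burger, oil on canvas",  # illustrative second prompt
]
generators = [torch.Generator(device="cpu").manual_seed(seed) for seed in (0, 1)]

images = pipe(
    prompts,
    generator=generators,
    guidance_scale=6.0,
    num_inference_steps=50,
    output_type="numpy",
).images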