From 3becd368b14d74ca361eada8408627234996e4d1 Mon Sep 17 00:00:00 2001
From: hwuebben
Date: Wed, 19 Apr 2023 18:58:13 +0200
Subject: [PATCH] Update pipeline_stable_diffusion_inpaint_legacy.py (#2903)

* Update pipeline_stable_diffusion_inpaint_legacy.py
* fix preprocessing of PIL images with adequate batch size
* revert map
* add tests
* reformat
* Update test_stable_diffusion_inpaint_legacy.py
* Update test_stable_diffusion_inpaint_legacy.py
* Update test_stable_diffusion_inpaint_legacy.py
* Update test_stable_diffusion_inpaint_legacy.py
* next try to fix the style
* wth is this
* Update testing_utils.py
* Update testing_utils.py
* Update test_stable_diffusion_inpaint_legacy.py
* Update test_stable_diffusion_inpaint_legacy.py
* Update test_stable_diffusion_inpaint_legacy.py
* Update test_stable_diffusion_inpaint_legacy.py
* Update test_stable_diffusion_inpaint_legacy.py
* Update test_stable_diffusion_inpaint_legacy.py

---------

Co-authored-by: Patrick von Platen
---
 ...ipeline_stable_diffusion_inpaint_legacy.py | 20 ++--
 src/diffusers/utils/testing_utils.py          | 10 ++
 .../test_stable_diffusion_inpaint_legacy.py   | 93 ++++++++++++++++++-
 3 files changed, 108 insertions(+), 15 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index 1c8377c7e5..3ad1d5e922 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -41,17 +41,17 @@ from .safety_checker import StableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)
 
 
-def preprocess_image(image):
+def preprocess_image(image, batch_size):
     w, h = image.size
     w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
     image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
+    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
     image = torch.from_numpy(image)
     return 2.0 * image - 1.0
 
 
-def preprocess_mask(mask, scale_factor=8):
+def preprocess_mask(mask, batch_size, scale_factor=8):
     if not isinstance(mask, torch.FloatTensor):
         mask = mask.convert("L")
         w, h = mask.size
@@ -59,7 +59,7 @@
         mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
         mask = np.array(mask).astype(np.float32) / 255.0
         mask = np.tile(mask, (4, 1, 1))
-        mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
+        mask = np.vstack([mask[None]] * batch_size)
         mask = 1 - mask  # repaint white, keep black
         mask = torch.from_numpy(mask)
         return mask
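Note: the two hunks above are the core of the fix. A single PIL image or mask now leaves preprocessing with an explicit batch axis, tiled `batch_size` times via `np.vstack`, instead of a fixed batch of one (the old no-op `transpose(0, 1, 2, 3)` is dropped). A minimal standalone sketch of the image path, assuming a synthetic input; the helper name `batched_image_tensor` is illustrative, not part of the patch:

    import numpy as np
    import torch
    from PIL import Image

    def batched_image_tensor(image, batch_size):
        # Same steps as the patched preprocess_image: snap H/W down to a
        # multiple of 8, scale to [0, 1], move channels first, then tile
        # the leading batch axis batch_size times.
        w, h = (x - x % 8 for x in image.size)
        image = image.resize((w, h), resample=Image.LANCZOS)
        arr = np.array(image).astype(np.float32) / 255.0  # (H, W, 3)
        arr = arr[None].transpose(0, 3, 1, 2)             # (1, 3, H, W)
        arr = np.vstack([arr] * batch_size)               # (batch_size, 3, H, W)
        return 2.0 * torch.from_numpy(arr) - 1.0          # scale to [-1, 1]

    img = Image.new("RGB", (64, 64), "white")  # stand-in for a real init image
    assert batched_image_tensor(img, batch_size=2).shape == (2, 3, 64, 64)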
@@ -521,14 +521,14 @@ class StableDiffusionInpaintPipelineLegacy(
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator):
+    def prepare_latents(self, image, timestep, num_images_per_prompt, dtype, device, generator):
         image = image.to(device=self.device, dtype=dtype)
         init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
         init_latents = self.vae.config.scaling_factor * init_latents
 
         # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
         init_latents_orig = init_latents
 
         # add noise to latents using the timesteps
@@ -659,9 +659,9 @@
 
         # 4. Preprocess image and mask
         if not isinstance(image, torch.FloatTensor):
-            image = preprocess_image(image)
+            image = preprocess_image(image, batch_size)
 
-        mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+        mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor)
 
         # 5. set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -671,12 +671,12 @@
         # 6. Prepare latent variables
         # encode the init image into latents and scale the latents
         latents, init_latents_orig, noise = self.prepare_latents(
-            image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
+            image, latent_timestep, num_images_per_prompt, prompt_embeds.dtype, device, generator
        )
 
         # 7. Prepare mask latent
         mask = mask_image.to(device=self.device, dtype=latents.dtype)
-        mask = torch.cat([mask] * batch_size * num_images_per_prompt)
+        mask = torch.cat([mask] * num_images_per_prompt)
 
         # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
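Note: the hunks above remove the now redundant `batch_size` factor. Because the preprocessed image already enters `prepare_latents` with `batch_size` entries, the VAE emits `batch_size` latents, and only the `num_images_per_prompt` replication is still needed; keeping the old multiplication would have produced batch_size squared latents for batched PIL inputs. A shape-only sketch, where the 4x8x8 latent size is illustrative:

    import torch

    batch_size, num_images_per_prompt = 2, 3

    # After the patched preprocessing, the VAE-encoded init latents already
    # carry the batch dimension:
    init_latents = torch.randn(batch_size, 4, 8, 8)

    # Old: torch.cat([init_latents] * batch_size * num_images_per_prompt)
    #      -> batch_size**2 * num_images_per_prompt entries (12), a mismatch.
    # New: replicate per num_images_per_prompt only.
    latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
    assert latents.shape[0] == batch_size * num_images_per_prompt  # 6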
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index afea0540b7..d8fed5dec1 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -279,6 +279,16 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
     return image
 
 
+def preprocess_image(image: PIL.Image, batch_size: int):
+    w, h = image.size
+    w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8
+    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
 def export_to_video(video_frames: List[np.ndarray], output_video_path: str = None) -> str:
     if is_opencv_available():
         import cv2
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
index 15d94414ea..f56fa31a96 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py
@@ -34,7 +34,7 @@ from diffusers import (
     VQModel,
 )
 from diffusers.utils import floats_tensor, load_image, nightly, slow, torch_device
-from diffusers.utils.testing_utils import load_numpy, require_torch_gpu
+from diffusers.utils.testing_utils import load_numpy, preprocess_image, require_torch_gpu
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -217,6 +217,55 @@ class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_inpaint_legacy_batched(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        init_images_tens = preprocess_image(init_image, batch_size=2)
+        init_masks_tens = init_images_tens + 4
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipelineLegacy(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=device).manual_seed(0)
+        images = sd_pipe(
+            [prompt] * 2,
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            image=init_images_tens,
+            mask_image=init_masks_tens,
+        ).images
+
+        assert images.shape == (2, 32, 32, 3)
+
+        image_slice_0 = images[0, -3:, -3:, -1].flatten()
+        image_slice_1 = images[1, -3:, -3:, -1].flatten()
+
+        expected_slice_0 = np.array([0.4697, 0.3770, 0.4096, 0.4653, 0.4497, 0.4183, 0.3950, 0.4668, 0.4672])
+        expected_slice_1 = np.array([0.4105, 0.4987, 0.5771, 0.4921, 0.4237, 0.5684, 0.5496, 0.4645, 0.5272])
+
+        assert np.abs(expected_slice_0 - image_slice_0).max() < 1e-2
+        assert np.abs(expected_slice_1 - image_slice_1).max() < 1e-2
+
     def test_stable_diffusion_inpaint_legacy_negative_prompt(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
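Note: the fast test above drives the batch through tensor inputs; the slow batched test below instead feeds a hand-built batched mask tensor. The same construction in isolation, assuming a synthetic all-white mask in place of the downloaded one:

    import numpy as np
    import torch
    from PIL import Image

    mask_pil = Image.new("L", (64, 64), 255)  # white = region to repaint

    mask = np.array(mask_pil.convert("L")).astype(np.float32) / 255.0
    mask = torch.from_numpy(1 - mask)             # invert: repaint white, keep black
    masks = torch.vstack([mask[None][None]] * 2)  # (2, 1, H, W) for batch size 2
    assert masks.shape == (2, 1, 64, 64)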
@@ -349,7 +398,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
+    def get_inputs(self, generator_device="cpu", seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)
         init_image = load_image(
             "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
@@ -379,7 +428,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
 
-        inputs = self.get_inputs(torch_device)
+        inputs = self.get_inputs()
         image = pipe(**inputs).images
         image_slice = image[0, 253:256, 253:256, -1].flatten()
 
@@ -388,6 +437,40 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
 
         assert np.abs(expected_slice - image_slice).max() < 1e-4
 
+    def test_stable_diffusion_inpaint_legacy_batched(self):
+        pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
+            "CompVis/stable-diffusion-v1-4", safety_checker=None
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        inputs = self.get_inputs()
+        inputs["prompt"] = [inputs["prompt"]] * 2
+        inputs["image"] = preprocess_image(inputs["image"], batch_size=2)
+
+        mask = inputs["mask_image"].convert("L")
+        mask = np.array(mask).astype(np.float32) / 255.0
+        mask = torch.from_numpy(1 - mask)
+        masks = torch.vstack([mask[None][None]] * 2)
+        inputs["mask_image"] = masks
+
+        image = pipe(**inputs).images
+        assert image.shape == (2, 512, 512, 3)
+
+        image_slice_0 = image[0, 253:256, 253:256, -1].flatten()
+        image_slice_1 = image[1, 253:256, 253:256, -1].flatten()
+
+        expected_slice_0 = np.array(
+            [0.52093095, 0.4176447, 0.32752383, 0.6175223, 0.50563973, 0.36470804, 0.65460044, 0.5775188, 0.44332123]
+        )
+        expected_slice_1 = np.array(
+            [0.3592432, 0.4233033, 0.3914635, 0.31014425, 0.3702293, 0.39412856, 0.17526966, 0.2642669, 0.37480092]
+        )
+
+        assert np.abs(expected_slice_0 - image_slice_0).max() < 1e-4
+        assert np.abs(expected_slice_1 - image_slice_1).max() < 1e-4
+
     def test_stable_diffusion_inpaint_legacy_k_lms(self):
         pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
             "CompVis/stable-diffusion-v1-4", safety_checker=None
@@ -397,7 +480,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
 
-        inputs = self.get_inputs(torch_device)
+        inputs = self.get_inputs()
         image = pipe(**inputs).images
         image_slice = image[0, 253:256, 253:256, -1].flatten()
 
@@ -437,7 +520,7 @@ class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
 
-        inputs = self.get_inputs(torch_device, dtype=torch.float16)
+        inputs = self.get_inputs()
         pipe(**inputs, callback=callback_fn, callback_steps=1)
         assert callback_fn.has_been_called
         assert number_of_steps == 2
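Note: taken together, the patch lets the legacy inpaint pipeline accept a list of prompts alongside a single PIL image and mask. A hedged usage sketch of the behavior the new tests pin down; the image file paths are placeholders, and the checkpoint is the one used in the slow tests:

    import torch
    from PIL import Image
    from diffusers import StableDiffusionInpaintPipelineLegacy

    pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
        "CompVis/stable-diffusion-v1-4", safety_checker=None
    ).to("cuda")

    init_image = Image.open("bench.png").convert("RGB").resize((512, 512))     # placeholder path
    mask_image = Image.open("bench_mask.png").convert("L").resize((512, 512))  # placeholder path

    # Two prompts with one PIL image/mask: preprocessing now tiles both to
    # batch size 2 instead of leaving a batch of one.
    images = pipe(
        prompt=["a red cat sitting on a bench", "a blue dog sitting on a bench"],
        image=init_image,
        mask_image=mask_image,
        num_inference_steps=50,
    ).images
    assert len(images) == 2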