From c727a6a5fb94deb05d3fec25d54dc42a174c9be6 Mon Sep 17 00:00:00 2001
From: Anton Lozhkov
Date: Fri, 16 Sep 2022 14:37:12 +0200
Subject: [PATCH] Finally fix the image-based SD tests (#509)

* Finally fix the image-based SD tests

* Remove autocast

* Remove autocast in image tests
---
 src/diffusers/testing_utils.py |  33 ++
 tests/test_pipelines.py        | 190 +++++++++++++++++++--------------
 2 files changed, 141 insertions(+), 82 deletions(-)

diff --git a/src/diffusers/testing_utils.py b/src/diffusers/testing_utils.py
index ff8b6aa9b4..7daf2bc633 100644
--- a/src/diffusers/testing_utils.py
+++ b/src/diffusers/testing_utils.py
@@ -2,9 +2,13 @@ import os
 import random
 import unittest
 from distutils.util import strtobool
+from typing import Union
 
 import torch
 
+import PIL.Image
+import PIL.ImageOps
+import requests
 from packaging import version
 
 
@@ -59,3 +63,32 @@ def slow(test_case):
 
     """
     return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
+    """
+    Loads `image` to a PIL Image.
+    Args:
+        image (`str` or `PIL.Image.Image`):
+            The image to convert to the PIL Image format.
+    Returns:
+        `PIL.Image.Image`: A PIL Image.
+    """
+    if isinstance(image, str):
+        if image.startswith("http://") or image.startswith("https://"):
+            image = PIL.Image.open(requests.get(image, stream=True).raw)
+        elif os.path.isfile(image):
+            image = PIL.Image.open(image)
+        else:
+            raise ValueError(
+                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
+            )
+    elif isinstance(image, PIL.Image.Image):
+        image = image
+    else:
+        raise ValueError(
+            "Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
+        )
+    image = PIL.ImageOps.exif_transpose(image)
+    image = image.convert("RGB")
+    return image
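The new `load_image` helper replaces the heavyweight `datasets.load_dataset` calls in the tests below. For reference, it accepts an `http(s)` URL, a local file path, or an already-open `PIL.Image.Image`, and always returns an RGB image with EXIF rotation applied. A minimal usage sketch (the fixture URL is one of those used later in this patch):

    from diffusers.testing_utils import load_image

    # Fetch a test fixture over HTTPS; the helper applies EXIF rotation and converts to RGB.
    image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        "/img2img/sketch-mountains-input.jpg"
    )
    print(image.mode, image.size)  # "RGB", (width, height)

    # Anything other than a str or a PIL.Image.Image raises ValueError.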
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 3d69136810..805ac116bb 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -22,7 +22,6 @@ import numpy as np
 import torch
 
 import PIL
-from datasets import load_dataset
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -47,7 +46,7 @@ from diffusers import (
     VQModel,
 )
 from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.testing_utils import floats_tensor, slow, torch_device
+from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
@@ -168,7 +167,7 @@ class PipelineFastTests(unittest.TestCase):
     @property
     def dummy_safety_checker(self):
         def check(images, *args, **kwargs):
-            return images, False
+            return images, [False] * len(images)
 
         return check
 
@@ -708,6 +707,13 @@ class PipelineTesterMixin(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()
 
+    @property
+    def dummy_safety_checker(self):
+        def check(images, *args, **kwargs):
+            return images, [False] * len(images)
+
+        return check
+
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
         model = UNet2DModel(
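Both `dummy_safety_checker` properties fix the same contract bug: the real safety checker returns `(images, has_nsfw_concepts)` with one boolean per image, so a stub that returns a bare `False` breaks any caller that iterates the flags alongside the images. A standalone sketch of the corrected stub (the `batch` variable is illustrative, not from the test file):

    def check(images, *args, **kwargs):
        # Mirror the real checker's return shape: one "is NSFW" flag per image.
        return images, [False] * len(images)

    batch = ["image_0", "image_1", "image_2"]  # stand-ins for decoded images
    images, nsfw_flags = check(batch)
    assert nsfw_flags == [False, False, False]  # per-image flags, never a bare bool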
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/img2img/sketch-mountains-input.jpg" - ], - "output": [ - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/img2img/fantasy_landscape_k_lms.png" - ], - }, + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/sketch-mountains-input.jpg" ) - - init_image = ds["input"]["image"][0].resize((768, 512)) - output_image = ds["output"]["image"][0].resize((768, 512)) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/img2img/fantasy_landscape_k_lms.png" + ) + init_image = init_image.resize((768, 512)) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear") @@ -1205,78 +1233,76 @@ class PipelineTesterMixin(unittest.TestCase): pipe = StableDiffusionImg2ImgPipeline.from_pretrained( model_id, scheduler=lms, + safety_checker=self.dummy_safety_checker, use_auth_token=True, ) - pipe.enable_attention_slicing() pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() prompt = "A fantasy landscape, trending on artstation" generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast("cuda"): - output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator) + output = pipe( + prompt=prompt, + init_image=init_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) image = output.images[0] - expected_array = np.array(output_image) / 255.0 - sampled_array = np.array(image) / 255.0 + Image.fromarray((image * 255).round().astype("uint8")).save("fantasy_landscape_k_lms.png") - assert sampled_array.shape == (512, 768, 3) - assert np.max(np.abs(sampled_array - expected_array)) < 1e-4 + assert image.shape == (512, 768, 3) + assert np.abs(expected_image - image).max() < 1e-2 @slow @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU") def test_stable_diffusion_inpaint_pipeline(self): - ds = load_dataset( - "imagefolder", - data_files={ - "input": [ - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo.png" - ], - "mask": [ - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" - ], - "output": [ - "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" - "/in_paint/red_cat_sitting_on_a_parking_bench.png" - ], - }, + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" ) - - init_image = ds["input"]["image"][0].resize((768, 512)) - mask_image = ds["mask"]["image"][0].resize((768, 512)) - output_image = ds["output"]["image"][0].resize((768, 512)) + mask_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" + ) + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/red_cat_sitting_on_a_park_bench.png" + ) + expected_image = np.array(expected_image, dtype=np.float32) / 255.0 model_id = 
"CompVis/stable-diffusion-v1-4" pipe = StableDiffusionInpaintPipeline.from_pretrained( model_id, + safety_checker=self.dummy_safety_checker, use_auth_token=True, ) pipe.to(torch_device) - pipe.enable_attention_slicing() pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() - prompt = "A red cat sitting on a parking bench" + prompt = "A red cat sitting on a park bench" generator = torch.Generator(device=torch_device).manual_seed(0) - with torch.autocast("cuda"): - output = pipe( - prompt=prompt, - init_image=init_image, - mask_image=mask_image, - strength=0.75, - guidance_scale=7.5, - generator=generator, - ) + output = pipe( + prompt=prompt, + init_image=init_image, + mask_image=mask_image, + strength=0.75, + guidance_scale=7.5, + generator=generator, + output_type="np", + ) image = output.images[0] - expected_array = np.array(output_image) / 255.0 - sampled_array = np.array(image) / 255.0 + Image.fromarray((image * 255).round().astype("uint8")).save("red_cat_sitting_on_a_park_bench.png") - assert sampled_array.shape == (512, 768, 3) - assert np.max(np.abs(sampled_array - expected_array)) < 1e-3 + assert image.shape == (512, 512, 3) + assert np.abs(expected_image - image).max() < 1e-2 @slow def test_stable_diffusion_onnx(self):