1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Finally fix the image-based SD tests (#509)

* Finally fix the image-based SD tests

* Remove autocast

* Remove autocast in image tests
This commit is contained in:
Anton Lozhkov
2022-09-16 14:37:12 +02:00
committed by GitHub
parent f73ca908e5
commit c727a6a5fb
2 changed files with 141 additions and 82 deletions

View File

@@ -2,9 +2,13 @@ import os
import random
import unittest
from distutils.util import strtobool
from typing import Union
import torch
import PIL.Image
import PIL.ImageOps
import requests
from packaging import version
@@ -59,3 +63,32 @@ def slow(test_case):
"""
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
Args:
Loads `image` to a PIL Image.
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
Returns:
`PIL.Image.Image`: A PIL Image.
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image

View File

@@ -22,7 +22,6 @@ import numpy as np
import torch
import PIL
from datasets import load_dataset
from diffusers import (
AutoencoderKL,
DDIMPipeline,
@@ -47,7 +46,7 @@ from diffusers import (
VQModel,
)
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -168,7 +167,7 @@ class PipelineFastTests(unittest.TestCase):
@property
def dummy_safety_checker(self):
def check(images, *args, **kwargs):
return images, False
return images, [False] * len(images)
return check
@@ -708,6 +707,13 @@ class PipelineTesterMixin(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
@property
def dummy_safety_checker(self):
def check(images, *args, **kwargs):
return images, [False] * len(images)
return check
def test_from_pretrained_save_pretrained(self):
# 1. Load models
model = UNet2DModel(
@@ -1139,65 +1145,87 @@ class PipelineTesterMixin(unittest.TestCase):
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_img2img_pipeline(self):
ds = load_dataset(
"imagefolder",
data_files={
"input": [
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
],
"output": [
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/fantasy_landscape.png"
],
},
def test_stable_diffusion_text2img_pipeline(self):
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/text2img/astronaut_riding_a_horse.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
init_image = ds["input"]["image"][0].resize((768, 512))
output_image = ds["output"]["image"][0].resize((768, 512))
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
safety_checker=self.dummy_safety_checker,
use_auth_token=True,
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "astronaut riding a horse"
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np")
image = output.images[0]
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_img2img_pipeline(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/fantasy_landscape.png"
)
init_image = init_image.resize((768, 512))
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
safety_checker=self.dummy_safety_checker,
use_auth_token=True,
)
pipe.to(torch_device)
pipe.enable_attention_slicing()
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast("cuda"):
output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)
output = pipe(
prompt=prompt,
init_image=init_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
expected_array = np.array(output_image) / 255.0
sampled_array = np.array(image) / 255.0
Image.fromarray((image * 255).round().astype("uint8")).save("fantasy_landscape.png")
assert sampled_array.shape == (512, 768, 3)
assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
assert image.shape == (512, 768, 3)
assert np.abs(expected_image - image).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_img2img_pipeline_k_lms(self):
ds = load_dataset(
"imagefolder",
data_files={
"input": [
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
],
"output": [
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/fantasy_landscape_k_lms.png"
],
},
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
)
init_image = ds["input"]["image"][0].resize((768, 512))
output_image = ds["output"]["image"][0].resize((768, 512))
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/fantasy_landscape_k_lms.png"
)
init_image = init_image.resize((768, 512))
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
@@ -1205,78 +1233,76 @@ class PipelineTesterMixin(unittest.TestCase):
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
model_id,
scheduler=lms,
safety_checker=self.dummy_safety_checker,
use_auth_token=True,
)
pipe.enable_attention_slicing()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast("cuda"):
output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)
output = pipe(
prompt=prompt,
init_image=init_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
expected_array = np.array(output_image) / 255.0
sampled_array = np.array(image) / 255.0
Image.fromarray((image * 255).round().astype("uint8")).save("fantasy_landscape_k_lms.png")
assert sampled_array.shape == (512, 768, 3)
assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
assert image.shape == (512, 768, 3)
assert np.abs(expected_image - image).max() < 1e-2
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_inpaint_pipeline(self):
ds = load_dataset(
"imagefolder",
data_files={
"input": [
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
],
"mask": [
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
],
"output": [
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/red_cat_sitting_on_a_parking_bench.png"
],
},
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo.png"
)
init_image = ds["input"]["image"][0].resize((768, 512))
mask_image = ds["mask"]["image"][0].resize((768, 512))
output_image = ds["output"]["image"][0].resize((768, 512))
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/in_paint/red_cat_sitting_on_a_park_bench.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
model_id,
safety_checker=self.dummy_safety_checker,
use_auth_token=True,
)
pipe.to(torch_device)
pipe.enable_attention_slicing()
pipe.set_progress_bar_config(disable=None)
pipe.enable_attention_slicing()
prompt = "A red cat sitting on a parking bench"
prompt = "A red cat sitting on a park bench"
generator = torch.Generator(device=torch_device).manual_seed(0)
with torch.autocast("cuda"):
output = pipe(
prompt=prompt,
init_image=init_image,
mask_image=mask_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
)
output = pipe(
prompt=prompt,
init_image=init_image,
mask_image=mask_image,
strength=0.75,
guidance_scale=7.5,
generator=generator,
output_type="np",
)
image = output.images[0]
expected_array = np.array(output_image) / 255.0
sampled_array = np.array(image) / 255.0
Image.fromarray((image * 255).round().astype("uint8")).save("red_cat_sitting_on_a_park_bench.png")
assert sampled_array.shape == (512, 768, 3)
assert np.max(np.abs(sampled_array - expected_array)) < 1e-3
assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-2
@slow
def test_stable_diffusion_onnx(self):