Mirror of https://github.com/huggingface/diffusers.git
Finally fix the image-based SD tests (#509)
* Finally fix the image-based SD tests
* Remove autocast
* Remove autocast in image tests
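The recurring pattern across the slow tests below: drop the `torch.autocast("cuda")` wrapper, request `output_type="np"`, and loosen the reference comparison from 1e-4 to 1e-2. A runnable sketch of just the new comparison convention (synthetic arrays stand in for the pipeline output and the reference image; illustration only, not part of the commit):

    import numpy as np

    # Stand-ins for a generated sample and its reference image (the real tests
    # load these from hf-internal-testing/diffusers-images and from the pipeline).
    expected_image = np.zeros((512, 768, 3), dtype=np.float32)
    image = expected_image + 5e-3  # pretend sample, within tolerance

    # New convention: float arrays scaled to [0, 1], compared with a loose
    # 1e-2 tolerance instead of the old 1e-4 pixel-exact check under autocast.
    assert image.shape == (512, 768, 3)
    assert np.abs(expected_image - image).max() < 1e-2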
@@ -2,9 +2,13 @@ import os
 import random
 import unittest
 from distutils.util import strtobool
+from typing import Union

 import torch

+import PIL.Image
+import PIL.ImageOps
+import requests
 from packaging import version

@@ -59,3 +63,32 @@ def slow(test_case):

     """
     return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
+    """
+    Loads `image` to a PIL Image.
+    Args:
+        image (`str` or `PIL.Image.Image`):
+            The image to convert to the PIL Image format.
+    Returns:
+        `PIL.Image.Image`: A PIL Image.
+    """
+    if isinstance(image, str):
+        if image.startswith("http://") or image.startswith("https://"):
+            image = PIL.Image.open(requests.get(image, stream=True).raw)
+        elif os.path.isfile(image):
+            image = PIL.Image.open(image)
+        else:
+            raise ValueError(
+                f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
+            )
+    elif isinstance(image, PIL.Image.Image):
+        image = image
+    else:
+        raise ValueError(
+            "Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
+        )
+    image = PIL.ImageOps.exif_transpose(image)
+    image = image.convert("RGB")
+    return image
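A minimal usage sketch of the new helper, as the tests below consume it (the URL is one of the reference images from this diff; `load_image` also accepts a local path or an existing PIL image):

    from diffusers.testing_utils import load_image

    # Fetch a reference image from the Hub dataset used by the slow tests.
    init_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        "/img2img/sketch-mountains-input.jpg"
    )
    init_image = init_image.resize((768, 512))  # PIL sizes are (width, height)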
@@ -22,7 +22,6 @@ import numpy as np
 import torch

 import PIL
-from datasets import load_dataset
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -47,7 +46,7 @@ from diffusers import (
     VQModel,
 )
 from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.testing_utils import floats_tensor, slow, torch_device
+from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

@@ -168,7 +167,7 @@ class PipelineFastTests(unittest.TestCase):
     @property
     def dummy_safety_checker(self):
         def check(images, *args, **kwargs):
-            return images, False
+            return images, [False] * len(images)

         return check

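The stub fix matters because the old version returned a single `False` where the pipeline apparently consumes one NSFW flag per generated image. The fixed stub can be exercised in isolation (toy stand-in inputs, illustration only):

    def check(images, *args, **kwargs):
        # Pass images through unmodified and report "not NSFW" for each one.
        return images, [False] * len(images)

    batch = ["img0", "img1", "img2"]  # stand-ins for decoded images
    images, nsfw_flags = check(batch)
    assert nsfw_flags == [False, False, False]  # one flag per image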
@@ -708,6 +707,13 @@ class PipelineTesterMixin(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()

+    @property
+    def dummy_safety_checker(self):
+        def check(images, *args, **kwargs):
+            return images, [False] * len(images)
+
+        return check
+
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
         model = UNet2DModel(
@@ -1139,65 +1145,87 @@ class PipelineTesterMixin(unittest.TestCase):

     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
-    def test_stable_diffusion_img2img_pipeline(self):
-        ds = load_dataset(
-            "imagefolder",
-            data_files={
-                "input": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/img2img/sketch-mountains-input.jpg"
-                ],
-                "output": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/img2img/fantasy_landscape.png"
-                ],
-            },
-        )
-
-        init_image = ds["input"]["image"][0].resize((768, 512))
-        output_image = ds["output"]["image"][0].resize((768, 512))
+    def test_stable_diffusion_text2img_pipeline(self):
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/text2img/astronaut_riding_a_horse.png"
+        )
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0
+
+        model_id = "CompVis/stable-diffusion-v1-4"
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_id,
+            safety_checker=self.dummy_safety_checker,
+            use_auth_token=True,
+        )
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()
+
+        prompt = "astronaut riding a horse"
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        output = pipe(prompt=prompt, strength=0.75, guidance_scale=7.5, generator=generator, output_type="np")
+        image = output.images[0]
+
+        assert image.shape == (512, 512, 3)
+        assert np.abs(expected_image - image).max() < 1e-2
+
+    @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_img2img_pipeline(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/fantasy_landscape.png"
+        )
+        init_image = init_image.resize((768, 512))
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

         model_id = "CompVis/stable-diffusion-v1-4"
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
             safety_checker=self.dummy_safety_checker,
             use_auth_token=True,
         )
         pipe.to(torch_device)
-        pipe.enable_attention_slicing()
         pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()

         prompt = "A fantasy landscape, trending on artstation"

         generator = torch.Generator(device=torch_device).manual_seed(0)
-        with torch.autocast("cuda"):
-            output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)
+        output = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+        )
         image = output.images[0]

-        expected_array = np.array(output_image) / 255.0
-        sampled_array = np.array(image) / 255.0
-        Image.fromarray((image * 255).round().astype("uint8")).save("fantasy_landscape.png")
-
-        assert sampled_array.shape == (512, 768, 3)
-        assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
+        assert image.shape == (512, 768, 3)
+        assert np.abs(expected_image - image).max() < 1e-2

     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_img2img_pipeline_k_lms(self):
-        ds = load_dataset(
-            "imagefolder",
-            data_files={
-                "input": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/img2img/sketch-mountains-input.jpg"
-                ],
-                "output": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/img2img/fantasy_landscape_k_lms.png"
-                ],
-            },
-        )
-
-        init_image = ds["input"]["image"][0].resize((768, 512))
-        output_image = ds["output"]["image"][0].resize((768, 512))
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/fantasy_landscape_k_lms.png"
+        )
+        init_image = init_image.resize((768, 512))
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

         lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

@@ -1205,78 +1233,76 @@ class PipelineTesterMixin(unittest.TestCase):
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
             model_id,
             scheduler=lms,
             safety_checker=self.dummy_safety_checker,
             use_auth_token=True,
         )
-        pipe.enable_attention_slicing()
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()

         prompt = "A fantasy landscape, trending on artstation"

         generator = torch.Generator(device=torch_device).manual_seed(0)
-        with torch.autocast("cuda"):
-            output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)
+        output = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+        )
         image = output.images[0]

-        expected_array = np.array(output_image) / 255.0
-        sampled_array = np.array(image) / 255.0
-        Image.fromarray((image * 255).round().astype("uint8")).save("fantasy_landscape_k_lms.png")
-
-        assert sampled_array.shape == (512, 768, 3)
-        assert np.max(np.abs(sampled_array - expected_array)) < 1e-4
+        assert image.shape == (512, 768, 3)
+        assert np.abs(expected_image - image).max() < 1e-2

     @slow
     @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion_inpaint_pipeline(self):
-        ds = load_dataset(
-            "imagefolder",
-            data_files={
-                "input": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/in_paint/overture-creations-5sI6fQgYIuo.png"
-                ],
-                "mask": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
-                ],
-                "output": [
-                    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-                    "/in_paint/red_cat_sitting_on_a_parking_bench.png"
-                ],
-            },
-        )
-
-        init_image = ds["input"]["image"][0].resize((768, 512))
-        mask_image = ds["mask"]["image"][0].resize((768, 512))
-        output_image = ds["output"]["image"][0].resize((768, 512))
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+        )
+        mask_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+        )
+        expected_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/red_cat_sitting_on_a_park_bench.png"
+        )
+        expected_image = np.array(expected_image, dtype=np.float32) / 255.0

         model_id = "CompVis/stable-diffusion-v1-4"
         pipe = StableDiffusionInpaintPipeline.from_pretrained(
             model_id,
             safety_checker=self.dummy_safety_checker,
             use_auth_token=True,
         )
         pipe.to(torch_device)
-        pipe.enable_attention_slicing()
         pipe.set_progress_bar_config(disable=None)
+        pipe.enable_attention_slicing()

-        prompt = "A red cat sitting on a parking bench"
+        prompt = "A red cat sitting on a park bench"

         generator = torch.Generator(device=torch_device).manual_seed(0)
-        with torch.autocast("cuda"):
-            output = pipe(
-                prompt=prompt,
-                init_image=init_image,
-                mask_image=mask_image,
-                strength=0.75,
-                guidance_scale=7.5,
-                generator=generator,
-            )
+        output = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            mask_image=mask_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            generator=generator,
+            output_type="np",
+        )
         image = output.images[0]

-        expected_array = np.array(output_image) / 255.0
-        sampled_array = np.array(image) / 255.0
-        Image.fromarray((image * 255).round().astype("uint8")).save("red_cat_sitting_on_a_park_bench.png")
-
-        assert sampled_array.shape == (512, 768, 3)
-        assert np.max(np.abs(sampled_array - expected_array)) < 1e-3
+        assert image.shape == (512, 512, 3)
+        assert np.abs(expected_image - image).max() < 1e-2

     @slow
     def test_stable_diffusion_onnx(self):