mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix ONNX img2img preprocessing and add fast tests coverage (#1727)
* Fix ONNX img2img preprocessing and add fast tests coverage * revert * disable progressbars
This commit is contained in:
@@ -90,7 +90,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
safety_checker: OnnxRuntimeModel
|
||||
feature_extractor: CLIPFeatureExtractor
|
||||
|
||||
_optional_components = ["safety_checker"]
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -332,13 +332,10 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if generator is None:
|
||||
generator = np.random
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
image = preprocess(image)
|
||||
image = preprocess(image).cpu().numpy()
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
|
||||
@@ -18,7 +18,15 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from diffusers import DDIMScheduler, LMSDiscreteScheduler, OnnxStableDiffusionPipeline
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
OnnxStableDiffusionPipeline,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils.testing_utils import is_onnx_available, require_onnxruntime, require_torch_gpu, slow
|
||||
|
||||
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
|
||||
@@ -29,8 +37,95 @@ if is_onnx_available():
|
||||
|
||||
|
||||
class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
|
||||
# FIXME: add fast tests
|
||||
pass
|
||||
hub_checkpoint = "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline"
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = np.random.RandomState(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_pipeline_default_ddim(self):
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.65072, 0.58492, 0.48219, 0.55521, 0.53180, 0.55939, 0.50697, 0.39800, 0.46455])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_pipeline_pndm(self):
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config, skip_prk_steps=True)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.65863, 0.59425, 0.49326, 0.56313, 0.53875, 0.56627, 0.51065, 0.39777, 0.46330])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_pipeline_lms(self):
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.53755, 0.60786, 0.47402, 0.49488, 0.51869, 0.49819, 0.47985, 0.38957, 0.44279])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_pipeline_euler(self):
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.53755, 0.60786, 0.47402, 0.49488, 0.51869, 0.49819, 0.47985, 0.38957, 0.44279])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_pipeline_euler_ancestral(self):
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.53817, 0.60812, 0.47384, 0.49530, 0.51894, 0.49814, 0.47984, 0.38958, 0.44271])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_pipeline_dpm_multistep(self):
|
||||
pipe = OnnxStableDiffusionPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.53895, 0.60808, 0.47933, 0.49608, 0.51886, 0.49950, 0.48053, 0.38957, 0.44200])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -13,11 +13,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionImg2ImgPipeline
|
||||
from diffusers import (
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
OnnxStableDiffusionImg2ImgPipeline,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import floats_tensor
|
||||
from diffusers.utils.testing_utils import is_onnx_available, load_image, require_onnxruntime, require_torch_gpu, slow
|
||||
|
||||
from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
|
||||
@@ -27,9 +36,102 @@ if is_onnx_available():
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
|
||||
# FIXME: add fast tests
|
||||
pass
|
||||
class OnnxStableDiffusionImg2ImgPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
|
||||
hub_checkpoint = "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline"
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
image = floats_tensor((1, 3, 128, 128), rng=random.Random(seed))
|
||||
generator = np.random.RandomState(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"image": image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"strength": 0.75,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_pipeline_default_ddim(self):
|
||||
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.69643, 0.58484, 0.50314, 0.58760, 0.55368, 0.59643, 0.51529, 0.41217, 0.49087])
|
||||
assert np.abs(image_slice - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_pndm(self):
|
||||
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config, skip_prk_steps=True)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.61710, 0.53390, 0.49310, 0.55622, 0.50982, 0.58240, 0.50716, 0.38629, 0.46856])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_lms(self):
|
||||
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# warmup pass to apply optimizations
|
||||
_ = pipe(**self.get_dummy_inputs())
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.52761, 0.59977, 0.49033, 0.49619, 0.54282, 0.50311, 0.47600, 0.40918, 0.45203])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_euler(self):
|
||||
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.52911, 0.60004, 0.49229, 0.49805, 0.54502, 0.50680, 0.47777, 0.41028, 0.45304])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_euler_ancestral(self):
|
||||
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.52911, 0.60004, 0.49229, 0.49805, 0.54502, 0.50680, 0.47777, 0.41028, 0.45304])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
def test_pipeline_dpm_multistep(self):
|
||||
pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(self.hub_checkpoint, provider="CPUExecutionProvider")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 128, 128, 3)
|
||||
expected_slice = np.array([0.65331, 0.58277, 0.48204, 0.56059, 0.53665, 0.56235, 0.50969, 0.40009, 0.46552])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user