mirror of https://github.com/huggingface/diffusers.git
[ONNX] Improve ONNXPipeline scheduler compatibility, fix safety_checker (#1173)
* [ONNX] Improve ONNX scheduler compatibility, fix safety_checker
* typo
@@ -81,6 +81,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
     output_path = Path(output_path)

     # TEXT ENCODER
+    num_tokens = pipeline.text_encoder.config.max_position_embeddings
+    text_hidden_size = pipeline.text_encoder.config.hidden_size
     text_input = pipeline.tokenizer(
         "A sample prompt",
         padding="max_length",
@@ -103,13 +105,15 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
     del pipeline.text_encoder

     # UNET
+    unet_in_channels = pipeline.unet.config.in_channels
+    unet_sample_size = pipeline.unet.config.sample_size
     unet_path = output_path / "unet" / "model.onnx"
     onnx_export(
         pipeline.unet,
         model_args=(
-            torch.randn(2, pipeline.unet.in_channels, 64, 64).to(device=device, dtype=dtype),
-            torch.LongTensor([0, 1]).to(device=device),
-            torch.randn(2, 77, 768).to(device=device, dtype=dtype),
+            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+            torch.randn(2).to(device=device, dtype=dtype),
+            torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype),
             False,
         ),
         output_path=unet_path,
@@ -142,11 +146,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):

     # VAE ENCODER
     vae_encoder = pipeline.vae
+    vae_in_channels = vae_encoder.config.in_channels
+    vae_sample_size = vae_encoder.config.sample_size
     # need to get the raw tensor output (sample) from the encoder
     vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample()
     onnx_export(
         vae_encoder,
-        model_args=(torch.randn(1, 3, 512, 512).to(device=device, dtype=dtype), False),
+        model_args=(
+            torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype),
+            False,
+        ),
         output_path=output_path / "vae_encoder" / "model.onnx",
         ordered_input_names=["sample", "return_dict"],
         output_names=["latent_sample"],
@@ -158,11 +167,16 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):

     # VAE DECODER
     vae_decoder = pipeline.vae
+    vae_latent_channels = vae_decoder.config.latent_channels
+    vae_out_channels = vae_decoder.config.out_channels
     # forward only through the decoder part
     vae_decoder.forward = vae_encoder.decode
     onnx_export(
         vae_decoder,
-        model_args=(torch.randn(1, 4, 64, 64).to(device=device, dtype=dtype), False),
+        model_args=(
+            torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype),
+            False,
+        ),
         output_path=output_path / "vae_decoder" / "model.onnx",
         ordered_input_names=["latent_sample", "return_dict"],
         output_names=["sample"],
@@ -174,24 +188,35 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
     del pipeline.vae

     # SAFETY CHECKER
-    safety_checker = pipeline.safety_checker
-    safety_checker.forward = safety_checker.forward_onnx
-    onnx_export(
-        pipeline.safety_checker,
-        model_args=(
-            torch.randn(1, 3, 224, 224).to(device=device, dtype=dtype),
-            torch.randn(1, 512, 512, 3).to(device=device, dtype=dtype),
-        ),
-        output_path=output_path / "safety_checker" / "model.onnx",
-        ordered_input_names=["clip_input", "images"],
-        output_names=["out_images", "has_nsfw_concepts"],
-        dynamic_axes={
-            "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
-            "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
-        },
-        opset=opset,
-    )
-    del pipeline.safety_checker
+    if pipeline.safety_checker is not None:
+        safety_checker = pipeline.safety_checker
+        clip_num_channels = safety_checker.config.vision_config.num_channels
+        clip_image_size = safety_checker.config.vision_config.image_size
+        safety_checker.forward = safety_checker.forward_onnx
+        onnx_export(
+            pipeline.safety_checker,
+            model_args=(
+                torch.randn(
+                    1,
+                    clip_num_channels,
+                    clip_image_size,
+                    clip_image_size,
+                ).to(device=device, dtype=dtype),
+                torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype),
+            ),
+            output_path=output_path / "safety_checker" / "model.onnx",
+            ordered_input_names=["clip_input", "images"],
+            output_names=["out_images", "has_nsfw_concepts"],
+            dynamic_axes={
+                "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
+                "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
+            },
+            opset=opset,
+        )
+        del pipeline.safety_checker
+        safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
+    else:
+        safety_checker = None

     onnx_pipeline = OnnxStableDiffusionPipeline(
         vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),
@@ -200,7 +225,7 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
         tokenizer=pipeline.tokenizer,
         unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
         scheduler=pipeline.scheduler,
-        safety_checker=OnnxRuntimeModel.from_pretrained(output_path / "safety_checker"),
+        safety_checker=safety_checker,
         feature_extractor=pipeline.feature_extractor,
     )

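
Note: the conversion hunks above swap hard-coded dummy-input shapes (64x64 latents, 77 tokens, 768 hidden dims, 512x512 images) for values read from each submodel's config, and the UNet timestep dummy becomes a float tensor instead of a LongTensor, matching what the schedulers emit. A minimal sketch of the config-driven idea outside the exporter (model id and the commented values are illustrative for SD v1.x):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# shapes come from the configs, so non-default checkpoints export correctly too
num_tokens = pipe.text_encoder.config.max_position_embeddings  # 77 for SD v1.x
text_hidden_size = pipe.text_encoder.config.hidden_size        # 768 for SD v1.x
unet_in_channels = pipe.unet.config.in_channels                # 4 for SD v1.x
unet_sample_size = pipe.unet.config.sample_size                # 64 for SD v1.x

# dummy inputs mirroring the onnx_export call in the script
sample = torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size)
timestep = torch.randn(2)  # float, not LongTensor, to match scheduler output
encoder_hidden_states = torch.randn(2, num_tokens, text_hidden_size)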
@@ -24,7 +24,7 @@ import numpy as np

 from huggingface_hub import hf_hub_download

-from .utils import ONNX_WEIGHTS_NAME, is_onnx_available, logging
+from .utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging


 if is_onnx_available():
@@ -33,13 +33,28 @@ if is_onnx_available():

 logger = logging.get_logger(__name__)

+ORT_TO_NP_TYPE = {
+    "tensor(bool)": np.bool_,
+    "tensor(int8)": np.int8,
+    "tensor(uint8)": np.uint8,
+    "tensor(int16)": np.int16,
+    "tensor(uint16)": np.uint16,
+    "tensor(int32)": np.int32,
+    "tensor(uint32)": np.uint32,
+    "tensor(int64)": np.int64,
+    "tensor(uint64)": np.uint64,
+    "tensor(float16)": np.float16,
+    "tensor(float)": np.float32,
+    "tensor(double)": np.float64,
+}
+

 class OnnxRuntimeModel:
     def __init__(self, model=None, **kwargs):
         logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
         self.model = model
         self.model_save_dir = kwargs.get("model_save_dir", None)
-        self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
+        self.latest_model_name = kwargs.get("latest_model_name", ONNX_WEIGHTS_NAME)

     def __call__(self, **kwargs):
         inputs = {k: np.array(v) for k, v in kwargs.items()}
@@ -84,6 +99,15 @@ class OnnxRuntimeModel:
         except shutil.SameFileError:
             pass

+        # copy external weights (for models >2GB)
+        src_path = self.model_save_dir.joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+        if src_path.exists():
+            dst_path = Path(save_directory).joinpath(ONNX_EXTERNAL_WEIGHTS_NAME)
+            try:
+                shutil.copyfile(src_path, dst_path)
+            except shutil.SameFileError:
+                pass
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
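
Note: ORT_TO_NP_TYPE maps onnxruntime's type strings to numpy dtypes so a pipeline can feed each input in exactly the element type the exported graph declares (a graph exported with fp16=True would otherwise be fed float32 timesteps). The external-weights copy exists because protobuf caps a single ONNX file at 2GB, so larger models keep their tensors in a side file (weights.pb). A sketch of the dtype lookup the pipelines perform below (the model path is a placeholder):

import numpy as np
import onnxruntime as ort

from diffusers.onnx_utils import ORT_TO_NP_TYPE  # added in this commit

session = ort.InferenceSession("unet/model.onnx")  # placeholder path

# find the element type the graph declares for "timestep"; default to float32
timestep_type = next(
    (inp.type for inp in session.get_inputs() if inp.name == "timestep"), "tensor(float)"
)
timestep = np.array([999], dtype=ORT_TO_NP_TYPE[timestep_type])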
@@ -541,7 +541,7 @@ class DiffusionPipeline(ConfigMixin):
             # if the model is in a pipeline module, then we load it from the pipeline
             if name in passed_class_obj:
                 # 1. check that passed_class_obj has correct parent class
-                if not is_pipeline_module:
+                if not is_pipeline_module and passed_class_obj[name] is not None:
                     library = importlib.import_module(library_name)
                     class_obj = getattr(library, class_name)
                     importable_classes = LOADABLE_CLASSES[library_name]
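
Note: the added `passed_class_obj[name] is not None` guard is what lets callers disable an optional component explicitly; previously `from_pretrained` would try to type-check the passed `None` against the expected library class and fail. Usage this change enables (mirrored by the new test further down):

from diffusers import OnnxStableDiffusionPipeline

# None now skips the parent-class check instead of raising
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", revision="onnx", safety_checker=None
)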
@@ -2,11 +2,12 @@ import inspect
 from typing import Callable, List, Optional, Union

 import numpy as np
+import torch

 from transformers import CLIPFeatureExtractor, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
@@ -186,7 +187,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * np.float(self.scheduler.init_noise_sigma)

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -197,15 +198,20 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.cpu().numpy()

             # predict the noise residual
-            noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
-            )
+            timestep = np.array([t], dtype=timestep_dtype)
+            noise_pred = self.unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings)
             noise_pred = noise_pred[0]

             # perform guidance
@@ -214,7 +220,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
+            latents = np.array(latents)

             # call the callback, if provided
@@ -235,6 +241,9 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
-            image, has_nsfw_concepts = self.safety_checker(clip_input=safety_checker_input, images=image)
+
+            # There will throw an error if use safety_checker batchsize>1
+            images, has_nsfw_concept = [], []
+            for i in range(image.shape[0]):
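
Note: diffusers schedulers do their math on torch tensors, while the ONNX pipelines keep latents as numpy arrays; the hunks above bridge the two by round-tripping through torch at `scale_model_input` and `step`, which is what makes schedulers such as LMSDiscreteScheduler usable from the ONNX pipelines. The same pattern is applied to the img2img and inpaint pipelines below. A condensed, self-contained sketch of the loop body (the zero `noise_pred` stands in for the ONNX UNet call):

import numpy as np
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(10)

latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
latents = latents * np.float32(scheduler.init_noise_sigma)

for t in scheduler.timesteps:
    # torch in, numpy out around every scheduler call
    model_input = scheduler.scale_model_input(torch.from_numpy(latents), t).cpu().numpy()
    noise_pred = np.zeros_like(model_input)  # stand-in for the ONNX UNet session
    latents = scheduler.step(
        torch.from_numpy(noise_pred), t, torch.from_numpy(latents)
    ).prev_sample.numpy()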
@@ -8,7 +8,7 @@ import PIL
 from transformers import CLIPFeatureExtractor, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
@@ -338,14 +338,21 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
         t_start = max(num_inference_steps - init_timestep + offset, 0)
         timesteps = self.scheduler.timesteps[t_start:].numpy()

+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
+            latent_model_input = latent_model_input.cpu().numpy()

             # predict the noise residual
+            timestep = np.array([t], dtype=timestep_dtype)
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
             )[0]

             # perform guidance
@@ -354,7 +361,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
+            latents = latents.numpy()

             # call the callback, if provided
@@ -375,7 +382,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
-            # There will throw an error if use safety_checker batchsize>1
+            # safety_checker does not support batched inputs yet
             images, has_nsfw_concept = [], []
             for i in range(image.shape[0]):
                 image_i, has_nsfw_concept_i = self.safety_checker(
@@ -8,7 +8,7 @@ import PIL
 from transformers import CLIPFeatureExtractor, CLIPTokenizer

 from ...configuration_utils import FrozenDict
-from ...onnx_utils import OnnxRuntimeModel
+from ...onnx_utils import ORT_TO_NP_TYPE, OnnxRuntimeModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from ...utils import deprecate, logging
@@ -352,7 +352,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)

         # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * np.float(self.scheduler.init_noise_sigma)

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -363,17 +363,23 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta

+        timestep_dtype = next(
+            (input.type for input in self.unet.model.get_inputs() if input.name == "timestep"), "tensor(float)"
+        )
+        timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
             # concat latents, mask, masked_image_latents in the channel dimension
             latent_model_input = np.concatenate([latent_model_input, mask, masked_image_latents], axis=1)
             latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
-            latent_model_input = latent_model_input.numpy()
+            latent_model_input = latent_model_input.cpu().numpy()

             # predict the noise residual
+            timestep = np.array([t], dtype=timestep_dtype)
             noise_pred = self.unet(
-                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+                sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings
             )[0]

             # perform guidance
@@ -382,7 +388,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, torch.from_numpy(latents), **extra_step_kwargs).prev_sample
+            latents = latents.numpy()

             # call the callback, if provided
@@ -403,7 +409,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             safety_checker_input = self.feature_extractor(
                 self.numpy_to_pil(image), return_tensors="np"
             ).pixel_values.astype(image.dtype)
-            # There will throw an error if use safety_checker batchsize>1
+            # safety_checker does not support batched inputs yet
             images, has_nsfw_concept = [], []
             for i in range(image.shape[0]):
                 image_i, has_nsfw_concept_i = self.safety_checker(
@@ -67,6 +67,7 @@ CONFIG_NAME = "config.json"
 WEIGHTS_NAME = "diffusion_pytorch_model.bin"
 FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
 ONNX_WEIGHTS_NAME = "model.onnx"
+ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
 DIFFUSERS_CACHE = default_cache_path
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
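
Note: ONNX models are protobuf files, which are limited to 2GB each; checkpoints above that keep their tensors in an external data file, and the new ONNX_EXTERNAL_WEIGHTS_NAME constant pins its expected name. A sketch of producing such a file with onnx's own API (paths illustrative):

import onnx

model = onnx.load("unet/model.onnx")  # illustrative path
onnx.save_model(
    model,
    "unet_ext/model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",  # matches ONNX_EXTERNAL_WEIGHTS_NAME
)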
@@ -13,11 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import tempfile
 import unittest

 import numpy as np

-from diffusers import OnnxStableDiffusionPipeline
+from diffusers import DDIMScheduler, LMSDiscreteScheduler, OnnxStableDiffusionPipeline
 from diffusers.utils.testing_utils import is_onnx_available, require_onnxruntime, require_torch_gpu, slow

 from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -36,32 +37,87 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
 @require_onnxruntime
 @require_torch_gpu
 class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
-    def test_inference(self):
-        provider = (
+    @property
+    def gpu_provider(self):
+        return (
             "CUDAExecutionProvider",
             {
-                "gpu_mem_limit": "17179869184",  # 16GB.
+                "gpu_mem_limit": "15000000000",  # 15GB
                 "arena_extend_strategy": "kSameAsRequested",
             },
         )
+
+    @property
+    def gpu_options(self):
+        options = ort.SessionOptions()
+        options.enable_mem_pattern = False
+        return options
+
+    def test_inference_default_pndm(self):
+        # using the PNDM scheduler by default
         sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
             "CompVis/stable-diffusion-v1-4",
             revision="onnx",
-            provider=provider,
-            sess_options=options,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
         )
         sd_pipe.set_progress_bar_config(disable=None)

         prompt = "A painting of a squirrel eating a burger"
         np.random.seed(0)
-        output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=5, output_type="np")
+        output = sd_pipe([prompt], guidance_scale=6.0, num_inference_steps=10, output_type="np")
         image = output.images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.3602, 0.3688, 0.3652, 0.3895, 0.3782, 0.3747, 0.3927, 0.4241, 0.4327])
+        expected_slice = np.array([0.0452, 0.0390, 0.0087, 0.0350, 0.0617, 0.0364, 0.0544, 0.0523, 0.0720])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

+    def test_inference_ddim(self):
+        ddim_scheduler = DDIMScheduler.from_config(
+            "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
+        )
+        sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            scheduler=ddim_scheduler,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "open neural network exchange"
+        generator = np.random.RandomState(0)
+        output = sd_pipe([prompt], guidance_scale=7.5, num_inference_steps=10, generator=generator, output_type="np")
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.2867, 0.1974, 0.1481, 0.7294, 0.7251, 0.6667, 0.4194, 0.5642, 0.6486])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+    def test_inference_k_lms(self):
+        lms_scheduler = LMSDiscreteScheduler.from_config(
+            "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
+        )
+        sd_pipe = OnnxStableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            scheduler=lms_scheduler,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "open neural network exchange"
+        generator = np.random.RandomState(0)
+        output = sd_pipe([prompt], guidance_scale=7.5, num_inference_steps=10, generator=generator, output_type="np")
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.2306, 0.1959, 0.1593, 0.6549, 0.6394, 0.5408, 0.5065, 0.6010, 0.6161])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
     def test_intermediate_state(self):
@@ -75,27 +131,61 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
                 assert latents.shape == (1, 4, 64, 64)
                 latents_slice = latents[0, -3:, -3:, -1]
                 expected_slice = np.array(
-                    [-0.5950, -0.3039, -1.1672, 0.1594, -1.1572, 0.6719, -1.9712, -0.0403, 0.9592]
+                    [-0.6772, -0.3835, -1.2456, 0.1905, -1.0974, 0.6967, -1.9353, 0.0178, 1.0167]
                 )
                 assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
             elif step == 5:
                 assert latents.shape == (1, 4, 64, 64)
                 latents_slice = latents[0, -3:, -3:, -1]
                 expected_slice = np.array(
-                    [-0.4776, -0.0119, -0.8519, -0.0275, -0.9764, 0.9820, -0.3843, 0.3788, 1.2264]
+                    [-0.3351, 0.2241, -0.1837, -0.2325, -0.6577, 0.3393, -0.0241, 0.5899, 1.3875]
                 )
                 assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3

         test_callback_fn.has_been_called = False

         pipe = OnnxStableDiffusionPipeline.from_pretrained(
-            "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CUDAExecutionProvider"
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
         )
         pipe.set_progress_bar_config(disable=None)

         prompt = "Andromeda galaxy in a bottle"

-        np.random.seed(0)
-        pipe(prompt=prompt, num_inference_steps=5, guidance_scale=7.5, callback=test_callback_fn, callback_steps=1)
+        generator = np.random.RandomState(0)
+        pipe(
+            prompt=prompt,
+            num_inference_steps=5,
+            guidance_scale=7.5,
+            generator=generator,
+            callback=test_callback_fn,
+            callback_steps=1,
+        )
         assert test_callback_fn.has_been_called
         assert number_of_steps == 6

+    def test_stable_diffusion_no_safety_checker(self):
+        pipe = OnnxStableDiffusionPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+            safety_checker=None,
+        )
+        assert isinstance(pipe, OnnxStableDiffusionPipeline)
+        assert pipe.safety_checker is None
+
+        image = pipe("example prompt", num_inference_steps=2).images[0]
+        assert image is not None
+
+        # check that there's no error when saving a pipeline with one of the models being None
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe.save_pretrained(tmpdirname)
+            pipe = OnnxStableDiffusionPipeline.from_pretrained(tmpdirname)
+
+        # sanity check that the pipeline still works
+        assert pipe.safety_checker is None
+        image = pipe("example prompt", num_inference_steps=2).images[0]
+        assert image is not None
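
Note: the tests now pass a np.random.RandomState as `generator` instead of seeding numpy's global RNG, mirroring the torch pipelines' `torch.Generator` argument and keeping runs reproducible without mutating global state. Illustrative usage (subject to the onnxruntime reproducibility caveats flagged in the TODOs below):

import numpy as np
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", revision="onnx")

# identically seeded generators should give matching images, without
# touching np.random's global state between the two calls
image_a = pipe("a prompt", num_inference_steps=5, generator=np.random.RandomState(0)).images[0]
image_b = pipe("a prompt", num_inference_steps=5, generator=np.random.RandomState(0)).images[0]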
@@ -17,7 +17,7 @@ import unittest

 import numpy as np

-from diffusers import OnnxStableDiffusionImg2ImgPipeline
+from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionImg2ImgPipeline
 from diffusers.utils.testing_utils import is_onnx_available, load_image, require_onnxruntime, require_torch_gpu, slow

 from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -35,45 +35,92 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
 @slow
 @require_onnxruntime
 @require_torch_gpu
-class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
-    def test_inference(self):
+class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase):
+    @property
+    def gpu_provider(self):
+        return (
+            "CUDAExecutionProvider",
+            {
+                "gpu_mem_limit": "15000000000",  # 15GB
+                "arena_extend_strategy": "kSameAsRequested",
+            },
+        )
+
+    @property
+    def gpu_options(self):
+        options = ort.SessionOptions()
+        options.enable_mem_pattern = False
+        return options
+
+    def test_inference_default_pndm(self):
         init_image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/img2img/sketch-mountains-input.jpg"
         )
         init_image = init_image.resize((768, 512))
-        provider = (
-            "CUDAExecutionProvider",
-            {
-                "gpu_mem_limit": "17179869184",  # 16GB.
-                "arena_extend_strategy": "kSameAsRequested",
-            },
-        )
-        options = ort.SessionOptions()
-        options.enable_mem_pattern = False
+        # using the PNDM scheduler by default
         pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
             "CompVis/stable-diffusion-v1-4",
             revision="onnx",
-            provider=provider,
-            sess_options=options,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
         )
         pipe.set_progress_bar_config(disable=None)

         prompt = "A fantasy landscape, trending on artstation"

-        np.random.seed(0)
+        generator = np.random.RandomState(0)
         output = pipe(
             prompt=prompt,
             init_image=init_image,
             strength=0.75,
             guidance_scale=7.5,
-            num_inference_steps=8,
+            num_inference_steps=10,
+            generator=generator,
             output_type="np",
         )
         images = output.images
         image_slice = images[0, 255:258, 383:386, -1]

         assert images.shape == (1, 512, 768, 3)
-        expected_slice = np.array([0.4830, 0.5242, 0.5603, 0.5016, 0.5131, 0.5111, 0.4928, 0.5025, 0.5055])
+        expected_slice = np.array([0.4909, 0.5059, 0.5372, 0.4623, 0.4876, 0.5049, 0.4820, 0.4956, 0.5019])
+        # TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
         assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2

+    def test_inference_k_lms(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/img2img/sketch-mountains-input.jpg"
+        )
+        init_image = init_image.resize((768, 512))
+        lms_scheduler = LMSDiscreteScheduler.from_config(
+            "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx"
+        )
+        pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5",
+            revision="onnx",
+            scheduler=lms_scheduler,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+        )
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A fantasy landscape, trending on artstation"
+
+        generator = np.random.RandomState(0)
+        output = pipe(
+            prompt=prompt,
+            init_image=init_image,
+            strength=0.75,
+            guidance_scale=7.5,
+            num_inference_steps=10,
+            generator=generator,
+            output_type="np",
+        )
+        images = output.images
+        image_slice = images[0, 255:258, 383:386, -1]
+
+        assert images.shape == (1, 512, 768, 3)
+        expected_slice = np.array([0.7950, 0.7923, 0.7903, 0.5516, 0.5501, 0.5476, 0.4965, 0.4933, 0.4910])
+        # TODO: lower the tolerance after finding the cause of onnxruntime reproducibility issues
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
@@ -17,7 +17,7 @@ import unittest

 import numpy as np

-from diffusers import OnnxStableDiffusionInpaintPipeline
+from diffusers import LMSDiscreteScheduler, OnnxStableDiffusionInpaintPipeline
 from diffusers.utils.testing_utils import is_onnx_available, load_image, require_onnxruntime, require_torch_gpu, slow

 from ...test_pipelines_onnx_common import OnnxPipelineTesterMixin
@@ -35,8 +35,24 @@ class OnnxStableDiffusionPipelineFastTests(OnnxPipelineTesterMixin, unittest.TestCase):
 @slow
 @require_onnxruntime
 @require_torch_gpu
-class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
-    def test_stable_diffusion_inpaint_onnx(self):
+class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
+    @property
+    def gpu_provider(self):
+        return (
+            "CUDAExecutionProvider",
+            {
+                "gpu_mem_limit": "15000000000",  # 15GB
+                "arena_extend_strategy": "kSameAsRequested",
+            },
+        )
+
+    @property
+    def gpu_options(self):
+        options = ort.SessionOptions()
+        options.enable_mem_pattern = False
+        return options
+
+    def test_inference_default_pndm(self):
         init_image = load_image(
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/in_paint/overture-creations-5sI6fQgYIuo.png"
@@ -45,37 +61,69 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase):
             "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
             "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
         )
-        provider = (
-            "CUDAExecutionProvider",
-            {
-                "gpu_mem_limit": "17179869184",  # 16GB.
-                "arena_extend_strategy": "kSameAsRequested",
-            },
-        )
-        options = ort.SessionOptions()
-        options.enable_mem_pattern = False
         pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
             "runwayml/stable-diffusion-inpainting",
             revision="onnx",
-            provider=provider,
-            sess_options=options,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
         )
         pipe.set_progress_bar_config(disable=None)

         prompt = "A red cat sitting on a park bench"

-        np.random.seed(0)
+        generator = np.random.RandomState(0)
         output = pipe(
             prompt=prompt,
             image=init_image,
             mask_image=mask_image,
             guidance_scale=7.5,
-            num_inference_steps=8,
+            num_inference_steps=10,
+            generator=generator,
             output_type="np",
         )
         images = output.images
         image_slice = images[0, 255:258, 255:258, -1]

         assert images.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.2951, 0.2955, 0.2922, 0.2036, 0.1977, 0.2279, 0.1716, 0.1641, 0.1799])
+        expected_slice = np.array([0.2514, 0.3007, 0.3517, 0.1790, 0.2382, 0.3167, 0.1944, 0.2273, 0.2464])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

+    def test_inference_k_lms(self):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo.png"
+        )
+        mask_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+            "/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
+        )
+        lms_scheduler = LMSDiscreteScheduler.from_config(
+            "runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx"
+        )
+        pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-inpainting",
+            revision="onnx",
+            scheduler=lms_scheduler,
+            provider=self.gpu_provider,
+            sess_options=self.gpu_options,
+        )
+        pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A red cat sitting on a park bench"
+
+        generator = np.random.RandomState(0)
+        output = pipe(
+            prompt=prompt,
+            image=init_image,
+            mask_image=mask_image,
+            guidance_scale=7.5,
+            num_inference_steps=10,
+            generator=generator,
+            output_type="np",
+        )
+        images = output.images
+        image_slice = images[0, 255:258, 255:258, -1]
+
+        assert images.shape == (1, 512, 512, 3)
+        expected_slice = np.array([0.2520, 0.2743, 0.2643, 0.2641, 0.2517, 0.2650, 0.2498, 0.2688, 0.2529])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3