
Postprocessing refactor img2img (#3268)

* refactor img2img VaeImageProcessor.postprocess

* remove copy from for init, run_safety_checker, decode_latents

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: YiYi Xu
Date: 2023-05-01 07:54:09 -10:00
Committed by: Daniel Gu
Parent: 6a84a7439d
Commit: 863bb75ea9

7 changed files with 198 additions and 125 deletions
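
At a glance: the heart of the change is a new `do_denormalize` argument on `VaeImageProcessor.postprocess`, which decides per image whether the VAE output in [-1, 1] is mapped back to [0, 1]. A minimal usage sketch against this revision (the tensor below is a stand-in for a real decoded latent, not code from the diff):

    import torch
    from diffusers.image_processor import VaeImageProcessor

    # Stand-in for a decoded VAE output: batch of 2 images in [-1, 1], NCHW layout.
    decoded = torch.rand(2, 3, 64, 64) * 2 - 1

    processor = VaeImageProcessor(do_normalize=True)

    # Denormalize the first image back to [0, 1]; leave the second as-is
    # (the pipelines use this to skip images the safety checker blacked out).
    images = processor.postprocess(decoded, output_type="pil", do_denormalize=[True, False])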

File: src/diffusers/image_processor.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings
-from typing import Union
+from typing import List, Optional, Union
 
 import numpy as np
 import PIL
@@ -21,7 +21,7 @@ import torch
 from PIL import Image
 
 from .configuration_utils import ConfigMixin, register_to_config
-from .utils import CONFIG_NAME, PIL_INTERPOLATION
+from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
 
 
 class VaeImageProcessor(ConfigMixin):
@@ -82,7 +82,7 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def pt_to_numpy(images):
         """
-        Convert a numpy image to a pytorch tensor
+        Convert a pytorch tensor to a numpy image
        """
         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
         return images
@@ -94,6 +94,13 @@
         """
         return 2.0 * images - 1.0
 
+    @staticmethod
+    def denormalize(images):
+        """
+        Denormalize an image array to [0,1]
+        """
+        return (images / 2 + 0.5).clamp(0, 1)
+
     def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
         """
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
@@ -165,17 +172,39 @@
     def postprocess(
         self,
-        image,
+        image: torch.FloatTensor,
         output_type: str = "pil",
+        do_denormalize: Optional[List[bool]] = None,
     ):
-        if isinstance(image, torch.Tensor) and output_type == "pt":
+        if not isinstance(image, torch.Tensor):
+            raise ValueError(
+                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+            )
+
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
             return image
 
+        if do_denormalize is None:
+            do_denormalize = [self.config.do_normalize] * image.shape[0]
+
+        image = torch.stack(
+            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
+        )
+
+        if output_type == "pt":
+            return image
+
         image = self.pt_to_numpy(image)
 
         if output_type == "np":
             return image
-        elif output_type == "pil":
+
+        if output_type == "pil":
             return self.numpy_to_pil(image)
-        else:
-            raise ValueError(f"Unsupported output_type {output_type}.")
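
Because every non-latent branch above now funnels through the same `denormalize` and `pt_to_numpy` steps, the `pt`, `np`, and `pil` outputs are successive stages of one conversion chain and should agree up to layout. A quick sanity sketch (illustrative, not part of the diff):

    import numpy as np
    import torch
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor(do_normalize=True)
    decoded = torch.rand(1, 3, 32, 32) * 2 - 1  # stand-in for a VAE output in [-1, 1]

    out_pt = processor.postprocess(decoded, output_type="pt")  # NCHW tensor in [0, 1]
    out_np = processor.postprocess(decoded, output_type="np")  # NHWC float array

    # same values, different layout
    assert np.allclose(out_pt.permute(0, 2, 3, 1).numpy(), out_np)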

File: src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
@@ -202,6 +203,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -212,11 +214,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.register_to_config(
-            requires_safety_checker=requires_safety_checker,
-        )
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
@@ -436,17 +435,32 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
-        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
-        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-        )
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
+        warnings.warn(
+            (
+                "The decode_latents method is deprecated and will be removed in a future version. Please"
+                " use VaeImageProcessor instead"
+            ),
+            FutureWarning,
+        )
+
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -730,27 +744,19 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
 
-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-            output_type = "np"
-
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
             image = latents
             has_nsfw_concept = None
-        else:
-            image = self.decode_latents(latents)
 
-        if self.safety_checker is not None:
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        else:
-            has_nsfw_concept = False
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
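
The `do_denormalize` bookkeeping above is what makes the per-image flag useful: the safety checker replaces flagged outputs with black (all-zero) images, which are already in display range, so denormalizing them a second time would shift them to mid-gray. A standalone sketch of that logic (values are illustrative):

    # `has_nsfw_concept` comes from the safety checker; None means it is disabled.
    has_nsfw_concept = [False, True]

    if has_nsfw_concept is None:
        do_denormalize = [True, True]  # denormalize every image
    else:
        # skip denormalization for images the checker zeroed out
        do_denormalize = [not flagged for flagged in has_nsfw_concept]

    assert do_denormalize == [True, False]

The Stable Diffusion img2img pipeline below receives the identical change.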

File: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
@@ -205,6 +206,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -215,11 +217,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.register_to_config(
-            requires_safety_checker=requires_safety_checker,
-        )
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -443,17 +442,30 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
         return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
-        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
-        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-        )
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
        return image, has_nsfw_concept
 
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
+
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -738,27 +750,19 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMi
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
 
-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-            output_type = "np"
-
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
             image = latents
             has_nsfw_concept = None
-        else:
-            image = self.decode_latents(latents)
 
-        if self.safety_checker is not None:
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        else:
-            has_nsfw_concept = False
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

File: tests/test_image_processor.py

@@ -42,7 +42,7 @@ class ImageProcessorTest(unittest.TestCase):
         return image
 
     def test_vae_image_processor_pt(self):
-        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
 
         input_pt = self.dummy_sample
         input_np = self.to_np(input_pt)
@@ -59,7 +59,7 @@ class ImageProcessorTest(unittest.TestCase):
         ), f"decoded output does not match input for output_type {output_type}"
 
     def test_vae_image_processor_np(self):
-        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
 
         input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
 
         for output_type in ["pt", "np", "pil"]:
@@ -72,7 +72,7 @@ class ImageProcessorTest(unittest.TestCase):
         ), f"decoded output does not match input for output_type {output_type}"
 
     def test_vae_image_processor_pil(self):
-        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
 
         input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
         input_pil = image_processor.numpy_to_pil(input_np)
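
These tests flip `do_normalize` to True because `postprocess` now denormalizes by default, so a preprocess/postprocess round trip should hand back the input unchanged. A sketch of the invariant being exercised (not a verbatim test from the suite):

    import torch
    from diffusers.image_processor import VaeImageProcessor

    processor = VaeImageProcessor(do_resize=False, do_normalize=True)

    sample = torch.rand(1, 3, 32, 32)                # dummy image in [0, 1]
    normalized = processor.preprocess(sample)        # mapped to [-1, 1]
    recovered = processor.postprocess(normalized, output_type="pt")  # back to [0, 1]

    assert torch.allclose(sample, recovered, atol=1e-6)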

File: tests/pipelines/pipeline_params.py

@@ -22,6 +22,10 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
 
 TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
 
+TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
+
+IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
+
 IMAGE_VARIATION_PARAMS = frozenset(
     [
         "image",

File: tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

@@ -35,18 +35,23 @@ from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
 
-from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import (
+    IMAGE_TO_IMAGE_IMAGE_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+)
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
 
     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -96,33 +101,19 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
         }
         return components
 
-    def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"):
+    def get_dummy_inputs(self, device, seed=0):
         image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
 
-        if input_image_type == "pt":
-            input_image = image
-        elif input_image_type == "np":
-            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
-        elif input_image_type == "pil":
-            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
-            input_image = VaeImageProcessor.numpy_to_pil(input_image)
-        else:
-            raise ValueError(f"unsupported input_image_type {input_image_type}.")
-
-        if output_type not in ["pt", "np", "pil"]:
-            raise ValueError(f"unsupported output_type {output_type}")
-
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "image": input_image,
+            "image": image,
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": output_type,
+            "output_type": "numpy",
         }
         return inputs
@@ -130,11 +121,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         image = sd_pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
@@ -147,11 +139,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         negative_prompt = "french fries"
         output = sd_pipe(**inputs, negative_prompt=negative_prompt)
         image = output.images
@@ -166,13 +159,14 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
         inputs["prompt"] = [inputs["prompt"]] * 2
         inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         image = sd_pipe(**inputs).images
         image_slice = image[-1, -3:, -3:, -1]
@@ -188,11 +182,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         image = sd_pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
@@ -217,36 +212,6 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.Test
     def test_attention_slicing_forward_pass(self):
         return super().test_attention_slicing_forward_pass()
 
-    @skip_mps
-    def test_pt_np_pil_outputs_equivalent(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0]
-        output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0]
-        output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0]
-
-        assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4
-        assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4
-
-    @skip_mps
-    def test_image_types_consistent(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0]
-        output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0]
-        output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0]
-
-        assert np.abs(output_pt - output_np).max() <= 1e-4
-        assert np.abs(output_pil - output_np).max() <= 1e-2
-
 
 @slow
 @require_torch_gpu

File: tests/pipelines/test_pipelines_common.py

@@ -12,6 +12,7 @@ import torch
 
 import diffusers
 from diffusers import DiffusionPipeline
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
 from diffusers.utils.testing_utils import require_torch, torch_device
@@ -27,6 +28,78 @@ def to_np(tensor):
     return tensor
 
 
+class PipelineLatentTesterMixin:
+    """
+    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
+    It provides a set of common tests for PyTorch pipeline that has vae, e.g.
+    equivalence of different input and output types, etc.
+    """
+
+    @property
+    def image_params(self) -> frozenset:
+        raise NotImplementedError(
+            "You need to set the attribute `image_params` in the child test class. "
+            "`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results"
+        )
+
+    def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
+        inputs = self.get_dummy_inputs(device, seed)
+
+        def convert_pt_to_type(image, input_image_type):
+            if input_image_type == "pt":
+                input_image = image
+            elif input_image_type == "np":
+                input_image = VaeImageProcessor.pt_to_numpy(image)
+            elif input_image_type == "pil":
+                input_image = VaeImageProcessor.pt_to_numpy(image)
+                input_image = VaeImageProcessor.numpy_to_pil(input_image)
+            else:
+                raise ValueError(f"unsupported input_image_type {input_image_type}.")
+            return input_image
+
+        for image_param in self.image_params:
+            if image_param in inputs.keys():
+                inputs[image_param] = convert_pt_to_type(inputs[image_param], input_image_type)
+
+        inputs["output_type"] = output_type
+
+        return inputs
+
+    def test_pt_np_pil_outputs_equivalent(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        output_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pt"))[0]
+        output_np = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="np"))[0]
+        output_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pil"))[0]
+
+        max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
+        self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generate different results from `output_type=='np'`")
+
+        max_diff = np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max()
+        self.assertLess(max_diff, 1e-4, "`output_type=='pil'` generate different results from `output_type=='np'`")
+
+    def test_pt_np_pil_inputs_equivalent(self):
+        if len(self.image_params) == 0:
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        out_input_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
+        out_input_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+        out_input_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pil"))[0]
+
+        max_diff = np.abs(out_input_pt - out_input_np).max()
+        self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`")
+        max_diff = np.abs(out_input_pil - out_input_np).max()
+        self.assertLess(max_diff, 1e-2, "`input_type=='pt'` generate different result from `input_type=='np'`")
+
+
 @require_torch
 class PipelineTesterMixin:
     """
@@ -339,9 +412,6 @@ class PipelineTesterMixin:
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_float16_inference(self):
-        self._test_float16_inference()
-
-    def _test_float16_inference(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.to(torch_device)
@@ -355,13 +425,10 @@
         output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0]
 
         max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
-        self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+        self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.")
 
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self):
-        self._test_save_load_float16()
-
-    def _test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         for name, module in components.items():
             if hasattr(module, "half"):
@@ -390,9 +457,7 @@
         output_loaded = pipe_loaded(**inputs)[0]
 
         max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
-        self.assertLess(
-            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
-        )
+        self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.")
 
     def test_save_load_optional_components(self):
         if not hasattr(self.pipeline_class, "_optional_components"):
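
For a pipeline test class, opting into these checks amounts to adding the mixin and declaring which inputs are images, as StableDiffusionImg2ImgPipelineFastTests does above. A hypothetical skeleton (the class name is a placeholder, not a real class in the repo):

    import unittest

    from ..pipeline_params import IMAGE_TO_IMAGE_IMAGE_PARAMS
    from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


    class MyImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
        pipeline_class = ...  # the pipeline under test
        image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS  # inputs the mixin converts between pt/np/pil

        def get_dummy_components(self):
            ...  # build the small dummy models, as in the img2img tests above

        def get_dummy_inputs(self, device, seed=0):
            ...  # must return the "image" entry as a pt tensor in [0, 1]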