Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

Postprocessing refactor img2img (#3268)

* refactor img2img VaeImageProcessor.postprocess
* remove "# Copied from" comments for __init__, run_safety_checker, decode_latents

Co-authored-by: yiyixuxu <yixu@yis-macbook-pro.lan>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
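The hunks below change `VaeImageProcessor.postprocess` to operate on torch tensors with per-image denormalization and route the img2img pipelines through it. As orientation, a minimal sketch of the refactored call (assuming a diffusers checkout at this commit; the random tensor stands in for `vae.decode(latents).sample`):

```python
import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
decoded = torch.rand(2, 3, 64, 64) * 2 - 1  # stand-in for vae.decode(latents).sample, in [-1, 1]

pil_images = processor.postprocess(decoded, output_type="pil")  # list of PIL.Image.Image
np_images = processor.postprocess(decoded, output_type="np")    # NHWC float array in [0, 1]
pt_images = processor.postprocess(decoded, output_type="pt")    # NCHW tensor in [0, 1]
```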
src/diffusers/image_processor.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import warnings
-from typing import Union
+from typing import List, Optional, Union
 
 import numpy as np
 import PIL
@@ -21,7 +21,7 @@ import torch
 from PIL import Image
 
 from .configuration_utils import ConfigMixin, register_to_config
-from .utils import CONFIG_NAME, PIL_INTERPOLATION
+from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
 
 
 class VaeImageProcessor(ConfigMixin):
@@ -82,7 +82,7 @@ class VaeImageProcessor(ConfigMixin):
     @staticmethod
     def pt_to_numpy(images):
         """
-        Convert a numpy image to a pytorch tensor
+        Convert a pytorch tensor to a numpy image
         """
         images = images.cpu().permute(0, 2, 3, 1).float().numpy()
         return images
@@ -94,6 +94,13 @@ class VaeImageProcessor(ConfigMixin):
         """
         return 2.0 * images - 1.0
 
+    @staticmethod
+    def denormalize(images):
+        """
+        Denormalize an image array to [0,1]
+        """
+        return (images / 2 + 0.5).clamp(0, 1)
+
     def resize(self, images: PIL.Image.Image) -> PIL.Image.Image:
         """
         Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor`
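`denormalize` is the inverse of the existing `normalize` (up to clamping). A quick sketch of the round trip, assuming both remain static methods as in the hunk above:

```python
import torch

from diffusers.image_processor import VaeImageProcessor

x = torch.rand(1, 3, 32, 32)               # pixel batch in [0, 1]
y = VaeImageProcessor.normalize(x)         # 2.0 * x - 1.0 -> [-1, 1]
x_back = VaeImageProcessor.denormalize(y)  # (y / 2 + 0.5).clamp(0, 1)
assert torch.allclose(x, x_back, atol=1e-6)
```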
@@ -165,17 +172,39 @@ class VaeImageProcessor(ConfigMixin):
 
     def postprocess(
         self,
-        image,
+        image: torch.FloatTensor,
         output_type: str = "pil",
+        do_denormalize: Optional[List[bool]] = None,
     ):
-        if isinstance(image, torch.Tensor) and output_type == "pt":
+        if not isinstance(image, torch.Tensor):
+            raise ValueError(
+                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+            )
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if output_type == "latent":
+            return image
+
+        if do_denormalize is None:
+            do_denormalize = [self.config.do_normalize] * image.shape[0]
+
+        image = torch.stack(
+            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
+        )
+
+        if output_type == "pt":
             return image
 
         image = self.pt_to_numpy(image)
 
         if output_type == "np":
             return image
-        elif output_type == "pil":
+
+        if output_type == "pil":
             return self.numpy_to_pil(image)
-        else:
-            raise ValueError(f"Unsupported output_type {output_type}.")
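The new `do_denormalize` argument selects denormalization per sample; the pipeline hunks further down use it to leave NSFW-filtered images, which are already in [0, 1], untouched. A hedged sketch:

```python
import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
batch = torch.rand(2, 3, 64, 64) * 2 - 1  # two decoded samples in [-1, 1]

# Denormalize only the first sample; the second passes through unchanged.
images = processor.postprocess(batch, output_type="np", do_denormalize=[True, False])
```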
src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
@@ -202,6 +203,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -212,11 +214,8 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.register_to_config(
-            requires_safety_checker=requires_safety_checker,
-        )
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
@@ -436,17 +435,32 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
-        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
-        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-        )
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
+        warnings.warn(
+            (
+                "The decode_latents method is deprecated and will be removed in a future version. Please"
+                " use VaeImageProcessor instead"
+            ),
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -730,27 +744,19 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-            output_type = "np"
-
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
             image = latents
             has_nsfw_concept = None
-        else:
-            image = self.decode_latents(latents)
 
-        if self.safety_checker is not None:
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
         else:
-            has_nsfw_concept = False
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
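With this restructuring, `output_type="latent"` short-circuits before the VAE decode and the safety checker. A sketch of how a caller might use it (the model id is illustrative, not taken from this diff):

```python
import torch

from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
out = pipe(
    prompt="a photo of an astronaut",
    image=torch.rand(1, 3, 512, 512),  # img2img input in [0, 1]
    output_type="latent",
)
latents = out.images  # raw latents; no decode, no safety check ran
```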
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
@@ -205,6 +206,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             new_config = dict(unet.config)
             new_config["sample_size"] = 64
             unet._internal_dict = FrozenDict(new_config)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -215,11 +217,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
             feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.register_to_config(
-            requires_safety_checker=requires_safety_checker,
-        )
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
@@ -443,17 +442,30 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
         return prompt_embeds
 
     def run_safety_checker(self, image, device, dtype):
-        feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
-        safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-        )
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
+        warnings.warn(
+            "The decode_latents method is deprecated and will be removed in a future version. Please"
+            " use VaeImageProcessor instead",
+            FutureWarning,
+        )
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
         return image
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
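For callers hit by the deprecation above, a minimal replacement sketch (equivalent to what the pipeline now does internally; `pipe` is any loaded img2img pipeline and `latents` its denoised latents):

```python
import numpy as np
import torch


def decode_latents_via_processor(pipe, latents: torch.FloatTensor) -> np.ndarray:
    # Mirrors the deprecated decode_latents: un-scale, decode, then map to [0, 1] NHWC.
    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
    return pipe.image_processor.postprocess(image, output_type="np")
```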
@@ -738,27 +750,19 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)
 
-        if output_type not in ["latent", "pt", "np", "pil"]:
-            deprecation_message = (
-                f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: "
-                "`pil`, `np`, `pt`, `latent`"
-            )
-            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
-            output_type = "np"
-
-        if output_type == "latent":
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
             image = latents
             has_nsfw_concept = None
-        else:
-            image = self.decode_latents(latents)
 
-        if self.safety_checker is not None:
-            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
         else:
-            has_nsfw_concept = False
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
tests/test_image_processor.py

@@ -42,7 +42,7 @@ class ImageProcessorTest(unittest.TestCase):
         return image
 
     def test_vae_image_processor_pt(self):
-        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
 
         input_pt = self.dummy_sample
         input_np = self.to_np(input_pt)
@@ -59,7 +59,7 @@ class ImageProcessorTest(unittest.TestCase):
         ), f"decoded output does not match input for output_type {output_type}"
 
     def test_vae_image_processor_np(self):
-        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
         input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
 
         for output_type in ["pt", "np", "pil"]:
@@ -72,7 +72,7 @@ class ImageProcessorTest(unittest.TestCase):
         ), f"decoded output does not match input for output_type {output_type}"
 
     def test_vae_image_processor_pil(self):
-        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False)
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
 
         input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1)
         input_pil = image_processor.numpy_to_pil(input_np)
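The switch to `do_normalize=True` makes these tests exercise the full normalize/denormalize round trip. A condensed sketch of what they assert:

```python
import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(do_resize=False, do_normalize=True)
pixels = torch.rand(1, 3, 32, 32)          # input in [0, 1]
normalized = processor.preprocess(pixels)  # mapped to [-1, 1]
restored = processor.postprocess(normalized, output_type="pt")
assert torch.allclose(pixels, restored, atol=1e-6)
```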
tests/pipelines/pipeline_params.py

@@ -22,6 +22,10 @@ TEXT_TO_IMAGE_PARAMS = frozenset(
 
 TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"])
 
+TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([])
+
+IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"])
+
 IMAGE_VARIATION_PARAMS = frozenset(
     [
         "image",
tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

@@ -35,18 +35,23 @@ from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
 
-from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..pipeline_params import (
+    IMAGE_TO_IMAGE_IMAGE_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
+    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
+)
+from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableDiffusionImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
 
     def get_dummy_components(self):
         torch.manual_seed(0)
@@ -96,33 +101,19 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         }
         return components
 
-    def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"):
+    def get_dummy_inputs(self, device, seed=0):
         image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
         else:
             generator = torch.Generator(device=device).manual_seed(seed)
-
-        if input_image_type == "pt":
-            input_image = image
-        elif input_image_type == "np":
-            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
-        elif input_image_type == "pil":
-            input_image = image.cpu().numpy().transpose(0, 2, 3, 1)
-            input_image = VaeImageProcessor.numpy_to_pil(input_image)
-        else:
-            raise ValueError(f"unsupported input_image_type {input_image_type}.")
-
-        if output_type not in ["pt", "np", "pil"]:
-            raise ValueError(f"unsupported output_type {output_type}")
-
         inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "image": input_image,
+            "image": image,
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": output_type,
+            "output_type": "numpy",
         }
         return inputs
@@ -130,11 +121,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         image = sd_pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
@@ -147,11 +139,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         negative_prompt = "french fries"
         output = sd_pipe(**inputs, negative_prompt=negative_prompt)
         image = output.images
@@ -166,13 +159,14 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
         inputs["prompt"] = [inputs["prompt"]] * 2
         inputs["image"] = inputs["image"].repeat(2, 1, 1, 1)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         image = sd_pipe(**inputs).images
         image_slice = image[-1, -3:, -3:, -1]
@@ -188,11 +182,12 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
         sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False)
+        sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=True)
         sd_pipe = sd_pipe.to(device)
         sd_pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(device)
+        inputs["image"] = inputs["image"] / 2 + 0.5
         image = sd_pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
@@ -217,36 +212,6 @@ class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_attention_slicing_forward_pass(self):
         return super().test_attention_slicing_forward_pass()
 
-    @skip_mps
-    def test_pt_np_pil_outputs_equivalent(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0]
-        output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0]
-        output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0]
-
-        assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4
-        assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4
-
-    @skip_mps
-    def test_image_types_consistent(self):
-        device = "cpu"
-        components = self.get_dummy_components()
-        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0]
-        output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0]
-        output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0]
-
-        assert np.abs(output_pt - output_np).max() <= 1e-4
-        assert np.abs(output_pil - output_np).max() <= 1e-2
-
 
 @slow
 @require_torch_gpu
tests/pipelines/test_pipelines_common.py

@@ -12,6 +12,7 @@ import torch
 
 import diffusers
 from diffusers import DiffusionPipeline
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
 from diffusers.utils.testing_utils import require_torch, torch_device
@@ -27,6 +28,78 @@ def to_np(tensor):
     return tensor
 
 
+class PipelineLatentTesterMixin:
+    """
+    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
+    It provides a set of common tests for PyTorch pipeline that has vae, e.g.
+    equivalence of different input and output types, etc.
+    """
+
+    @property
+    def image_params(self) -> frozenset:
+        raise NotImplementedError(
+            "You need to set the attribute `image_params` in the child test class. "
+            "`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results"
+        )
+
+    def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
+        inputs = self.get_dummy_inputs(device, seed)
+
+        def convert_pt_to_type(image, input_image_type):
+            if input_image_type == "pt":
+                input_image = image
+            elif input_image_type == "np":
+                input_image = VaeImageProcessor.pt_to_numpy(image)
+            elif input_image_type == "pil":
+                input_image = VaeImageProcessor.pt_to_numpy(image)
+                input_image = VaeImageProcessor.numpy_to_pil(input_image)
+            else:
+                raise ValueError(f"unsupported input_image_type {input_image_type}.")
+            return input_image
+
+        for image_param in self.image_params:
+            if image_param in inputs.keys():
+                inputs[image_param] = convert_pt_to_type(inputs[image_param], input_image_type)
+
+        inputs["output_type"] = output_type
+
+        return inputs
+
+    def test_pt_np_pil_outputs_equivalent(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        output_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pt"))[0]
+        output_np = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="np"))[0]
+        output_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, output_type="pil"))[0]
+
+        max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
+        self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generate different results from `output_type=='np'`")
+
+        max_diff = np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max()
+        self.assertLess(max_diff, 1e-4, "`output_type=='pil'` generate different results from `output_type=='np'`")
+
+    def test_pt_np_pil_inputs_equivalent(self):
+        if len(self.image_params) == 0:
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        out_input_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0]
+        out_input_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+        out_input_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pil"))[0]
+
+        max_diff = np.abs(out_input_pt - out_input_np).max()
+        self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`")
+        max_diff = np.abs(out_input_pil - out_input_np).max()
+        self.assertLess(max_diff, 1e-2, "`input_type=='pt'` generate different result from `input_type=='np'`")
+
+
 @require_torch
 class PipelineTesterMixin:
     """
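A child test class opts in by listing its image-typed inputs, as the img2img tests above now do. A sketch (everything except the two mixins and `IMAGE_TO_IMAGE_IMAGE_PARAMS` is a placeholder):

```python
import unittest

from ..pipeline_params import IMAGE_TO_IMAGE_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


class MyImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    # `image_params` tells the mixin which inputs to convert between pt/np/pil.
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS  # frozenset({"image"})
    # pipeline_class, params, batch_params, get_dummy_components and
    # get_dummy_inputs are defined exactly as in the existing fast tests.
```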
@@ -339,9 +412,6 @@ class PipelineTesterMixin:
 
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_float16_inference(self):
-        self._test_float16_inference()
-
-    def _test_float16_inference(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe.to(torch_device)
@@ -355,13 +425,10 @@ class PipelineTesterMixin:
         output_fp16 = pipe_fp16(**self.get_dummy_inputs(torch_device))[0]
 
         max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
-        self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+        self.assertLess(max_diff, 1e-2, "The outputs of the fp16 and fp32 pipelines are too different.")
 
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self):
-        self._test_save_load_float16()
-
-    def _test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         for name, module in components.items():
             if hasattr(module, "half"):
@@ -390,9 +457,7 @@ class PipelineTesterMixin:
         output_loaded = pipe_loaded(**inputs)[0]
 
         max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
-        self.assertLess(
-            max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
-        )
+        self.assertLess(max_diff, 1e-2, "The output of the fp16 pipeline changed after saving and loading.")
 
     def test_save_load_optional_components(self):
         if not hasattr(self.pipeline_class, "_optional_components"):