Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[modular] add tests for qwen modular (#12585)
* add tests for qwenimage modular.
* qwenimage edit.
* qwenimage edit plus.
* empty
* align with the latest structure
* up
* up
* reason
* up
* fix multiple issues.
* up
* up
* fix
* up
* make it similar to the original pipeline.
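The new suite lives under tests/modular_pipelines/qwen/ (see the file listing below). As a convenience, a hedged sketch of running just this file from Python with plain pytest (standard pytest.main usage, nothing diffusers-specific):

    import pytest

    # Run only the new QwenImage modular tests; "-q" keeps output terse.
    pytest.main(["tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py", "-q"])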
@@ -132,6 +132,7 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
+            InputParam("latents"),
             InputParam(name="height"),
             InputParam(name="width"),
             InputParam(name="num_images_per_prompt", default=1),
@@ -196,11 +197,11 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
                     f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
                     f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                 )

-        block_state.latents = randn_tensor(
-            shape, generator=block_state.generator, device=device, dtype=block_state.dtype
-        )
-        block_state.latents = components.pachifier.pack_latents(block_state.latents)
+        if block_state.latents is None:
+            block_state.latents = randn_tensor(
+                shape, generator=block_state.generator, device=device, dtype=block_state.dtype
+            )
+            block_state.latents = components.pachifier.pack_latents(block_state.latents)

         self.set_block_state(state, block_state)
         return components, state
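The substance of this hunk: noise is now sampled only when the caller did not supply latents, so user-provided (already packed) latents pass through untouched. A minimal sketch of the guard with simplified names, not the exact diffusers helper:

    import torch

    def prepare_latents(latents, shape, generator, device, dtype):
        # Only draw fresh noise when no latents were passed in.
        if latents is None:
            latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        return latents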
@@ -549,8 +550,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
                     block_state.width // components.vae_scale_factor // 2,
                 )
             ]
-            * block_state.batch_size
-        ]
+        ] * block_state.batch_size
         block_state.txt_seq_lens = (
             block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
         )
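For context on the txt_seq_lens bookkeeping: per-prompt text lengths are recovered by summing the padding mask along the sequence axis. A small self-contained illustration:

    import torch

    # Rows are prompts; ones mark real tokens, zeros mark padding.
    prompt_embeds_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
    print(txt_seq_lens)  # [3, 2]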
@@ -74,8 +74,9 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
         block_state = self.get_block_state(state)

         # YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
+        vae_scale_factor = components.vae_scale_factor
         block_state.latents = components.pachifier.unpack_latents(
-            block_state.latents, block_state.height, block_state.width
+            block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
         )
         block_state.latents = block_state.latents.to(components.vae.dtype)

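The decode step now threads vae_scale_factor into unpack_latents. Assuming the usual relationship (also visible in the RoPE hunk above), the packed latent grid is the pixel size divided by the VAE scale factor and then by the 2x2 patch size; a sketch with example values only:

    # Example values; the real numbers come from the loaded components.
    height, width = 1024, 1024
    vae_scale_factor = 8
    latent_height = height // vae_scale_factor // 2  # 64
    latent_width = width // vae_scale_factor // 2    # 64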
@@ -503,6 +503,8 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks):
         block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
         block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or ""
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
@@ -627,6 +629,8 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
             device=device,
         )

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or " "
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
@@ -679,6 +683,8 @@ class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
             device=device,
         )

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or " "
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (

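All three encoder steps get the same fix: negative embeddings start as None and are only computed when the guider needs an unconditional branch. A hedged sketch of that gating (the real property lives on the pipeline components; the guidance_scale > 1 condition is an assumption):

    class Components:
        def __init__(self, guidance_scale: float):
            self.guidance_scale = guidance_scale

        @property
        def requires_unconditional_embeds(self) -> bool:
            # CFG only needs a negative branch when it actually mixes
            # conditional and unconditional predictions.
            return self.guidance_scale > 1.0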
@@ -26,10 +26,7 @@ class QwenImagePachifier(ConfigMixin):
     config_name = "config.json"

     @register_to_config
-    def __init__(
-        self,
-        patch_size: int = 2,
-    ):
+    def __init__(self, patch_size: int = 2):
         super().__init__()

     def pack_latents(self, latents):
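For readers new to the pachifier: pack_latents folds each 2x2 spatial patch into the channel axis, turning a latent image into a token sequence. A hedged sketch of Flux-style packing; the real QwenImagePachifier may differ in detail:

    import torch

    def pack_latents(latents: torch.Tensor, patch_size: int = 2) -> torch.Tensor:
        # (B, C, H, W) -> (B, (H/p) * (W/p), C * p * p)
        b, c, h, w = latents.shape
        latents = latents.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        return latents.reshape(b, (h // patch_size) * (w // patch_size), c * patch_size * patch_size)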
@@ -55,6 +55,9 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
         }
         return inputs

+    def test_float16_inference(self):
+        super().test_float16_inference(9e-2)
+

 class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxModularPipeline
@@ -118,6 +121,9 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

+    def test_float16_inference(self):
+        super().test_float16_inference(8e-2)
+

 class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxKontextModularPipeline
@@ -170,3 +176,6 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
         image_slices.append(image[0, -3:, -3:, -1].flatten())

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+    def test_float16_inference(self):
+        super().test_float16_inference(9e-2)
tests/modular_pipelines/qwen/__init__.py (new file, 0 lines)
tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import PIL
import pytest

from diffusers.modular_pipelines import (
    QwenImageAutoBlocks,
    QwenImageEditAutoBlocks,
    QwenImageEditModularPipeline,
    QwenImageEditPlusAutoBlocks,
    QwenImageEditPlusModularPipeline,
    QwenImageModularPipeline,
)

from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin

class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageModularPipeline
    pipeline_blocks_class = QwenImageAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-modular"

    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

    def get_dummy_inputs(self):
        generator = self.get_generator()
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=5e-4)

class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageEditModularPipeline
    pipeline_blocks_class = QwenImageEditAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-edit-modular"

    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

    def get_dummy_inputs(self):
        generator = self.get_generator()
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
        return inputs

    def test_guider_cfg(self):
        super().test_guider_cfg(7e-5)

class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageEditPlusModularPipeline
    pipeline_blocks_class = QwenImageEditPlusAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"

    # No `mask_image` yet.
    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image"])

    def get_dummy_inputs(self):
        generator = self.get_generator()
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
        return inputs

    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
    def test_num_images_per_prompt(self):
        super().test_num_images_per_prompt()

    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
    def test_inference_batch_consistent(self):
        super().test_inference_batch_consistent()

    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical()

    def test_guider_cfg(self):
        super().test_guider_cfg(1e-3)
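For orientation, the construction path these tests exercise, mirroring get_pipeline in the common mixin further down (the repo id and call pattern come straight from this file; treat it as a sketch):

    import torch
    from diffusers.modular_pipelines import QwenImageAutoBlocks

    pipe = QwenImageAutoBlocks().init_pipeline("hf-internal-testing/tiny-qwenimage-modular")
    pipe.load_components(torch_dtype=torch.float32)
    image = pipe(prompt="dance monkey", num_inference_steps=2, height=32, width=32, output="images")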
@@ -25,7 +25,7 @@ from diffusers.loaders import ModularIPAdapterMixin

 from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modular_pipelines_common import ModularPipelineTesterMixin
+from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin


 enable_full_determinism()
@@ -37,13 +37,11 @@ class SDXLModularTesterMixin:
     """

     def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
-        sd_pipe = self.get_pipeline()
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe.set_progress_bar_config(disable=None)
+        sd_pipe = self.get_pipeline().to(torch_device)

         inputs = self.get_dummy_inputs()
         image = sd_pipe(**inputs, output="images")
-        image_slice = image[0, -3:, -3:, -1]
+        image_slice = image[0, -3:, -3:, -1].cpu()

         assert image.shape == expected_image_shape
         max_diff = torch.abs(image_slice.flatten() - expected_slice).max()
@@ -110,7 +108,7 @@ class SDXLModularIPAdapterTesterMixin:
         pipe = blocks.init_pipeline(self.repo)
         pipe.load_components(torch_dtype=torch.float32)
         pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)

+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim")

         # forward pass without ip adapter
@@ -219,9 +217,7 @@ class SDXLModularControlNetTesterMixin:
         # compare against static slices and that can be shaky (with a VVVV low probability).
         expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff

-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         # forward pass without controlnet
         inputs = self.get_dummy_inputs()
@@ -251,9 +247,7 @@ class SDXLModularControlNetTesterMixin:
         assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"

     def test_controlnet_cfg(self):
-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         # forward pass with CFG not applied
         guider = ClassifierFreeGuidance(guidance_scale=1.0)
@@ -273,35 +267,11 @@ class SDXLModularControlNetTesterMixin:
         assert max_diff > 1e-2, "Output with CFG must be different from normal inference"


-class SDXLModularGuiderTesterMixin:
-    def test_guider_cfg(self):
-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        # forward pass with CFG not applied
-        guider = ClassifierFreeGuidance(guidance_scale=1.0)
-        pipe.update_components(guider=guider)
-
-        inputs = self.get_dummy_inputs()
-        out_no_cfg = pipe(**inputs, output="images")
-
-        # forward pass with CFG applied
-        guider = ClassifierFreeGuidance(guidance_scale=7.5)
-        pipe.update_components(guider=guider)
-        inputs = self.get_dummy_inputs()
-        out_cfg = pipe(**inputs, output="images")
-
-        assert out_cfg.shape == out_no_cfg.shape
-        max_diff = np.abs(out_cfg - out_no_cfg).max()
-        assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
-
-
 class TestSDXLModularPipelineFast(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL modular pipeline fast tests."""
@@ -335,18 +305,7 @@ class TestSDXLModularPipelineFast(
         self._test_stable_diffusion_xl_euler(
             expected_image_shape=self.expected_image_output_shape,
             expected_slice=torch.tensor(
-                [
-                    0.5966781,
-                    0.62939394,
-                    0.48465094,
-                    0.51573336,
-                    0.57593524,
-                    0.47035995,
-                    0.53410417,
-                    0.51436996,
-                    0.47313565,
-                ],
-                device=torch_device,
+                [0.3886, 0.4685, 0.4953, 0.4217, 0.4317, 0.3945, 0.4847, 0.4704, 0.4731],
             ),
             expected_max_diff=1e-2,
         )
@@ -359,7 +318,7 @@ class TestSDXLImg2ImgModularPipelineFast(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
@@ -400,20 +359,7 @@ class TestSDXLImg2ImgModularPipelineFast(
     def test_stable_diffusion_xl_euler(self):
         self._test_stable_diffusion_xl_euler(
             expected_image_shape=self.expected_image_output_shape,
-            expected_slice=torch.tensor(
-                [
-                    0.56943184,
-                    0.4702148,
-                    0.48048905,
-                    0.6235963,
-                    0.551138,
-                    0.49629188,
-                    0.60031277,
-                    0.5688907,
-                    0.43996853,
-                ],
-                device=torch_device,
-            ),
+            expected_slice=torch.tensor([0.5246, 0.4466, 0.444, 0.3246, 0.4443, 0.5108, 0.5225, 0.559, 0.5147]),
             expected_max_diff=1e-2,
         )

@@ -425,7 +371,7 @@ class SDXLInpaintingModularPipelineFastTests(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""

@@ -2,22 +2,17 @@ import gc
 import tempfile
 from typing import Callable, Union

+import pytest
 import torch

 import diffusers
 from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
 from diffusers.guiders import ClassifierFreeGuidance
 from diffusers.utils import logging

-from ..testing_utils import (
-    backend_empty_cache,
-    numpy_cosine_similarity_distance,
-    require_accelerator,
-    require_torch,
-    torch_device,
-)
+from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device


 @require_torch
 class ModularPipelineTesterMixin:
     """
     It provides a set of common tests for each modular pipeline,
@@ -32,20 +27,9 @@ class ModularPipelineTesterMixin:
     # Canonical parameters that are passed to `__call__` regardless
     # of the type of pipeline. They are always optional and have common
     # sense default values.
-    optional_params = frozenset(
-        [
-            "num_inference_steps",
-            "num_images_per_prompt",
-            "latents",
-            "output_type",
-        ]
-    )
+    optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
     # this is modular specific: generator needs to be an intermediate input because it's mutable
-    intermediate_params = frozenset(
-        [
-            "generator",
-        ]
-    )
+    intermediate_params = frozenset(["generator"])

     def get_generator(self, seed=0):
         generator = torch.Generator("cpu").manual_seed(seed)
@@ -121,6 +105,7 @@ class ModularPipelineTesterMixin:
     def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
         pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
         pipeline.load_components(torch_dtype=torch_dtype)
+        pipeline.set_progress_bar_config(disable=None)
         return pipeline

     def test_pipeline_call_signature(self):
@@ -138,9 +123,7 @@ class ModularPipelineTesterMixin:
         _check_for_parameters(self.optional_params, optional_parameters, "optional")

     def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
-        pipe = self.get_pipeline()
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         inputs = self.get_dummy_inputs()
         inputs["generator"] = self.get_generator(0)
@@ -179,9 +162,8 @@ class ModularPipelineTesterMixin:
         batch_size=2,
         expected_max_diff=1e-4,
     ):
-        pipe = self.get_pipeline()
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         inputs = self.get_dummy_inputs()

         # Reset generator in case it has been used in self.get_dummy_inputs
@@ -219,11 +201,9 @@ class ModularPipelineTesterMixin:
     def test_float16_inference(self, expected_max_diff=5e-2):
         pipe = self.get_pipeline()
         pipe.to(torch_device, torch.float32)
-        pipe.set_progress_bar_config(disable=None)

         pipe_fp16 = self.get_pipeline()
         pipe_fp16.to(torch_device, torch.float16)
-        pipe_fp16.set_progress_bar_config(disable=None)

         inputs = self.get_dummy_inputs()
         # Reset generator in case it is used inside dummy inputs
@@ -237,19 +217,16 @@ class ModularPipelineTesterMixin:
         fp16_inputs["generator"] = self.get_generator(0)
         output_fp16 = pipe_fp16(**fp16_inputs, output="images")

-        if isinstance(output, torch.Tensor):
-            output = output.cpu()
-            output_fp16 = output_fp16.cpu()
+        output = output.cpu()
+        output_fp16 = output_fp16.cpu()

         max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
         assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"

     @require_accelerator
     def test_to_device(self):
-        pipe = self.get_pipeline()
-        pipe.set_progress_bar_config(disable=None)
-
-        pipe.to("cpu")
+        pipe = self.get_pipeline().to("cpu")
         model_devices = [
             component.device.type for component in pipe.components.values() if hasattr(component, "device")
         ]
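The fp16 check compares flattened fp32 and fp16 outputs via a cosine-similarity distance. A hedged sketch of what such a helper computes (the actual numpy_cosine_similarity_distance lives in diffusers' testing utilities and may differ in detail):

    import numpy as np

    def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
        # 1 minus cosine similarity: 0 for identical directions, up to 2 for opposite ones.
        sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        return float(1.0 - sim)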
@@ -264,30 +241,23 @@ class ModularPipelineTesterMixin:
         )

     def test_inference_is_not_nan_cpu(self):
-        pipe = self.get_pipeline()
-        pipe.set_progress_bar_config(disable=None)
-        pipe.to("cpu")
+        pipe = self.get_pipeline().to("cpu")

         output = pipe(**self.get_dummy_inputs(), output="images")
         assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"

     @require_accelerator
     def test_inference_is_not_nan(self):
-        pipe = self.get_pipeline()
-        pipe.set_progress_bar_config(disable=None)
-        pipe.to(torch_device)
+        pipe = self.get_pipeline().to(torch_device)

         output = pipe(**self.get_dummy_inputs(), output="images")
         assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"

     def test_num_images_per_prompt(self):
-        pipe = self.get_pipeline()
+        pipe = self.get_pipeline().to(torch_device)

         if "num_images_per_prompt" not in pipe.blocks.input_names:
-            return
-
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+            pytest.skip("Skipping test as `num_images_per_prompt` is not present in input names.")

         batch_sizes = [1, 2]
         num_images_per_prompts = [1, 2]
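A note on the skip pattern above: pytest.skip() raises at call time, which is what a runtime capability probe needs, whereas pytest.mark.skip is a collection-time marker and has no effect when invoked inside a test body. A minimal pattern:

    import pytest

    def test_optional_feature():
        supported = False  # stand-in for the input-name probe above
        if not supported:
            pytest.skip("feature not available")
        assert supported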
@@ -342,3 +312,25 @@ class ModularPipelineTesterMixin:
         image_slices.append(image[0, -3:, -3:, -1].flatten())

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+
+class ModularGuiderTesterMixin:
+    def test_guider_cfg(self, expected_max_diff=1e-2):
+        pipe = self.get_pipeline().to(torch_device)
+
+        # forward pass with CFG not applied
+        guider = ClassifierFreeGuidance(guidance_scale=1.0)
+        pipe.update_components(guider=guider)
+
+        inputs = self.get_dummy_inputs()
+        out_no_cfg = pipe(**inputs, output="images")
+
+        # forward pass with CFG applied
+        guider = ClassifierFreeGuidance(guidance_scale=7.5)
+        pipe.update_components(guider=guider)
+        inputs = self.get_dummy_inputs()
+        out_cfg = pipe(**inputs, output="images")
+
+        assert out_cfg.shape == out_no_cfg.shape
+        max_diff = torch.abs(out_cfg - out_no_cfg).max()
+        assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"