mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[modular] add tests for qwen modular (#12585)

* add tests for qwenimage modular.

* qwenimage edit.

* qwenimage edit plus.

* empty

* align with the latest structure

* up

* up

* reason

* up

* fix multiple issues.

* up

* up

* fix

* up

* make it similar to the original pipeline.
This commit is contained in:
Sayak Paul
2025-11-12 17:37:42 +05:30
committed by GitHub
parent 093cd3f040
commit f5e5f34823
9 changed files with 194 additions and 123 deletions

View File

@@ -132,6 +132,7 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
     @property
     def inputs(self) -> List[InputParam]:
         return [
+            InputParam("latents"),
             InputParam(name="height"),
             InputParam(name="width"),
             InputParam(name="num_images_per_prompt", default=1),
@@ -196,11 +197,11 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
                 f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-        block_state.latents = randn_tensor(
-            shape, generator=block_state.generator, device=device, dtype=block_state.dtype
-        )
-        block_state.latents = components.pachifier.pack_latents(block_state.latents)
+        if block_state.latents is None:
+            block_state.latents = randn_tensor(
+                shape, generator=block_state.generator, device=device, dtype=block_state.dtype
+            )
+        block_state.latents = components.pachifier.pack_latents(block_state.latents)

         self.set_block_state(state, block_state)
         return components, state
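The guard above is the heart of this hunk: noise is drawn only when the caller did not supply `latents`, so a user-provided tensor survives into `pack_latents`. A minimal standalone sketch of the pattern, with plain `torch.randn` standing in for `randn_tensor` and the pachifier omitted:

import torch

def prepare_latents(shape, generator=None, device="cpu", dtype=torch.float32, latents=None):
    # Draw fresh noise only when no latents were supplied, so caller-provided
    # latents (e.g. for reproducibility tests) pass through untouched.
    if latents is None:
        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    return latents

fixed = torch.zeros(1, 16, 4, 4)
assert prepare_latents(fixed.shape, latents=fixed) is fixed  # supplied latents respected
assert prepare_latents((1, 16, 4, 4)).shape == (1, 16, 4, 4)  # otherwise sampled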
@@ -549,8 +550,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
                     block_state.width // components.vae_scale_factor // 2,
                 )
             ]
-            * block_state.batch_size
-        ]
+        ] * block_state.batch_size
         block_state.txt_seq_lens = (
             block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
         )
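Moving the multiplier is a semantic fix, not a style change: `[[shape] * batch_size]` repeats the tuple inside one inner list, while `[[shape]] * batch_size` yields one `img_shapes` entry per batch item. In plain Python:

shape, batch_size = (1, 4, 4), 2

old_style = [[shape] * batch_size]   # one inner list holding repeated tuples
new_style = [[shape]] * batch_size   # one single-tuple list per batch item

assert old_style == [[(1, 4, 4), (1, 4, 4)]]
assert new_style == [[(1, 4, 4)], [(1, 4, 4)]]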

View File

@@ -74,8 +74,9 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
         block_state = self.get_block_state(state)

         # YiYi Notes: remove support for output_type = "latents", we can just skip the decode/encode step in modular
+        vae_scale_factor = components.vae_scale_factor
         block_state.latents = components.pachifier.unpack_latents(
-            block_state.latents, block_state.height, block_state.width
+            block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
         )
         block_state.latents = block_state.latents.to(components.vae.dtype)

View File

@@ -503,6 +503,8 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks):
         block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
         block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or ""
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
@@ -627,6 +629,8 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
             device=device,
         )

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or " "
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
@@ -679,6 +683,8 @@ class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
             device=device,
         )

+        block_state.negative_prompt_embeds = None
+        block_state.negative_prompt_embeds_mask = None
         if components.requires_unconditional_embeds:
             negative_prompt = block_state.negative_prompt or " "
             block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
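All three encoder steps gain the same two lines: the negative embeddings are reset to None before the conditional branch, so a block state reused across calls can never leak stale unconditional embeddings when `requires_unconditional_embeds` is off. A runnable sketch of the pattern (the `encode` stub is a stand-in for `get_qwen_prompt_embeds`):

import torch

def encode(text):
    # stand-in for get_qwen_prompt_embeds: returns (embeds, mask)
    return torch.zeros(1, 4, 8), torch.ones(1, 4, dtype=torch.long)

def encode_step(negative_prompt=None, requires_unconditional_embeds=False):
    negative_prompt_embeds = None       # reset first, unconditionally
    negative_prompt_embeds_mask = None
    if requires_unconditional_embeds:
        negative_prompt_embeds, negative_prompt_embeds_mask = encode(negative_prompt or "")
    return negative_prompt_embeds, negative_prompt_embeds_mask

assert encode_step(requires_unconditional_embeds=False) == (None, None)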

View File

@@ -26,10 +26,7 @@ class QwenImagePachifier(ConfigMixin):
     config_name = "config.json"

     @register_to_config
-    def __init__(
-        self,
-        patch_size: int = 2,
-    ):
+    def __init__(self, patch_size: int = 2):
         super().__init__()

     def pack_latents(self, latents):
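For context, `pack_latents`/`unpack_latents` fold each `patch_size x patch_size` latent patch into a single token and back, which is why the decoder step above now threads `vae_scale_factor` through to `unpack_latents`. A hedged sketch of the convention (the real `QwenImagePachifier` may differ in detail; `vae_scale_factor=8` is an assumption here):

import torch

def pack_latents(latents, patch_size=2):
    b, c, h, w = latents.shape
    p = patch_size
    # (b, c, h, w) -> (b, h/p * w/p, c * p * p): each p x p patch becomes one token
    latents = latents.view(b, c, h // p, p, w // p, p).permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // p) * (w // p), c * p * p)

def unpack_latents(latents, height, width, vae_scale_factor=8, patch_size=2):
    # height/width are in pixels; divide by the VAE scale factor to recover
    # the latent grid, then invert the patchify above
    b, _, dim = latents.shape
    p = patch_size
    h, w = height // vae_scale_factor, width // vae_scale_factor
    c = dim // (p * p)
    latents = latents.view(b, h // p, w // p, c, p, p).permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(b, c, h, w)

x = torch.randn(1, 16, 8, 8)
assert torch.equal(unpack_latents(pack_latents(x), 64, 64), x)  # lossless round trip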

View File

@@ -55,6 +55,9 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
         }
         return inputs

+    def test_float16_inference(self):
+        super().test_float16_inference(9e-2)
+

 class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxModularPipeline
@@ -118,6 +121,9 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

+    def test_float16_inference(self):
+        super().test_float16_inference(8e-2)
+

 class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
     pipeline_class = FluxKontextModularPipeline
@@ -170,3 +176,6 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
         image_slices.append(image[0, -3:, -3:, -1].flatten())

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+    def test_float16_inference(self):
+        super().test_float16_inference(9e-2)

View File

@@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import PIL
import pytest

from diffusers.modular_pipelines import (
    QwenImageAutoBlocks,
    QwenImageEditAutoBlocks,
    QwenImageEditModularPipeline,
    QwenImageEditPlusAutoBlocks,
    QwenImageEditPlusModularPipeline,
    QwenImageModularPipeline,
)

from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin


class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageModularPipeline
    pipeline_blocks_class = QwenImageAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-modular"

    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

    def get_dummy_inputs(self):
        generator = self.get_generator()
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=5e-4)


class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageEditModularPipeline
    pipeline_blocks_class = QwenImageEditAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-edit-modular"

    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

    def get_dummy_inputs(self):
        generator = self.get_generator()
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
        return inputs

    def test_guider_cfg(self):
        super().test_guider_cfg(7e-5)


class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageEditPlusModularPipeline
    pipeline_blocks_class = QwenImageEditPlusAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"

    # No `mask_image` yet.
    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image"])

    def get_dummy_inputs(self):
        generator = self.get_generator()
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
        return inputs

    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
    def test_num_images_per_prompt(self):
        super().test_num_images_per_prompt()

    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
    def test_inference_batch_consistent(self):
        super().test_inference_batch_consistent()

    @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical()

    def test_guider_cfg(self):
        super().test_guider_cfg(1e-3)
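Note the `strict=True` on the xfail markers above: if batching support is later fixed and one of these tests starts passing, pytest reports the XPASS as a failure, forcing the stale marker to be removed rather than rot. A minimal illustration:

import pytest

# Once the body stops failing, strict=True turns the XPASS into a hard failure,
# so the marker must be deleted as soon as the underlying gap is fixed.
@pytest.mark.xfail(condition=True, reason="known gap", strict=True)
def test_known_gap():
    raise AssertionError("still broken")  # reported as XFAIL; the suite stays green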

View File

@@ -25,7 +25,7 @@ from diffusers.loaders import ModularIPAdapterMixin
 from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modular_pipelines_common import ModularPipelineTesterMixin
+from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin

 enable_full_determinism()
@@ -37,13 +37,11 @@ class SDXLModularTesterMixin:
     """

     def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
-        sd_pipe = self.get_pipeline()
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe.set_progress_bar_config(disable=None)
+        sd_pipe = self.get_pipeline().to(torch_device)
         inputs = self.get_dummy_inputs()

         image = sd_pipe(**inputs, output="images")
-        image_slice = image[0, -3:, -3:, -1]
+        image_slice = image[0, -3:, -3:, -1].cpu()

         assert image.shape == expected_image_shape
         max_diff = torch.abs(image_slice.flatten() - expected_slice).max()
@@ -110,7 +108,7 @@ class SDXLModularIPAdapterTesterMixin:
         pipe = blocks.init_pipeline(self.repo)
         pipe.load_components(torch_dtype=torch.float32)
         pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)

         cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
         # forward pass without ip adapter
@@ -219,9 +217,7 @@ class SDXLModularControlNetTesterMixin:
         # compare against static slices and that can be shaky (with a VVVV low probability).
         expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff

-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         # forward pass without controlnet
         inputs = self.get_dummy_inputs()
@@ -251,9 +247,7 @@ class SDXLModularControlNetTesterMixin:
         assert max_diff_with_controlnet_scale > 1e-2, "Output with controlnet must be different from normal inference"

     def test_controlnet_cfg(self):
-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         # forward pass with CFG not applied
         guider = ClassifierFreeGuidance(guidance_scale=1.0)
@@ -273,35 +267,11 @@ class SDXLModularControlNetTesterMixin:
         assert max_diff > 1e-2, "Output with CFG must be different from normal inference"


-class SDXLModularGuiderTesterMixin:
-    def test_guider_cfg(self):
-        pipe = self.get_pipeline()
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        # forward pass with CFG not applied
-        guider = ClassifierFreeGuidance(guidance_scale=1.0)
-        pipe.update_components(guider=guider)
-
-        inputs = self.get_dummy_inputs()
-        out_no_cfg = pipe(**inputs, output="images")
-
-        # forward pass with CFG applied
-        guider = ClassifierFreeGuidance(guidance_scale=7.5)
-        pipe.update_components(guider=guider)
-
-        inputs = self.get_dummy_inputs()
-        out_cfg = pipe(**inputs, output="images")
-
-        assert out_cfg.shape == out_no_cfg.shape
-        max_diff = np.abs(out_cfg - out_no_cfg).max()
-        assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
-
-
 class TestSDXLModularPipelineFast(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL modular pipeline fast tests."""
@@ -335,18 +305,7 @@ class TestSDXLModularPipelineFast(
         self._test_stable_diffusion_xl_euler(
             expected_image_shape=self.expected_image_output_shape,
             expected_slice=torch.tensor(
-                [
-                    0.5966781,
-                    0.62939394,
-                    0.48465094,
-                    0.51573336,
-                    0.57593524,
-                    0.47035995,
-                    0.53410417,
-                    0.51436996,
-                    0.47313565,
-                ],
-                device=torch_device,
+                [0.3886, 0.4685, 0.4953, 0.4217, 0.4317, 0.3945, 0.4847, 0.4704, 0.4731],
             ),
             expected_max_diff=1e-2,
         )
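The regenerated slice is compact (4 decimals) and device-free, which works because `image_slice` is now moved to CPU before comparison. When an intended numerical change requires refreshing such a slice, a hypothetical helper like this produces values in the same format:

import torch

def print_expected_slice(image: torch.Tensor) -> None:
    # Bottom-right 3x3 patch of the last channel, flattened and rounded to
    # 4 decimals -- ready to paste into expected_slice above.
    slice_ = image[0, -3:, -3:, -1].flatten().cpu()
    print([round(v, 4) for v in slice_.tolist()])

print_expected_slice(torch.rand(1, 64, 64, 3))  # e.g. [0.3886, 0.4685, ...]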
@@ -359,7 +318,7 @@ class TestSDXLImg2ImgModularPipelineFast(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
@@ -400,20 +359,7 @@ class TestSDXLImg2ImgModularPipelineFast(
     def test_stable_diffusion_xl_euler(self):
         self._test_stable_diffusion_xl_euler(
             expected_image_shape=self.expected_image_output_shape,
-            expected_slice=torch.tensor(
-                [
-                    0.56943184,
-                    0.4702148,
-                    0.48048905,
-                    0.6235963,
-                    0.551138,
-                    0.49629188,
-                    0.60031277,
-                    0.5688907,
-                    0.43996853,
-                ],
-                device=torch_device,
-            ),
+            expected_slice=torch.tensor([0.5246, 0.4466, 0.444, 0.3246, 0.4443, 0.5108, 0.5225, 0.559, 0.5147]),
             expected_max_diff=1e-2,
         )
@@ -425,7 +371,7 @@ class SDXLInpaintingModularPipelineFastTests(
     SDXLModularTesterMixin,
     SDXLModularIPAdapterTesterMixin,
     SDXLModularControlNetTesterMixin,
-    SDXLModularGuiderTesterMixin,
+    ModularGuiderTesterMixin,
     ModularPipelineTesterMixin,
 ):
     """Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""

View File

@@ -2,22 +2,17 @@ import gc
 import tempfile
 from typing import Callable, Union

+import pytest
 import torch

 import diffusers
 from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
+from diffusers.guiders import ClassifierFreeGuidance
 from diffusers.utils import logging
-from ..testing_utils import (
-    backend_empty_cache,
-    numpy_cosine_similarity_distance,
-    require_accelerator,
-    require_torch,
-    torch_device,
-)
+from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device

-@require_torch
 class ModularPipelineTesterMixin:
     """
     It provides a set of common tests for each modular pipeline,
@@ -32,20 +27,9 @@ class ModularPipelineTesterMixin:
     # Canonical parameters that are passed to `__call__` regardless
     # of the type of pipeline. They are always optional and have common
     # sense default values.
-    optional_params = frozenset(
-        [
-            "num_inference_steps",
-            "num_images_per_prompt",
-            "latents",
-            "output_type",
-        ]
-    )
+    optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])

     # this is modular specific: generator needs to be an intermediate input because it's mutable
-    intermediate_params = frozenset(
-        [
-            "generator",
-        ]
-    )
+    intermediate_params = frozenset(["generator"])

     def get_generator(self, seed=0):
         generator = torch.Generator("cpu").manual_seed(seed)
@@ -121,6 +105,7 @@ class ModularPipelineTesterMixin:
     def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
         pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
         pipeline.load_components(torch_dtype=torch_dtype)
+        pipeline.set_progress_bar_config(disable=None)
         return pipeline

     def test_pipeline_call_signature(self):
@@ -138,9 +123,7 @@ class ModularPipelineTesterMixin:
         _check_for_parameters(self.optional_params, optional_parameters, "optional")

     def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
-        pipe = self.get_pipeline()
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         inputs = self.get_dummy_inputs()
         inputs["generator"] = self.get_generator(0)
@@ -179,9 +162,8 @@ class ModularPipelineTesterMixin:
         batch_size=2,
         expected_max_diff=1e-4,
     ):
-        pipe = self.get_pipeline()
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to(torch_device)

         inputs = self.get_dummy_inputs()
         # Reset generator in case it has been used in self.get_dummy_inputs
@@ -219,11 +201,9 @@ class ModularPipelineTesterMixin:
     def test_float16_inference(self, expected_max_diff=5e-2):
         pipe = self.get_pipeline()
         pipe.to(torch_device, torch.float32)
-        pipe.set_progress_bar_config(disable=None)

         pipe_fp16 = self.get_pipeline()
         pipe_fp16.to(torch_device, torch.float16)
-        pipe_fp16.set_progress_bar_config(disable=None)

         inputs = self.get_dummy_inputs()
         # Reset generator in case it is used inside dummy inputs
@@ -237,19 +217,16 @@ class ModularPipelineTesterMixin:
         fp16_inputs["generator"] = self.get_generator(0)
         output_fp16 = pipe_fp16(**fp16_inputs, output="images")

-        if isinstance(output, torch.Tensor):
-            output = output.cpu()
-            output_fp16 = output_fp16.cpu()
+        output = output.cpu()
+        output_fp16 = output_fp16.cpu()

         max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
         assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"
     @require_accelerator
     def test_to_device(self):
-        pipe = self.get_pipeline()
-        pipe.set_progress_bar_config(disable=None)
+        pipe = self.get_pipeline().to("cpu")

-        pipe.to("cpu")
         model_devices = [
             component.device.type for component in pipe.components.values() if hasattr(component, "device")
         ]
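The fp16-vs-fp32 comparison above deliberately uses cosine distance rather than an element-wise max difference: it tolerates the uniform low-precision noise that fp16 spreads over every pixel while still catching directional divergence. diffusers' `numpy_cosine_similarity_distance` is essentially `1 - cosine_similarity`; a hedged re-implementation:

import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    a = a.flatten().astype(np.float64)
    b = b.flatten().astype(np.float64)
    # 1 - cos(a, b): 0 for identical directions, up to 2 for opposite ones
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

x = np.random.rand(64)
assert cosine_similarity_distance(x, x) < 1e-12   # identical -> ~0
assert cosine_similarity_distance(x, -x) > 1.0    # opposite -> ~2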
@@ -264,30 +241,23 @@ class ModularPipelineTesterMixin:
         )

     def test_inference_is_not_nan_cpu(self):
-        pipe = self.get_pipeline()
-        pipe.set_progress_bar_config(disable=None)
-        pipe.to("cpu")
+        pipe = self.get_pipeline().to("cpu")

         output = pipe(**self.get_dummy_inputs(), output="images")
         assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"

     @require_accelerator
     def test_inference_is_not_nan(self):
-        pipe = self.get_pipeline()
-        pipe.set_progress_bar_config(disable=None)
-        pipe.to(torch_device)
+        pipe = self.get_pipeline().to(torch_device)

         output = pipe(**self.get_dummy_inputs(), output="images")
         assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"

     def test_num_images_per_prompt(self):
-        pipe = self.get_pipeline()
+        pipe = self.get_pipeline().to(torch_device)

         if "num_images_per_prompt" not in pipe.blocks.input_names:
-            return
-
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+            pytest.skip("Skipping test as `num_images_per_prompt` is not present in input names.")

         batch_sizes = [1, 2]
         num_images_per_prompts = [1, 2]
@@ -342,3 +312,25 @@ class ModularPipelineTesterMixin:
         image_slices.append(image[0, -3:, -3:, -1].flatten())

         assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+
+class ModularGuiderTesterMixin:
+    def test_guider_cfg(self, expected_max_diff=1e-2):
+        pipe = self.get_pipeline().to(torch_device)
+
+        # forward pass with CFG not applied
+        guider = ClassifierFreeGuidance(guidance_scale=1.0)
+        pipe.update_components(guider=guider)
+
+        inputs = self.get_dummy_inputs()
+        out_no_cfg = pipe(**inputs, output="images")
+
+        # forward pass with CFG applied
+        guider = ClassifierFreeGuidance(guidance_scale=7.5)
+        pipe.update_components(guider=guider)
+
+        inputs = self.get_dummy_inputs()
+        out_cfg = pipe(**inputs, output="images")
+
+        assert out_cfg.shape == out_no_cfg.shape
+        max_diff = torch.abs(out_cfg - out_no_cfg).max()
+        assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
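With the guider test hoisted into this shared mixin, any modular pipeline suite picks it up by inheritance instead of the deleted per-family SDXL mixin. A hypothetical composition (all class names and the repo id below are placeholders, not real diffusers identifiers; a real suite would import its pipeline and blocks from diffusers.modular_pipelines):

class MyModularPipeline: ...   # placeholder stand-in
class MyAutoBlocks: ...        # placeholder stand-in

class TestMyModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = MyModularPipeline
    pipeline_blocks_class = MyAutoBlocks
    repo = "hf-internal-testing/tiny-my-modular"  # hypothetical repo id

    params = frozenset(["prompt", "height", "width", "negative_prompt"])
    batch_params = frozenset(["prompt", "negative_prompt"])

    def get_dummy_inputs(self):
        return {
            "prompt": "dance monkey",
            "generator": self.get_generator(),
            "num_inference_steps": 2,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }

    def test_guider_cfg(self):
        # threshold overrides stay one-liners, as in the Qwen tests above
        super().test_guider_cfg(1e-3)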