1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-07-15 14:51:19 +05:30
parent 5337132c69
commit 29d8763e68

View File

@@ -172,7 +172,7 @@ class SDFunctionTesterMixin:
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
inputs["output_type"] = "np"
output = pipe(**inputs)[0]
output = self.get_base_output(pipe)
# FreeU-enabled inference
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
@@ -193,12 +193,12 @@ class SDFunctionTesterMixin:
inputs["output_type"] = "np"
output_no_freeu = pipe(**inputs)[0]
assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
"Enabling of FreeU should lead to different results."
)
assert np.allclose(output, output_no_freeu, atol=1e-2), (
f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
)
assert not np.allclose(
output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
), "Enabling of FreeU should lead to different results."
assert np.allclose(
output, output_no_freeu, atol=1e-2
), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -219,12 +219,12 @@ class SDFunctionTesterMixin:
and hasattr(component, "original_attn_processors")
and component.original_attn_processors is not None
):
assert check_qkv_fusion_processors_exist(component), (
"Something wrong with the fused attention processors. Expected all the attention processors to be fused."
)
assert check_qkv_fusion_matches_attn_procs_length(component, component.original_attn_processors), (
"Something wrong with the attention processors concerning the fused QKV projections."
)
assert check_qkv_fusion_processors_exist(
component
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
component, component.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
inputs["return_dict"] = False
@@ -237,15 +237,15 @@ class SDFunctionTesterMixin:
image_disabled = pipe(**inputs)[0]
image_slice_disabled = image_disabled[0, -3:, -3:, -1]
assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
"Fusion of QKV projections shouldn't affect the outputs."
)
assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
"Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
)
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
"Original outputs should match when fused QKV projections are disabled."
)
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
class IPAdapterTesterMixin:
@@ -893,7 +893,7 @@ class PipelineFromPipeTesterMixin:
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs_pipe(torch_device)
output = pipe(**inputs)[0]
output = self.get_base_output(pipe)
pipe_from_original = self.pipeline_class.from_pipe(pipe_original, **current_pipe_additional_components)
pipe_from_original.to(torch_device)
@@ -916,9 +916,9 @@ class PipelineFromPipeTesterMixin:
for component in pipe_original.components.values():
if hasattr(component, "attn_processors"):
assert all(type(proc) == AttnProcessor for proc in component.attn_processors.values()), (
"`from_pipe` changed the attention processor in original pipeline."
)
assert all(
type(proc) == AttnProcessor for proc in component.attn_processors.values()
), "`from_pipe` changed the attention processor in original pipeline."
@require_accelerator
@require_accelerate_version_greater("0.14.0")
@@ -931,7 +931,7 @@ class PipelineFromPipeTesterMixin:
pipe.enable_model_cpu_offload(device=torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs_pipe(torch_device)
output = pipe(**inputs)[0]
output = self.get_base_output(pipe)
original_expected_modules, _ = self.original_pipeline_class._get_signature_keys(self.original_pipeline_class)
# pipeline components that are also expected to be in the original pipeline
@@ -1009,7 +1009,7 @@ class PipelineKarrasSchedulerTesterMixin:
scheduler_cls = getattr(diffusers, scheduler_enum.name)
pipe.scheduler = scheduler_cls.from_config(pipe.scheduler.config)
output = pipe(**inputs)[0]
output = self.get_base_output(pipe)
outputs.append(output)
if "KDPM2" in scheduler_enum.name:
@@ -1053,6 +1053,19 @@ class PipelineTesterMixin:
generator = torch.Generator(device).manual_seed(seed)
return generator
def get_base_output(self, pipe):
"""
Compute and cache the base output from a pipeline to avoid redundant computation
in tests that compare against baseline results.
"""
if not hasattr(self, "_base_pipeline_output"):
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = self.get_generator(0)
with torch.no_grad():
output = pipe(**inputs)[0]
self._base_pipeline_output = output
return self._base_pipeline_output
@property
def pipeline_class(self) -> Union[Callable, DiffusionPipeline]:
raise NotImplementedError(
@@ -1147,8 +1160,7 @@ class PipelineTesterMixin:
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
output = self.get_base_output(pipe)
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel(diffusers.logging.INFO)
@@ -1386,7 +1398,7 @@ class PipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs)[0]
output = self.get_base_output(pipe)
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
@@ -1416,8 +1428,7 @@ class PipelineTesterMixin:
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output = pipe(**inputs)[0]
output = self.get_base_output(pipe)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
@@ -1460,7 +1471,7 @@ class PipelineTesterMixin:
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]
output = self.get_base_output(pipe)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
@@ -1516,48 +1527,6 @@ class PipelineTesterMixin:
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
self._test_attention_slicing_forward_pass(expected_max_diff=expected_max_diff)
def _test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
if test_mean_pixel_difference:
assert_mean_pixel_difference(to_np(output_with_slicing1[0]), to_np(output_without_slicing[0]))
assert_mean_pixel_difference(to_np(output_with_slicing2[0]), to_np(output_without_slicing[0]))
@require_accelerator
@require_accelerate_version_greater("0.14.0")
def test_sequential_cpu_offload_forward_pass(self, expected_max_diff=1e-4):
@@ -2587,12 +2556,12 @@ class PyramidAttentionBroadcastTesterMixin:
image_slice_pab_disabled = output.flatten()
image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
assert np.allclose(original_image_slice, image_slice_pab_enabled, atol=expected_atol), (
"PAB outputs should not differ much in specified timestep range."
)
assert np.allclose(original_image_slice, image_slice_pab_disabled, atol=1e-4), (
"Outputs from normal inference and after disabling cache should not differ."
)
assert np.allclose(
original_image_slice, image_slice_pab_enabled, atol=expected_atol
), "PAB outputs should not differ much in specified timestep range."
assert np.allclose(
original_image_slice, image_slice_pab_disabled, atol=1e-4
), "Outputs from normal inference and after disabling cache should not differ."
class FasterCacheTesterMixin:
@@ -2657,12 +2626,12 @@ class FasterCacheTesterMixin:
output = run_forward(pipe).flatten()
image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol), (
"FasterCache outputs should not differ much in specified timestep range."
)
assert np.allclose(original_image_slice, image_slice_faster_cache_disabled, atol=1e-4), (
"Outputs from normal inference and after disabling cache should not differ."
)
assert np.allclose(
original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
), "FasterCache outputs should not differ much in specified timestep range."
assert np.allclose(
original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
), "Outputs from normal inference and after disabling cache should not differ."
def test_faster_cache_state(self):
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
@@ -2797,12 +2766,12 @@ class FirstBlockCacheTesterMixin:
output = run_forward(pipe).flatten()
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
"FirstBlockCache outputs should not differ much."
)
assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
"Outputs from normal inference and after disabling cache should not differ."
)
assert np.allclose(
original_image_slice, image_slice_fbc_enabled, atol=expected_atol
), "FirstBlockCache outputs should not differ much."
assert np.allclose(
original_image_slice, image_slice_fbc_disabled, atol=1e-4
), "Outputs from normal inference and after disabling cache should not differ."
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.