Commit d87907da30, authored by DN6 on 2025-07-15 15:37:46 +05:30 (parent 62e2cce917).

tests/pipelines/test_pipelines_common.py

@@ -17,10 +17,8 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 import diffusers
 from diffusers import (
-    AsymmetricAutoencoderKL,
     AutoencoderKL,
     AutoencoderTiny,
-    ConsistencyDecoderVAE,
     DDIMScheduler,
     DiffusionPipeline,
     FasterCacheConfig,
@@ -160,46 +158,6 @@ class SDFunctionTesterMixin:
                 zeros = torch.zeros(shape).to(torch_device)
                 pipe.vae.decode(zeros)
 
-    # MPS currently doesn't support ComplexFloats, which are required for FreeU - see https://github.com/huggingface/diffusers/issues/7569.
-    @skip_mps
-    def test_freeu(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        # Normal inference
-        inputs = self.get_dummy_inputs(torch_device)
-        inputs["return_dict"] = False
-        inputs["output_type"] = "np"
-        output = pipe(**inputs)[0]
-
-        # FreeU-enabled inference
-        pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
-        inputs = self.get_dummy_inputs(torch_device)
-        inputs["return_dict"] = False
-        inputs["output_type"] = "np"
-        output_freeu = pipe(**inputs)[0]
-
-        # FreeU-disabled inference
-        pipe.disable_freeu()
-        freeu_keys = {"s1", "s2", "b1", "b2"}
-        for upsample_block in pipe.unet.up_blocks:
-            for key in freeu_keys:
-                assert getattr(upsample_block, key) is None, f"Disabling of FreeU should have set {key} to None."
-
-        inputs = self.get_dummy_inputs(torch_device)
-        inputs["return_dict"] = False
-        inputs["output_type"] = "np"
-        output_no_freeu = pipe(**inputs)[0]
-
-        assert not np.allclose(output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]), (
-            "Enabling of FreeU should lead to different results."
-        )
-        assert np.allclose(output, output_no_freeu, atol=1e-2), (
-            f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
-        )
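Note: the deleted test_freeu exercised the FreeU toggle that diffusers exposes on UNet-based pipelines. A minimal sketch of that usage, assuming a placeholder model id and the same scaling factors the test used:

import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint; any UNet-based text-to-image pipeline exposes the same toggle.
pipe = DiffusionPipeline.from_pretrained("some-org/sd-checkpoint", torch_dtype=torch.float16).to("cuda")

pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)  # re-weight UNet backbone (b) and skip (s) features
image_freeu = pipe("an astronaut riding a horse").images[0]

pipe.disable_freeu()  # resets s1/s2/b1/b2 on the up blocks to None, as the removed test asserted
image_default = pipe("an astronaut riding a horse").images[0]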
 
     def test_fused_qkv_projections(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -219,12 +177,12 @@ class SDFunctionTesterMixin:
                 and hasattr(component, "original_attn_processors")
                 and component.original_attn_processors is not None
             ):
-                assert check_qkv_fusion_processors_exist(component), (
-                    "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
-                )
-                assert check_qkv_fusion_matches_attn_procs_length(component, component.original_attn_processors), (
-                    "Something wrong with the attention processors concerning the fused QKV projections."
-                )
+                assert check_qkv_fusion_processors_exist(
+                    component
+                ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+                assert check_qkv_fusion_matches_attn_procs_length(
+                    component, component.original_attn_processors
+                ), "Something wrong with the attention processors concerning the fused QKV projections."
 
         inputs = self.get_dummy_inputs(device)
         inputs["return_dict"] = False
@@ -237,15 +195,15 @@ class SDFunctionTesterMixin:
         image_disabled = pipe(**inputs)[0]
         image_slice_disabled = image_disabled[0, -3:, -3:, -1]
 
-        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2), (
-            "Fusion of QKV projections shouldn't affect the outputs."
-        )
-        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2), (
-            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
-        )
-        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
-            "Original outputs should match when fused QKV projections are disabled."
-        )
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
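The reflowed asserts above belong to test_fused_qkv_projections. For reference, a sketch of the fusion API under test, assuming a pipeline that implements the optional fuse/unfuse pair and a placeholder model id:

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some-org/sdxl-checkpoint")  # placeholder id

pipe.fuse_qkv_projections()    # merge each attention layer's Q/K/V weights into one projection
image_fused = pipe("a photo of a cat").images[0]

pipe.unfuse_qkv_projections()  # restore the original attention processors
image_unfused = pipe("a photo of a cat").images[0]
# The test requires both outputs to stay within atol=1e-2 / rtol=1e-2 of the unfused baseline.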
 
 
 class IPAdapterTesterMixin:
@@ -759,34 +717,6 @@ class PipelineLatentTesterMixin:
         max_diff = np.abs(out - out_latents_inputs).max()
         self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image")
 
-    def test_multi_vae(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        block_out_channels = pipe.vae.config.block_out_channels
-        norm_num_groups = pipe.vae.config.norm_num_groups
-
-        vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
-        configs = [
-            get_autoencoder_kl_config(block_out_channels, norm_num_groups),
-            get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
-            get_consistency_vae_config(block_out_channels, norm_num_groups),
-            get_autoencoder_tiny_config(block_out_channels),
-        ]
-
-        out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
-        for vae_cls, config in zip(vae_classes, configs):
-            vae = vae_cls(**config)
-            vae = vae.to(torch_device)
-            components["vae"] = vae
-            vae_pipe = self.pipeline_class(**components)
-            out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
-            assert out_vae_np.shape == out_np.shape
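The deleted test_multi_vae relied on pipelines accepting any compatible VAE through their components dict. A sketch of that swap pattern, assuming a Stable Diffusion pipeline and a placeholder checkpoint (the tiny-VAE id is a real community model, used here only as an example):

from diffusers import AutoencoderTiny, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("some-org/sd-checkpoint")  # placeholder id
components = pipe.components
components["vae"] = AutoencoderTiny.from_pretrained("madebyollin/taesd")  # swap in a different VAE class
vae_pipe = StableDiffusionPipeline(**components)
# The deleted assert only required the swapped-VAE output to keep the same shape as the original.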
 
 
 @require_torch
 class PipelineFromPipeTesterMixin:
@@ -916,9 +846,9 @@ class PipelineFromPipeTesterMixin:
         for component in pipe_original.components.values():
             if hasattr(component, "attn_processors"):
-                assert all(type(proc) == AttnProcessor for proc in component.attn_processors.values()), (
-                    "`from_pipe` changed the attention processor in original pipeline."
-                )
+                assert all(
+                    type(proc) == AttnProcessor for proc in component.attn_processors.values()
+                ), "`from_pipe` changed the attention processor in original pipeline."
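The assert above guards `from_pipe`'s no-mutation contract: a derived pipeline shares components with the original without altering it. A sketch, assuming the AutoPipeline classes and a placeholder checkpoint:

from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image

pipe_t2i = AutoPipelineForText2Image.from_pretrained("some-org/sd-checkpoint")  # placeholder id
pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i)  # reuses the loaded weights, no copy
# pipe_t2i's components, including its attention processors, must be left untouched.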
 
     @require_accelerator
     @require_accelerate_version_greater("0.14.0")
@@ -1137,6 +1067,15 @@ class PipelineTesterMixin:
         gc.collect()
         backend_empty_cache(torch_device)
 
+    def get_base_pipeline_output(self, pipe):
+        if not hasattr(self, "_base_pipeline_output"):
+            inputs = self.get_dummy_inputs(torch_device)
+            inputs["generator"] = self.get_generator(0)
+            output = pipe(**inputs)[0]
+            self._base_pipeline_output = output
+        return self._base_pipeline_output
+
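The new helper memoizes the baseline pipeline output on the test instance: the save/load, fp16, and serialization hunks below replace their own `output = pipe(**inputs)[0]` call with it, so one inference pass is shared. The same idiom in isolation (hypothetical names):

class Example:
    def expensive(self):
        return 42  # stands in for a full pipeline forward pass

    def cached(self):
        if not hasattr(self, "_cache"):
            self._cache = self.expensive()  # computed once on first call
        return self._cache  # every later caller reuses the stored result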
     def test_save_load_local(self, expected_max_difference=5e-4):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
@@ -1148,7 +1087,7 @@ class PipelineTesterMixin:
         pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(torch_device)
-        output = pipe(**inputs)[0]
+        output = self.get_base_pipeline_output(pipe)
 
         logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
         logger.setLevel(diffusers.logging.INFO)
@@ -1267,7 +1206,7 @@ class PipelineTesterMixin:
             output = pipe(**batched_input)
             assert len(output[0]) == batch_size
 
-    def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4):
+    def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=1e-4):
         self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)
 
     def _test_inference_batch_single_identical(
@@ -1386,7 +1325,7 @@ class PipelineTesterMixin:
         # Reset generator in case it is used inside dummy inputs
         if "generator" in inputs:
             inputs["generator"] = self.get_generator(0)
-        output = pipe(**inputs)[0]
+        output = self.get_base_pipeline_output(pipe)
 
         fp16_inputs = self.get_dummy_inputs(torch_device)
         # Reset generator in case it is used inside dummy inputs
@@ -1417,7 +1356,7 @@ class PipelineTesterMixin:
         pipe.set_progress_bar_config(disable=None)
 
         inputs = self.get_dummy_inputs(torch_device)
-        output = pipe(**inputs)[0]
+        output = self.get_base_pipeline_output(pipe)
 
         with tempfile.TemporaryDirectory() as tmpdir:
             pipe.save_pretrained(tmpdir)
@@ -1460,7 +1399,7 @@ class PipelineTesterMixin:
         generator_device = "cpu"
         inputs = self.get_dummy_inputs(generator_device)
 
         torch.manual_seed(0)
-        output = pipe(**inputs)[0]
+        output = self.get_base_pipeline_output(pipe)
 
         with tempfile.TemporaryDirectory() as tmpdir:
             pipe.save_pretrained(tmpdir, safe_serialization=False)
@@ -2587,12 +2526,12 @@ class PyramidAttentionBroadcastTesterMixin:
         image_slice_pab_disabled = output.flatten()
         image_slice_pab_disabled = np.concatenate((image_slice_pab_disabled[:8], image_slice_pab_disabled[-8:]))
 
-        assert np.allclose(original_image_slice, image_slice_pab_enabled, atol=expected_atol), (
-            "PAB outputs should not differ much in specified timestep range."
-        )
-        assert np.allclose(original_image_slice, image_slice_pab_disabled, atol=1e-4), (
-            "Outputs from normal inference and after disabling cache should not differ."
-        )
+        assert np.allclose(
+            original_image_slice, image_slice_pab_enabled, atol=expected_atol
+        ), "PAB outputs should not differ much in specified timestep range."
+        assert np.allclose(
+            original_image_slice, image_slice_pab_disabled, atol=1e-4
+        ), "Outputs from normal inference and after disabling cache should not differ."
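These asserts compare output with Pyramid Attention Broadcast enabled against normal and cache-disabled inference. For context, a hedged sketch of enabling PAB outside the test harness, assuming the CacheMixin enable_cache/disable_cache API on a transformer-backed pipeline and a placeholder model id:

from diffusers import DiffusionPipeline, PyramidAttentionBroadcastConfig

pipe = DiffusionPipeline.from_pretrained("some-org/video-checkpoint")  # placeholder id
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,              # reuse attention states every other block step
    spatial_attention_timestep_skip_range=(100, 800),  # only broadcast inside this timestep window
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)   # approximate attention reuse during denoising
# ... run inference ...
pipe.transformer.disable_cache()        # exact attention again, as the second assert requires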
 
 
 class FasterCacheTesterMixin:
@@ -2657,12 +2596,12 @@ class FasterCacheTesterMixin:
         output = run_forward(pipe).flatten()
         image_slice_faster_cache_disabled = np.concatenate((output[:8], output[-8:]))
 
-        assert np.allclose(original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol), (
-            "FasterCache outputs should not differ much in specified timestep range."
-        )
-        assert np.allclose(original_image_slice, image_slice_faster_cache_disabled, atol=1e-4), (
-            "Outputs from normal inference and after disabling cache should not differ."
-        )
+        assert np.allclose(
+            original_image_slice, image_slice_faster_cache_enabled, atol=expected_atol
+        ), "FasterCache outputs should not differ much in specified timestep range."
+        assert np.allclose(
+            original_image_slice, image_slice_faster_cache_disabled, atol=1e-4
+        ), "Outputs from normal inference and after disabling cache should not differ."
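FasterCacheConfig is imported from diffusers at the top of this file (first hunk). A trimmed sketch of its documented usage, assuming a transformer-backed pipeline and a placeholder model id:

from diffusers import DiffusionPipeline, FasterCacheConfig

pipe = DiffusionPipeline.from_pretrained("some-org/video-checkpoint")  # placeholder id
config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)   # skip/approximate attention in the configured range
pipe.transformer.disable_cache()        # restore exact inference, which the test compares against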
 
     def test_faster_cache_state(self):
         from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
@@ -2797,12 +2736,12 @@ class FirstBlockCacheTesterMixin:
         output = run_forward(pipe).flatten()
         image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
 
-        assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
-            "FirstBlockCache outputs should not differ much."
-        )
-        assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
-            "Outputs from normal inference and after disabling cache should not differ."
-        )
+        assert np.allclose(
+            original_image_slice, image_slice_fbc_enabled, atol=expected_atol
+        ), "FirstBlockCache outputs should not differ much."
+        assert np.allclose(
+            original_image_slice, image_slice_fbc_disabled, atol=1e-4
+        ), "Outputs from normal inference and after disabling cache should not differ."
 
 
 # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.