mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
DN6
2025-12-18 13:16:50 +05:30
parent d9b73ffd51
commit e82001e40d
8 changed files with 1445 additions and 328 deletions

View File

@@ -1,4 +1,13 @@
from .attention import AttentionTesterMixin
from .cache import (
    CacheTesterMixin,
    FasterCacheConfigMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheConfigMixin,
    FirstBlockCacheTesterMixin,
    PyramidAttentionBroadcastConfigMixin,
    PyramidAttentionBroadcastTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
@@ -6,11 +15,22 @@ from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
    BitsAndBytesCompileTesterMixin,
    BitsAndBytesConfigMixin,
    BitsAndBytesTesterMixin,
    GGUFCompileTesterMixin,
    GGUFConfigMixin,
    GGUFTesterMixin,
    ModelOptCompileTesterMixin,
    ModelOptConfigMixin,
    ModelOptTesterMixin,
    QuantizationCompileTesterMixin,
    QuantizationTesterMixin,
    QuantoCompileTesterMixin,
    QuantoConfigMixin,
    QuantoTesterMixin,
    TorchAoCompileTesterMixin,
    TorchAoConfigMixin,
    TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
@@ -20,9 +40,18 @@ from .training import TrainingTesterMixin
__all__ = [
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesCompileTesterMixin",
"BitsAndBytesConfigMixin",
"BitsAndBytesTesterMixin",
"CacheTesterMixin",
"ContextParallelTesterMixin",
"CPUOffloadTesterMixin",
"FasterCacheConfigMixin",
"FasterCacheTesterMixin",
"FirstBlockCacheConfigMixin",
"FirstBlockCacheTesterMixin",
"GGUFCompileTesterMixin",
"GGUFConfigMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
@@ -30,11 +59,20 @@ __all__ = [
"LoraHotSwappingForModelTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptCompileTesterMixin",
"ModelOptConfigMixin",
"ModelOptTesterMixin",
"ModelTesterMixin",
"PyramidAttentionBroadcastConfigMixin",
"PyramidAttentionBroadcastTesterMixin",
"QuantizationCompileTesterMixin",
"QuantizationTesterMixin",
"QuantoCompileTesterMixin",
"QuantoConfigMixin",
"QuantoTesterMixin",
"SingleFileTesterMixin",
"TorchAoCompileTesterMixin",
"TorchAoConfigMixin",
"TorchAoTesterMixin",
"TorchCompileTesterMixin",
"TrainingTesterMixin",

View File

@@ -0,0 +1,536 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from diffusers.models.cache_utils import CacheMixin
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device

def require_cache_mixin(func):
    """Decorator to skip tests if model doesn't use CacheMixin."""

    def wrapper(self, *args, **kwargs):
        if not issubclass(self.model_class, CacheMixin):
            pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
        return func(self, *args, **kwargs)

    return wrapper

class CacheTesterMixin:
    """
    Base mixin class providing common test implementations for cache testing.

    Cache-specific mixins should:
    1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
    2. Inherit from this mixin
    3. Define the cache config to use for tests

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)

    Expected methods in test classes:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    """

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def _get_cache_config(self):
        """
        Get the cache config for testing.
        Should be implemented by subclasses.
        """
        raise NotImplementedError("Subclass must implement _get_cache_config")

    def _get_hook_names(self):
        """
        Get the hook names to check for this cache type.
        Should be implemented by subclasses.
        Returns a list of hook name strings.
        """
        raise NotImplementedError("Subclass must implement _get_hook_names")

    def _test_cache_enable_disable_state(self):
        """Test that cache enable/disable updates the is_cache_enabled state correctly."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        # Initially cache should not be enabled
        assert not model.is_cache_enabled, "Cache should not be enabled initially."

        config = self._get_cache_config()

        # Enable cache
        model.enable_cache(config)
        assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."

        # Disable cache
        model.disable_cache()
        assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."

    def _test_cache_double_enable_raises_error(self):
        """Test that enabling cache twice raises an error."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        config = self._get_cache_config()
        model.enable_cache(config)

        # Trying to enable again should raise ValueError
        with pytest.raises(ValueError, match="Caching has already been enabled"):
            model.enable_cache(config)

        # Cleanup
        model.disable_cache()

    def _test_cache_hooks_registered(self):
        """Test that cache hooks are properly registered and removed."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        config = self._get_cache_config()
        hook_names = self._get_hook_names()

        model.enable_cache(config)

        # Check that at least one hook was registered
        hook_count = 0
        for module in model.modules():
            if hasattr(module, "_diffusers_hook"):
                for hook_name in hook_names:
                    hook = module._diffusers_hook.get_hook(hook_name)
                    if hook is not None:
                        hook_count += 1
        assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"

        # Disable and verify hooks are removed
        model.disable_cache()
        hook_count_after = 0
        for module in model.modules():
            if hasattr(module, "_diffusers_hook"):
                for hook_name in hook_names:
                    hook = module._diffusers_hook.get_hook(hook_name)
                    if hook is not None:
                        hook_count_after += 1
        assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."

    @torch.no_grad()
    def _test_cache_inference(self):
        """Test that model can run inference with cache enabled."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()
        model.enable_cache(config)

        # First pass populates the cache
        _ = model(**inputs_dict, return_dict=False)[0]

        # Create modified inputs for second pass (vary hidden_states to simulate denoising)
        inputs_dict_step2 = inputs_dict.copy()
        if "hidden_states" in inputs_dict_step2:
            inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1

        # Second pass uses cached attention with different hidden_states (produces approximated output)
        output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
        assert output_with_cache is not None, "Model output should not be None with cache enabled."
        assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."

        # Run same inputs without cache to compare
        model.disable_cache()
        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]

        # Cached output should be different from non-cached output (due to approximation)
        assert not torch.allclose(
            output_without_cache, output_with_cache, atol=1e-5
        ), "Cached output should be different from non-cached output due to cache approximation."

    def _test_cache_context_manager(self):
        """Test the cache_context context manager."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        config = self._get_cache_config()
        model.enable_cache(config)

        # Test cache_context works without error
        with model.cache_context("test_context"):
            pass

        model.disable_cache()

    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the cache state."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()
        model.enable_cache(config)

        # Run forward to populate cache state
        with torch.no_grad():
            _ = model(**inputs_dict, return_dict=False)[0]

        # Reset should not raise any errors
        model._reset_stateful_cache()

        model.disable_cache()
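
For orientation, here is a minimal sketch of how a model test class is expected to satisfy this contract. It is illustrative only and not part of this commit: `MyTinyTransformer`, the init kwargs, and the tensor shapes are hypothetical placeholders, and the PAB mixin it inherits is defined further down in this file.

class TestMyTinyTransformerCache(PyramidAttentionBroadcastTesterMixin):
    model_class = MyTinyTransformer  # hypothetical CacheMixin-based model

    def get_init_dict(self):
        return {"num_layers": 2, "num_attention_heads": 2, "attention_head_dim": 8}

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn(1, 16, 16, device=torch_device),
            "timestep": torch.tensor([1.0], device=torch_device),
        }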

@is_cache
class PyramidAttentionBroadcastConfigMixin:
    """
    Base mixin providing PyramidAttentionBroadcast cache config.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)
    """

    # Default PAB config - can be overridden by subclasses
    PAB_CONFIG = {
        "spatial_attention_block_skip_range": 2,
    }

    # Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
    _current_timestep = 500

    def _get_cache_config(self):
        config_kwargs = self.PAB_CONFIG.copy()
        config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
        return PyramidAttentionBroadcastConfig(**config_kwargs)

    def _get_hook_names(self):
        return [_PYRAMID_ATTENTION_BROADCAST_HOOK]

@is_cache
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
    """
    Mixin class for testing PyramidAttentionBroadcast caching on models.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: cache
    Use `pytest -m "not cache"` to skip these tests
    """

    @require_cache_mixin
    def test_pab_cache_enable_disable_state(self):
        self._test_cache_enable_disable_state()

    @require_cache_mixin
    def test_pab_cache_double_enable_raises_error(self):
        self._test_cache_double_enable_raises_error()

    @require_cache_mixin
    def test_pab_cache_hooks_registered(self):
        self._test_cache_hooks_registered()

    @require_cache_mixin
    def test_pab_cache_inference(self):
        self._test_cache_inference()

    @require_cache_mixin
    def test_pab_cache_context_manager(self):
        self._test_cache_context_manager()

    @require_cache_mixin
    def test_pab_reset_stateful_cache(self):
        self._test_reset_stateful_cache()
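
The config assembled by `_get_cache_config()` mirrors what a user would pass to `enable_cache` directly. A hedged sketch, assuming `model` is any diffusers model that inherits `CacheMixin`:

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    current_timestep_callback=lambda: 500,  # must fall inside the default (100, 800) window for skipping
)
model.enable_cache(config)
# ... run the denoising steps ...
model.disable_cache()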

@is_cache
class FirstBlockCacheConfigMixin:
    """
    Base mixin providing FirstBlockCache config.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)
    """

    # Default FBC config - can be overridden by subclasses
    # Higher threshold makes FBC more aggressive about caching (skips more often)
    FBC_CONFIG = {
        "threshold": 1.0,
    }

    def _get_cache_config(self):
        return FirstBlockCacheConfig(**self.FBC_CONFIG)

    def _get_hook_names(self):
        return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]

@is_cache
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
    """
    Mixin class for testing FirstBlockCache on models.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: cache
    Use `pytest -m "not cache"` to skip these tests
    """

    @torch.no_grad()
    def _test_cache_inference(self):
        """Test that model can run inference with FBC cache enabled (requires cache_context)."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()
        model.enable_cache(config)

        # FBC requires cache_context to be set for inference
        with model.cache_context("fbc_test"):
            # First pass populates the cache
            _ = model(**inputs_dict, return_dict=False)[0]

            # Create modified inputs for second pass (small perturbation keeps residuals similar)
            inputs_dict_step2 = inputs_dict.copy()
            if "hidden_states" in inputs_dict_step2:
                inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.01

            # Second pass - FBC should skip remaining blocks and use cached residuals
            output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]

        assert output_with_cache is not None, "Model output should not be None with cache enabled."
        assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."

        # Run same inputs without cache to compare
        model.disable_cache()
        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]

        # Cached output should be different from non-cached output (due to approximation)
        assert not torch.allclose(
            output_without_cache, output_with_cache, atol=1e-5
        ), "Cached output should be different from non-cached output due to cache approximation."

    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()
        model.enable_cache(config)

        # FBC requires cache_context to be set for inference
        with model.cache_context("fbc_test"):
            with torch.no_grad():
                _ = model(**inputs_dict, return_dict=False)[0]

        # Reset should not raise any errors
        model._reset_stateful_cache()

        model.disable_cache()

    @require_cache_mixin
    def test_fbc_cache_enable_disable_state(self):
        self._test_cache_enable_disable_state()

    @require_cache_mixin
    def test_fbc_cache_double_enable_raises_error(self):
        self._test_cache_double_enable_raises_error()

    @require_cache_mixin
    def test_fbc_cache_hooks_registered(self):
        self._test_cache_hooks_registered()

    @require_cache_mixin
    def test_fbc_cache_inference(self):
        self._test_cache_inference()

    @require_cache_mixin
    def test_fbc_cache_context_manager(self):
        self._test_cache_context_manager()

    @require_cache_mixin
    def test_fbc_reset_stateful_cache(self):
        self._test_reset_stateful_cache()
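
As the overridden inference test above shows, FirstBlockCache only takes effect inside a `cache_context`. A rough sketch of the intended call pattern, assuming `model` uses `CacheMixin` and `step_inputs` is a hypothetical sequence of model kwargs:

model.enable_cache(FirstBlockCacheConfig(threshold=1.0))
with model.cache_context("denoise"):
    for step_kwargs in step_inputs:
        output = model(**step_kwargs, return_dict=False)[0]
model.disable_cache()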

@is_cache
class FasterCacheConfigMixin:
    """
    Base mixin providing FasterCache config.

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)
    """

    # Default FasterCache config - can be overridden by subclasses
    FASTER_CACHE_CONFIG = {
        "spatial_attention_block_skip_range": 2,
        "spatial_attention_timestep_skip_range": (-1, 901),
        "tensor_format": "BCHW",
    }

    # Store timestep for callback - use a list so it can be mutated during test
    # Starts outside skip range so first pass computes; changed to inside range for subsequent passes
    _current_timestep = [1000]

    def _get_cache_config(self):
        config_kwargs = self.FASTER_CACHE_CONFIG.copy()
        config_kwargs["current_timestep_callback"] = lambda: self._current_timestep[0]
        return FasterCacheConfig(**config_kwargs)

    def _get_hook_names(self):
        return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]

@is_cache
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
    """
    Mixin class for testing FasterCache on models.

    Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
    and timestep management. Inference tests are skipped at model level - FasterCache should
    be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).

    Expected class attributes:
    - model_class: The model class to test (must use CacheMixin)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: cache
    Use `pytest -m "not cache"` to skip these tests
    """

    @torch.no_grad()
    def _test_cache_inference(self):
        """Test that model can run inference with FasterCache enabled."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()
        model.enable_cache(config)

        # First pass with timestep outside skip range - computes and populates cache
        self._current_timestep[0] = 1000
        _ = model(**inputs_dict, return_dict=False)[0]

        # Move timestep inside skip range so subsequent passes use cache
        self._current_timestep[0] = 500

        # Create modified inputs for second pass
        inputs_dict_step2 = inputs_dict.copy()
        if "hidden_states" in inputs_dict_step2:
            inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1

        # Second pass uses cached attention with different hidden_states
        output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
        assert output_with_cache is not None, "Model output should not be None with cache enabled."
        assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."

        # Run same inputs without cache to compare
        model.disable_cache()
        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]

        # Cached output should be different from non-cached output (due to approximation)
        assert not torch.allclose(
            output_without_cache, output_with_cache, atol=1e-5
        ), "Cached output should be different from non-cached output due to cache approximation."

    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the FasterCache state."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()

        config = self._get_cache_config()
        model.enable_cache(config)

        # First pass with timestep outside skip range
        self._current_timestep[0] = 1000
        with torch.no_grad():
            _ = model(**inputs_dict, return_dict=False)[0]

        # Reset should not raise any errors
        model._reset_stateful_cache()

        model.disable_cache()

    @require_cache_mixin
    def test_faster_cache_enable_disable_state(self):
        self._test_cache_enable_disable_state()

    @require_cache_mixin
    def test_faster_cache_double_enable_raises_error(self):
        self._test_cache_double_enable_raises_error()

    @require_cache_mixin
    def test_faster_cache_hooks_registered(self):
        self._test_cache_hooks_registered()

    @require_cache_mixin
    def test_faster_cache_inference(self):
        self._test_cache_inference()

    @require_cache_mixin
    def test_faster_cache_context_manager(self):
        self._test_cache_context_manager()

    @require_cache_mixin
    def test_faster_cache_reset_stateful_cache(self):
        self._test_reset_stateful_cache()
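
FasterCache keys its skip decision off `current_timestep_callback`, which is why these tests mutate a one-element list. The same pattern outside the test harness would look roughly like the sketch below (assumptions: `model` uses `CacheMixin`, `inputs_dict` is a kwargs dict, and the timestep values are illustrative):

current_timestep = [1000]  # mutable holder so the callback always reads the latest timestep
config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 901),
    tensor_format="BCHW",
    current_timestep_callback=lambda: current_timestep[0],
)
model.enable_cache(config)
for t in (1000, 900, 800):  # the first step computes; later steps fall inside the skip range
    current_timestep[0] = t
    output = model(**inputs_dict, return_dict=False)[0]
model.disable_cache()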

View File

@@ -15,7 +15,6 @@
import gc
import os
import tempfile
import pytest
import torch
@@ -140,7 +139,7 @@ class TorchCompileTesterMixin:
inputs_dict = self.get_dummy_inputs(height=height, width=width)
_ = model(**inputs_dict)
def test_compile_works_with_aot(self):
def test_compile_works_with_aot(self, tmp_path):
from torch._inductor.package import load_package
init_dict = self.get_init_dict()
@@ -149,11 +148,10 @@ class TorchCompileTesterMixin:
model = self.model_class(**init_dict).to(torch_device)
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
with tempfile.TemporaryDirectory() as tmpdir:
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
assert os.path.exists(package_path), f"Package file not created at {package_path}"
loaded_binary = load_package(package_path, run_single_threaded=True)
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
assert os.path.exists(package_path), f"Package file not created at {package_path}"
loaded_binary = load_package(package_path, run_single_threaded=True)
model.forward = loaded_binary

View File

@@ -16,7 +16,6 @@
import gc
import glob
import inspect
import tempfile
from functools import wraps
import pytest
@@ -97,7 +96,7 @@ class CPUOffloadTesterMixin:
model_split_percents = [0.5, 0.7]
@require_offload_support
def test_cpu_offload(self):
def test_cpu_offload(self, tmp_path):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
@@ -110,25 +109,24 @@ class CPUOffloadTesterMixin:
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
model.cpu().save_pretrained(str(tmp_path))
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading"
)
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading"
)
@require_offload_support
def test_disk_offload_without_safetensors(self):
def test_disk_offload_without_safetensors(self, tmp_path):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
@@ -143,26 +141,25 @@ class CPUOffloadTesterMixin:
# Force disk offload by setting very small CPU memory
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
# This errors out because it's missing an offload folder
with pytest.raises(ValueError):
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
# This errors out because it's missing an offload folder
with pytest.raises(ValueError):
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
)
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading"
)
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading"
)
@require_offload_support
def test_disk_offload_with_safetensors(self):
def test_disk_offload_with_safetensors(self, tmp_path):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
@@ -173,26 +170,25 @@ class CPUOffloadTesterMixin:
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
model.cpu().save_pretrained(str(tmp_path))
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
)
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0],
new_output[0],
atol=1e-5,
rtol=0,
msg="Output should match with disk offloading (safetensors)",
)
assert_tensors_close(
base_output[0],
new_output[0],
atol=1e-5,
rtol=0,
msg="Output should match with disk offloading (safetensors)",
)
@is_group_offload
@@ -312,7 +308,7 @@ class GroupOffloadTesterMixin:
@require_group_offload_support
@torch.no_grad()
@torch.inference_mode()
def test_group_offloading_with_disk(self, offload_type="block_level", record_stream=False, atol=1e-5):
def test_group_offloading_with_disk(self, tmp_path, offload_type="block_level", record_stream=False, atol=1e-5):
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
@@ -340,41 +336,41 @@ class GroupOffloadTesterMixin:
num_blocks_per_group = None if offload_type == "leaf_level" else 1
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
with tempfile.TemporaryDirectory() as tmpdir:
model.enable_group_offload(
torch_device,
offload_type=offload_type,
tmpdir = str(tmp_path)
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading,
atol=atol,
rtol=0,
msg="Output should match with disk-based group offloading",
)
output_with_group_offloading = _run_forward(model, inputs_dict)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading,
atol=atol,
rtol=0,
msg="Output should match with disk-based group offloading",
)
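
The disk-offload assertions above all hinge on a single `enable_group_offload` call. A hedged sketch of the same setup outside the test harness (the model instance, the kwargs dict, and the scratch path are illustrative):

# Sketch: stream weights from disk in block-sized groups during inference.
# Assumes `model` is a diffusers model and "/tmp/diffusers-offload" is a scratch directory.
model.enable_group_offload(
    torch_device,
    offload_type="block_level",
    num_blocks_per_group=1,
    offload_to_disk_path="/tmp/diffusers-offload",
    use_stream=True,
)
output = model(**inputs_dict)  # parameters are onloaded group by group as the forward pass runs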
class LayerwiseCastingTesterMixin:

File diff suppressed because it is too large

View File

@@ -14,7 +14,6 @@
# limitations under the License.
import gc
import tempfile
import torch
from huggingface_hub import hf_hub_download, snapshot_download
@@ -151,21 +150,20 @@ class SingleFileTesterMixin:
param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}"
)
def test_single_file_loading_local_files_only(self):
def test_single_file_loading_local_files_only(self, tmp_path):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
with tempfile.TemporaryDirectory() as tmpdir:
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, local_files_only=True, **single_file_kwargs
)
model_single_file = self.model_class.from_single_file(
local_ckpt_path, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with local_files_only=True"
assert model_single_file is not None, "Failed to load model with local_files_only=True"
def test_single_file_loading_with_diffusers_config(self):
single_file_kwargs = {}
@@ -196,22 +194,21 @@ class SingleFileTesterMixin:
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_loading_with_diffusers_config_local_files_only(self):
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
with tempfile.TemporaryDirectory() as tmpdir:
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir)
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
)
model_single_file = self.model_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
def test_single_file_loading_dtype(self):
for dtype in [torch.float32, torch.float16]:

View File

@@ -26,17 +26,25 @@ from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
    AttentionTesterMixin,
    BaseModelTesterConfig,
    BitsAndBytesCompileTesterMixin,
    BitsAndBytesTesterMixin,
    ContextParallelTesterMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheTesterMixin,
    GGUFCompileTesterMixin,
    GGUFTesterMixin,
    IPAdapterTesterMixin,
    LoraHotSwappingForModelTesterMixin,
    LoraTesterMixin,
    MemoryTesterMixin,
    ModelOptCompileTesterMixin,
    ModelOptTesterMixin,
    ModelTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    QuantoCompileTesterMixin,
    QuantoTesterMixin,
    SingleFileTesterMixin,
    TorchAoCompileTesterMixin,
    TorchAoTesterMixin,
    TorchCompileTesterMixin,
    TrainingTesterMixin,
@@ -353,3 +361,94 @@ class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMix
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
return {
"hidden_states": randn_tensor((1, 4096, 64)),
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
"pooled_projections": randn_tensor((1, 768)),
"timestep": torch.tensor([1.0]).to(torch_device),
"img_ids": randn_tensor((4096, 3)),
"txt_ids": randn_tensor((512, 3)),
"guidance": torch.tensor([3.5]).to(torch_device),
}
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
pass
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""
pass
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
"""FasterCache tests for Flux Transformer."""
# Flux is guidance distilled, so we can test at model level without CFG batch handling
FASTER_CACHE_CONFIG = {
"spatial_attention_block_skip_range": 2,
"spatial_attention_timestep_skip_range": (-1, 901),
"tensor_format": "BCHW",
"is_guidance_distilled": True,
}
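
Put together, the model-level configuration this test class exercises would look roughly like the following sketch. It is illustrative only: `transformer` stands in for a loaded Flux transformer instance, and the mutable timestep holder mirrors the test harness rather than a real denoising loop.

current_timestep = [1000]
transformer.enable_cache(
    FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 901),
        tensor_format="BCHW",
        is_guidance_distilled=True,  # Flux is guidance distilled, so no CFG batch handling is needed
        current_timestep_callback=lambda: current_timestep[0],
    )
)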

View File

@@ -414,6 +414,15 @@ def is_group_offload(test_case):
    return pytest.mark.group_offload(test_case)


def is_quantization(test_case):
    """
    Decorator marking a test as a quantization test. These tests can be filtered using:
    pytest -m "not quantization" to skip
    pytest -m quantization to run only these tests
    """
    return pytest.mark.quantization(test_case)


def is_bitsandbytes(test_case):
    """
    Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
@@ -468,6 +477,15 @@ def is_context_parallel(test_case):
    return pytest.mark.context_parallel(test_case)


def is_cache(test_case):
    """
    Decorator marking a test as a cache test. These tests can be filtered using:
    pytest -m "not cache" to skip
    pytest -m cache to run only these tests
    """
    return pytest.mark.cache(test_case)
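
These helpers just attach pytest marks, so selection happens on the command line. A tiny illustration (class and mixin names are hypothetical):

@is_cache
class TestMyModelFBCCache(MyModelTesterConfig, FirstBlockCacheTesterMixin):
    pass

# Run only cache-marked tests:  pytest -m cache
# Skip them:                    pytest -m "not cache"
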
def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
@@ -835,7 +853,7 @@ def require_modelopt_version_greater_or_equal(modelopt_version):
        ) >= version.parse(modelopt_version)
        return pytest.mark.skipif(
            not correct_nvidia_modelopt_version,
            f"Test requires modelopt with version greater than {modelopt_version}.",
            reason=f"Test requires modelopt with version greater than {modelopt_version}.",
        )(test_case)

    return decorator