mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update
This commit is contained in:
@@ -1,4 +1,13 @@
|
||||
from .attention import AttentionTesterMixin
|
||||
from .cache import (
|
||||
CacheTesterMixin,
|
||||
FasterCacheConfigMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheConfigMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
PyramidAttentionBroadcastConfigMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
)
|
||||
from .common import BaseModelTesterConfig, ModelTesterMixin
|
||||
from .compile import TorchCompileTesterMixin
|
||||
from .ip_adapter import IPAdapterTesterMixin
|
||||
@@ -6,11 +15,22 @@ from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
|
||||
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
|
||||
from .parallelism import ContextParallelTesterMixin
|
||||
from .quantization import (
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesConfigMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFConfigMixin,
|
||||
GGUFTesterMixin,
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptConfigMixin,
|
||||
ModelOptTesterMixin,
|
||||
QuantizationCompileTesterMixin,
|
||||
QuantizationTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoConfigMixin,
|
||||
QuantoTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoConfigMixin,
|
||||
TorchAoTesterMixin,
|
||||
)
|
||||
from .single_file import SingleFileTesterMixin
|
||||
@@ -20,9 +40,18 @@ from .training import TrainingTesterMixin
|
||||
__all__ = [
|
||||
"AttentionTesterMixin",
|
||||
"BaseModelTesterConfig",
|
||||
"BitsAndBytesCompileTesterMixin",
|
||||
"BitsAndBytesConfigMixin",
|
||||
"BitsAndBytesTesterMixin",
|
||||
"CacheTesterMixin",
|
||||
"ContextParallelTesterMixin",
|
||||
"CPUOffloadTesterMixin",
|
||||
"FasterCacheConfigMixin",
|
||||
"FasterCacheTesterMixin",
|
||||
"FirstBlockCacheConfigMixin",
|
||||
"FirstBlockCacheTesterMixin",
|
||||
"GGUFCompileTesterMixin",
|
||||
"GGUFConfigMixin",
|
||||
"GGUFTesterMixin",
|
||||
"GroupOffloadTesterMixin",
|
||||
"IPAdapterTesterMixin",
|
||||
@@ -30,11 +59,20 @@ __all__ = [
|
||||
"LoraHotSwappingForModelTesterMixin",
|
||||
"LoraTesterMixin",
|
||||
"MemoryTesterMixin",
|
||||
"ModelOptCompileTesterMixin",
|
||||
"ModelOptConfigMixin",
|
||||
"ModelOptTesterMixin",
|
||||
"ModelTesterMixin",
|
||||
"PyramidAttentionBroadcastConfigMixin",
|
||||
"PyramidAttentionBroadcastTesterMixin",
|
||||
"QuantizationCompileTesterMixin",
|
||||
"QuantizationTesterMixin",
|
||||
"QuantoCompileTesterMixin",
|
||||
"QuantoConfigMixin",
|
||||
"QuantoTesterMixin",
|
||||
"SingleFileTesterMixin",
|
||||
"TorchAoCompileTesterMixin",
|
||||
"TorchAoConfigMixin",
|
||||
"TorchAoTesterMixin",
|
||||
"TorchCompileTesterMixin",
|
||||
"TrainingTesterMixin",
|
||||
|
||||
536
tests/models/testing_utils/cache.py
Normal file
536
tests/models/testing_utils/cache.py
Normal file
@@ -0,0 +1,536 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
|
||||
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
from diffusers.models.cache_utils import CacheMixin
|
||||
|
||||
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
|
||||
|
||||
|
||||
def require_cache_mixin(func):
|
||||
"""Decorator to skip tests if model doesn't use CacheMixin."""
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not issubclass(self.model_class, CacheMixin):
|
||||
pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class CacheTesterMixin:
|
||||
"""
|
||||
Base mixin class providing common test implementations for cache testing.
|
||||
|
||||
Cache-specific mixins should:
|
||||
1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
|
||||
2. Inherit from this mixin
|
||||
3. Define the cache config to use for tests
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods in test classes:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _get_cache_config(self):
|
||||
"""
|
||||
Get the cache config for testing.
|
||||
Should be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _get_cache_config")
|
||||
|
||||
def _get_hook_names(self):
|
||||
"""
|
||||
Get the hook names to check for this cache type.
|
||||
Should be implemented by subclasses.
|
||||
Returns a list of hook name strings.
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _get_hook_names")
|
||||
|
||||
def _test_cache_enable_disable_state(self):
|
||||
"""Test that cache enable/disable updates the is_cache_enabled state correctly."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Initially cache should not be enabled
|
||||
assert not model.is_cache_enabled, "Cache should not be enabled initially."
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
# Enable cache
|
||||
model.enable_cache(config)
|
||||
assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."
|
||||
|
||||
# Disable cache
|
||||
model.disable_cache()
|
||||
assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."
|
||||
|
||||
def _test_cache_double_enable_raises_error(self):
|
||||
"""Test that enabling cache twice raises an error."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Trying to enable again should raise ValueError
|
||||
with pytest.raises(ValueError, match="Caching has already been enabled"):
|
||||
model.enable_cache(config)
|
||||
|
||||
# Cleanup
|
||||
model.disable_cache()
|
||||
|
||||
def _test_cache_hooks_registered(self):
|
||||
"""Test that cache hooks are properly registered and removed."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
hook_names = self._get_hook_names()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Check that at least one hook was registered
|
||||
hook_count = 0
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
for hook_name in hook_names:
|
||||
hook = module._diffusers_hook.get_hook(hook_name)
|
||||
if hook is not None:
|
||||
hook_count += 1
|
||||
|
||||
assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"
|
||||
|
||||
# Disable and verify hooks are removed
|
||||
model.disable_cache()
|
||||
|
||||
hook_count_after = 0
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
for hook_name in hook_names:
|
||||
hook = module._diffusers_hook.get_hook(hook_name)
|
||||
if hook is not None:
|
||||
hook_count_after += 1
|
||||
|
||||
assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with cache enabled."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass populates the cache
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create modified inputs for second pass (vary hidden_states to simulate denoising)
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if "hidden_states" in inputs_dict_step2:
|
||||
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
|
||||
|
||||
# Second pass uses cached attention with different hidden_states (produces approximated output)
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(
|
||||
output_without_cache, output_with_cache, atol=1e-5
|
||||
), "Cached output should be different from non-cached output due to cache approximation."
|
||||
|
||||
def _test_cache_context_manager(self):
|
||||
"""Test the cache_context context manager."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Test cache_context works without error
|
||||
with model.cache_context("test_context"):
|
||||
pass
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the cache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Run forward to populate cache state
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Reset should not raise any errors
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class PyramidAttentionBroadcastConfigMixin:
|
||||
"""
|
||||
Base mixin providing PyramidAttentionBroadcast cache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default PAB config - can be overridden by subclasses
|
||||
PAB_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
}
|
||||
|
||||
# Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
|
||||
_current_timestep = 500
|
||||
|
||||
def _get_cache_config(self):
|
||||
config_kwargs = self.PAB_CONFIG.copy()
|
||||
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
|
||||
return PyramidAttentionBroadcastConfig(**config_kwargs)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_PYRAMID_ATTENTION_BROADCAST_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing PyramidAttentionBroadcast caching on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class FirstBlockCacheConfigMixin:
|
||||
"""
|
||||
Base mixin providing FirstBlockCache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default FBC config - can be overridden by subclasses
|
||||
# Higher threshold makes FBC more aggressive about caching (skips more often)
|
||||
FBC_CONFIG = {
|
||||
"threshold": 1.0,
|
||||
}
|
||||
|
||||
def _get_cache_config(self):
|
||||
return FirstBlockCacheConfig(**self.FBC_CONFIG)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing FirstBlockCache on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with FBC cache enabled (requires cache_context)."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# FBC requires cache_context to be set for inference
|
||||
with model.cache_context("fbc_test"):
|
||||
# First pass populates the cache
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create modified inputs for second pass (small perturbation keeps residuals similar)
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if "hidden_states" in inputs_dict_step2:
|
||||
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.01
|
||||
|
||||
# Second pass - FBC should skip remaining blocks and use cached residuals
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(
|
||||
output_without_cache, output_with_cache, atol=1e-5
|
||||
), "Cached output should be different from non-cached output due to cache approximation."
|
||||
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# FBC requires cache_context to be set for inference
|
||||
with model.cache_context("fbc_test"):
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Reset should not raise any errors
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class FasterCacheConfigMixin:
|
||||
"""
|
||||
Base mixin providing FasterCache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default FasterCache config - can be overridden by subclasses
|
||||
FASTER_CACHE_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
"spatial_attention_timestep_skip_range": (-1, 901),
|
||||
"tensor_format": "BCHW",
|
||||
}
|
||||
|
||||
# Store timestep for callback - use a list so it can be mutated during test
|
||||
# Starts outside skip range so first pass computes; changed to inside range for subsequent passes
|
||||
_current_timestep = [1000]
|
||||
|
||||
def _get_cache_config(self):
|
||||
config_kwargs = self.FASTER_CACHE_CONFIG.copy()
|
||||
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep[0]
|
||||
return FasterCacheConfig(**config_kwargs)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing FasterCache on models.
|
||||
|
||||
Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
|
||||
and timestep management. Inference tests are skipped at model level - FasterCache should
|
||||
be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with FasterCache enabled."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass with timestep outside skip range - computes and populates cache
|
||||
self._current_timestep[0] = 1000
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Move timestep inside skip range so subsequent passes use cache
|
||||
self._current_timestep[0] = 500
|
||||
|
||||
# Create modified inputs for second pass
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if "hidden_states" in inputs_dict_step2:
|
||||
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
|
||||
|
||||
# Second pass uses cached attention with different hidden_states
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(
|
||||
output_without_cache, output_with_cache, atol=1e-5
|
||||
), "Cached output should be different from non-cached output due to cache approximation."
|
||||
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the FasterCache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass with timestep outside skip range
|
||||
self._current_timestep[0] = 1000
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Reset should not raise any errors
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -140,7 +139,7 @@ class TorchCompileTesterMixin:
|
||||
inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_works_with_aot(self):
|
||||
def test_compile_works_with_aot(self, tmp_path):
|
||||
from torch._inductor.package import load_package
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
@@ -149,11 +148,10 @@ class TorchCompileTesterMixin:
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
package_path = os.path.join(tmpdir, f"{self.model_class.__name__}.pt2")
|
||||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||
assert os.path.exists(package_path), f"Package file not created at {package_path}"
|
||||
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
|
||||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||
assert os.path.exists(package_path), f"Package file not created at {package_path}"
|
||||
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||
|
||||
model.forward = loaded_binary
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
import gc
|
||||
import glob
|
||||
import inspect
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
|
||||
import pytest
|
||||
@@ -97,7 +96,7 @@ class CPUOffloadTesterMixin:
|
||||
model_split_percents = [0.5, 0.7]
|
||||
|
||||
@require_offload_support
|
||||
def test_cpu_offload(self):
|
||||
def test_cpu_offload(self, tmp_path):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
@@ -110,25 +109,24 @@ class CPUOffloadTesterMixin:
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
model.cpu().save_pretrained(str(tmp_path))
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will actually end up offloaded
|
||||
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will actually end up offloaded
|
||||
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading"
|
||||
)
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading"
|
||||
)
|
||||
|
||||
@require_offload_support
|
||||
def test_disk_offload_without_safetensors(self):
|
||||
def test_disk_offload_without_safetensors(self, tmp_path):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
@@ -143,26 +141,25 @@ class CPUOffloadTesterMixin:
|
||||
# Force disk offload by setting very small CPU memory
|
||||
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
|
||||
# This errors out because it's missing an offload folder
|
||||
with pytest.raises(ValueError):
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
|
||||
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
|
||||
# This errors out because it's missing an offload folder
|
||||
with pytest.raises(ValueError):
|
||||
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
|
||||
)
|
||||
new_model = self.model_class.from_pretrained(
|
||||
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading"
|
||||
)
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading"
|
||||
)
|
||||
|
||||
@require_offload_support
|
||||
def test_disk_offload_with_safetensors(self):
|
||||
def test_disk_offload_with_safetensors(self, tmp_path):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
@@ -173,26 +170,25 @@ class CPUOffloadTesterMixin:
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
model.cpu().save_pretrained(str(tmp_path))
|
||||
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
new_model = self.model_class.from_pretrained(
|
||||
tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
|
||||
)
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
new_model = self.model_class.from_pretrained(
|
||||
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0],
|
||||
new_output[0],
|
||||
atol=1e-5,
|
||||
rtol=0,
|
||||
msg="Output should match with disk offloading (safetensors)",
|
||||
)
|
||||
assert_tensors_close(
|
||||
base_output[0],
|
||||
new_output[0],
|
||||
atol=1e-5,
|
||||
rtol=0,
|
||||
msg="Output should match with disk offloading (safetensors)",
|
||||
)
|
||||
|
||||
|
||||
@is_group_offload
|
||||
@@ -312,7 +308,7 @@ class GroupOffloadTesterMixin:
|
||||
@require_group_offload_support
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def test_group_offloading_with_disk(self, offload_type="block_level", record_stream=False, atol=1e-5):
|
||||
def test_group_offloading_with_disk(self, tmp_path, offload_type="block_level", record_stream=False, atol=1e-5):
|
||||
def _has_generator_arg(model):
|
||||
sig = inspect.signature(model.forward)
|
||||
params = sig.parameters
|
||||
@@ -340,41 +336,41 @@ class GroupOffloadTesterMixin:
|
||||
|
||||
num_blocks_per_group = None if offload_type == "leaf_level" else 1
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
model.enable_group_offload(
|
||||
torch_device,
|
||||
offload_type=offload_type,
|
||||
tmpdir = str(tmp_path)
|
||||
model.enable_group_offload(
|
||||
torch_device,
|
||||
offload_type=offload_type,
|
||||
offload_to_disk_path=tmpdir,
|
||||
use_stream=True,
|
||||
record_stream=record_stream,
|
||||
**additional_kwargs,
|
||||
)
|
||||
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
|
||||
assert has_safetensors, "No safetensors found in the directory."
|
||||
|
||||
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
|
||||
# in nature. So, skip it.
|
||||
if offload_type != "leaf_level":
|
||||
is_correct, extra_files, missing_files = _check_safetensors_serialization(
|
||||
module=model,
|
||||
offload_to_disk_path=tmpdir,
|
||||
use_stream=True,
|
||||
record_stream=record_stream,
|
||||
**additional_kwargs,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
)
|
||||
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
|
||||
assert has_safetensors, "No safetensors found in the directory."
|
||||
if not is_correct:
|
||||
if extra_files:
|
||||
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
|
||||
elif missing_files:
|
||||
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
|
||||
|
||||
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
|
||||
# in nature. So, skip it.
|
||||
if offload_type != "leaf_level":
|
||||
is_correct, extra_files, missing_files = _check_safetensors_serialization(
|
||||
module=model,
|
||||
offload_to_disk_path=tmpdir,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
)
|
||||
if not is_correct:
|
||||
if extra_files:
|
||||
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
|
||||
elif missing_files:
|
||||
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
|
||||
|
||||
output_with_group_offloading = _run_forward(model, inputs_dict)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading,
|
||||
atol=atol,
|
||||
rtol=0,
|
||||
msg="Output should match with disk-based group offloading",
|
||||
)
|
||||
output_with_group_offloading = _run_forward(model, inputs_dict)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading,
|
||||
atol=atol,
|
||||
rtol=0,
|
||||
msg="Output should match with disk-based group offloading",
|
||||
)
|
||||
|
||||
|
||||
class LayerwiseCastingTesterMixin:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
@@ -151,21 +150,20 @@ class SingleFileTesterMixin:
|
||||
param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}"
|
||||
)
|
||||
|
||||
def test_single_file_loading_local_files_only(self):
|
||||
def test_single_file_loading_local_files_only(self, tmp_path):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with local_files_only=True"
|
||||
assert model_single_file is not None, "Failed to load model with local_files_only=True"
|
||||
|
||||
def test_single_file_loading_with_diffusers_config(self):
|
||||
single_file_kwargs = {}
|
||||
@@ -196,22 +194,21 @@ class SingleFileTesterMixin:
|
||||
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
)
|
||||
|
||||
def test_single_file_loading_with_diffusers_config_local_files_only(self):
|
||||
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, tmpdir)
|
||||
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, tmpdir)
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
|
||||
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
|
||||
|
||||
def test_single_file_loading_dtype(self):
|
||||
for dtype in [torch.float32, torch.float16]:
|
||||
|
||||
@@ -26,17 +26,25 @@ from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
IPAdapterTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
SingleFileTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
@@ -353,3 +361,94 @@ class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMix
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
|
||||
gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
return {
|
||||
"hidden_states": randn_tensor((1, 4096, 64)),
|
||||
"encoder_hidden_states": randn_tensor((1, 512, 4096)),
|
||||
"pooled_projections": randn_tensor((1, 768)),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device),
|
||||
"img_ids": randn_tensor((4096, 3)),
|
||||
"txt_ids": randn_tensor((512, 3)),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
|
||||
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
|
||||
"""FirstBlockCache tests for Flux Transformer."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
|
||||
"""FasterCache tests for Flux Transformer."""
|
||||
|
||||
# Flux is guidance distilled, so we can test at model level without CFG batch handling
|
||||
FASTER_CACHE_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
"spatial_attention_timestep_skip_range": (-1, 901),
|
||||
"tensor_format": "BCHW",
|
||||
"is_guidance_distilled": True,
|
||||
}
|
||||
|
||||
@@ -414,6 +414,15 @@ def is_group_offload(test_case):
|
||||
return pytest.mark.group_offload(test_case)
|
||||
|
||||
|
||||
def is_quantization(test_case):
|
||||
"""
|
||||
Decorator marking a test as a quantization test. These tests can be filtered using:
|
||||
pytest -m "not quantization" to skip
|
||||
pytest -m quantization to run only these tests
|
||||
"""
|
||||
return pytest.mark.quantization(test_case)
|
||||
|
||||
|
||||
def is_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
|
||||
@@ -468,6 +477,15 @@ def is_context_parallel(test_case):
|
||||
return pytest.mark.context_parallel(test_case)
|
||||
|
||||
|
||||
def is_cache(test_case):
|
||||
"""
|
||||
Decorator marking a test as a cache test. These tests can be filtered using:
|
||||
pytest -m "not cache" to skip
|
||||
pytest -m cache to run only these tests
|
||||
"""
|
||||
return pytest.mark.cache(test_case)
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
|
||||
@@ -835,7 +853,7 @@ def require_modelopt_version_greater_or_equal(modelopt_version):
|
||||
) >= version.parse(modelopt_version)
|
||||
return pytest.mark.skipif(
|
||||
not correct_nvidia_modelopt_version,
|
||||
f"Test requires modelopt with version greater than {modelopt_version}.",
|
||||
reason=f"Test requires modelopt with version greater than {modelopt_version}.",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
Reference in New Issue
Block a user