Mirror of https://github.com/huggingface/diffusers.git
Commit: update
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
+
 import pytest
 import torch

@@ -23,6 +25,7 @@ from diffusers.models.attention_processor import (

 from ...testing_utils import (
+    assert_tensors_close,
     backend_empty_cache,
     is_attention,
     torch_device,
 )

@@ -38,11 +41,10 @@ class AttentionTesterMixin:
     - QKV projection fusion/unfusion
     - Attention backends (XFormers, NPU, etc.)

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test
     - uses_custom_attn_processor: Whether model uses custom attention processors (default: False)

-    Expected methods to be implemented by subclasses:
+    Expected methods from config mixin:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

@@ -50,6 +52,14 @@ class AttentionTesterMixin:
     Use `pytest -m "not attention"` to skip these tests
     """

+    def setup_method(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def teardown_method(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
     @torch.no_grad()
     def test_fuse_unfuse_qkv_projections(self):
         init_dict = self.get_init_dict()
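For orientation, a minimal sketch of the config-mixin contract these docstrings refer to. It is not part of this commit; `DummyTransformer2DModel`, `DummyModelTesterConfig`, and the tiny shapes are hypothetical stand-ins:

class DummyModelTesterConfig:
    model_class = DummyTransformer2DModel  # hypothetical model class under test

    def get_init_dict(self):
        # smallest config that still builds attention blocks
        return {"num_layers": 1, "num_attention_heads": 2, "attention_head_dim": 8}

    def get_dummy_inputs(self):
        return {"hidden_states": torch.randn(1, 4, 8, 8, device=torch_device)}


class TestDummyModelAttention(DummyModelTesterConfig, AttentionTesterMixin):
    """Inherits test_fuse_unfuse_qkv_projections plus the mixin's setup/teardown."""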
@@ -24,7 +24,7 @@ from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK
 from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
 from diffusers.models.cache_utils import CacheMixin

-from ...testing_utils import backend_empty_cache, is_cache, torch_device
+from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device


 def require_cache_mixin(func):
@@ -53,8 +53,15 @@ class CacheTesterMixin:
     Expected methods in test classes:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
+
+    Optional overrides:
+    - cache_input_key: Property returning the input tensor key to vary between passes (default: "hidden_states")
     """

+    @property
+    def cache_input_key(self):
+        return "hidden_states"
+
     def setup_method(self):
         gc.collect()
         backend_empty_cache(torch_device)
@@ -161,12 +168,14 @@ class CacheTesterMixin:
         # First pass populates the cache
         _ = model(**inputs_dict, return_dict=False)[0]

-        # Create modified inputs for second pass (vary hidden_states to simulate denoising)
+        # Create modified inputs for second pass (vary input tensor to simulate denoising)
         inputs_dict_step2 = inputs_dict.copy()
-        if "hidden_states" in inputs_dict_step2:
-            inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
+        if self.cache_input_key in inputs_dict_step2:
+            inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
+                inputs_dict_step2[self.cache_input_key]
+            )

-        # Second pass uses cached attention with different hidden_states (produces approximated output)
+        # Second pass uses cached attention with different inputs (produces approximated output)
         output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]

         assert output_with_cache is not None, "Model output should not be None with cache enabled."
@@ -181,18 +190,32 @@ class CacheTesterMixin:
                 "Cached output should be different from non-cached output due to cache approximation."
             )

     @torch.no_grad()
     def _test_cache_context_manager(self):
-        """Test the cache_context context manager."""
+        """Test the cache_context context manager properly isolates cache state."""
         init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)
         model.eval()

         config = self._get_cache_config()

         model.enable_cache(config)

-        # Test cache_context works without error
-        with model.cache_context("test_context"):
-            pass
+        # Run inference in first context
+        with model.cache_context("context_1"):
+            output_ctx1 = model(**inputs_dict, return_dict=False)[0]
+
+        # Run same inference in second context (cache should be reset)
+        with model.cache_context("context_2"):
+            output_ctx2 = model(**inputs_dict, return_dict=False)[0]
+
+        # Both contexts should produce the same output (first pass in each)
+        assert_tensors_close(
+            output_ctx1,
+            output_ctx2,
+            atol=1e-5,
+            msg="First pass in different cache contexts should produce the same output.",
+        )

         model.disable_cache()
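In usage terms, the guarantee this rewritten test pins down, sketched with the same API (`model`, `config`, and `inputs_dict` as set up in the hunk above):

model.enable_cache(config)

with model.cache_context("cond"):  # each named context starts from fresh cache state
    out_cond = model(**inputs_dict, return_dict=False)[0]

with model.cache_context("uncond"):  # does not reuse values cached under "cond"
    out_uncond = model(**inputs_dict, return_dict=False)[0]

model.disable_cache()

Identical inputs therefore give identical first-pass outputs in each context, which is exactly what assert_tensors_close verifies.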
@@ -336,10 +359,12 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
         # First pass populates the cache
         _ = model(**inputs_dict, return_dict=False)[0]

-        # Create modified inputs for second pass (small perturbation keeps residuals similar)
+        # Create modified inputs for second pass
         inputs_dict_step2 = inputs_dict.copy()
-        if "hidden_states" in inputs_dict_step2:
-            inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.01
+        if self.cache_input_key in inputs_dict_step2:
+            inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
+                inputs_dict_step2[self.cache_input_key]
+            )

         # Second pass - FBC should skip remaining blocks and use cached residuals
         output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
@@ -415,13 +440,11 @@ class FasterCacheConfigMixin:
         "tensor_format": "BCHW",
     }

-    # Store timestep for callback - use a list so it can be mutated during test
-    # Starts outside skip range so first pass computes; changed to inside range for subsequent passes
-    _current_timestep = [1000]
-
-    def _get_cache_config(self):
+    def _get_cache_config(self, current_timestep_callback=None):
         config_kwargs = self.FASTER_CACHE_CONFIG.copy()
-        config_kwargs["current_timestep_callback"] = lambda: self._current_timestep[0]
+        if current_timestep_callback is None:
+            current_timestep_callback = lambda: 1000  # noqa: E731
+        config_kwargs["current_timestep_callback"] = current_timestep_callback
         return FasterCacheConfig(**config_kwargs)

     def _get_hook_names(self):
@@ -456,23 +479,26 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
         model = self.model_class(**init_dict).to(torch_device)
         model.eval()

-        config = self._get_cache_config()
+        current_timestep = [1000]
+        config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0])

         model.enable_cache(config)

         # First pass with timestep outside skip range - computes and populates cache
-        self._current_timestep[0] = 1000
+        current_timestep[0] = 1000
         _ = model(**inputs_dict, return_dict=False)[0]

         # Move timestep inside skip range so subsequent passes use cache
-        self._current_timestep[0] = 500
+        current_timestep[0] = 500

         # Create modified inputs for second pass
         inputs_dict_step2 = inputs_dict.copy()
-        if "hidden_states" in inputs_dict_step2:
-            inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
+        if self.cache_input_key in inputs_dict_step2:
+            inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
+                inputs_dict_step2[self.cache_input_key]
+            )

-        # Second pass uses cached attention with different hidden_states
+        # Second pass uses cached attention with different inputs
         output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]

         assert output_with_cache is not None, "Model output should not be None with cache enabled."
@@ -498,7 +524,6 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
         config = self._get_cache_config()
         model.enable_cache(config)

-        self._current_timestep[0] = 1000
         _ = model(**inputs_dict, return_dict=False)[0]

         model._reset_stateful_cache()
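The refactor replaces shared mutable class state (`_current_timestep`) with a callback injected per test. The idiom, distilled from the hunks above:

current_timestep = [1000]  # one-element list acts as a mutable cell shared with the lambda
config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0])
model.enable_cache(config)

current_timestep[0] = 1000  # outside the skip range: pass computes and populates the cache
_ = model(**inputs_dict, return_dict=False)[0]

current_timestep[0] = 500  # inside the skip range: pass reuses cached attention
_ = model(**inputs_dict, return_dict=False)[0]

Keeping the cell local to the test means two tests can no longer interfere through a class-level timestep.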
@@ -35,11 +35,13 @@ class TorchCompileTesterMixin:
     """
     Mixin class for testing torch.compile functionality on models.

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test
-    - different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing

-    Expected methods to be implemented by subclasses:
+    Optional properties:
+    - different_shapes_for_compilation: List of (height, width) tuples for dynamic shape testing (default: None)
+
+    Expected methods from config mixin:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

@@ -47,7 +49,10 @@ class TorchCompileTesterMixin:
     Use `pytest -m "not compile"` to skip these tests
     """

-    different_shapes_for_compilation = None
+    @property
+    def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
+        """Optional list of (height, width) tuples for dynamic shape testing."""
+        return None

     def setup_method(self):
         torch.compiler.reset()
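Subclasses opt in to dynamic-shape compilation by overriding the property rather than shadowing a class attribute, e.g. (hypothetical config class; the Flux tests at the end of this commit do the same):

class TestMyModelCompile(MyModelTesterConfig, TorchCompileTesterMixin):
    @property
    def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
        return [(4, 4), (4, 8), (8, 8)]  # (height, width) pairs to recompile across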
@@ -13,11 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc

 import pytest
 import torch

-from ...testing_utils import is_ip_adapter, torch_device
+from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device


 def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
@@ -41,10 +42,17 @@ class IPAdapterTesterMixin:
     """
     Mixin class for testing IP Adapter functionality on models.

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test

-    Expected methods to be implemented by subclasses:
+    Required properties (must be implemented by subclasses):
+    - ip_adapter_processor_cls: The IP Adapter processor class to use
+
+    Required methods (must be implemented by subclasses):
     - create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing
     - modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data
+
+    Expected methods from config mixin:
+    - get_init_dict(): Returns dict of arguments to initialize the model
+    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

@@ -52,7 +60,18 @@ class IPAdapterTesterMixin:
     Use `pytest -m "not ip_adapter"` to skip these tests
     """

-    ip_adapter_processor_cls = None
+    def setup_method(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def teardown_method(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    @property
+    def ip_adapter_processor_cls(self):
+        """IP Adapter processor class to use for testing. Must be implemented by subclasses."""
+        raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.")

     def create_ip_adapter_state_dict(self, model):
         raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
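A sketch of the contract a concrete test class now satisfies (`MyModelTesterConfig` and `MyIPAdapterAttnProcessor` are hypothetical stand-ins; the Flux subclass later in this commit is a real instance):

class TestMyModelIPAdapter(MyModelTesterConfig, IPAdapterTesterMixin):
    @property
    def ip_adapter_processor_cls(self):
        return MyIPAdapterAttnProcessor  # processor class the mixin checks for

    def create_ip_adapter_state_dict(self, model):
        ...  # build a small adapter state dict matching the model's attention layout

    def modify_inputs_for_ip_adapter(self, model, inputs_dict):
        ...  # attach the image embeddings the adapter consumes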
@@ -67,10 +67,10 @@ class LoraTesterMixin:
     """
     Mixin class for testing LoRA/PEFT functionality on models.

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test

-    Expected methods to be implemented by subclasses:
+    Expected methods from config mixin:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -254,11 +254,13 @@ class LoraHotSwappingForModelTesterMixin:
     See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
     for the analogous PEFT test.

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test
-    - different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests

-    Expected methods to be implemented by subclasses:
+    Optional properties:
+    - different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None)
+
+    Expected methods from config mixin:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -266,7 +268,10 @@ class LoraHotSwappingForModelTesterMixin:
     Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
     """

-    different_shapes_for_compilation = None
+    @property
+    def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
+        """Optional list of (height, width) tuples for dynamic compilation tests."""
+        return None

     def setup_method(self):
         if not issubclass(self.model_class, PeftAdapterMixin):
@@ -81,11 +81,13 @@ class CPUOffloadTesterMixin:
     """
     Mixin class for testing CPU offloading functionality.

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test
-    - model_split_percents: List of percentages for splitting model across devices

-    Expected methods to be implemented by subclasses:
+    Optional properties:
+    - model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
+
+    Expected methods from config mixin:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

@@ -93,7 +95,10 @@ class CPUOffloadTesterMixin:
     Use `pytest -m "not cpu_offload"` to skip these tests
     """

-    model_split_percents = [0.5, 0.7]
+    @property
+    def model_split_percents(self) -> list[float]:
+        """List of percentages for splitting model across devices during offloading tests."""
+        return [0.5, 0.7]

     @require_offload_support
     @torch.no_grad()
@@ -199,10 +204,10 @@ class GroupOffloadTesterMixin:
     """
     Mixin class for testing group offloading functionality.

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test

-    Expected methods to be implemented by subclasses:
+    Expected methods from config mixin:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -385,10 +390,10 @@ class LayerwiseCastingTesterMixin:
     """
     Mixin class for testing layerwise dtype casting for memory optimization.

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test

-    Expected methods to be implemented by subclasses:
+    Expected methods from config mixin:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
     """
@@ -456,7 +461,7 @@ class LayerwiseCastingTesterMixin:
         model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
         model.train()

-        inputs_dict = self.get_inputs_dict()
+        inputs_dict = self.get_dummy_inputs()
         inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
         with torch.amp.autocast(device_type=torch.device(torch_device).type):
             output = model(**inputs_dict, return_dict=False)[0]
@@ -486,16 +491,16 @@ class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, Layerwis
     - GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
     - LayerwiseCastingTesterMixin: Layerwise dtype casting tests

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test

+    Optional properties:
+    - model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
+
-    Expected methods to be implemented by subclasses:
+    Expected methods from config mixin:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

     Pytest mark: memory
     Use `pytest -m "not memory"` to skip these tests
     """

     pass
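MemoryTesterMixin itself is pure composition; a concrete test class only has to supply the config mixin, mirroring TestFluxTransformerMemory later in this commit (`MyModelTesterConfig` is a hypothetical name):

class TestMyModelMemory(MyModelTesterConfig, MemoryTesterMixin):
    """Runs the CPU offload, group offload, and layerwise casting tests."""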
@@ -63,19 +63,40 @@ class SingleFileTesterMixin:
     """
     Mixin class for testing single file loading for models.

-    Expected class attributes:
+    Required properties (must be implemented by subclasses):
+    - ckpt_path: Path or Hub path to the single file checkpoint
+
+    Optional properties:
+    - torch_dtype: torch dtype to use for testing (default: None)
+    - alternate_ckpt_paths: List of alternate checkpoint paths for variant testing (default: None)
+
+    Expected from config mixin:
     - model_class: The model class to test
     - pretrained_model_name_or_path: Hub repository ID for the pretrained model
-    - ckpt_path: Path or Hub path to the single file checkpoint
-    - subfolder: (Optional) Subfolder within the repo
-    - torch_dtype: (Optional) torch dtype to use for testing
+    - pretrained_model_kwargs: Additional kwargs for from_pretrained (e.g., subfolder)

     Pytest mark: single_file
     Use `pytest -m "not single_file"` to skip these tests
     """

-    pretrained_model_name_or_path = None
-    ckpt_path = None
+    # ==================== Required Properties ====================
+
+    @property
+    def ckpt_path(self) -> str:
+        """Path or Hub path to the single file checkpoint. Must be implemented by subclasses."""
+        raise NotImplementedError("Subclasses must implement the `ckpt_path` property.")
+
+    # ==================== Optional Properties ====================
+
+    @property
+    def torch_dtype(self) -> torch.dtype | None:
+        """torch dtype to use for single file testing."""
+        return None
+
+    @property
+    def alternate_ckpt_paths(self) -> list[str] | None:
+        """List of alternate checkpoint paths for variant testing."""
+        return None

     def setup_method(self):
         gc.collect()
@@ -86,16 +107,10 @@ class SingleFileTesterMixin:
         backend_empty_cache(torch_device)

     def test_single_file_model_config(self):
-        pretrained_kwargs = {}
-        single_file_kwargs = {}
+        pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
+        single_file_kwargs = {"device": torch_device}

-        pretrained_kwargs["device"] = torch_device
-        single_file_kwargs["device"] = torch_device
-
-        if hasattr(self, "subfolder") and self.subfolder:
-            pretrained_kwargs["subfolder"] = self.subfolder
-
-        if hasattr(self, "torch_dtype") and self.torch_dtype:
+        if self.torch_dtype:
             pretrained_kwargs["torch_dtype"] = self.torch_dtype
             single_file_kwargs["torch_dtype"] = self.torch_dtype
@@ -112,16 +127,10 @@ class SingleFileTesterMixin:
         )

     def test_single_file_model_parameters(self):
-        pretrained_kwargs = {}
-        single_file_kwargs = {}
+        pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
+        single_file_kwargs = {"device": torch_device}

-        pretrained_kwargs["device"] = torch_device
-        single_file_kwargs["device"] = torch_device
-
-        if hasattr(self, "subfolder") and self.subfolder:
-            pretrained_kwargs["subfolder"] = self.subfolder
-
-        if hasattr(self, "torch_dtype") and self.torch_dtype:
+        if self.torch_dtype:
             pretrained_kwargs["torch_dtype"] = self.torch_dtype
             single_file_kwargs["torch_dtype"] = self.torch_dtype
@@ -153,7 +162,7 @@ class SingleFileTesterMixin:
     def test_single_file_loading_local_files_only(self, tmp_path):
         single_file_kwargs = {}

-        if hasattr(self, "torch_dtype") and self.torch_dtype:
+        if self.torch_dtype:
             single_file_kwargs["torch_dtype"] = self.torch_dtype

         pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
@@ -168,7 +177,7 @@ class SingleFileTesterMixin:
     def test_single_file_loading_with_diffusers_config(self):
         single_file_kwargs = {}

-        if hasattr(self, "torch_dtype") and self.torch_dtype:
+        if self.torch_dtype:
             single_file_kwargs["torch_dtype"] = self.torch_dtype

         # Load with config parameter
@@ -177,10 +186,8 @@ class SingleFileTesterMixin:
         )

         # Load pretrained for comparison
-        pretrained_kwargs = {}
-        if hasattr(self, "subfolder") and self.subfolder:
-            pretrained_kwargs["subfolder"] = self.subfolder
-        if hasattr(self, "torch_dtype") and self.torch_dtype:
+        pretrained_kwargs = {**self.pretrained_model_kwargs}
+        if self.torch_dtype:
             pretrained_kwargs["torch_dtype"] = self.torch_dtype

         model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
@@ -197,7 +204,7 @@ class SingleFileTesterMixin:
     def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
         single_file_kwargs = {}

-        if hasattr(self, "torch_dtype") and self.torch_dtype:
+        if self.torch_dtype:
             single_file_kwargs["torch_dtype"] = self.torch_dtype

         pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
@@ -225,14 +232,14 @@ class SingleFileTesterMixin:
         backend_empty_cache(torch_device)

     def test_checkpoint_variant_loading(self):
-        if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
+        if not self.alternate_ckpt_paths:
            return

         for ckpt_path in self.alternate_ckpt_paths:
             backend_empty_cache(torch_device)

             single_file_kwargs = {}
-            if hasattr(self, "torch_dtype") and self.torch_dtype:
+            if self.torch_dtype:
                 single_file_kwargs["torch_dtype"] = self.torch_dtype

             model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
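What these tests assert, stated as a usage sketch; `MyModel` and the repo/file names are placeholders, not values from this commit:

model_sf = MyModel.from_single_file("https://huggingface.co/org/repo/blob/main/model.safetensors")
model_hub = MyModel.from_pretrained("org/repo", subfolder="transformer")
# the tests then compare config values and parameter tensors between model_sf and model_hub

Replacing the hasattr checks with always-present properties means a subclass that forgets ckpt_path now fails loudly with NotImplementedError instead of silently building empty kwargs.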
@@ -14,13 +14,20 @@
 # limitations under the License.

 import copy
+import gc

 import pytest
 import torch

 from diffusers.training_utils import EMAModel

-from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device
+from ...testing_utils import (
+    backend_empty_cache,
+    is_training,
+    require_torch_accelerator_with_training,
+    torch_all_close,
+    torch_device,
+)


 @is_training
@@ -29,20 +36,26 @@ class TrainingTesterMixin:
     """
     Mixin class for testing training functionality on models.

-    Expected class attributes to be set by subclasses:
+    Expected from config mixin:
     - model_class: The model class to test
-    - output_shape: Tuple defining the expected output shape

-    Expected methods to be implemented by subclasses:
+    Expected methods from config mixin:
     - get_init_dict(): Returns dict of arguments to initialize the model
     - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

+    Expected properties to be implemented by subclasses:
+    - output_shape: Tuple defining the expected output shape
+
     Pytest mark: training
     Use `pytest -m "not training"` to skip these tests
     """

+    def setup_method(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def teardown_method(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
     def test_training(self):
         init_dict = self.get_init_dict()
         inputs_dict = self.get_dummy_inputs()
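The body of test_training is not shown in this hunk; such tests plausibly follow the usual diffusers pattern of a forward pass in train mode, a dummy regression target, and a backward pass. A sketch under that assumption, not this commit's code:

model = self.model_class(**init_dict).to(torch_device)
model.train()
output = model(**inputs_dict, return_dict=False)[0]
target = torch.randn_like(output)  # dummy target matching output_shape
loss = torch.nn.functional.mse_loss(output, target)
loss.backward()  # training "works" if gradients flow without error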
@@ -198,31 +198,25 @@ class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
 class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
     """Memory optimization tests for Flux Transformer."""

-    pass
-

 class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
     """Training tests for Flux Transformer."""

-    pass
-

 class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
     """Attention processor tests for Flux Transformer."""

-    pass
-

 class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
     """Context Parallel inference tests for Flux Transformer"""

-    pass
-

 class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
     """IP Adapter tests for Flux Transformer."""

-    ip_adapter_processor_cls = FluxIPAdapterAttnProcessor
+    @property
+    def ip_adapter_processor_cls(self):
+        return FluxIPAdapterAttnProcessor

     def modify_inputs_for_ip_adapter(self, model, inputs_dict):
         torch.manual_seed(0)
@@ -241,13 +235,13 @@ class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterM
 class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
     """LoRA adapter tests for Flux Transformer."""

-    pass
-

 class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
     """LoRA hot-swapping tests for Flux Transformer."""

-    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+    @property
+    def different_shapes_for_compilation(self):
+        return [(4, 4), (4, 8), (8, 8)]

     def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for LoRA hotswap tests."""
@@ -268,7 +262,9 @@ class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappin

 class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
-    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+    @property
+    def different_shapes_for_compilation(self):
+        return [(4, 4), (4, 8), (8, 8)]

     def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for compilation tests."""
@@ -289,11 +285,17 @@ class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTester

 class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
-    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
-    alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
-    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
-    subfolder = "transformer"
-    pass
+    @property
+    def ckpt_path(self):
+        return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
+
+    @property
+    def alternate_ckpt_paths(self):
+        return ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
+
+    @property
+    def pretrained_model_name_or_path(self):
+        return "black-forest-labs/FLUX.1-dev"


 class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
@@ -433,14 +435,10 @@ class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCo
 class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
     """PyramidAttentionBroadcast cache tests for Flux Transformer."""

-    pass
-

 class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
     """FirstBlockCache tests for Flux Transformer."""

-    pass
-

 class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
     """FasterCache tests for Flux Transformer."""