From 5c2d30623ebbc94b30248ec2ff7d4f27ca289560 Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 13 Jan 2026 10:38:16 +0530 Subject: [PATCH] update --- tests/models/testing_utils/attention.py | 16 +++- tests/models/testing_utils/cache.py | 77 ++++++++++++------- tests/models/testing_utils/compile.py | 13 +++- tests/models/testing_utils/ip_adapter.py | 27 ++++++- tests/models/testing_utils/lora.py | 17 ++-- tests/models/testing_utils/memory.py | 31 ++++---- tests/models/testing_utils/single_file.py | 73 ++++++++++-------- tests/models/testing_utils/training.py | 25 ++++-- .../test_models_transformer_flux.py | 42 +++++----- 9 files changed, 204 insertions(+), 117 deletions(-) diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index 3f89026dfa..04e5524ab0 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc + import pytest import torch @@ -23,6 +25,7 @@ from diffusers.models.attention_processor import ( from ...testing_utils import ( assert_tensors_close, + backend_empty_cache, is_attention, torch_device, ) @@ -38,11 +41,10 @@ class AttentionTesterMixin: - QKV projection fusion/unfusion - Attention backends (XFormers, NPU, etc.) - Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test - - uses_custom_attn_processor: Whether model uses custom attention processors (default: False) - Expected methods to be implemented by subclasses: + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass @@ -50,6 +52,14 @@ class AttentionTesterMixin: Use `pytest -m "not attention"` to skip these tests """ + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + @torch.no_grad() def test_fuse_unfuse_qkv_projections(self): init_dict = self.get_init_dict() diff --git a/tests/models/testing_utils/cache.py b/tests/models/testing_utils/cache.py index 1f828ca9f8..f5025735bb 100644 --- a/tests/models/testing_utils/cache.py +++ b/tests/models/testing_utils/cache.py @@ -24,7 +24,7 @@ from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK from diffusers.models.cache_utils import CacheMixin -from ...testing_utils import backend_empty_cache, is_cache, torch_device +from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device def require_cache_mixin(func): @@ -53,8 +53,15 @@ class CacheTesterMixin: Expected methods in test classes: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass + + Optional overrides: + - cache_input_key: Property returning the input tensor key to vary between passes (default: "hidden_states") """ + @property + def cache_input_key(self): + return "hidden_states" + def setup_method(self): gc.collect() backend_empty_cache(torch_device) @@ -161,12 +168,14 @@ class CacheTesterMixin: # First pass populates the cache _ = model(**inputs_dict, return_dict=False)[0] - # Create modified inputs for second pass (vary hidden_states to simulate denoising) + # Create 
modified inputs for second pass (vary hidden_states to simulate denoising)
+        # Create modified inputs for second pass (vary input tensor to simulate denoising)
         inputs_dict_step2 = inputs_dict.copy()
-        if "hidden_states" in inputs_dict_step2:
-            inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
+        if self.cache_input_key in inputs_dict_step2:
+            inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
+                inputs_dict_step2[self.cache_input_key]
+            )
 
-        # Second pass uses cached attention with different hidden_states (produces approximated output)
+        # Second pass uses cached attention with different inputs (produces approximated output)
         output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
 
         assert output_with_cache is not None, "Model output should not be None with cache enabled."
@@ -181,18 +190,32 @@ class CacheTesterMixin:
             "Cached output should be different from non-cached output due to cache approximation."
         )
 
+    @torch.no_grad()
     def _test_cache_context_manager(self):
-        """Test the cache_context context manager."""
+        """Test that the cache_context context manager properly isolates cache state."""
         init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
 
         config = self._get_cache_config()
         model.enable_cache(config)
 
-        # Test cache_context works without error
-        with model.cache_context("test_context"):
-            pass
+        # Run inference in first context
+        with model.cache_context("context_1"):
+            output_ctx1 = model(**inputs_dict, return_dict=False)[0]
+
+        # Run same inference in second context (cache should be reset)
+        with model.cache_context("context_2"):
+            output_ctx2 = model(**inputs_dict, return_dict=False)[0]
+        # Both contexts should produce the same output (first pass in each)
+        assert_tensors_close(
+            output_ctx1,
+            output_ctx2,
+            atol=1e-5,
+            msg="First pass in different cache contexts should produce the same output.",
+        )
 
         model.disable_cache()
@@ -336,10 +359,12 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
         # First pass populates the cache
         _ = model(**inputs_dict, return_dict=False)[0]
 
         # Create modified inputs for second pass (small perturbation keeps residuals similar)
         inputs_dict_step2 = inputs_dict.copy()
-        if "hidden_states" in inputs_dict_step2:
-            inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.01
+        if self.cache_input_key in inputs_dict_step2:
+            inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + 0.01 * torch.randn_like(
+                inputs_dict_step2[self.cache_input_key]
+            )
 
         # Second pass - FBC should skip remaining blocks and use cached residuals
         output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
@@ -415,13 +440,11 @@ class FasterCacheConfigMixin:
         "tensor_format": "BCHW",
     }
 
-    # Store timestep for callback - use a list so it can be mutated during test
-    # Starts outside skip range so first pass computes; changed to inside range for subsequent passes
-    _current_timestep = [1000]
-
-    def _get_cache_config(self):
+    def _get_cache_config(self, current_timestep_callback=None):
         config_kwargs = self.FASTER_CACHE_CONFIG.copy()
-        config_kwargs["current_timestep_callback"] = lambda: self._current_timestep[0]
+        if current_timestep_callback is None:
+            current_timestep_callback = lambda: 1000  # noqa: E731
+        config_kwargs["current_timestep_callback"] = current_timestep_callback
         return FasterCacheConfig(**config_kwargs)
 
     def _get_hook_names(self):
@@ -456,23 +479,26 @@ class 
FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin): model = self.model_class(**init_dict).to(torch_device) model.eval() - config = self._get_cache_config() + current_timestep = [1000] + config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0]) model.enable_cache(config) # First pass with timestep outside skip range - computes and populates cache - self._current_timestep[0] = 1000 + current_timestep[0] = 1000 _ = model(**inputs_dict, return_dict=False)[0] # Move timestep inside skip range so subsequent passes use cache - self._current_timestep[0] = 500 + current_timestep[0] = 500 # Create modified inputs for second pass inputs_dict_step2 = inputs_dict.copy() - if "hidden_states" in inputs_dict_step2: - inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1 + if self.cache_input_key in inputs_dict_step2: + inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like( + inputs_dict_step2[self.cache_input_key] + ) - # Second pass uses cached attention with different hidden_states + # Second pass uses cached attention with different inputs output_with_cache = model(**inputs_dict_step2, return_dict=False)[0] assert output_with_cache is not None, "Model output should not be None with cache enabled." @@ -498,7 +524,6 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin): config = self._get_cache_config() model.enable_cache(config) - self._current_timestep[0] = 1000 _ = model(**inputs_dict, return_dict=False)[0] model._reset_stateful_cache() diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py index d7969f6330..950d4d5d1f 100644 --- a/tests/models/testing_utils/compile.py +++ b/tests/models/testing_utils/compile.py @@ -35,11 +35,13 @@ class TorchCompileTesterMixin: """ Mixin class for testing torch.compile functionality on models. - Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test - - different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing - Expected methods to be implemented by subclasses: + Optional properties: + - different_shapes_for_compilation: List of (height, width) tuples for dynamic shape testing (default: None) + + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass @@ -47,7 +49,10 @@ class TorchCompileTesterMixin: Use `pytest -m "not compile"` to skip these tests """ - different_shapes_for_compilation = None + @property + def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None: + """Optional list of (height, width) tuples for dynamic shape testing.""" + return None def setup_method(self): torch.compiler.reset() diff --git a/tests/models/testing_utils/ip_adapter.py b/tests/models/testing_utils/ip_adapter.py index 72aaf491f2..632019c874 100644 --- a/tests/models/testing_utils/ip_adapter.py +++ b/tests/models/testing_utils/ip_adapter.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import pytest import torch -from ...testing_utils import is_ip_adapter, torch_device +from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool: @@ -41,10 +42,17 @@ class IPAdapterTesterMixin: """ Mixin class for testing IP Adapter functionality on models. - Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test - Expected methods to be implemented by subclasses: + Required properties (must be implemented by subclasses): + - ip_adapter_processor_cls: The IP Adapter processor class to use + + Required methods (must be implemented by subclasses): + - create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing + - modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data + + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass @@ -52,7 +60,18 @@ class IPAdapterTesterMixin: Use `pytest -m "not ip_adapter"` to skip these tests """ - ip_adapter_processor_cls = None + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + @property + def ip_adapter_processor_cls(self): + """IP Adapter processor class to use for testing. Must be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.") def create_ip_adapter_state_dict(self, model): raise NotImplementedError("child class must implement method to create IPAdapter State Dict") diff --git a/tests/models/testing_utils/lora.py b/tests/models/testing_utils/lora.py index 5d7b98fbd3..7654495e01 100644 --- a/tests/models/testing_utils/lora.py +++ b/tests/models/testing_utils/lora.py @@ -67,10 +67,10 @@ class LoraTesterMixin: """ Mixin class for testing LoRA/PEFT functionality on models. - Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test - Expected methods to be implemented by subclasses: + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass @@ -254,11 +254,13 @@ class LoraHotSwappingForModelTesterMixin: See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252 for the analogous PEFT test. 
- Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test - - different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests - Expected methods to be implemented by subclasses: + Optional properties: + - different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None) + + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass @@ -266,7 +268,10 @@ class LoraHotSwappingForModelTesterMixin: Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests """ - different_shapes_for_compilation = None + @property + def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None: + """Optional list of (height, width) tuples for dynamic compilation tests.""" + return None def setup_method(self): if not issubclass(self.model_class, PeftAdapterMixin): diff --git a/tests/models/testing_utils/memory.py b/tests/models/testing_utils/memory.py index 72b617ab76..2ff5acdc41 100644 --- a/tests/models/testing_utils/memory.py +++ b/tests/models/testing_utils/memory.py @@ -81,11 +81,13 @@ class CPUOffloadTesterMixin: """ Mixin class for testing CPU offloading functionality. - Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test - - model_split_percents: List of percentages for splitting model across devices - Expected methods to be implemented by subclasses: + Optional properties: + - model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7]) + + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass @@ -93,7 +95,10 @@ class CPUOffloadTesterMixin: Use `pytest -m "not cpu_offload"` to skip these tests """ - model_split_percents = [0.5, 0.7] + @property + def model_split_percents(self) -> list[float]: + """List of percentages for splitting model across devices during offloading tests.""" + return [0.5, 0.7] @require_offload_support @torch.no_grad() @@ -199,10 +204,10 @@ class GroupOffloadTesterMixin: """ Mixin class for testing group offloading functionality. - Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test - Expected methods to be implemented by subclasses: + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass @@ -385,10 +390,10 @@ class LayerwiseCastingTesterMixin: """ Mixin class for testing layerwise dtype casting for memory optimization. 
- Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test - Expected methods to be implemented by subclasses: + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass """ @@ -456,7 +461,7 @@ class LayerwiseCastingTesterMixin: model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) model.train() - inputs_dict = self.get_inputs_dict() + inputs_dict = self.get_dummy_inputs() inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype) with torch.amp.autocast(device_type=torch.device(torch_device).type): output = model(**inputs_dict, return_dict=False)[0] @@ -486,16 +491,16 @@ class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, Layerwis - GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level) - LayerwiseCastingTesterMixin: Layerwise dtype casting tests - Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test + + Optional properties: - model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7]) - Expected methods to be implemented by subclasses: + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass Pytest mark: memory Use `pytest -m "not memory"` to skip these tests """ - - pass diff --git a/tests/models/testing_utils/single_file.py b/tests/models/testing_utils/single_file.py index 52890fc3c6..f5ae495a89 100644 --- a/tests/models/testing_utils/single_file.py +++ b/tests/models/testing_utils/single_file.py @@ -63,19 +63,40 @@ class SingleFileTesterMixin: """ Mixin class for testing single file loading for models. - Expected class attributes: + Required properties (must be implemented by subclasses): + - ckpt_path: Path or Hub path to the single file checkpoint + + Optional properties: + - torch_dtype: torch dtype to use for testing (default: None) + - alternate_ckpt_paths: List of alternate checkpoint paths for variant testing (default: None) + + Expected from config mixin: - model_class: The model class to test - pretrained_model_name_or_path: Hub repository ID for the pretrained model - - ckpt_path: Path or Hub path to the single file checkpoint - - subfolder: (Optional) Subfolder within the repo - - torch_dtype: (Optional) torch dtype to use for testing + - pretrained_model_kwargs: Additional kwargs for from_pretrained (e.g., subfolder) Pytest mark: single_file Use `pytest -m "not single_file"` to skip these tests """ - pretrained_model_name_or_path = None - ckpt_path = None + # ==================== Required Properties ==================== + + @property + def ckpt_path(self) -> str: + """Path or Hub path to the single file checkpoint. 
Must be implemented by subclasses.""" + raise NotImplementedError("Subclasses must implement the `ckpt_path` property.") + + # ==================== Optional Properties ==================== + + @property + def torch_dtype(self) -> torch.dtype | None: + """torch dtype to use for single file testing.""" + return None + + @property + def alternate_ckpt_paths(self) -> list[str] | None: + """List of alternate checkpoint paths for variant testing.""" + return None def setup_method(self): gc.collect() @@ -86,16 +107,10 @@ class SingleFileTesterMixin: backend_empty_cache(torch_device) def test_single_file_model_config(self): - pretrained_kwargs = {} - single_file_kwargs = {} + pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs} + single_file_kwargs = {"device": torch_device} - pretrained_kwargs["device"] = torch_device - single_file_kwargs["device"] = torch_device - - if hasattr(self, "subfolder") and self.subfolder: - pretrained_kwargs["subfolder"] = self.subfolder - - if hasattr(self, "torch_dtype") and self.torch_dtype: + if self.torch_dtype: pretrained_kwargs["torch_dtype"] = self.torch_dtype single_file_kwargs["torch_dtype"] = self.torch_dtype @@ -112,16 +127,10 @@ class SingleFileTesterMixin: ) def test_single_file_model_parameters(self): - pretrained_kwargs = {} - single_file_kwargs = {} + pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs} + single_file_kwargs = {"device": torch_device} - pretrained_kwargs["device"] = torch_device - single_file_kwargs["device"] = torch_device - - if hasattr(self, "subfolder") and self.subfolder: - pretrained_kwargs["subfolder"] = self.subfolder - - if hasattr(self, "torch_dtype") and self.torch_dtype: + if self.torch_dtype: pretrained_kwargs["torch_dtype"] = self.torch_dtype single_file_kwargs["torch_dtype"] = self.torch_dtype @@ -153,7 +162,7 @@ class SingleFileTesterMixin: def test_single_file_loading_local_files_only(self, tmp_path): single_file_kwargs = {} - if hasattr(self, "torch_dtype") and self.torch_dtype: + if self.torch_dtype: single_file_kwargs["torch_dtype"] = self.torch_dtype pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) @@ -168,7 +177,7 @@ class SingleFileTesterMixin: def test_single_file_loading_with_diffusers_config(self): single_file_kwargs = {} - if hasattr(self, "torch_dtype") and self.torch_dtype: + if self.torch_dtype: single_file_kwargs["torch_dtype"] = self.torch_dtype # Load with config parameter @@ -177,10 +186,8 @@ class SingleFileTesterMixin: ) # Load pretrained for comparison - pretrained_kwargs = {} - if hasattr(self, "subfolder") and self.subfolder: - pretrained_kwargs["subfolder"] = self.subfolder - if hasattr(self, "torch_dtype") and self.torch_dtype: + pretrained_kwargs = {**self.pretrained_model_kwargs} + if self.torch_dtype: pretrained_kwargs["torch_dtype"] = self.torch_dtype model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs) @@ -197,7 +204,7 @@ class SingleFileTesterMixin: def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path): single_file_kwargs = {} - if hasattr(self, "torch_dtype") and self.torch_dtype: + if self.torch_dtype: single_file_kwargs["torch_dtype"] = self.torch_dtype pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path) @@ -225,14 +232,14 @@ class SingleFileTesterMixin: backend_empty_cache(torch_device) def test_checkpoint_variant_loading(self): - if not hasattr(self, "alternate_ckpt_paths") or 
not self.alternate_ckpt_paths: + if not self.alternate_ckpt_paths: return for ckpt_path in self.alternate_ckpt_paths: backend_empty_cache(torch_device) single_file_kwargs = {} - if hasattr(self, "torch_dtype") and self.torch_dtype: + if self.torch_dtype: single_file_kwargs["torch_dtype"] = self.torch_dtype model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs) diff --git a/tests/models/testing_utils/training.py b/tests/models/testing_utils/training.py index f6612dd3be..44cce6af68 100644 --- a/tests/models/testing_utils/training.py +++ b/tests/models/testing_utils/training.py @@ -14,13 +14,20 @@ # limitations under the License. import copy +import gc import pytest import torch from diffusers.training_utils import EMAModel -from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device +from ...testing_utils import ( + backend_empty_cache, + is_training, + require_torch_accelerator_with_training, + torch_all_close, + torch_device, +) @is_training @@ -29,20 +36,26 @@ class TrainingTesterMixin: """ Mixin class for testing training functionality on models. - Expected class attributes to be set by subclasses: + Expected from config mixin: - model_class: The model class to test + - output_shape: Tuple defining the expected output shape - Expected methods to be implemented by subclasses: + Expected methods from config mixin: - get_init_dict(): Returns dict of arguments to initialize the model - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass - Expected properties to be implemented by subclasses: - - output_shape: Tuple defining the expected output shape - Pytest mark: training Use `pytest -m "not training"` to skip these tests """ + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + def test_training(self): init_dict = self.get_init_dict() inputs_dict = self.get_dummy_inputs() diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index b264eda22c..11da230549 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -198,31 +198,25 @@ class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin): class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin): """Memory optimization tests for Flux Transformer.""" - pass - class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin): """Training tests for Flux Transformer.""" - pass - class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin): """Attention processor tests for Flux Transformer.""" - pass - class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin): """Context Parallel inference tests for Flux Transformer""" - pass - class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin): """IP Adapter tests for Flux Transformer.""" - ip_adapter_processor_cls = FluxIPAdapterAttnProcessor + @property + def ip_adapter_processor_cls(self): + return FluxIPAdapterAttnProcessor def modify_inputs_for_ip_adapter(self, model, inputs_dict): torch.manual_seed(0) @@ -241,13 +235,13 @@ class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterM class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin): """LoRA adapter tests 
for Flux Transformer."""
 
-    pass
-
 
 class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
     """LoRA hot-swapping tests for Flux Transformer."""
 
-    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+    @property
+    def different_shapes_for_compilation(self):
+        return [(4, 4), (4, 8), (8, 8)]
 
     def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for LoRA hotswap tests."""
@@ -268,7 +262,9 @@ class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappin
 
 
 class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
-    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+    @property
+    def different_shapes_for_compilation(self):
+        return [(4, 4), (4, 8), (8, 8)]
 
     def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for compilation tests."""
@@ -289,11 +285,21 @@ class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTester
 
 
 class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
-    ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
-    alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
-    pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
-    subfolder = "transformer"
-    pass
+    @property
+    def ckpt_path(self):
+        return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
+
+    @property
+    def alternate_ckpt_paths(self):
+        return ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
+
+    @property
+    def pretrained_model_name_or_path(self):
+        return "black-forest-labs/FLUX.1-dev"
+
+    @property
+    def pretrained_model_kwargs(self):
+        return {"subfolder": "transformer"}
 
 
 class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
@@ -433,14 +439,10 @@ class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCo
 class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
     """PyramidAttentionBroadcast cache tests for Flux Transformer."""
 
-    pass
-
 
 class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
     """FirstBlockCache tests for Flux Transformer."""
 
-    pass
-
 
 class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
    """FasterCache tests for Flux Transformer."""
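
Note for reviewers: a minimal sketch of how a model test module is expected to compose these mixins after this change. `MyTransformer`, `MyTransformerTesterConfig`, the init kwargs, the input shapes, and the import paths are hypothetical placeholders for illustration, not part of this patch.

    import torch

    # Import paths are assumed; mirror the actual layout under tests/models/testing_utils/
    from tests.models.testing_utils.cache import PyramidAttentionBroadcastTesterMixin
    from tests.models.testing_utils.compile import TorchCompileTesterMixin

    from diffusers import MyTransformer  # hypothetical model class


    class MyTransformerTesterConfig:
        """Config mixin supplying what the tester mixins expect."""

        model_class = MyTransformer

        def get_init_dict(self):
            # Arguments used to construct the model under test.
            return {"num_layers": 2, "attention_head_dim": 8}

        def get_dummy_inputs(self):
            # Inputs passed to the model's forward pass.
            return {"hidden_states": torch.randn(1, 4, 16, 16)}


    class TestMyTransformerPABCache(MyTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
        @property
        def cache_input_key(self):
            # Override the "hidden_states" default when the model's primary
            # input tensor uses a different keyword.
            return "sample"


    class TestMyTransformerCompile(MyTransformerTesterConfig, TorchCompileTesterMixin):
        @property
        def different_shapes_for_compilation(self):
            # Opt in to the dynamic-shape compilation tests.
            return [(4, 4), (4, 8), (8, 8)]

A plain class attribute on the subclass would also take precedence over the mixin property during attribute lookup; the patch standardizes on `@property` overrides so mixin defaults and subclass overrides stay symmetric.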