mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
DN6
2026-01-13 10:38:16 +05:30
parent 6caa0a9bf4
commit 5c2d30623e
9 changed files with 204 additions and 117 deletions

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import gc
import pytest
import torch
@@ -23,6 +25,7 @@ from diffusers.models.attention_processor import (
from ...testing_utils import (
assert_tensors_close,
+backend_empty_cache,
is_attention,
torch_device,
)
@@ -38,11 +41,10 @@ class AttentionTesterMixin:
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
-- uses_custom_attn_processor: Whether model uses custom attention processors (default: False)
-Expected methods to be implemented by subclasses:
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -50,6 +52,14 @@ class AttentionTesterMixin:
Use `pytest -m "not attention"` to skip these tests
"""
+def setup_method(self):
+gc.collect()
+backend_empty_cache(torch_device)
+def teardown_method(self):
+gc.collect()
+backend_empty_cache(torch_device)
@torch.no_grad()
def test_fuse_unfuse_qkv_projections(self):
init_dict = self.get_init_dict()
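For orientation, the "config mixin" these docstrings refer to is a small class that supplies model_class, get_init_dict(), and get_dummy_inputs(). A minimal sketch of that contract, assuming a tiny UNet2DModel as a stand-in (the class name and all config values here are illustrative, not part of this commit):

import torch
from diffusers import UNet2DModel

class DummyUNetTesterConfig:
    model_class = UNet2DModel
    uses_custom_attn_processor = False  # default assumed by AttentionTesterMixin

    def get_init_dict(self):
        # Keep the model tiny so the attention tests run quickly.
        return {
            "sample_size": 16,
            "in_channels": 3,
            "out_channels": 3,
            "block_out_channels": (8, 16),
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
            "norm_num_groups": 4,
        }

    def get_dummy_inputs(self):
        return {"sample": torch.randn(1, 3, 16, 16), "timestep": torch.tensor([1])}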

View File

@@ -24,7 +24,7 @@ from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from diffusers.models.cache_utils import CacheMixin
-from ...testing_utils import backend_empty_cache, is_cache, torch_device
+from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
def require_cache_mixin(func):
@@ -53,8 +53,15 @@ class CacheTesterMixin:
Expected methods in test classes:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
+Optional overrides:
+- cache_input_key: Property returning the input tensor key to vary between passes (default: "hidden_states")
"""
+@property
+def cache_input_key(self):
+return "hidden_states"
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
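As a usage note, a config for a model whose varying input is not named "hidden_states" now only overrides the new property; a hypothetical sketch (the class names and input key are illustrative):

class DummySampleModelCacheTests(DummySampleModelTesterConfig, CacheTesterMixin):
    @property
    def cache_input_key(self):
        # Hypothetical model whose denoised input is called "sample".
        return "sample"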
@@ -161,12 +168,14 @@ class CacheTesterMixin:
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
-# Create modified inputs for second pass (vary hidden_states to simulate denoising)
+# Create modified inputs for second pass (vary input tensor to simulate denoising)
inputs_dict_step2 = inputs_dict.copy()
-if "hidden_states" in inputs_dict_step2:
-inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
+if self.cache_input_key in inputs_dict_step2:
+inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
+inputs_dict_step2[self.cache_input_key]
+)
-# Second pass uses cached attention with different hidden_states (produces approximated output)
+# Second pass uses cached attention with different inputs (produces approximated output)
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
@@ -181,18 +190,32 @@ class CacheTesterMixin:
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_cache_context_manager(self):
"""Test the cache_context context manager."""
"""Test the cache_context context manager properly isolates cache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
-# Test cache_context works without error
-with model.cache_context("test_context"):
-pass
+# Run inference in first context
+with model.cache_context("context_1"):
+output_ctx1 = model(**inputs_dict, return_dict=False)[0]
+# Run same inference in second context (cache should be reset)
+with model.cache_context("context_2"):
+output_ctx2 = model(**inputs_dict, return_dict=False)[0]
+# Both contexts should produce the same output (first pass in each)
+assert_tensors_close(
+output_ctx1,
+output_ctx2,
+atol=1e-5,
+msg="First pass in different cache contexts should produce the same output.",
+)
model.disable_cache()
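The isolation this test asserts mirrors how cache_context is used at inference time. A minimal sketch, assuming model is a CacheMixin model, inputs its dummy inputs, and illustrative config values:

from diffusers.hooks import PyramidAttentionBroadcastConfig

model.enable_cache(
    PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,
        current_timestep_callback=lambda: 500,  # illustrative timestep
    )
)
with model.cache_context("sample_1"):
    out_1 = model(**inputs, return_dict=False)[0]  # fresh cache, full compute
with model.cache_context("sample_2"):
    out_2 = model(**inputs, return_dict=False)[0]  # separate cache, so identical first-pass output
model.disable_cache()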
@@ -336,10 +359,12 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
-# Create modified inputs for second pass (small perturbation keeps residuals similar)
+# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
-if "hidden_states" in inputs_dict_step2:
-inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.01
+if self.cache_input_key in inputs_dict_step2:
+inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
+inputs_dict_step2[self.cache_input_key]
+)
# Second pass - FBC should skip remaining blocks and use cached residuals
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
@@ -415,13 +440,11 @@ class FasterCacheConfigMixin:
"tensor_format": "BCHW",
}
-# Store timestep for callback - use a list so it can be mutated during test
-# Starts outside skip range so first pass computes; changed to inside range for subsequent passes
-_current_timestep = [1000]
-def _get_cache_config(self):
+def _get_cache_config(self, current_timestep_callback=None):
config_kwargs = self.FASTER_CACHE_CONFIG.copy()
-config_kwargs["current_timestep_callback"] = lambda: self._current_timestep[0]
+if current_timestep_callback is None:
+current_timestep_callback = lambda: 1000 # noqa: E731
+config_kwargs["current_timestep_callback"] = current_timestep_callback
return FasterCacheConfig(**config_kwargs)
def _get_hook_names(self):
@@ -456,23 +479,26 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
model = self.model_class(**init_dict).to(torch_device)
model.eval()
-config = self._get_cache_config()
+current_timestep = [1000]
+config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0])
model.enable_cache(config)
# First pass with timestep outside skip range - computes and populates cache
-self._current_timestep[0] = 1000
+current_timestep[0] = 1000
_ = model(**inputs_dict, return_dict=False)[0]
# Move timestep inside skip range so subsequent passes use cache
-self._current_timestep[0] = 500
+current_timestep[0] = 500
# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
if "hidden_states" in inputs_dict_step2:
inputs_dict_step2["hidden_states"] = inputs_dict_step2["hidden_states"] + 0.1
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass uses cached attention with different hidden_states
# Second pass uses cached attention with different inputs
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
@@ -498,7 +524,6 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
config = self._get_cache_config()
model.enable_cache(config)
-self._current_timestep[0] = 1000
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
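To make the refactor concrete: the mutable timestep now lives inside each test rather than as shared class state. A sketch of the resulting pattern, assuming model and inputs as above (the skip-range values are illustrative):

from diffusers.hooks import FasterCacheConfig

current_timestep = [1000]  # outside the skip range, so the first pass computes
config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 901),  # illustrative range
    current_timestep_callback=lambda: current_timestep[0],
)
model.enable_cache(config)
_ = model(**inputs, return_dict=False)[0]  # computes and populates the cache
current_timestep[0] = 500  # now inside the skip range
out = model(**inputs, return_dict=False)[0]  # served partly from the cache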

View File

@@ -35,11 +35,13 @@ class TorchCompileTesterMixin:
"""
Mixin class for testing torch.compile functionality on models.
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
-- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic shape testing
-Expected methods to be implemented by subclasses:
+Optional properties:
+- different_shapes_for_compilation: List of (height, width) tuples for dynamic shape testing (default: None)
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -47,7 +49,10 @@ class TorchCompileTesterMixin:
Use `pytest -m "not compile"` to skip these tests
"""
-different_shapes_for_compilation = None
+@property
+def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
+"""Optional list of (height, width) tuples for dynamic shape testing."""
+return None
def setup_method(self):
torch.compiler.reset()
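Subclasses that do exercise dynamic shapes simply override the property, as the Flux compile tests later in this commit do:

class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
    @property
    def different_shapes_for_compilation(self):
        return [(4, 4), (4, 8), (8, 8)]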

View File

@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import gc
import pytest
import torch
-from ...testing_utils import is_ip_adapter, torch_device
+from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
@@ -41,10 +42,17 @@ class IPAdapterTesterMixin:
"""
Mixin class for testing IP Adapter functionality on models.
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
-Expected methods to be implemented by subclasses:
+Required properties (must be implemented by subclasses):
+- ip_adapter_processor_cls: The IP Adapter processor class to use
+Required methods (must be implemented by subclasses):
- create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing
- modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -52,7 +60,18 @@ class IPAdapterTesterMixin:
Use `pytest -m "not ip_adapter"` to skip these tests
"""
-ip_adapter_processor_cls = None
+def setup_method(self):
+gc.collect()
+backend_empty_cache(torch_device)
+def teardown_method(self):
+gc.collect()
+backend_empty_cache(torch_device)
+@property
+def ip_adapter_processor_cls(self):
+"""IP Adapter processor class to use for testing. Must be implemented by subclasses."""
+raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.")
def create_ip_adapter_state_dict(self, model):
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
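Put together, a conforming test class follows the Flux example later in this commit; a condensed sketch with the two required method bodies elided:

class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
    @property
    def ip_adapter_processor_cls(self):
        return FluxIPAdapterAttnProcessor

    def create_ip_adapter_state_dict(self, model):
        ...  # build a small IP-Adapter state dict matched to the dummy model

    def modify_inputs_for_ip_adapter(self, model, inputs_dict):
        ...  # attach the image-embedding inputs the processors consume
        return inputs_dict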

View File

@@ -67,10 +67,10 @@ class LoraTesterMixin:
"""
Mixin class for testing LoRA/PEFT functionality on models.
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
-Expected methods to be implemented by subclasses:
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -254,11 +254,13 @@ class LoraHotSwappingForModelTesterMixin:
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
for the analogous PEFT test.
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
-- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests
-Expected methods to be implemented by subclasses:
+Optional properties:
+- different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None)
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -266,7 +268,10 @@ class LoraHotSwappingForModelTesterMixin:
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
"""
-different_shapes_for_compilation = None
+@property
+def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
+"""Optional list of (height, width) tuples for dynamic compilation tests."""
+return None
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):

View File

@@ -81,11 +81,13 @@ class CPUOffloadTesterMixin:
"""
Mixin class for testing CPU offloading functionality.
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
-- model_split_percents: List of percentages for splitting model across devices
-Expected methods to be implemented by subclasses:
+Optional properties:
+- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -93,7 +95,10 @@ class CPUOffloadTesterMixin:
Use `pytest -m "not cpu_offload"` to skip these tests
"""
-model_split_percents = [0.5, 0.7]
+@property
+def model_split_percents(self) -> list[float]:
+"""List of percentages for splitting model across devices during offloading tests."""
+return [0.5, 0.7]
@require_offload_support
@torch.no_grad()
@@ -199,10 +204,10 @@ class GroupOffloadTesterMixin:
"""
Mixin class for testing group offloading functionality.
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
-Expected methods to be implemented by subclasses:
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
@@ -385,10 +390,10 @@ class LayerwiseCastingTesterMixin:
"""
Mixin class for testing layerwise dtype casting for memory optimization.
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
-Expected methods to be implemented by subclasses:
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
@@ -456,7 +461,7 @@ class LayerwiseCastingTesterMixin:
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
model.train()
-inputs_dict = self.get_inputs_dict()
+inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
with torch.amp.autocast(device_type=torch.device(torch_device).type):
output = model(**inputs_dict, return_dict=False)[0]
@@ -486,16 +491,16 @@ class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, Layerwis
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
Optional properties:
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
-Expected methods to be implemented by subclasses:
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: memory
Use `pytest -m "not memory"` to skip these tests
"""
pass
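In practice a model's memory suite is a config mixin paired with this composite, overriding model_split_percents only when the defaults split the model awkwardly; a hypothetical sketch:

class DummyTransformerMemoryTests(DummyTransformerTesterConfig, MemoryTesterMixin):
    @property
    def model_split_percents(self):
        return [0.4, 0.6]  # hypothetical override of the [0.5, 0.7] default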

View File

@@ -63,19 +63,40 @@ class SingleFileTesterMixin:
"""
Mixin class for testing single file loading for models.
-Expected class attributes:
+Required properties (must be implemented by subclasses):
+- ckpt_path: Path or Hub path to the single file checkpoint
+Optional properties:
+- torch_dtype: torch dtype to use for testing (default: None)
+- alternate_ckpt_paths: List of alternate checkpoint paths for variant testing (default: None)
+Expected from config mixin:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
-- ckpt_path: Path or Hub path to the single file checkpoint
-- subfolder: (Optional) Subfolder within the repo
-- torch_dtype: (Optional) torch dtype to use for testing
+- pretrained_model_kwargs: Additional kwargs for from_pretrained (e.g., subfolder)
Pytest mark: single_file
Use `pytest -m "not single_file"` to skip these tests
"""
-pretrained_model_name_or_path = None
-ckpt_path = None
+# ==================== Required Properties ====================
+@property
+def ckpt_path(self) -> str:
+"""Path or Hub path to the single file checkpoint. Must be implemented by subclasses."""
+raise NotImplementedError("Subclasses must implement the `ckpt_path` property.")
+# ==================== Optional Properties ====================
+@property
+def torch_dtype(self) -> torch.dtype | None:
+"""torch dtype to use for single file testing."""
+return None
+@property
+def alternate_ckpt_paths(self) -> list[str] | None:
+"""List of alternate checkpoint paths for variant testing."""
+return None
def setup_method(self):
gc.collect()
@@ -86,16 +107,10 @@ class SingleFileTesterMixin:
backend_empty_cache(torch_device)
def test_single_file_model_config(self):
-pretrained_kwargs = {}
-single_file_kwargs = {}
+pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
+single_file_kwargs = {"device": torch_device}
-pretrained_kwargs["device"] = torch_device
-single_file_kwargs["device"] = torch_device
-if hasattr(self, "subfolder") and self.subfolder:
-pretrained_kwargs["subfolder"] = self.subfolder
-if hasattr(self, "torch_dtype") and self.torch_dtype:
+if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
@@ -112,16 +127,10 @@ class SingleFileTesterMixin:
)
def test_single_file_model_parameters(self):
-pretrained_kwargs = {}
-single_file_kwargs = {}
+pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
+single_file_kwargs = {"device": torch_device}
-pretrained_kwargs["device"] = torch_device
-single_file_kwargs["device"] = torch_device
-if hasattr(self, "subfolder") and self.subfolder:
-pretrained_kwargs["subfolder"] = self.subfolder
-if hasattr(self, "torch_dtype") and self.torch_dtype:
+if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
@@ -153,7 +162,7 @@ class SingleFileTesterMixin:
def test_single_file_loading_local_files_only(self, tmp_path):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
@@ -168,7 +177,7 @@ class SingleFileTesterMixin:
def test_single_file_loading_with_diffusers_config(self):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
# Load with config parameter
@@ -177,10 +186,8 @@ class SingleFileTesterMixin:
)
# Load pretrained for comparison
-pretrained_kwargs = {}
-if hasattr(self, "subfolder") and self.subfolder:
-pretrained_kwargs["subfolder"] = self.subfolder
-if hasattr(self, "torch_dtype") and self.torch_dtype:
+pretrained_kwargs = {**self.pretrained_model_kwargs}
+if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
@@ -197,7 +204,7 @@ class SingleFileTesterMixin:
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
@@ -225,14 +232,14 @@ class SingleFileTesterMixin:
backend_empty_cache(torch_device)
def test_checkpoint_variant_loading(self):
if not hasattr(self, "alternate_ckpt_paths") or not self.alternate_ckpt_paths:
if not self.alternate_ckpt_paths:
return
for ckpt_path in self.alternate_ckpt_paths:
backend_empty_cache(torch_device)
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
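The property-based contract ends up reading like the Flux single-file config later in this commit; pretrained_model_kwargs is assumed to default to an empty dict on the config mixin and to carry what the old subfolder attribute held:

class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
    @property
    def ckpt_path(self):
        return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"

    @property
    def pretrained_model_name_or_path(self):
        return "black-forest-labs/FLUX.1-dev"

    @property
    def pretrained_model_kwargs(self):
        return {"subfolder": "transformer"}  # assumed replacement for the old `subfolder` attribute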

View File

@@ -14,13 +14,20 @@
# limitations under the License.
import copy
+import gc
import pytest
import torch
from diffusers.training_utils import EMAModel
-from ...testing_utils import is_training, require_torch_accelerator_with_training, torch_all_close, torch_device
+from ...testing_utils import (
+backend_empty_cache,
+is_training,
+require_torch_accelerator_with_training,
+torch_all_close,
+torch_device,
+)
@is_training
@@ -29,20 +36,26 @@ class TrainingTesterMixin:
"""
Mixin class for testing training functionality on models.
-Expected class attributes to be set by subclasses:
+Expected from config mixin:
- model_class: The model class to test
-- output_shape: Tuple defining the expected output shape
-Expected methods to be implemented by subclasses:
+Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
+Expected properties to be implemented by subclasses:
+- output_shape: Tuple defining the expected output shape
Pytest mark: training
Use `pytest -m "not training"` to skip these tests
"""
+def setup_method(self):
+gc.collect()
+backend_empty_cache(torch_device)
+def teardown_method(self):
+gc.collect()
+backend_empty_cache(torch_device)
def test_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
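For orientation, the core of the training smoke test is one forward/backward pass; a minimal sketch assuming the config-mixin contract and an output_shape property:

model = self.model_class(**self.get_init_dict()).to(torch_device)
model.train()
output = model(**self.get_dummy_inputs(), return_dict=False)[0]
assert output.shape == self.output_shape  # shape contract from the config mixin
torch.nn.functional.mse_loss(output, torch.randn_like(output)).backward()  # gradients must flow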

View File

@@ -198,31 +198,25 @@ class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux Transformer."""
-pass
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
-pass
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
-pass
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux Transformer"""
-pass
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""
-ip_adapter_processor_cls = FluxIPAdapterAttnProcessor
+@property
+def ip_adapter_processor_cls(self):
+return FluxIPAdapterAttnProcessor
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
torch.manual_seed(0)
@@ -241,13 +235,13 @@ class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterM
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux Transformer."""
-pass
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux Transformer."""
-different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+@property
+def different_shapes_for_compilation(self):
+return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
@@ -268,7 +262,9 @@ class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappin
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
-different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+@property
+def different_shapes_for_compilation(self):
+return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
@@ -289,11 +285,17 @@ class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTester
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
pretrained_model_name_or_path = "black-forest-labs/FLUX.1-dev"
subfolder = "transformer"
pass
@property
def ckpt_path(self):
return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
@property
def alternate_ckpt_paths(self):
return ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
@property
def pretrained_model_name_or_path(self):
return "black-forest-labs/FLUX.1-dev"
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
@@ -433,14 +435,10 @@ class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCo
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
-pass
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""
-pass
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
"""FasterCache tests for Flux Transformer."""