mirror of https://github.com/huggingface/diffusers.git
update
@@ -21,11 +21,10 @@ from diffusers.models.attention_processor import (
     AttnProcessor,
 )
 
-from ...testing_utils import is_attention, require_accelerator, torch_device
+from ...testing_utils import is_attention, torch_device
 
 
 @is_attention
-@require_accelerator
 class AttentionTesterMixin:
     """
     Mixin class for testing attention processor and module functionality on models.
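Note: AttentionTesterMixin is meant to be paired with a small per-model config class, the same composition the Flux test classes later in this diff use. A minimal sketch of that pattern, with illustrative class names and assuming the mixin above is in scope:

import torch


class DummyModelTesterConfig:
    # Illustrative stand-in for a per-model config such as FluxTransformerTesterConfig.
    model_class = None  # the diffusers model class under test would go here

    def get_init_dict(self):
        return {}

    def get_dummy_inputs(self):
        return {"hidden_states": torch.randn(1, 16, 4)}


class TestDummyModelAttention(DummyModelTesterConfig, AttentionTesterMixin):
    """Inherits the attention tests; only the config half is model-specific."""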
@@ -16,10 +16,10 @@
 import json
 import os
 import tempfile
-from typing import Dict, List, Tuple
 
+import pytest
 import torch
 import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
 
 from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant
@@ -30,8 +30,8 @@ from ...testing_utils import torch_device
 
 def compute_module_persistent_sizes(
     model: nn.Module,
-    dtype: Optional[Union[str, torch.device]] = None,
-    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
+    dtype: str | torch.device | None = None,
+    special_dtypes: dict[str, str | torch.device] | None = None,
 ):
     """
     Compute the size of each submodule of a given model (parameters + persistent buffers).
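Note: the new signature swaps typing.Optional/Union for built-in generics and PEP 604 unions (Python 3.10+). A small equivalence sketch, purely illustrative:

from typing import Dict, Optional, Union

import torch

# Old-style annotations, which need the typing imports above
OldDtype = Optional[Union[str, torch.device]]
OldSpecialDtypes = Optional[Dict[str, Union[str, torch.device]]]

# Equivalent new-style annotations as used in the hunk
NewDtype = str | torch.device | None
NewSpecialDtypes = dict[str, str | torch.device] | None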
@@ -128,6 +128,7 @@ class ModelTesterMixin:
         )
 
     def test_from_save_pretrained(self, expected_max_diff=5e-5):
+        torch.manual_seed(0)
         model = self.model_class(**self.get_init_dict())
         model.to(torch_device)
         model.eval()
@@ -273,10 +274,10 @@ class ModelTesterMixin:
                 return t.to(device)
 
         def recursive_check(tuple_object, dict_object):
-            if isinstance(tuple_object, (List, Tuple)):
+            if isinstance(tuple_object, (list, tuple)):
                 for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
                     recursive_check(tuple_iterable_value, dict_iterable_value)
-            elif isinstance(tuple_object, Dict):
+            elif isinstance(tuple_object, dict):
                 for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
                     recursive_check(tuple_iterable_value, dict_iterable_value)
             elif tuple_object is None:
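Note: recursive_check walks tuple- and dict-shaped model outputs in parallel and compares the tensor leaves. A self-contained sketch of the idea; the final comparison branch is an assumption, since the hunk cuts off at the None case:

import torch


def recursive_check(tuple_object, dict_object):
    # Recurse through parallel tuple/dict structures and compare leaf tensors.
    if isinstance(tuple_object, (list, tuple)):
        for t, d in zip(tuple_object, dict_object.values()):
            recursive_check(t, d)
    elif isinstance(tuple_object, dict):
        for t, d in zip(tuple_object.values(), dict_object.values()):
            recursive_check(t, d)
    elif tuple_object is None:
        return
    else:
        assert torch.allclose(tuple_object, dict_object, atol=1e-5)


sample = torch.ones(2, 3)
recursive_check((sample, None), {"sample": sample, "extra": None})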
@@ -25,6 +25,7 @@ from diffusers.utils.import_utils import (
     is_gguf_available,
+    is_nvidia_modelopt_available,
     is_optimum_quanto_available,
     is_torchao_available,
 )
 
 from ...testing_utils import (
@@ -41,6 +42,7 @@ from ...testing_utils import (
     require_gguf_version_greater_or_equal,
     require_quanto,
     require_torchao_version_greater_or_equal,
+    require_modelopt_version_greater_or_equal,
     torch_device,
 )
 
@@ -58,7 +60,6 @@ if is_gguf_available():
     pass
 
 if is_torchao_available():
-
     if is_torchao_version(">=", "0.9.0"):
         pass
 
@@ -644,9 +645,7 @@ class TorchAoTesterMixin(QuantizationTesterMixin):
         if modules_to_exclude is None:
             pytest.skip("modules_to_not_convert_for_test not defined for this model")
 
-        self._test_quantization_modules_to_not_convert(
-            self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
-        )
+        self._test_quantization_modules_to_not_convert(self.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude)
 
     def test_torchao_device_map(self):
         """Test that device_map='auto' works correctly with quantization."""
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any
+
 import torch
 
 from diffusers import FluxTransformer2DModel
@@ -46,7 +48,11 @@ class FluxTransformerTesterConfig:
     pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
     pretrained_model_kwargs = {"subfolder": "transformer"}
 
-    def get_init_dict(self):
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+
+    def get_init_dict(self) -> dict[str, int | list[int]]:
         """Return Flux model initialization arguments."""
         return {
             "patch_size": 1,
@@ -60,30 +66,32 @@ class FluxTransformerTesterConfig:
             "axes_dims_rope": [4, 4, 8],
         }
 
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         batch_size = 1
         height = width = 4
         num_latent_channels = 4
         num_image_channels = 3
-        sequence_length = 24
-        embedding_dim = 8
+        sequence_length = 48
+        embedding_dim = 32
 
         return {
-            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels)),
-            "encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim)),
-            "pooled_projections": randn_tensor((batch_size, embedding_dim)),
-            "img_ids": randn_tensor((height * width, num_image_channels)),
-            "txt_ids": randn_tensor((sequence_length, num_image_channels)),
+            "hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), generator=self.generator),
+            "encoder_hidden_states": randn_tensor(
+                (batch_size, sequence_length, embedding_dim), generator=self.generator
+            ),
+            "pooled_projections": randn_tensor((batch_size, embedding_dim), generator=self.generator),
+            "img_ids": randn_tensor((height * width, num_image_channels), generator=self.generator),
+            "txt_ids": randn_tensor((sequence_length, num_image_channels), generator=self.generator),
             "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
         }
 
     @property
-    def input_shape(self):
-        return (16, 4)
+    def input_shape(self) -> tuple[int, int]:
+        return (1, 16, 4)
 
     @property
-    def output_shape(self):
-        return (16, 4)
+    def output_shape(self) -> tuple[int, int]:
+        return (1, 16, 4)
 
 
 class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
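Note: because the generator property builds a freshly seeded CPU generator on every access, each randn_tensor call above draws from the same fixed seed, so the dummy inputs are deterministic across calls. A minimal sketch of the effect (illustrative, using the public randn_tensor helper):

import torch

from diffusers.utils.torch_utils import randn_tensor


class _Config:
    @property
    def generator(self):
        # A new generator is created per access, so the seed is re-applied each time.
        return torch.Generator("cpu").manual_seed(0)


cfg = _Config()
a = randn_tensor((1, 4), generator=cfg.generator)
b = randn_tensor((1, 4), generator=cfg.generator)
assert torch.equal(a, b)  # identical draws from the fixed seed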
@@ -140,7 +148,7 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM
 class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
     """IP Adapter tests for Flux Transformer."""
 
-    def create_ip_adapter_state_dict(self, model):
+    def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
         from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
 
         ip_cross_attn_state_dict = {}
@@ -202,7 +210,7 @@ class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappin
 
     different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
-    def get_dummy_inputs(self, height=4, width=4):
+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for LoRA hotswap tests."""
         batch_size = 1
         num_latent_channels = 4
@@ -223,7 +231,7 @@ class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappin
 class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
     different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
 
-    def get_dummy_inputs(self, height=4, width=4):
+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         """Override to support dynamic height/width for compilation tests."""
         batch_size = 1
         num_latent_channels = 4
@@ -250,7 +258,7 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
 
 
 class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -263,7 +271,7 @@ class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesT
 
 
 class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -276,7 +284,7 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
 
 
 class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -291,7 +299,7 @@ class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin
 class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
     gguf_filename = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"
 
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -304,7 +312,7 @@ class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
 
 
 class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
         return {
             "hidden_states": randn_tensor((1, 4096, 64)),
             "encoder_hidden_states": randn_tensor((1, 512, 4096)),
@@ -37,6 +37,7 @@ from diffusers.utils.import_utils import (
     is_flax_available,
     is_gguf_available,
     is_kernels_available,
+    is_nvidia_modelopt_available,
     is_note_seq_available,
     is_onnx_available,
     is_opencv_available,
@@ -765,6 +766,19 @@ def require_kernels_version_greater_or_equal(kernels_version):
     return decorator
 
 
+def require_modelopt_version_greater_or_equal(modelopt_version):
+    def decorator(test_case):
+        correct_nvidia_modelopt_version = is_nvidia_modelopt_available() and version.parse(
+            version.parse(importlib.metadata.version("modelopt")).base_version
+        ) >= version.parse(modelopt_version)
+        return pytest.mark.skipif(
+            not correct_nvidia_modelopt_version,
+            f"Test requires modelopt with version greater than {modelopt_version}.",
+        )(test_case)
+
+    return decorator
+
+
 def deprecate_after_peft_backend(test_case):
     """
     Decorator marking a test that will be skipped after PEFT backend
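Note: the new require_modelopt_version_greater_or_equal helper follows the existing require_* decorators, wrapping a test in pytest.mark.skipif so it only runs when nvidia-modelopt is installed at or above the requested version. A hypothetical usage sketch (test name and version pin are illustrative; assumes the decorator above is in scope, e.g. imported from the tests' testing_utils):

@require_modelopt_version_greater_or_equal("0.19.0")
def test_modelopt_quantized_inference():
    ...  # collected but skipped unless nvidia-modelopt >= 0.19.0 is available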