1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[tests] Fix group offloading and layerwise casting test interaction (#11796)

* update

* update

* update
This commit is contained in:
Aryan
2025-06-24 17:33:32 +05:30
committed by GitHub
parent 7392c8ff5a
commit 5df02fc171
2 changed files with 17 additions and 16 deletions

View File

@@ -110,8 +110,11 @@ class CosmosPatchEmbed3d(nn.Module):
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False)
wavelets = _WAVELETS.get(patch_method).clone()
arange = torch.arange(wavelets.shape[0])
self.register_buffer("wavelets", wavelets, persistent=False)
self.register_buffer("_arange", arange, persistent=False)
def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
dtype = hidden_states.dtype
@@ -185,12 +188,11 @@ class CosmosUnpatcher3d(nn.Module):
self.patch_size = patch_size
self.patch_method = patch_method
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
self.register_buffer(
"_arange",
torch.arange(_WAVELETS[patch_method].shape[0]),
persistent=False,
)
wavelets = _WAVELETS.get(patch_method).clone()
arange = torch.arange(wavelets.shape[0])
self.register_buffer("wavelets", wavelets, persistent=False)
self.register_buffer("_arange", arange, persistent=False)
def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
device = hidden_states.device

View File

@@ -1528,14 +1528,16 @@ class ModelTesterMixin:
test_fn(torch.float8_e5m2, torch.float32)
test_fn(torch.float8_e4m3fn, torch.bfloat16)
@torch.no_grad()
def test_layerwise_casting_inference(self):
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
model = self.model_class(**config)
model.eval()
model.to(torch_device)
base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
def check_linear_dtype(module, storage_dtype, compute_dtype):
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1573,6 +1575,7 @@ class ModelTesterMixin:
test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
@require_torch_accelerator
@torch.no_grad()
def test_layerwise_casting_memory(self):
MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0
@@ -1706,10 +1709,6 @@ class ModelTesterMixin:
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
@@ -1725,7 +1724,7 @@ class ModelTesterMixin:
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
_ = model(**inputs_dict)[0]
def test_auto_model(self, expected_max_diff=5e-5):