mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[tests] Fix group offloading and layerwise casting test interaction (#11796)
* update * update * update
This commit is contained in:
@@ -110,8 +110,11 @@ class CosmosPatchEmbed3d(nn.Module):
|
||||
self.patch_size = patch_size
|
||||
self.patch_method = patch_method
|
||||
|
||||
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
|
||||
self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False)
|
||||
wavelets = _WAVELETS.get(patch_method).clone()
|
||||
arange = torch.arange(wavelets.shape[0])
|
||||
|
||||
self.register_buffer("wavelets", wavelets, persistent=False)
|
||||
self.register_buffer("_arange", arange, persistent=False)
|
||||
|
||||
def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
|
||||
dtype = hidden_states.dtype
|
||||
@@ -185,12 +188,11 @@ class CosmosUnpatcher3d(nn.Module):
|
||||
self.patch_size = patch_size
|
||||
self.patch_method = patch_method
|
||||
|
||||
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
|
||||
self.register_buffer(
|
||||
"_arange",
|
||||
torch.arange(_WAVELETS[patch_method].shape[0]),
|
||||
persistent=False,
|
||||
)
|
||||
wavelets = _WAVELETS.get(patch_method).clone()
|
||||
arange = torch.arange(wavelets.shape[0])
|
||||
|
||||
self.register_buffer("wavelets", wavelets, persistent=False)
|
||||
self.register_buffer("_arange", arange, persistent=False)
|
||||
|
||||
def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
|
||||
device = hidden_states.device
|
||||
|
||||
@@ -1528,14 +1528,16 @@ class ModelTesterMixin:
|
||||
test_fn(torch.float8_e5m2, torch.float32)
|
||||
test_fn(torch.float8_e4m3fn, torch.bfloat16)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_layerwise_casting_inference(self):
|
||||
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
|
||||
|
||||
torch.manual_seed(0)
|
||||
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
|
||||
model = self.model_class(**config)
|
||||
model.eval()
|
||||
model.to(torch_device)
|
||||
base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
|
||||
|
||||
def check_linear_dtype(module, storage_dtype, compute_dtype):
|
||||
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
|
||||
@@ -1573,6 +1575,7 @@ class ModelTesterMixin:
|
||||
test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@torch.no_grad()
|
||||
def test_layerwise_casting_memory(self):
|
||||
MB_TOLERANCE = 0.2
|
||||
LEAST_COMPUTE_CAPABILITY = 8.0
|
||||
@@ -1706,10 +1709,6 @@ class ModelTesterMixin:
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
|
||||
torch.manual_seed(0)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
@@ -1725,7 +1724,7 @@ class ModelTesterMixin:
|
||||
**additional_kwargs,
|
||||
)
|
||||
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
|
||||
assert has_safetensors, "No safetensors found in the directory."
|
||||
self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
def test_auto_model(self, expected_max_diff=5e-5):
|
||||
|
||||
Reference in New Issue
Block a user