diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 40183cd9a0..7d33415d73 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1802,9 +1802,9 @@ class PeftLoraLoaderMixinTests: if any((re.search(pattern, name) for pattern in patterns_to_check)): dtype_to_check = compute_dtype if getattr(module, "weight", None) is not None: - self.assertEqual(module.weight.dtype, dtype_to_check) + assert module.weight.dtype == dtype_to_check if getattr(module, "bias", None) is not None: - self.assertEqual(module.bias.dtype, dtype_to_check) + assert module.bias.dtype == dtype_to_check if isinstance(module, BaseTunerLayer): assert getattr(module, "_diffusers_hook", None is not None) assert module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None