diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 1f3a36b5d1..7186cb181a 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -96,9 +96,6 @@ class ModuleGroup:
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()
 
-        if self.stream is None and self.record_stream:
-            raise ValueError("`record_stream` cannot be True when `stream` is None.")
-
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
@@ -513,6 +510,9 @@ def apply_group_offloading(
         else:
             raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
 
+    if not use_stream and record_stream:
+        raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
     if offload_type == "block_level":
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 16f0d48365..e6941a521d 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -1022,15 +1022,3 @@ class LoraBaseMixin:
     @classmethod
     def _optionally_disable_offloading(cls, _pipeline):
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
-    @classmethod
-    def _fetch_state_dict(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
-        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
-        return _fetch_state_dict(*args, **kwargs)
-
-    @classmethod
-    def _best_guess_weight_name(cls, *args, **kwargs):
-        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
-        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
-        return _best_guess_weight_name(*args, **kwargs)
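Review note: moving the `record_stream` check out of `ModuleGroup.__init__` and into `apply_group_offloading` surfaces the misconfiguration once, at the public entry point, instead of deep inside group construction. A minimal sketch of the caller-visible behavior, assuming the public `apply_group_offloading` helper exported from `diffusers.hooks` (the toy `nn.Sequential` model is just for illustration):

```python
# Sketch: the flag validation introduced above now fires before any
# ModuleGroup is built, so no accelerator is needed to hit it.
import torch
import torch.nn as nn

from diffusers.hooks import apply_group_offloading

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

try:
    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=False,
        record_stream=True,  # incompatible: recording requires a transfer stream
    )
except ValueError as err:
    print(err)  # "`record_stream` cannot be True when `use_stream=False`."
```

The `lora_base.py` deletion completes a deprecation cycle: `_fetch_state_dict` and `_best_guess_weight_name` were slated for removal in 0.35.0, and the free functions in `diffusers.loaders.lora_base` remain the supported replacements.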
- deprecate("_best_guess_weight_name", "0.35.0", deprecation_message) - return _best_guess_weight_name(*args, **kwargs) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py index 15fe8e02e0..7ab79a0bb8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py @@ -110,8 +110,11 @@ class CosmosPatchEmbed3d(nn.Module): self.patch_size = patch_size self.patch_method = patch_method - self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False) - self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False) + wavelets = _WAVELETS.get(patch_method).clone() + arange = torch.arange(wavelets.shape[0]) + + self.register_buffer("wavelets", wavelets, persistent=False) + self.register_buffer("_arange", arange, persistent=False) def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor: dtype = hidden_states.dtype @@ -185,12 +188,11 @@ class CosmosUnpatcher3d(nn.Module): self.patch_size = patch_size self.patch_method = patch_method - self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False) - self.register_buffer( - "_arange", - torch.arange(_WAVELETS[patch_method].shape[0]), - persistent=False, - ) + wavelets = _WAVELETS.get(patch_method).clone() + arange = torch.arange(wavelets.shape[0]) + + self.register_buffer("wavelets", wavelets, persistent=False) + self.register_buffer("_arange", arange, persistent=False) def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor: device = hidden_states.device diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index f6906074b3..1254b6725f 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -23,12 +23,14 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...models.autoencoders import AutoencoderKL from ...models.transformers import OmniGenTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils import is_torch_xla_available, is_torchvision_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .processor_omnigen import OmniGenMultiModalProcessor +if is_torchvision_available(): + from .processor_omnigen import OmniGenMultiModalProcessor + if is_torch_xla_available(): XLA_AVAILABLE = True else: diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py index be5ff82c4a..7ed11871bb 100644 --- a/src/diffusers/pipelines/omnigen/processor_omnigen.py +++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py @@ -18,7 +18,12 @@ from typing import Dict, List import numpy as np import torch from PIL import Image -from torchvision import transforms + +from ...utils import is_torchvision_available + + +if is_torchvision_available(): + from torchvision import transforms def crop_image(pil_image, max_image_size): diff --git a/tests/lora/test_deprecated_utilities.py b/tests/lora/test_deprecated_utilities.py deleted file mode 100644 index 4275ef8089..0000000000 --- a/tests/lora/test_deprecated_utilities.py +++ /dev/null @@ -1,39 +0,0 @@ -import 
diff --git a/tests/lora/test_deprecated_utilities.py b/tests/lora/test_deprecated_utilities.py
deleted file mode 100644
index 4275ef8089..0000000000
--- a/tests/lora/test_deprecated_utilities.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import os
-import tempfile
-import unittest
-
-import torch
-
-from diffusers.loaders.lora_base import LoraBaseMixin
-
-
-class UtilityMethodDeprecationTests(unittest.TestCase):
-    def test_fetch_state_dict_cls_method_raises_warning(self):
-        state_dict = torch.nn.Linear(3, 3).state_dict()
-        with self.assertWarns(FutureWarning) as warning:
-            _ = LoraBaseMixin._fetch_state_dict(
-                state_dict,
-                weight_name=None,
-                use_safetensors=False,
-                local_files_only=True,
-                cache_dir=None,
-                force_download=False,
-                proxies=None,
-                token=None,
-                revision=None,
-                subfolder=None,
-                user_agent=None,
-                allow_pickle=None,
-            )
-        warning_message = str(warning.warnings[0].message)
-        assert "Using the `_fetch_state_dict()` method from" in warning_message
-
-    def test_best_guess_weight_name_cls_method_raises_warning(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            state_dict = torch.nn.Linear(3, 3).state_dict()
-            torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))
-
-            with self.assertWarns(FutureWarning) as warning:
-                _ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
-            warning_message = str(warning.warnings[0].message)
-            assert "Using the `_best_guess_weight_name()` method from" in warning_message
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 0ea7030d06..ff52ee701d 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1528,14 +1528,16 @@ class ModelTesterMixin:
         test_fn(torch.float8_e5m2, torch.float32)
         test_fn(torch.float8_e4m3fn, torch.bfloat16)
 
+    @torch.no_grad()
     def test_layerwise_casting_inference(self):
         from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
 
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**config).eval()
-        model = model.to(torch_device)
-        base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+        model = self.model_class(**config)
+        model.eval()
+        model.to(torch_device)
+        base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
 
         def check_linear_dtype(module, storage_dtype, compute_dtype):
             patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1573,6 +1575,7 @@
         test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
 
     @require_torch_accelerator
+    @torch.no_grad()
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
         LEAST_COMPUTE_CAPABILITY = 8.0
@@ -1706,10 +1709,6 @@
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")
 
-        torch.manual_seed(0)
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
@@ -1725,7 +1724,7 @@
                 **additional_kwargs,
             )
             has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
-            assert has_safetensors, "No safetensors found in the directory."
+            self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
             _ = model(**inputs_dict)[0]
 
     def test_auto_model(self, expected_max_diff=5e-5):
@@ -2126,7 +2125,7 @@ class LoraHotSwappingForModelTesterMixin:
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
         if "unet" not in self.model_class.__name__.lower():
-            return
+            pytest.skip("Test only applies to UNet.")
 
         # It's important to add this context to raise an error on recompilation
         target_modules = ["conv", "conv1", "conv2"]
@@ -2136,7 +2135,7 @@ class LoraHotSwappingForModelTesterMixin:
     @parameterized.expand([(11, 11), (7, 13), (13, 7)])  # important to test small to large and vice versa
     def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
         if "unet" not in self.model_class.__name__.lower():
-            return
+            pytest.skip("Test only applies to UNet.")
 
         # It's important to add this context to raise an error on recompilation
         target_modules = ["to_q", "conv"]
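Review note: the `@torch.no_grad()` decorators matter for both speed and the memory numbers these tests compare, since a plain forward pass records an autograd graph that keeps intermediate activations alive. A toy illustration of the difference:

```python
# Why the layerwise-casting tests are wrapped in no_grad: without it the
# forward pass retains a graph (and activations) purely for a test that
# never calls backward().
import torch

model = torch.nn.Linear(1024, 1024)
x = torch.randn(64, 1024)

y = model(x)
print(y.requires_grad)  # True -- graph retained, higher peak memory

with torch.no_grad():
    y = model(x)
print(y.requires_grad)  # False -- no graph, inference-only cost
```

The `pytest.skip("Test only applies to UNet.")` changes are about reporting rather than behavior: a bare `return` makes a non-applicable parametrization show up as a pass, while a skip is visible as such in the test summary. The removed `torch.manual_seed(0)`/model-construction block was a verbatim duplicate of the three lines immediately following it.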
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index 84085f9d7d..b2d6f0fc05 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -289,6 +289,6 @@
         image = output.images[0]
 
         assert image.shape == (512, 512, 3)
 
         max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
-        assert max_diff < 1e-4
+        assert max_diff < 2e-4
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
index 342561d4f5..ab0221dc81 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
@@ -29,6 +29,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils.testing_utils import (
+    Expectations,
     backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
@@ -244,7 +245,35 @@ class LEditsPPPipelineStableDiffusionSlowTests(unittest.TestCase):
         output_slice = reconstruction[150:153, 140:143, -1]
         output_slice = output_slice.flatten()
-        expected_slice = np.array(
-            [0.9453125, 0.93310547, 0.84521484, 0.94628906, 0.9111328, 0.80859375, 0.93847656, 0.9042969, 0.8144531]
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array(
+                    [
+                        0.9511719,
+                        0.94140625,
+                        0.87597656,
+                        0.9472656,
+                        0.9296875,
+                        0.8378906,
+                        0.94433594,
+                        0.91503906,
+                        0.8491211,
+                    ]
+                ),
+                ("cuda", 7): np.array(
+                    [
+                        0.9453125,
+                        0.93310547,
+                        0.84521484,
+                        0.94628906,
+                        0.9111328,
+                        0.80859375,
+                        0.93847656,
+                        0.9042969,
+                        0.8144531,
+                    ]
+                ),
+            }
         )
+        expected_slice = expected_slices.get_expectation()
 
         assert np.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 207cff2a3c..4a3a9b1796 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -49,6 +49,7 @@ from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
 from diffusers.utils.testing_utils import (
     CaptureLogger,
     backend_empty_cache,
+    numpy_cosine_similarity_distance,
     require_accelerate_version_greater,
     require_accelerator,
     require_hf_hub_version_greater,
@@ -1394,9 +1395,8 @@
             fp16_inputs["generator"] = self.get_generator(0)
 
             output_fp16 = pipe_fp16(**fp16_inputs)[0]
-
-        max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
-        self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+        max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
+        assert max_diff < 2e-4
 
     @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
     @require_accelerator
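Review note: `test_float16_inference` now scores fp16 against fp32 with cosine distance instead of an element-wise max difference. Uniform half-precision rounding nudges every element a little, which can trip a max-abs threshold while leaving the output direction essentially unchanged. A self-contained sketch; the real `numpy_cosine_similarity_distance` lives in `diffusers.utils.testing_utils`, and the formula below is assumed to match it:

```python
# 1 - cosine similarity between two flattened outputs.
import numpy as np


def numpy_cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)


rng = np.random.default_rng(0)
fp32 = rng.random(10_000).astype(np.float32)
fp16 = fp32.astype(np.float16).astype(np.float32)  # simulate a half-precision round trip

print(np.abs(fp32 - fp16).max())                     # on the order of 1e-4: pure rounding noise
print(numpy_cosine_similarity_distance(fp32, fp16))  # orders of magnitude below the 2e-4 gate
```

The Kandinsky and LEDITS hunks are the same theme applied to hardware variance: the ControlNet tolerance is loosened to 2e-4, and the LEDITS expected slice becomes an `Expectations` table with per-device reference values.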
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index ae3900459d..5d1fa4c22e 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -286,33 +286,33 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase)
             {
                 ("xpu", 3): np.array(
                     [
-                        0.19335938,
-                        0.3125,
-                        0.3203125,
-                        0.1328125,
-                        0.3046875,
-                        0.296875,
-                        0.11914062,
-                        0.2890625,
-                        0.2890625,
-                        0.16796875,
-                        0.30273438,
-                        0.33203125,
-                        0.14648438,
-                        0.31640625,
-                        0.33007812,
+                        0.16210938,
+                        0.2734375,
+                        0.27734375,
+                        0.109375,
+                        0.27148438,
+                        0.2578125,
+                        0.1015625,
+                        0.2578125,
+                        0.2578125,
+                        0.14453125,
+                        0.26953125,
+                        0.29492188,
                         0.12890625,
-                        0.3046875,
-                        0.30859375,
-                        0.17773438,
-                        0.33789062,
-                        0.33203125,
-                        0.16796875,
-                        0.34570312,
-                        0.32421875,
+                        0.28710938,
+                        0.30078125,
+                        0.11132812,
+                        0.27734375,
+                        0.27929688,
                         0.15625,
-                        0.33203125,
-                        0.31445312,
+                        0.31054688,
+                        0.296875,
+                        0.15234375,
+                        0.3203125,
+                        0.29492188,
+                        0.140625,
+                        0.3046875,
+                        0.28515625,
                     ]
                 ),
                 ("cuda", 7): np.array(
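Review note: the GGUF hunk (truncated here before the CUDA values) refreshes the XPU reference slice inside the same `Expectations` pattern used above for LEDITS. A simplified, illustrative re-creation of the device-keyed lookup; the real class in `diffusers.utils.testing_utils` is more general, and this sketch only mirrors the shape of the data (values taken from the LEDITS hunk, cut to three entries):

```python
# Toy stand-in for the Expectations lookup: keys are
# (device_type, major_version), and the current machine picks its row.
import numpy as np

EXPECTED_SLICES = {
    ("xpu", 3): np.array([0.9511719, 0.94140625, 0.87597656]),
    ("cuda", 7): np.array([0.9453125, 0.93310547, 0.84521484]),
}


def get_expectation(device_type: str, major_version: int) -> np.ndarray:
    """Exact (device, version) match first, then any entry for the device."""
    key = (device_type, major_version)
    if key in EXPECTED_SLICES:
        return EXPECTED_SLICES[key]
    for (device, _), value in EXPECTED_SLICES.items():
        if device == device_type:
            return value
    raise KeyError(f"no expectation recorded for device {device_type!r}")


print(get_expectation("cuda", 7))  # the CUDA 7.x reference slice
```

Keeping one table per test, keyed by device and compute-capability major version, lets a single assertion serve CUDA and XPU runners without stacking tolerances.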