From 5d49b3e83b97b45d8745ed6fc9f06c32d5ef9286 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 8 Apr 2025 16:47:03 +0100 Subject: [PATCH 1/3] Flux quantized with lora (#10990) * Flux quantized with lora * fix * changes * Apply suggestions from code review Co-authored-by: Sayak Paul * Apply style fixes * enable model cpu offload() * Update src/diffusers/loaders/lora_pipeline.py Co-authored-by: hlky * update * Apply suggestions from code review * update * add peft as an additional dependency for gguf --------- Co-authored-by: Sayak Paul Co-authored-by: github-actions[bot] Co-authored-by: Dhruv Nair --- .github/workflows/nightly_tests.yml | 2 +- src/diffusers/loaders/lora_pipeline.py | 63 ++++++++++++++++++++++++-- src/diffusers/quantizers/gguf/utils.py | 2 + tests/quantization/bnb/test_4bit.py | 45 +++++++++++++++++- tests/quantization/gguf/test_gguf.py | 46 +++++++++++++++++++ 5 files changed, 151 insertions(+), 7 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 2b39eea2fe..88343a128b 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -417,7 +417,7 @@ jobs: additional_deps: ["peft"] - backend: "gguf" test_location: "gguf" - additional_deps: [] + additional_deps: ["peft"] - backend: "torchao" test_location: "torchao" additional_deps: [] diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 513c15ac5d..a29b77acce 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -22,6 +22,8 @@ from ..utils import ( USE_PEFT_BACKEND, deprecate, get_submodule_by_name, + is_bitsandbytes_available, + is_gguf_available, is_peft_available, is_peft_version, is_torch_version, @@ -68,6 +70,49 @@ TRANSFORMER_NAME = "transformer" _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"} +def _maybe_dequantize_weight_for_expanded_lora(model, module): + if is_bitsandbytes_available(): + from ..quantizers.bitsandbytes import dequantize_bnb_weight + + if is_gguf_available(): + from ..quantizers.gguf.utils import dequantize_gguf_tensor + + is_bnb_4bit_quantized = module.weight.__class__.__name__ == "Params4bit" + is_gguf_quantized = module.weight.__class__.__name__ == "GGUFParameter" + + if is_bnb_4bit_quantized and not is_bitsandbytes_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints." + ) + if is_gguf_quantized and not is_gguf_available(): + raise ValueError( + "The checkpoint seems to have been quantized with `gguf`. Install `gguf` to load quantized checkpoints." + ) + + weight_on_cpu = False + if not module.weight.is_cuda: + weight_on_cpu = True + + if is_bnb_4bit_quantized: + module_weight = dequantize_bnb_weight( + module.weight.cuda() if weight_on_cpu else module.weight, + state=module.weight.quant_state, + dtype=model.dtype, + ).data + elif is_gguf_quantized: + module_weight = dequantize_gguf_tensor( + module.weight.cuda() if weight_on_cpu else module.weight, + ) + module_weight = module_weight.to(model.dtype) + else: + module_weight = module.weight.data + + if weight_on_cpu: + module_weight = module_weight.cpu() + + return module_weight + + class StableDiffusionLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and @@ -2267,6 +2312,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): overwritten_params = {} is_peft_loaded = getattr(transformer, "peft_config", None) is not None + is_quantized = hasattr(transformer, "hf_quantizer") for name, module in transformer.named_modules(): if isinstance(module, torch.nn.Linear): module_weight = module.weight.data @@ -2291,9 +2337,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): if tuple(module_weight_shape) == (out_features, in_features): continue - # TODO (sayakpaul): We still need to consider if the module we're expanding is - # quantized and handle it accordingly if that is the case. - module_out_features, module_in_features = module_weight.shape + module_out_features, module_in_features = module_weight_shape debug_message = "" if in_features > module_in_features: debug_message += ( @@ -2316,6 +2360,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin): parent_module_name, _, current_module_name = name.rpartition(".") parent_module = transformer.get_submodule(parent_module_name) + if is_quantized: + module_weight = _maybe_dequantize_weight_for_expanded_lora(transformer, module) + + # TODO: consider if this layer needs to be a quantized layer as well if `is_quantized` is True. with torch.device("meta"): expanded_module = torch.nn.Linear( in_features, out_features, bias=bias, dtype=module_weight.dtype @@ -2327,7 +2375,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): new_weight = torch.zeros_like( expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype ) - slices = tuple(slice(0, dim) for dim in module_weight.shape) + slices = tuple(slice(0, dim) for dim in module_weight_shape) new_weight[slices] = module_weight tmp_state_dict = {"weight": new_weight} if module_bias is not None: @@ -2416,7 +2464,12 @@ class FluxLoraLoaderMixin(LoraBaseMixin): base_weight_param_name: str = None, ) -> "torch.Size": def _get_weight_shape(weight: torch.Tensor): - return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape + if weight.__class__.__name__ == "Params4bit": + return weight.quant_state.shape + elif weight.__class__.__name__ == "GGUFParameter": + return weight.quant_shape + else: + return weight.shape if base_module is not None: return _get_weight_shape(base_module.weight) diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index effc39d8fe..de82dcab07 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -400,6 +400,8 @@ class GGUFParameter(torch.nn.Parameter): data = data if data is not None else torch.empty(0) self = torch.Tensor._make_subclass(cls, data, requires_grad) self.quant_type = quant_type + block_size, type_size = GGML_QUANT_SIZES[quant_type] + self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size) return self diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 29a3e212c4..fdcc5314d2 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -21,8 +21,15 @@ import numpy as np import pytest import safetensors.torch from huggingface_hub import hf_hub_download +from PIL import Image -from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel +from diffusers import ( + BitsAndBytesConfig, + DiffusionPipeline, + FluxControlPipeline, + FluxTransformer2DModel, + SD3Transformer2DModel, +) from diffusers.utils import is_accelerate_version, logging from diffusers.utils.testing_utils import ( CaptureLogger, @@ -696,6 +703,42 @@ class SlowBnb4BitFluxTests(Base4bitTests): self.assertTrue(max_diff < 1e-3) +@require_transformers_version_greater("4.44.0") +@require_peft_backend +class SlowBnb4BitFluxControlWithLoraTests(Base4bitTests): + def setUp(self) -> None: + gc.collect() + torch.cuda.empty_cache() + + self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16) + self.pipeline_4bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_lora_loading(self): + self.pipeline_4bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") + + output = self.pipeline_4bit( + prompt=self.prompt, + control_image=Image.new(mode="RGB", size=(256, 256)), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + num_inference_steps=8, + generator=torch.Generator().manual_seed(42), + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}") + + @slow class BaseBnb4BitSerializationTests(Base4bitTests): def tearDown(self): diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 5e3875c7c9..e4cf1dfee1 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -8,12 +8,14 @@ import torch.nn as nn from diffusers import ( AuraFlowPipeline, AuraFlowTransformer2DModel, + FluxControlPipeline, FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig, SD3Transformer2DModel, StableDiffusion3Pipeline, ) +from diffusers.utils import load_image from diffusers.utils.testing_utils import ( is_gguf_available, nightly, @@ -21,6 +23,7 @@ from diffusers.utils.testing_utils import ( require_accelerate, require_big_gpu_with_torch_cuda, require_gguf_version_greater_or_equal, + require_peft_backend, torch_device, ) @@ -456,3 +459,46 @@ class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): ) max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) assert max_diff < 1e-4 + + +@require_peft_backend +@nightly +@require_big_gpu_with_torch_cuda +@require_accelerate +@require_gguf_version_greater_or_equal("0.10.0") +class FluxControlLoRAGGUFTests(unittest.TestCase): + def test_lora_loading(self): + ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf" + transformer = FluxTransformer2DModel.from_single_file( + ckpt_path, + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16), + torch_dtype=torch.bfloat16, + ) + pipe = FluxControlPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + transformer=transformer, + torch_dtype=torch.bfloat16, + ).to("cuda") + pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora") + + prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts." + control_image = load_image( + "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/control_image_robot_canny.png" + ) + + output = pipe( + prompt=prompt, + control_image=control_image, + height=256, + width=256, + num_inference_steps=10, + guidance_scale=30.0, + output_type="np", + generator=torch.manual_seed(0), + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.8047, 0.8359, 0.8711, 0.6875, 0.7070, 0.7383, 0.5469, 0.5820, 0.6641]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) From 4b27c4a494bb07849f8a9a509b2d268bf314f7a7 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 8 Apr 2025 21:17:49 +0530 Subject: [PATCH 2/3] [feat] implement `record_stream` when using CUDA streams during group offloading (#11081) * implement record_stream for better performance. * fix * style. * merge #11097 * Update src/diffusers/hooks/group_offloading.py Co-authored-by: Aryan * fixes * docstring. * remaining todos in low_cpu_mem_usage * tests * updates to docs. --------- Co-authored-by: Aryan --- docs/source/en/optimization/memory.md | 4 ++ src/diffusers/hooks/group_offloading.py | 66 +++++++++++++++++++++++-- src/diffusers/models/modeling_utils.py | 2 + tests/models/test_modeling_common.py | 7 ++- 4 files changed, 73 insertions(+), 6 deletions(-) diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md index fd72957471..fc93947761 100644 --- a/docs/source/en/optimization/memory.md +++ b/docs/source/en/optimization/memory.md @@ -178,6 +178,9 @@ pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch # We can utilize the enable_group_offload method for Diffusers model implementations pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True) +# Uncomment the following to also allow recording the current streams. +# pipe.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True, record_stream=True) + # For any other model implementations, the apply_group_offloading function can be used apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2) apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level") @@ -205,6 +208,7 @@ Group offloading (for CUDA devices with support for asynchronous data transfer s - The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html) - If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems. - The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading. +- When using `use_stream=True`, users can additionally specify `record_stream=True` to get better speedups at the expense of slightly increased memory usage. Refer to the [official PyTorch docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) to know more about this. For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`]. diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 4c1d354a0f..ac6cf65364 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -56,6 +56,7 @@ class ModuleGroup: buffers: Optional[List[torch.Tensor]] = None, non_blocking: bool = False, stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False, low_cpu_mem_usage=False, onload_self: bool = True, ) -> None: @@ -68,11 +69,14 @@ class ModuleGroup: self.buffers = buffers or [] self.non_blocking = non_blocking or stream is not None self.stream = stream + self.record_stream = record_stream self.onload_self = onload_self self.low_cpu_mem_usage = low_cpu_mem_usage - self.cpu_param_dict = self._init_cpu_param_dict() + if self.stream is None and self.record_stream: + raise ValueError("`record_stream` cannot be True when `stream` is None.") + def _init_cpu_param_dict(self): cpu_param_dict = {} if self.stream is None: @@ -112,6 +116,8 @@ class ModuleGroup: def onload_(self): r"""Onloads the group of modules to the onload_device.""" context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream) + current_stream = torch.cuda.current_stream() if self.record_stream else None + if self.stream is not None: # Wait for previous Host->Device transfer to complete self.stream.synchronize() @@ -122,14 +128,22 @@ class ModuleGroup: for group_module in self.modules: for param in group_module.parameters(): param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + param.data.record_stream(current_stream) for buffer in group_module.buffers(): buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) for param in self.parameters: param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + param.data.record_stream(current_stream) for buffer in self.buffers: buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) else: for group_module in self.modules: @@ -143,11 +157,14 @@ class ModuleGroup: for buffer in self.buffers: buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) + if self.record_stream: + buffer.data.record_stream(current_stream) def offload_(self): r"""Offloads the group of modules to the offload_device.""" if self.stream is not None: - torch.cuda.current_stream().synchronize() + if not self.record_stream: + torch.cuda.current_stream().synchronize() for group_module in self.modules: for param in group_module.parameters(): param.data = self.cpu_param_dict[param] @@ -331,6 +348,7 @@ def apply_group_offloading( num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, + record_stream: bool = False, low_cpu_mem_usage: bool = False, ) -> None: r""" @@ -378,6 +396,10 @@ def apply_group_offloading( use_stream (`bool`, defaults to `False`): If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for overlapping computation and data transfer. + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. low_cpu_mem_usage (`bool`, defaults to `False`): If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when @@ -417,11 +439,24 @@ def apply_group_offloading( raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.") _apply_group_offloading_block_level( - module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage + module=module, + num_blocks_per_group=num_blocks_per_group, + offload_device=offload_device, + onload_device=onload_device, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, ) elif offload_type == "leaf_level": _apply_group_offloading_leaf_level( - module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage + module=module, + offload_device=offload_device, + onload_device=onload_device, + non_blocking=non_blocking, + stream=stream, + record_stream=record_stream, + low_cpu_mem_usage=low_cpu_mem_usage, ) else: raise ValueError(f"Unsupported offload_type: {offload_type}") @@ -434,6 +469,7 @@ def _apply_group_offloading_block_level( onload_device: torch.device, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, ) -> None: r""" @@ -453,6 +489,14 @@ def _apply_group_offloading_block_level( stream (`torch.cuda.Stream`, *optional*): If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful for overlapping computation and data transfer. + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. + low_cpu_mem_usage (`bool`, defaults to `False`): + If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This + option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when + the CPU memory is a bottleneck but may counteract the benefits of using streams. """ # Create module groups for ModuleList and Sequential blocks @@ -475,6 +519,7 @@ def _apply_group_offloading_block_level( onload_leader=current_modules[0], non_blocking=non_blocking, stream=stream, + record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=stream is None, ) @@ -512,6 +557,7 @@ def _apply_group_offloading_block_level( buffers=buffers, non_blocking=False, stream=None, + record_stream=False, onload_self=True, ) next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None @@ -524,6 +570,7 @@ def _apply_group_offloading_leaf_level( onload_device: torch.device, non_blocking: bool, stream: Optional[torch.cuda.Stream] = None, + record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, ) -> None: r""" @@ -545,6 +592,14 @@ def _apply_group_offloading_leaf_level( stream (`torch.cuda.Stream`, *optional*): If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful for overlapping computation and data transfer. + record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor + as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the + [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more + details. + low_cpu_mem_usage (`bool`, defaults to `False`): + If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This + option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when + the CPU memory is a bottleneck but may counteract the benefits of using streams. """ # Create module groups for leaf modules and apply group offloading hooks @@ -560,6 +615,7 @@ def _apply_group_offloading_leaf_level( onload_leader=submodule, non_blocking=non_blocking, stream=stream, + record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) @@ -605,6 +661,7 @@ def _apply_group_offloading_leaf_level( buffers=buffers, non_blocking=non_blocking, stream=stream, + record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) @@ -624,6 +681,7 @@ def _apply_group_offloading_leaf_level( buffers=None, non_blocking=False, stream=None, + record_stream=False, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, ) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 19ac868cda..2a22bc09ad 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): num_blocks_per_group: Optional[int] = None, non_blocking: bool = False, use_stream: bool = False, + record_stream: bool = False, low_cpu_mem_usage=False, ) -> None: r""" @@ -594,6 +595,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): num_blocks_per_group, non_blocking, use_stream, + record_stream, low_cpu_mem_usage=low_cpu_mem_usage, ) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index d55ff6e628..847677884a 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1525,8 +1525,9 @@ class ModelTesterMixin: or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE ) + @parameterized.expand([False, True]) @require_torch_gpu - def test_group_offloading(self): + def test_group_offloading(self, record_stream): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() torch.manual_seed(0) @@ -1566,7 +1567,9 @@ class ModelTesterMixin: torch.manual_seed(0) model = self.model_class(**init_dict) - model.enable_group_offload(torch_device, offload_type="leaf_level", use_stream=True) + model.enable_group_offload( + torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream + ) output_with_group_offloading4 = run_forward(model) self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) From 1a04812439c82a9dd318d14a800bb04e84dbbfc0 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 8 Apr 2025 21:18:34 +0530 Subject: [PATCH 3/3] [bistandbytes] improve replacement warnings for bnb (#11132) * improve replacement warnings for bnb * updates to docs. --- src/diffusers/quantizers/bitsandbytes/utils.py | 16 ++++++++++------ tests/quantization/bnb/test_4bit.py | 14 ++++++++++++++ tests/quantization/bnb/test_mixed_int8.py | 14 ++++++++++++++ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py index a9771b368a..e150281e81 100644 --- a/src/diffusers/quantizers/bitsandbytes/utils.py +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name models by reducing the precision of the weights and activations, thus making models more efficient in terms of both storage and computation. """ - model, has_been_replaced = _replace_with_bnb_linear( - model, modules_to_not_convert, current_key_name, quantization_config - ) + model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config) + has_been_replaced = any( + isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)) + for _, replaced_module in model.named_modules() + ) if not has_been_replaced: logger.warning( "You are loading your model in 8bit or 4bit but no linear modules were found in your model." @@ -283,16 +285,18 @@ def dequantize_and_replace( modules_to_not_convert=None, quantization_config=None, ): - model, has_been_replaced = _dequantize_and_replace( + model, _ = _dequantize_and_replace( model, dtype=model.dtype, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config, ) - + has_been_replaced = any( + isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules() + ) if not has_been_replaced: logger.warning( - "For some reason the model has not been properly dequantized. You might see unexpected behavior." + "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model." ) return model diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index fdcc5314d2..096ee4c344 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -70,6 +70,8 @@ if is_torch_available(): if is_bitsandbytes_available(): import bitsandbytes as bnb + from diffusers.quantizers.bitsandbytes.utils import replace_with_bnb_linear + @require_bitsandbytes_version_greater("0.43.2") @require_accelerate @@ -371,6 +373,18 @@ class BnB4BitBasicTests(Base4bitTests): assert key_to_target in str(err_context.exception) + def test_bnb_4bit_logs_warning_for_no_quantization(self): + model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU()) + quantization_config = BitsAndBytesConfig(load_in_4bit=True) + logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config) + assert ( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + in cap_logger.out + ) + class BnB4BitTrainingTests(Base4bitTests): def setUp(self): diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index a5e38f931e..1049bfecba 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -68,6 +68,8 @@ if is_torch_available(): if is_bitsandbytes_available(): import bitsandbytes as bnb + from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear + @require_bitsandbytes_version_greater("0.43.2") @require_accelerate @@ -317,6 +319,18 @@ class BnB8bitBasicTests(Base8bitTests): # Check that this does not throw an error _ = self.model_fp16.to(torch_device) + def test_bnb_8bit_logs_warning_for_no_quantization(self): + model_with_no_linear = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3), torch.nn.ReLU()) + quantization_config = BitsAndBytesConfig(load_in_8bit=True) + logger = logging.get_logger("diffusers.quantizers.bitsandbytes.utils") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + _ = replace_with_bnb_linear(model_with_no_linear, quantization_config=quantization_config) + assert ( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + in cap_logger.out + ) + class Bnb8bitDeviceTests(Base8bitTests): def setUp(self) -> None: