diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index ba5b93605f..482c6d0103 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1360,12 +1360,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Checks if the model has been loaded in 4-bit or 8-bit with BNB
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            if getattr(self, "is_loaded_in_8bit", False):
+            if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
                 raise ValueError(
-                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
-                    " Please use the model as it is, since the model has already been set to the correct devices."
+                    "Calling `cuda()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
                 )
-            elif is_bitsandbytes_version("<", "0.43.2"):
+            elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
                 raise ValueError(
                     "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
@@ -1412,17 +1412,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             )
 
         if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
-            if getattr(self, "is_loaded_in_8bit", False):
+            if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
                 raise ValueError(
-                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
-                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                    "Calling `to()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
                 )
-            elif is_bitsandbytes_version("<", "0.43.2"):
+            elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
                 raise ValueError(
                     "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                     f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
                 )
-
         if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
             logger.warning(
                 f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
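What this enables at the model level, as a minimal sketch (the checkpoint and surrounding setup are illustrative, assuming bitsandbytes >= 0.48.0 is installed):

import torch

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Illustrative checkpoint; any bnb-quantizable ModelMixin behaves the same way.
model_8bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# With bitsandbytes >= 0.48.0 these device moves no longer raise; on older
# versions they still hit the ValueError branches shown in the diff above.
model_8bit.to("cpu")
model_8bit.cuda()

# dtype casts on the quantized model are still rejected, e.g.
# model_8bit.to(torch.float16) or model_8bit.half() raise ValueError.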
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 5c4ac8a655..c4f118d7e1 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -60,6 +60,7 @@ from ..utils import (
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
+    is_bitsandbytes_version,
     is_hpu_available,
     is_torch_npu_available,
     is_torch_version,
@@ -444,7 +445,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
 
             _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
 
-            if is_loaded_in_8bit_bnb:
+            # https://github.com/huggingface/accelerate/pull/3907
+            if is_loaded_in_8bit_bnb and (
+                is_bitsandbytes_version("<", "0.48.0") or is_accelerate_version("<", "1.13.0.dev0")
+            ):
                 return False
 
             return hasattr(module, "_hf_hook") and (
@@ -523,9 +527,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
                 )
 
-            if is_loaded_in_8bit_bnb and device is not None:
+            if is_loaded_in_8bit_bnb and device is not None and is_bitsandbytes_version("<", "0.48.0"):
                 logger.warning(
                     f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
+                    " You need to upgrade bitsandbytes to at least 0.48.0."
                 )
 
             # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
@@ -542,6 +547,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
             if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                 module.to(device=device)
+            # added here https://github.com/huggingface/transformers/pull/43258
+            if (
+                is_loaded_in_8bit_bnb
+                and device is not None
+                and is_transformers_version(">", "4.58.0")
+                and is_bitsandbytes_version(">=", "0.48.0")
+            ):
+                module.to(device=device)
             elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
                 module.to(device, dtype)

@@ -1223,7 +1236,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             # This is because the model would already be placed on a CUDA device.
             _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
 
-            if is_loaded_in_8bit_bnb:
+            if is_loaded_in_8bit_bnb and (
+                is_transformers_version("<", "4.58.0") or is_bitsandbytes_version("<", "0.48.0")
+            ):
                 logger.info(
                     f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
                 )
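At the pipeline level, the same version gates decide whether an 8-bit component is moved by `DiffusionPipeline.to()` and whether `enable_model_cpu_offload()` places hooks on it. A rough usage sketch, assuming a new enough transformers (the diff gates on 4.58.0) and bitsandbytes >= 0.48.0; the checkpoint is illustrative:

import torch

from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel

transformer_8bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    transformer=transformer_8bit,
    torch_dtype=torch.float16,
)

# Previously the 8-bit module was skipped with a warning and excluded from hook
# placement; with new enough versions both code paths now move it as well.
pipe.to("cuda")
# or, alternatively:
# pipe.enable_model_cpu_offload()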
""" - with self.assertRaises(ValueError): - # Tries with `str` - self.model_8bit.to("cpu") with self.assertRaises(ValueError): # Tries with a `dtype`` self.model_8bit.to(torch.float16) - with self.assertRaises(ValueError): - # Tries with a `device` - self.model_8bit.to(torch.device(f"{torch_device}:0")) - with self.assertRaises(ValueError): # Tries with a `device` self.model_8bit.float() with self.assertRaises(ValueError): - # Tries with a `device` + # Tries with a `dtype` self.model_8bit.half() + # This should work with 0.48.0 + self.model_8bit.to("cpu") + self.model_8bit.to(torch.device(f"{torch_device}:0")) + # Test if we did not break anything self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) input_dict_for_transformer = self.get_dummy_inputs() @@ -837,7 +835,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests): @require_torch_version_greater_equal("2.6.0") -@require_bitsandbytes_version_greater("0.45.5") +@require_bitsandbytes_version_greater("0.48.0") class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase): @property def quantization_config(self): @@ -848,7 +846,7 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase): ) @pytest.mark.xfail( - reason="Test fails because of an offloading problem from Accelerate with confusion in hooks." + reason="Test fails because of a type change when recompiling." " Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details." ) def test_torch_compile(self): @@ -858,6 +856,5 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase): def test_torch_compile_with_cpu_offload(self): super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16) - @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.") def test_torch_compile_with_group_offload_leaf(self): super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)