From a0c22997fd45770fffd9b454625e9ab525fa2b16 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 13 Feb 2025 23:12:54 +0530
Subject: [PATCH] Disable PEFT input autocast when using fp8 layerwise casting
 (#10685)

* disable peft input autocast

* use new peft method name; only disable peft input autocast if submodule layerwise casting active

* add test; reference PeftInputAutocastDisableHook in peft docs

* add load_lora_weights test

* casted -> cast

* Update tests/lora/utils.py
---
 .../en/tutorials/using_peft_for_inference.md |  4 +
 src/diffusers/hooks/layerwise_casting.py     | 58 +++++++++++-
 tests/lora/utils.py                          | 91 +++++++++++++++++++
 3 files changed, 151 insertions(+), 2 deletions(-)

diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index 9cf8a73395..33414a331e 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -221,3 +221,7 @@ pipe.delete_adapters("toy")
 pipe.get_active_adapters()
 ["pixel"]
 ```
+
+## PeftInputAutocastDisableHook
+
+[[autodoc]] hooks.layerwise_casting.PeftInputAutocastDisableHook
diff --git a/src/diffusers/hooks/layerwise_casting.py b/src/diffusers/hooks/layerwise_casting.py
index 038625e21f..6f2cfdc348 100644
--- a/src/diffusers/hooks/layerwise_casting.py
+++ b/src/diffusers/hooks/layerwise_casting.py
@@ -17,7 +17,7 @@
 from typing import Optional, Tuple, Type, Union
 
 import torch
 
-from ..utils import get_logger
+from ..utils import get_logger, is_peft_available, is_peft_version
 from .hooks import HookRegistry, ModelHook
 
@@ -25,6 +25,8 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 # fmt: off
+_LAYERWISE_CASTING_HOOK = "layerwise_casting"
+_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
 SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -34,6 +36,11 @@ SUPPORTED_PYTORCH_LAYERS = (
 DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
 # fmt: on
 
+_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST = is_peft_available() and is_peft_version(">", "0.14.0")
+if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
+    from peft.helpers import disable_input_dtype_casting
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
 
 class LayerwiseCastingHook(ModelHook):
     r"""
@@ -70,6 +77,32 @@ class LayerwiseCastingHook(ModelHook):
         return output
 
 
+class PeftInputAutocastDisableHook(ModelHook):
+    r"""
+    A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
+    casts the inputs to the weight dtype of the module, which can lead to precision loss.
+
+    The reasons for needing this are:
+        - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
+          inputs will be cast to the, possibly lower precision, storage dtype. Reference:
+          https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
+        - We can, on our end, use something like accelerate's `send_to_device` but for dtypes. This way, we can
+          ensure that the inputs are always correctly cast to the computation dtype. However, there are two goals we
+          are hoping to achieve:
+            1. Making forward implementations independent of device/dtype casting operations as much as possible.
+            2. Performing inference without losing information from casting to different precisions. With the current
+               PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
+               with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
+               torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
+               forward pass in PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
+               LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality.
+    """
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        with disable_input_dtype_casting(module):
+            return self.fn_ref.original_forward(*args, **kwargs)
+
+
 def apply_layerwise_casting(
     module: torch.nn.Module,
     storage_dtype: torch.dtype,
@@ -134,6 +167,7 @@ def apply_layerwise_casting(
         skip_modules_classes,
         non_blocking,
     )
+    _disable_peft_input_autocast(module)
 
 
 def _apply_layerwise_casting(
@@ -188,4 +222,24 @@ def apply_layerwise_casting_hook(
     """
     registry = HookRegistry.check_if_exists_or_initialize(module)
     hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
-    registry.register_hook(hook, "layerwise_casting")
+    registry.register_hook(hook, _LAYERWISE_CASTING_HOOK)
+
+
+def _is_layerwise_casting_active(module: torch.nn.Module) -> bool:
+    for submodule in module.modules():
+        if (
+            hasattr(submodule, "_diffusers_hook")
+            and submodule._diffusers_hook.get_hook(_LAYERWISE_CASTING_HOOK) is not None
+        ):
+            return True
+    return False
+
+
+def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
+    if not _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
+        return
+    for submodule in module.modules():
+        if isinstance(submodule, BaseTunerLayer) and _is_layerwise_casting_active(submodule):
+            registry = HookRegistry.check_if_exists_or_initialize(submodule)
+            hook = PeftInputAutocastDisableHook()
+            registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index d0d39d05b0..b56d729207 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -2157,3 +2157,94 @@ class PeftLoraLoaderMixinTests:
 
         pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
         pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
+
+    @require_peft_version_greater("0.14.0")
+    def test_layerwise_casting_peft_input_autocast_denoiser(self):
+        r"""
+        A test that checks that layerwise casting works correctly with PEFT layers and that the forward pass does not
+        fail. This is different from `test_layerwise_casting_inference_denoiser`, which disables the application of
+        layerwise casting hooks on the PEFT layers (relevant logic in
+        `models.modeling_utils.ModelMixin.enable_layerwise_casting`). In this test, we enable layerwise casting on
+        the PEFT layers as well. If run with PEFT version <= 0.14.0, this test will fail with the following error:
+
+        ```
+        RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
+        ```
+
+        See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
+        """
+
+        from diffusers.hooks.layerwise_casting import (
+            _PEFT_AUTOCAST_DISABLE_HOOK,
+            DEFAULT_SKIP_MODULES_PATTERN,
+            SUPPORTED_PYTORCH_LAYERS,
+            apply_layerwise_casting,
+        )
+
+        storage_dtype = torch.float8_e4m3fn
+        compute_dtype = torch.float32
+
+        def check_module(denoiser):
+            # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
+            for name, module in denoiser.named_modules():
+                if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+                    continue
+                dtype_to_check = storage_dtype
+                if any(re.search(pattern, name) for pattern in patterns_to_check):
+                    dtype_to_check = compute_dtype
+                if getattr(module, "weight", None) is not None:
+                    self.assertEqual(module.weight.dtype, dtype_to_check)
+                if getattr(module, "bias", None) is not None:
+                    self.assertEqual(module.bias.dtype, dtype_to_check)
+                if isinstance(module, BaseTunerLayer):
+                    self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
+                    self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
+
+        # 1. Test forward with add_adapter
+        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device, dtype=compute_dtype)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
+        if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
+            patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)
+
+        apply_layerwise_casting(
+            denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
+        )
+        check_module(denoiser)
+
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # 2. Test forward with load_lora_weights
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device, dtype=compute_dtype)
+            pipe.set_progress_bar_config(disable=None)
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            apply_layerwise_casting(
+                denoiser,
+                storage_dtype=storage_dtype,
+                compute_dtype=compute_dtype,
+                skip_modules_pattern=patterns_to_check,
+            )
+            check_module(denoiser)
+
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            pipe(**inputs, generator=torch.manual_seed(0))[0]
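Example usage (not part of the diff above): a minimal sketch of how this change is exercised from user code, mirroring what the new test does; the LoRA is loaded first so the PEFT layers exist, then fp8 layerwise casting is applied to the denoiser. The checkpoint and LoRA identifiers are placeholders, and `pipe.transformer` assumes a transformer denoiser; none of these are prescribed by the patch.

```python
import torch

from diffusers import DiffusionPipeline
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, apply_layerwise_casting

# Placeholder identifiers: any pipeline whose denoiser lives at `pipe.transformer`, plus a LoRA trained for it.
pipe = DiffusionPipeline.from_pretrained("<base-model-id>", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("<lora-repo-or-path>")

# With PEFT > 0.14.0 installed, apply_layerwise_casting() also registers PeftInputAutocastDisableHook on
# each BaseTunerLayer whose submodules carry a layerwise casting hook, so LoRA inputs are no longer
# downcast to torch.float8_e4m3fn before the matmul.
apply_layerwise_casting(
    pipe.transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=DEFAULT_SKIP_MODULES_PATTERN,
)

image = pipe("a photo of an astronaut riding a horse", generator=torch.manual_seed(0)).images[0]
```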