
Disable PEFT input autocast when using fp8 layerwise casting (#10685)

* disable peft input autocast

* use new peft method name; only disable peft input autocast if submodule layerwise casting active

* add test; reference PeftInputAutocastDisableHook in peft docs

* add load_lora_weights test

* casted -> cast

* Update tests/lora/utils.py
Author: Aryan
Date: 2025-02-13 23:12:54 +05:30
Committed by: GitHub
Parent: 97abdd2210
Commit: a0c22997fd
3 changed files with 151 additions and 2 deletions


@@ -221,3 +221,7 @@ pipe.delete_adapters("toy")
pipe.get_active_adapters()
["pixel"]
```
## PeftInputAutocastDisableHook
[[autodoc]] hooks.layerwise_casting.PeftInputAutocastDisableHook
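
For context, here is a minimal usage sketch of the code path this change affects; it is not part of the commit. The model id, LoRA id, the `pipe.transformer` attribute, and the call signature are placeholders, and `apply_layerwise_casting` is the function modified in the diff below.

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks.layerwise_casting import apply_layerwise_casting

# Placeholder ids; assumes a pipeline class that supports `load_lora_weights`.
pipe = DiffusionPipeline.from_pretrained("some/model-id", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("some/lora-id")

# Store the denoiser weights in fp8 while computing in bf16. With PEFT > 0.14.0,
# `apply_layerwise_casting` now also registers `PeftInputAutocastDisableHook` on
# every LoRA (`BaseTunerLayer`) submodule, so inputs are no longer downcast to
# torch.float8_e4m3fn inside the PEFT layers.
apply_layerwise_casting(
    pipe.transformer,  # or pipe.unet, depending on the pipeline
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)

# Prompt and arguments depend on the pipeline; shown only to complete the sketch.
image = pipe("a prompt", num_inference_steps=4).images[0]
```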


@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Type, Union
import torch
-from ..utils import get_logger
+from ..utils import get_logger, is_peft_available, is_peft_version
from .hooks import HookRegistry, ModelHook
@@ -25,6 +25,8 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
# fmt: off
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -34,6 +36,11 @@ SUPPORTED_PYTORCH_LAYERS = (
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on
_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST = is_peft_available() and is_peft_version(">", "0.14.0")
if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
from peft.helpers import disable_input_dtype_casting
from peft.tuners.tuners_utils import BaseTunerLayer
class LayerwiseCastingHook(ModelHook):
r"""
@@ -70,6 +77,32 @@ class LayerwiseCastingHook(ModelHook):
return output
class PeftInputAutocastDisableHook(ModelHook):
r"""
A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
casts the inputs to the weight dtype of the module, which can lead to precision loss.
The reasons for needing this are:
- If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
inputs will be cast to the, possibly lower-precision, storage dtype. Reference:
https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
- We can, on our end, use something like accelerate's `send_to_device` but for dtypes. This way, we can ensure
that the inputs are always cast to the computation dtype. However, there are two goals we are
hoping to achieve:
1. Making forward implementations independent of device/dtype casting operations as much as possible.
2. Performing inference without losing information from casting to different precisions. With the current
PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
forward pass in PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality.
"""
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
with disable_input_dtype_casting(module):
return self.fn_ref.original_forward(*args, **kwargs)
def apply_layerwise_casting(
module: torch.nn.Module,
storage_dtype: torch.dtype,
@@ -134,6 +167,7 @@ def apply_layerwise_casting(
skip_modules_classes,
non_blocking,
)
_disable_peft_input_autocast(module)
def _apply_layerwise_casting(
@@ -188,4 +222,24 @@ def apply_layerwise_casting_hook(
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
registry.register_hook(hook, "layerwise_casting")
registry.register_hook(hook, _LAYERWISE_CASTING_HOOK)
def _is_layerwise_casting_active(module: torch.nn.Module) -> bool:
for submodule in module.modules():
if (
hasattr(submodule, "_diffusers_hook")
and submodule._diffusers_hook.get_hook(_LAYERWISE_CASTING_HOOK) is not None
):
return True
return False
def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
if not _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
return
for submodule in module.modules():
if isinstance(submodule, BaseTunerLayer) and _is_layerwise_casting_active(submodule):
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = PeftInputAutocastDisableHook()
registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
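
The `PeftInputAutocastDisableHook` docstring above argues that routing inputs through the fp8 storage dtype and back is lossy. A standalone PyTorch check (not part of the commit) illustrates the point:

```python
import torch

x = torch.randn(8, dtype=torch.bfloat16)
# Round-trip through the fp8 storage dtype, as would happen if PEFT cast the
# inputs to the LoRA layer's weight dtype before this hook disables that cast.
roundtrip = x.to(torch.float8_e4m3fn).to(torch.bfloat16)
print((x - roundtrip).abs().max())  # typically non-zero: the fp8 cast discards precision
```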


@@ -2157,3 +2157,94 @@ class PeftLoraLoaderMixinTests:
pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
@require_peft_version_greater("0.14.0")
def test_layerwise_casting_peft_input_autocast_denoiser(self):
r"""
A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This
is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise
cast hooks on the PEFT layers (relevant logic in `models.modeling_utils.ModelMixin.enable_layerwise_casting`).
In this test, we enable the layerwise casting on the PEFT layers as well. If run with PEFT version <= 0.14.0,
this test will fail with the following error:
```
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
```
See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
"""
from diffusers.hooks.layerwise_casting import (
_PEFT_AUTOCAST_DISABLE_HOOK,
DEFAULT_SKIP_MODULES_PATTERN,
SUPPORTED_PYTORCH_LAYERS,
apply_layerwise_casting,
)
storage_dtype = torch.float8_e4m3fn
compute_dtype = torch.float32
def check_module(denoiser):
# This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
for name, module in denoiser.named_modules():
if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):
dtype_to_check = compute_dtype
if getattr(module, "weight", None) is not None:
self.assertEqual(module.weight.dtype, dtype_to_check)
if getattr(module, "bias", None) is not None:
self.assertEqual(module.bias.dtype, dtype_to_check)
if isinstance(module, BaseTunerLayer):
self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
# 1. Test forward with add_adapter
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)
apply_layerwise_casting(
denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
)
check_module(denoiser)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe(**inputs, generator=torch.manual_seed(0))[0]
# 2. Test forward with load_lora_weights
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
apply_layerwise_casting(
denoiser,
storage_dtype=storage_dtype,
compute_dtype=compute_dtype,
skip_modules_pattern=patterns_to_check,
)
check_module(denoiser)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe(**inputs, generator=torch.manual_seed(0))[0]
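
Also not part of the commit: a condensed version of the check that `check_module` performs, usable on any LoRA-adapted module (assumed here to be `denoiser`) after `apply_layerwise_casting` has been applied to it, to confirm the new hook is attached to every PEFT layer.

```python
from peft.tuners.tuners_utils import BaseTunerLayer
from diffusers.hooks.layerwise_casting import _PEFT_AUTOCAST_DISABLE_HOOK

for name, submodule in denoiser.named_modules():
    if isinstance(submodule, BaseTunerLayer):
        # Every LoRA layer should carry the `peft_autocast_disable` hook.
        hook = submodule._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK)
        assert hook is not None, f"missing PEFT autocast-disable hook on {name}"
```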