Mirror of https://github.com/huggingface/diffusers.git
Disable PEFT input autocast when using fp8 layerwise casting (#10685)
* disable peft input autocast
* use new peft method name; only disable peft input autocast if submodule layerwise casting active
* add test; reference PeftInputAutocastDisableHook in peft docs
* add load_lora_weights test
* casted -> cast
* Update tests/lora/utils.py
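For context, a minimal sketch of the user-facing workflow this change targets: combining a LoRA with fp8 layerwise casting via `ModelMixin.enable_layerwise_casting`. The checkpoint and LoRA identifiers below are placeholders, not from this commit, and the exact skip behavior for PEFT layers depends on the installed peft version:

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder identifiers; substitute a real transformer-based pipeline checkpoint and LoRA.
pipe = DiffusionPipeline.from_pretrained("org/some-transformer-pipeline", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("org/some-lora")

# Enable fp8 layerwise casting on the denoiser. With peft > 0.14.0 installed, the hook
# introduced in this commit keeps the LoRA layers from casting their inputs down to the
# fp8 storage dtype.
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```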
@@ -221,3 +221,7 @@ pipe.delete_adapters("toy")
pipe.get_active_adapters()
["pixel"]
```

## PeftInputAutocastDisableHook

[[autodoc]] hooks.layerwise_casting.PeftInputAutocastDisableHook
@@ -17,7 +17,7 @@ from typing import Optional, Tuple, Type, Union
import torch

from ..utils import get_logger
from ..utils import get_logger, is_peft_available, is_peft_version
from .hooks import HookRegistry, ModelHook
@@ -25,6 +25,8 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
# fmt: off
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -34,6 +36,11 @@ SUPPORTED_PYTORCH_LAYERS = (
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on

_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST = is_peft_available() and is_peft_version(">", "0.14.0")
if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
    from peft.helpers import disable_input_dtype_casting
    from peft.tuners.tuners_utils import BaseTunerLayer


class LayerwiseCastingHook(ModelHook):
    r"""
@@ -70,6 +77,32 @@ class LayerwiseCastingHook(ModelHook):
        return output


class PeftInputAutocastDisableHook(ModelHook):
    r"""
    A hook that disables the casting of inputs to the module weight dtype during the forward pass. By default, PEFT
    casts the inputs to the weight dtype of the module, which can lead to precision loss.

    The reasons for needing this are:
    - If we don't add PEFT layers' weight names to `skip_modules_pattern` when applying layerwise casting, the
      inputs will be cast to the, possibly lower precision, storage dtype. Reference:
      https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
    - We can, on our end, use something like accelerate's `send_to_device` but for dtypes. This way, we can ensure
      that the inputs are always cast correctly to the computation dtype. However, there are two goals we are
      hoping to achieve:
        1. Making forward implementations independent of device/dtype casting operations as much as possible.
        2. Performing inference without losing information from casting to different precisions. With the current
           PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
           with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
           torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
           forward pass in PEFT linear forward or Diffusers layer forward, with a `send_to_dtype` operation from
           LayerwiseCastingHook. This will be a lossy operation and result in poorer generation quality.
    """

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        with disable_input_dtype_casting(module):
            return self.fn_ref.original_forward(*args, **kwargs)
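The docstring above argues that the fp8 round-trip is lossy. A minimal sketch of that effect, outside the diff, assuming a PyTorch build with `torch.float8_e4m3fn` support:

```python
import torch

x = torch.randn(8, dtype=torch.bfloat16)
# Downcast to fp8 and back, mimicking a LoRA layer casting its input to the fp8
# storage dtype before the next op upcasts it to the compute dtype again.
roundtrip = x.to(torch.float8_e4m3fn).to(torch.bfloat16)
print((x - roundtrip).abs().max())  # nonzero: the downcast discarded precision
```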


def apply_layerwise_casting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
@@ -134,6 +167,7 @@ def apply_layerwise_casting(
        skip_modules_classes,
        non_blocking,
    )
    _disable_peft_input_autocast(module)


def _apply_layerwise_casting(
@@ -188,4 +222,24 @@ def apply_layerwise_casting_hook(
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = LayerwiseCastingHook(storage_dtype, compute_dtype, non_blocking)
    registry.register_hook(hook, "layerwise_casting")
    registry.register_hook(hook, _LAYERWISE_CASTING_HOOK)


def _is_layerwise_casting_active(module: torch.nn.Module) -> bool:
    for submodule in module.modules():
        if (
            hasattr(submodule, "_diffusers_hook")
            and submodule._diffusers_hook.get_hook(_LAYERWISE_CASTING_HOOK) is not None
        ):
            return True
    return False


def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
    if not _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
        return
    for submodule in module.modules():
        if isinstance(submodule, BaseTunerLayer) and _is_layerwise_casting_active(submodule):
            registry = HookRegistry.check_if_exists_or_initialize(submodule)
            hook = PeftInputAutocastDisableHook()
            registry.register_hook(hook, _PEFT_AUTOCAST_DISABLE_HOOK)
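To see the two helpers above working together, here is a small sketch (not part of the commit) that applies layerwise casting to a toy LoRA-injected module and checks that the autocast-disable hook landed on the `BaseTunerLayer`. `TinyModel` is hypothetical, and the snippet assumes peft > 0.14.0, a PyTorch build with float8 support, and the `diffusers.hooks.layerwise_casting` layout introduced in this commit:

```python
import torch
from peft import LoraConfig, inject_adapter_in_model
from peft.tuners.tuners_utils import BaseTunerLayer

from diffusers.hooks.layerwise_casting import _PEFT_AUTOCAST_DISABLE_HOOK, apply_layerwise_casting


class TinyModel(torch.nn.Module):  # hypothetical toy module, not from diffusers
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


model = inject_adapter_in_model(LoraConfig(target_modules=["proj"]), TinyModel())
apply_layerwise_casting(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)

# Each LoRA (BaseTunerLayer) submodule whose children received the layerwise casting hook
# should also carry the PEFT input-autocast-disable hook registered by the helper above.
for name, submodule in model.named_modules():
    if isinstance(submodule, BaseTunerLayer):
        hook = submodule._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK)
        print(name, hook is not None)
```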
@@ -2157,3 +2157,94 @@ class PeftLoraLoaderMixinTests:
        pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
        pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]

    @require_peft_version_greater("0.14.0")
    def test_layerwise_casting_peft_input_autocast_denoiser(self):
        r"""
        A test that checks if layerwise casting works correctly with PEFT layers and the forward pass does not fail.
        This is different from `test_layerwise_casting_inference_denoiser` as that test disables the application of
        layerwise cast hooks on the PEFT layers (relevant logic in
        `models.modeling_utils.ModelMixin.enable_layerwise_casting`). In this test, we enable layerwise casting on
        the PEFT layers as well. If run with PEFT version <= 0.14.0, this test will fail with the following error:

        ```
        RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
        ```

        See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
        """

        from diffusers.hooks.layerwise_casting import (
            _PEFT_AUTOCAST_DISABLE_HOOK,
            DEFAULT_SKIP_MODULES_PATTERN,
            SUPPORTED_PYTORCH_LAYERS,
            apply_layerwise_casting,
        )

        storage_dtype = torch.float8_e4m3fn
        compute_dtype = torch.float32

        def check_module(denoiser):
            # This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
            for name, module in denoiser.named_modules():
                if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
                    continue
                dtype_to_check = storage_dtype
                if any(re.search(pattern, name) for pattern in patterns_to_check):
                    dtype_to_check = compute_dtype
                if getattr(module, "weight", None) is not None:
                    self.assertEqual(module.weight.dtype, dtype_to_check)
                if getattr(module, "bias", None) is not None:
                    self.assertEqual(module.bias.dtype, dtype_to_check)
                if isinstance(module, BaseTunerLayer):
                    self.assertTrue(getattr(module, "_diffusers_hook", None) is not None)
                    self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)

        # 1. Test forward with add_adapter
        components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device, dtype=compute_dtype)
        pipe.set_progress_bar_config(disable=None)

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
        if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
            patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)

        apply_layerwise_casting(
            denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
        )
        check_module(denoiser)

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        pipe(**inputs, generator=torch.manual_seed(0))[0]

        # 2. Test forward with load_lora_weights
        with tempfile.TemporaryDirectory() as tmpdirname:
            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
            )

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device, dtype=compute_dtype)
            pipe.set_progress_bar_config(disable=None)
            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            apply_layerwise_casting(
                denoiser,
                storage_dtype=storage_dtype,
                compute_dtype=compute_dtype,
                skip_modules_pattern=patterns_to_check,
            )
            check_module(denoiser)

            _, _, inputs = self.get_dummy_inputs(with_generator=False)
            pipe(**inputs, generator=torch.manual_seed(0))[0]