diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
index 1293ded558..7863a12688 100644
--- a/src/diffusers/hooks/first_block_cache.py
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -18,6 +18,7 @@ from typing import Tuple, Union
 import torch
 
 from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
 from ._helpers import TransformerBlockRegistry
 from .hooks import BaseMarkedState, HookRegistry, ModelHook
@@ -71,7 +72,7 @@ class FBCHeadBlockHook(ModelHook):
         self._metadata = None
 
     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(module.__class__)
+        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -147,7 +148,7 @@ class FBCBlockHook(ModelHook):
         self._metadata = None
 
     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(module.__class__)
+        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py
index 9e8128d0bb..c42592783d 100644
--- a/src/diffusers/hooks/hooks.py
+++ b/src/diffusers/hooks/hooks.py
@@ -18,6 +18,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 
 from ..utils.logging import get_logger
+from ..utils.torch_utils import unwrap_module
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -47,7 +48,7 @@ class BaseMarkedState(BaseState):
             self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
         return self._state_cache[self._mark_name]
 
-    def mark_batch(self, name: str) -> None:
+    def mark_state(self, name: str) -> None:
         self._mark_name = name
 
     def reset(self, *args, **kwargs) -> None:
@@ -59,7 +60,7 @@ class BaseMarkedState(BaseState):
     def __getattribute__(self, name):
         if name in (
             "get_current_state",
-            "mark_batch",
+            "mark_state",
             "reset",
             "_init_args",
             "_init_kwargs",
@@ -74,7 +75,7 @@ class BaseMarkedState(BaseState):
     def __setattr__(self, name, value):
         if name in (
             "get_current_state",
-            "mark_batch",
+            "mark_state",
             "reset",
             "_init_args",
             "_init_kwargs",
@@ -164,11 +165,11 @@ class ModelHook:
         return module
 
     def _mark_state(self, module: torch.nn.Module, name: str) -> None:
-        # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_batch` on them.
+        # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
             if isinstance(attr, BaseMarkedState):
-                attr.mark_batch(name)
+                attr.mark_state(name)
         return module
 
 
@@ -283,9 +284,10 @@ class HookRegistry:
                 hook.reset_state(self._module_ref)
 
         if recurse:
-            for module_name, module in self._module_ref.named_modules():
+            for module_name, module in unwrap_module(self._module_ref).named_modules():
                 if module_name == "":
                     continue
+                module = unwrap_module(module)
                 if hasattr(module, "_diffusers_hook"):
                     module._diffusers_hook.reset_stateful_hooks(recurse=False)
 
@@ -301,9 +303,10 @@ class HookRegistry:
             if hook._is_stateful:
                 hook._mark_state(self._module_ref, name)
 
-        for module_name, module in self._module_ref.named_modules():
+        for module_name, module in unwrap_module(self._module_ref).named_modules():
             if module_name == "":
                 continue
+            module = unwrap_module(module)
             if hasattr(module, "_diffusers_hook"):
                 module._diffusers_hook._mark_state(name)
 
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 3c8911773e..06f9981f01 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
     return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
 
 
+def unwrap_module(module):
+    """Unwraps a module if it was compiled with torch.compile()"""
+    return module._orig_mod if is_compiled_module(module) else module
+
+
 def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
     """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
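
Note (not part of the patch): `torch.compile` wraps an `nn.Module` in `torch._dynamo.eval_frame.OptimizedModule` and keeps the original module on `_orig_mod`, so a class-based lookup such as `TransformerBlockRegistry.get(module.__class__)` would resolve to the wrapper rather than the registered block class. The sketch below is a minimal illustration of that behavior using the two helpers from `torch_utils.py` above; it assumes torch >= 2.0, and `TinyBlock` is a made-up module used only for demonstration.

```python
import torch


def is_compiled_module(module) -> bool:
    # Same check as in src/diffusers/utils/torch_utils.py
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


def unwrap_module(module):
    """Unwraps a module if it was compiled with torch.compile()"""
    return module._orig_mod if is_compiled_module(module) else module


class TinyBlock(torch.nn.Module):  # hypothetical stand-in for a transformer block
    def forward(self, x):
        return x + 1


block = TinyBlock()
compiled = torch.compile(block)  # returns an OptimizedModule wrapping `block`

print(type(compiled).__name__)                 # OptimizedModule: the wrong key for a class registry
print(type(unwrap_module(compiled)).__name__)  # TinyBlock: the class the registry actually knows
print(unwrap_module(block) is block)           # True: uncompiled modules pass through unchanged
```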