1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Aryan
2025-04-02 21:52:33 +02:00
parent 315e357a18
commit c76e1cc17e
3 changed files with 18 additions and 9 deletions

View File

@@ -18,6 +18,7 @@ from typing import Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseMarkedState, HookRegistry, ModelHook
@@ -71,7 +72,7 @@ class FBCHeadBlockHook(ModelHook):
self._metadata = None
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(module.__class__)
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -147,7 +148,7 @@ class FBCBlockHook(ModelHook):
self._metadata = None
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(module.__class__)
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):

View File

@@ -18,6 +18,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -47,7 +48,7 @@ class BaseMarkedState(BaseState):
self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
return self._state_cache[self._mark_name]
def mark_batch(self, name: str) -> None:
def mark_state(self, name: str) -> None:
self._mark_name = name
def reset(self, *args, **kwargs) -> None:
@@ -59,7 +60,7 @@ class BaseMarkedState(BaseState):
def __getattribute__(self, name):
if name in (
"get_current_state",
"mark_batch",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
@@ -74,7 +75,7 @@ class BaseMarkedState(BaseState):
def __setattr__(self, name, value):
if name in (
"get_current_state",
"mark_batch",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
@@ -164,11 +165,11 @@ class ModelHook:
return module
def _mark_state(self, module: torch.nn.Module, name: str) -> None:
# Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_batch` on them.
# Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, BaseMarkedState):
attr.mark_batch(name)
attr.mark_state(name)
return module
@@ -283,9 +284,10 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
for module_name, module in self._module_ref.named_modules():
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -301,9 +303,10 @@ class HookRegistry:
if hook._is_stateful:
hook._mark_state(self._module_ref, name)
for module_name, module in self._module_ref.named_modules():
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._mark_state(name)

View File

@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def unwrap_module(module):
"""Unwraps a module if it was compiled with torch.compile()"""
return module._orig_mod if is_compiled_module(module) else module
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).