update
@@ -18,6 +18,7 @@ from typing import Tuple, Union
 import torch
 
 from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
 from ._helpers import TransformerBlockRegistry
 from .hooks import BaseMarkedState, HookRegistry, ModelHook
@@ -71,7 +72,7 @@ class FBCHeadBlockHook(ModelHook):
         self._metadata = None
 
     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(module.__class__)
+        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -147,7 +148,7 @@ class FBCBlockHook(ModelHook):
         self._metadata = None
 
     def initialize_hook(self, module):
-        self._metadata = TransformerBlockRegistry.get(module.__class__)
+        self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
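Note: the lookup change above matters once a block is wrapped by torch.compile. Compilation wraps a module in torch._dynamo.eval_frame.OptimizedModule, so module.__class__ would be the wrapper class rather than the block class that TransformerBlockRegistry was keyed on, and the metadata lookup would fail. A minimal sketch of the mismatch, using a hypothetical Block class for illustration:

import torch

class Block(torch.nn.Module):  # hypothetical stand-in for a registered transformer block
    def forward(self, x):
        return x * 2

block = torch.compile(Block())
print(type(block).__name__)            # OptimizedModule -- misses a registry keyed on Block
print(type(block._orig_mod).__name__)  # Block -- the class unwrap_module(module).__class__ recovers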
@@ -18,6 +18,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 
 from ..utils.logging import get_logger
+from ..utils.torch_utils import unwrap_module
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -47,7 +48,7 @@ class BaseMarkedState(BaseState):
             self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
         return self._state_cache[self._mark_name]
 
-    def mark_batch(self, name: str) -> None:
+    def mark_state(self, name: str) -> None:
         self._mark_name = name
 
     def reset(self, *args, **kwargs) -> None:
@@ -59,7 +60,7 @@ class BaseMarkedState(BaseState):
     def __getattribute__(self, name):
         if name in (
             "get_current_state",
-            "mark_batch",
+            "mark_state",
             "reset",
             "_init_args",
             "_init_kwargs",
@@ -74,7 +75,7 @@ class BaseMarkedState(BaseState):
     def __setattr__(self, name, value):
        if name in (
             "get_current_state",
-            "mark_batch",
+            "mark_state",
             "reset",
             "_init_args",
             "_init_kwargs",
@@ -164,11 +165,11 @@ class ModelHook:
         return module
 
     def _mark_state(self, module: torch.nn.Module, name: str) -> None:
-        # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_batch` on them.
+        # Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
             if isinstance(attr, BaseMarkedState):
-                attr.mark_batch(name)
+                attr.mark_state(name)
         return module
 
 
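Note: the mark_batch -> mark_state rename is mechanical, but the pattern it renames is worth spelling out: a BaseMarkedState lazily keeps one sub-state per mark name and hands back whichever one matches the current mark, so a stateful hook can keep, for example, separate caches per guidance branch. A simplified, hypothetical reconstruction of that behavior (not the library's exact code, which also intercepts attribute access):

class DemoMarkedState:
    """Simplified stand-in for BaseMarkedState: one cached sub-state per mark name."""

    def __init__(self, *args, **kwargs):
        self._init_args = args      # kept so sub-states can be constructed lazily
        self._init_kwargs = kwargs
        self._state_cache = {}
        self._mark_name = None

    def mark_state(self, name: str) -> None:
        # Renamed from mark_batch in this commit; selects the active sub-state.
        self._mark_name = name

    def get_current_state(self) -> dict:
        # Mirrors the diff's lazy self._state_cache[self._mark_name] = ... line,
        # with a plain dict standing in for self.__class__(*args, **kwargs).
        if self._mark_name not in self._state_cache:
            self._state_cache[self._mark_name] = {}
        return self._state_cache[self._mark_name]

state = DemoMarkedState()
state.mark_state("cond")
state.get_current_state()["cache"] = "per-branch data"
state.mark_state("uncond")
assert "cache" not in state.get_current_state()  # branches stay isolated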
@@ -283,9 +284,10 @@ class HookRegistry:
             hook.reset_state(self._module_ref)
 
         if recurse:
-            for module_name, module in self._module_ref.named_modules():
+            for module_name, module in unwrap_module(self._module_ref).named_modules():
                 if module_name == "":
                     continue
+                module = unwrap_module(module)
                 if hasattr(module, "_diffusers_hook"):
                     module._diffusers_hook.reset_stateful_hooks(recurse=False)
 
@@ -301,9 +303,10 @@ class HookRegistry:
             if hook._is_stateful:
                 hook._mark_state(self._module_ref, name)
 
-        for module_name, module in self._module_ref.named_modules():
+        for module_name, module in unwrap_module(self._module_ref).named_modules():
             if module_name == "":
                 continue
+            module = unwrap_module(module)
             if hasattr(module, "_diffusers_hook"):
                 module._diffusers_hook._mark_state(name)
 
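Note: the pair of changes above covers two compile patterns at once. Unwrapping self._module_ref handles the case where the whole model was compiled, while the inner module = unwrap_module(module) appears to handle submodules compiled in place, whose _diffusers_hook registry lives on the original module rather than on the wrapper. A small illustrative sketch of the second case (assumption: submodules were compiled individually):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
model[0] = torch.compile(model[0])  # submodule compiled in place

for name, sub in model.named_modules():
    if name == "":
        continue
    # hook attributes such as _diffusers_hook must be looked up on the original module
    if isinstance(sub, torch._dynamo.eval_frame.OptimizedModule):
        sub = sub._orig_mod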
@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
     return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
 
 
+def unwrap_module(module):
+    """Unwraps a module if it was compiled with torch.compile()"""
+    return module._orig_mod if is_compiled_module(module) else module
+
+
 def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
     """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
 
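Note: the new helper composes with is_compiled_module directly above it and is a no-op for plain modules, so callers can apply it unconditionally. A quick usage check, assuming both helpers are importable from diffusers.utils.torch_utils as the hunk suggests:

import torch
from diffusers.utils.torch_utils import is_compiled_module, unwrap_module

linear = torch.nn.Linear(2, 2)
compiled = torch.compile(linear)

assert is_compiled_module(compiled) and not is_compiled_module(linear)
assert unwrap_module(compiled) is linear  # strips the OptimizedModule wrapper
assert unwrap_module(linear) is linear    # plain modules pass through untouched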