mirror of https://github.com/huggingface/diffusers.git

address review comments
@@ -21,7 +21,7 @@ from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
 from ._helpers import TransformerBlockRegistry
-from .hooks import ContextAwareState, HookRegistry, ModelHook
+from .hooks import BaseState, HookRegistry, ModelHook, StateManager
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -49,7 +49,7 @@ class FirstBlockCacheConfig:
     threshold: float = 0.05
 
 
-class FBCSharedBlockState(ContextAwareState):
+class FBCSharedBlockState(BaseState):
     def __init__(self) -> None:
         super().__init__()
 
@@ -66,8 +66,8 @@ class FBCSharedBlockState(ContextAwareState):
 class FBCHeadBlockHook(ModelHook):
     _is_stateful = True
 
-    def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
-        self.shared_state = shared_state
+    def __init__(self, state_manager: StateManager, threshold: float):
+        self.state_manager = state_manager
         self.threshold = threshold
         self._metadata = None
 
@@ -91,24 +91,24 @@ class FBCHeadBlockHook(ModelHook):
         else:
             hidden_states_residual = output - original_hidden_states
 
+        shared_state: FBCSharedBlockState = self.state_manager.get_state()
         hidden_states, encoder_hidden_states = None, None
         should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
-        self.shared_state.should_compute = should_compute
+        shared_state.should_compute = should_compute
 
         if not should_compute:
             # Apply caching
             if is_output_tuple:
                 hidden_states = (
-                    self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
+                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
                 )
             else:
-                hidden_states = self.shared_state.tail_block_residuals[0] + output
+                hidden_states = shared_state.tail_block_residuals[0] + output
 
             if self._metadata.return_encoder_hidden_states_index is not None:
                 assert is_output_tuple
                 encoder_hidden_states = (
-                    self.shared_state.tail_block_residuals[1]
-                    + output[self._metadata.return_encoder_hidden_states_index]
+                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
                 )
 
             if is_output_tuple:
@@ -126,20 +126,21 @@ class FBCHeadBlockHook(ModelHook):
                 head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
             else:
                 head_block_output = output
-            self.shared_state.head_block_output = head_block_output
-            self.shared_state.head_block_residual = hidden_states_residual
+            shared_state.head_block_output = head_block_output
+            shared_state.head_block_residual = hidden_states_residual
 
         return output
 
     def reset_state(self, module):
-        self.shared_state.reset()
+        self.state_manager.reset()
         return module
 
     @torch.compiler.disable
     def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
-        if self.shared_state.head_block_residual is None:
+        shared_state = self.state_manager.get_state()
+        if shared_state.head_block_residual is None:
             return True
-        prev_hidden_states_residual = self.shared_state.head_block_residual
+        prev_hidden_states_residual = shared_state.head_block_residual
         absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
         prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
         diff = (absmean / prev_hidden_states_absmean).item()
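
For intuition, a numeric sketch of the decision rule in _should_compute_remaining_blocks (toy values, illustrative only, not part of this commit): the remaining blocks are recomputed only when the head block's residual changes by more than the configured threshold relative to the previous step.

import torch

threshold = 0.05
prev_residual = torch.tensor([1.0, -2.0, 0.5])     # head_block_residual cached at step t-1
curr_residual = torch.tensor([1.02, -2.01, 0.49])  # residual computed at step t

absmean = (curr_residual - prev_residual).abs().mean()  # mean absolute change
prev_absmean = prev_residual.abs().mean()
diff = (absmean / prev_absmean).item()                  # relative change, ~0.011 here
should_compute = diff > threshold                       # False -> reuse cached residuals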
@@ -147,9 +148,9 @@ class FBCHeadBlockHook(ModelHook):
 
 
 class FBCBlockHook(ModelHook):
-    def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
+    def __init__(self, state_manager: StateManager, is_tail: bool = False):
         super().__init__()
-        self.shared_state = shared_state
+        self.state_manager = state_manager
         self.is_tail = is_tail
         self._metadata = None
 
@@ -166,21 +167,22 @@ class FBCBlockHook(ModelHook):
         if self._metadata.return_encoder_hidden_states_index is not None:
             original_encoder_hidden_states = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]
 
-        if self.shared_state.should_compute:
+        shared_state = self.state_manager.get_state()
+
+        if shared_state.should_compute:
             output = self.fn_ref.original_forward(*args, **kwargs)
             if self.is_tail:
                 hidden_states_residual = encoder_hidden_states_residual = None
                 if isinstance(output, tuple):
                     hidden_states_residual = (
-                        output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0]
+                        output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
                     )
                     encoder_hidden_states_residual = (
-                        output[self._metadata.return_encoder_hidden_states_index]
-                        - self.shared_state.head_block_output[1]
+                        output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
                     )
                 else:
-                    hidden_states_residual = output - self.shared_state.head_block_output
-                self.shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
+                    hidden_states_residual = output - shared_state.head_block_output
+                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
             return output
 
         output_count = len(outputs_if_skipped)
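
The tail-block bookkeeping above pairs with the cache application in FBCHeadBlockHook. A toy sketch of both sides (illustrative tensors, not the real hooks):

import torch

# Full pass at step t: every block runs; the tail hook caches the residual.
head_output = torch.randn(2, 4)
tail_output = head_output + torch.randn(2, 4)      # stand-in for the remaining blocks
tail_block_residual = tail_output - head_output    # stored in shared_state.tail_block_residuals

# Skipped pass at step t+1: only the head block runs; the head hook re-adds
# the cached residual instead of executing the remaining blocks.
new_head_output = head_output + 0.001 * torch.randn(2, 4)  # nearly unchanged -> cache hit
approx_tail_output = new_head_output + tail_block_residual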
@@ -194,7 +196,7 @@ class FBCBlockHook(ModelHook):
 
 
 def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
-    shared_state = FBCSharedBlockState()
+    state_manager = StateManager(FBCSharedBlockState, (), {})
     remaining_blocks = []
 
     for name, submodule in module.named_children():
@@ -207,23 +209,23 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf
     tail_block_name, tail_block = remaining_blocks.pop(-1)
 
     logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
-    _apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
+    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
 
     for name, block in remaining_blocks:
         logger.debug(f"Applying FBCBlockHook to '{name}'")
-        _apply_fbc_block_hook(block, shared_state)
+        _apply_fbc_block_hook(block, state_manager)
 
     logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
-    _apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
+    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
 
 
-def _apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
+def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
-    hook = FBCHeadBlockHook(state, threshold)
+    hook = FBCHeadBlockHook(state_manager, threshold)
     registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
 
 
-def _apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
+def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
-    hook = FBCBlockHook(state, is_tail)
+    hook = FBCBlockHook(state_manager, is_tail)
     registry.register_hook(hook, _FBC_BLOCK_HOOK)
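
End to end, the public entry point is unchanged by this refactor. A hedged usage sketch (assumes a diffusers build containing this commit, with FirstBlockCacheConfig and apply_first_block_cache exported from diffusers.hooks, and FLUX weights available locally):

import torch
from diffusers import FluxPipeline
from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
image = pipe("a photo of a cat", num_inference_steps=28).images[0]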
@@ -31,23 +31,19 @@ class BaseState:
         )
 
 
-class ContextAwareState(BaseState):
-    def __init__(self, init_args=None, init_kwargs=None):
-        super().__init__()
-
+class StateManager:
+    def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+        self._state_cls = state_cls
         self._init_args = init_args if init_args is not None else ()
         self._init_kwargs = init_kwargs if init_kwargs is not None else {}
-        self._current_context = None
         self._state_cache = {}
+        self._current_context = None
 
-    def get_state(self) -> "ContextAwareState":
+    def get_state(self):
         if self._current_context is None:
-            # If no context is set, simply return a dummy object since we're not going to be using it
-            return self
+            raise ValueError("No context is set. Please set a context before retrieving the state.")
         if self._current_context not in self._state_cache.keys():
-            self._state_cache[self._current_context] = ContextAwareState._create_state(
-                self.__class__, self._init_args, self._init_kwargs
-            )
+            self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
         return self._state_cache[self._current_context]
 
     def set_context(self, name: str) -> None:
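
A self-contained sketch of the pattern StateManager implements (a local mimic for illustration, not the private diffusers class): one state object per named context, created lazily on the first get_state() call.

class DemoState:
    def __init__(self):
        self.head_block_residual = None

class DemoStateManager:
    def __init__(self, state_cls):
        self._state_cls = state_cls
        self._state_cache = {}
        self._current_context = None

    def set_context(self, name):
        self._current_context = name

    def get_state(self):
        if self._current_context is None:
            raise ValueError("No context is set. Please set a context before retrieving the state.")
        if self._current_context not in self._state_cache:
            self._state_cache[self._current_context] = self._state_cls()
        return self._state_cache[self._current_context]

manager = DemoStateManager(DemoState)
manager.set_context("cond")
manager.get_state().head_block_residual = "residual-for-cond"
manager.set_context("uncond")
assert manager.get_state().head_block_residual is None  # each context gets its own state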
@@ -59,30 +55,6 @@ class ContextAwareState(BaseState):
             self._state_cache.pop(name)
         self._current_context = None
 
-    @staticmethod
-    def _create_state(cls, init_args, init_kwargs) -> "ContextAwareState":
-        return cls(*init_args, **init_kwargs)
-
-    def __getattribute__(self, name):
-        # fmt: off
-        direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
-        # fmt: on
-        if name in direct_attrs or _is_dunder_method(name):
-            return object.__getattribute__(self, name)
-        else:
-            current_state = ContextAwareState.get_state(self)
-            return object.__getattribute__(current_state, name)
-
-    def __setattr__(self, name, value):
-        # fmt: off
-        direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
-        # fmt: on
-        if name in direct_attrs or _is_dunder_method(name):
-            object.__setattr__(self, name, value)
-        else:
-            current_state = ContextAwareState.get_state(self)
-            object.__setattr__(current_state, name, value)
-
 
 class ModelHook:
     r"""
@@ -161,10 +133,10 @@ class ModelHook:
         return module
 
     def _set_context(self, module: torch.nn.Module, name: str) -> None:
-        # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them.
+        # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
-            if isinstance(attr, ContextAwareState):
+            if isinstance(attr, StateManager):
                 attr.set_context(name)
         return module
 
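
A minimal sketch of the attribute scan in _set_context (toy classes with the same shape as the real loop): any hook attribute that is a StateManager receives the new context name.

class ToyStateManager:
    def __init__(self):
        self.context = None

    def set_context(self, name):
        self.context = name

class ToyHook:
    def __init__(self):
        self.state_manager = ToyStateManager()

    def _set_context(self, name):
        # Mirrors ModelHook._set_context: scan attributes and forward the
        # context name to every state manager found.
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if isinstance(attr, ToyStateManager):
                attr.set_context(name)

hook = ToyHook()
hook._set_context("denoiser_uncond")
assert hook.state_manager.context == "denoiser_uncond"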