mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

address review comments

Aryan
2025-05-14 12:14:24 +02:00
parent 2ed59c178d
commit 0a44380a36
2 changed files with 41 additions and 67 deletions

View File

@@ -21,7 +21,7 @@ from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
 from ._helpers import TransformerBlockRegistry
-from .hooks import ContextAwareState, HookRegistry, ModelHook
+from .hooks import BaseState, HookRegistry, ModelHook, StateManager

 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -49,7 +49,7 @@ class FirstBlockCacheConfig:
     threshold: float = 0.05


-class FBCSharedBlockState(ContextAwareState):
+class FBCSharedBlockState(BaseState):
     def __init__(self) -> None:
         super().__init__()
@@ -66,8 +66,8 @@ class FBCSharedBlockState(ContextAwareState):
 class FBCHeadBlockHook(ModelHook):
     _is_stateful = True

-    def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
-        self.shared_state = shared_state
+    def __init__(self, state_manager: StateManager, threshold: float):
+        self.state_manager = state_manager
         self.threshold = threshold
         self._metadata = None
@@ -91,24 +91,24 @@ class FBCHeadBlockHook(ModelHook):
         else:
             hidden_states_residual = output - original_hidden_states

+        shared_state: FBCSharedBlockState = self.state_manager.get_state()
         hidden_states, encoder_hidden_states = None, None
         should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
-        self.shared_state.should_compute = should_compute
+        shared_state.should_compute = should_compute

         if not should_compute:
             # Apply caching
             if is_output_tuple:
                 hidden_states = (
-                    self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
+                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
                 )
             else:
-                hidden_states = self.shared_state.tail_block_residuals[0] + output
+                hidden_states = shared_state.tail_block_residuals[0] + output

             if self._metadata.return_encoder_hidden_states_index is not None:
                 assert is_output_tuple
                 encoder_hidden_states = (
-                    self.shared_state.tail_block_residuals[1]
-                    + output[self._metadata.return_encoder_hidden_states_index]
+                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
                 )

             if is_output_tuple:
@@ -126,20 +126,21 @@ class FBCHeadBlockHook(ModelHook):
             head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
         else:
             head_block_output = output
-        self.shared_state.head_block_output = head_block_output
-        self.shared_state.head_block_residual = hidden_states_residual
+        shared_state.head_block_output = head_block_output
+        shared_state.head_block_residual = hidden_states_residual

         return output

     def reset_state(self, module):
-        self.shared_state.reset()
+        self.state_manager.reset()
         return module

     @torch.compiler.disable
     def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
-        if self.shared_state.head_block_residual is None:
+        shared_state = self.state_manager.get_state()
+        if shared_state.head_block_residual is None:
             return True
-        prev_hidden_states_residual = self.shared_state.head_block_residual
+        prev_hidden_states_residual = shared_state.head_block_residual
         absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
         prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
         diff = (absmean / prev_hidden_states_absmean).item()
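
Note: the hunk above implements a relative-change test. The head block's residual from the current denoising step is compared against the previous step's, normalized by the previous residual's mean magnitude; the remaining blocks are recomputed only when that relative change is large enough. A minimal numeric sketch with made-up tensors (the method's closing `return diff > self.threshold` line falls outside this hunk):

import torch

# Hypothetical first-block residuals from two consecutive denoising steps.
prev_residual = torch.tensor([1.0, -2.0, 1.5])
curr_residual = torch.tensor([1.05, -1.95, 1.45])

absmean = (curr_residual - prev_residual).abs().mean()  # 0.05
prev_absmean = prev_residual.abs().mean()                # 1.5
diff = (absmean / prev_absmean).item()                   # ~0.0333

# With the default FirstBlockCacheConfig threshold of 0.05, diff stays below
# the threshold, so the hook reuses the cached tail-block residuals instead
# of running the remaining transformer blocks.
print(diff > 0.05)  # False
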
@@ -147,9 +148,9 @@ class FBCHeadBlockHook(ModelHook):


 class FBCBlockHook(ModelHook):
-    def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
+    def __init__(self, state_manager: StateManager, is_tail: bool = False):
         super().__init__()
-        self.shared_state = shared_state
+        self.state_manager = state_manager
         self.is_tail = is_tail
         self._metadata = None
@@ -166,21 +167,22 @@ class FBCBlockHook(ModelHook):
         if self._metadata.return_encoder_hidden_states_index is not None:
             original_encoder_hidden_states = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]

-        if self.shared_state.should_compute:
+        shared_state = self.state_manager.get_state()
+
+        if shared_state.should_compute:
             output = self.fn_ref.original_forward(*args, **kwargs)
             if self.is_tail:
                 hidden_states_residual = encoder_hidden_states_residual = None
                 if isinstance(output, tuple):
                     hidden_states_residual = (
-                        output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0]
+                        output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
                     )
                     encoder_hidden_states_residual = (
-                        output[self._metadata.return_encoder_hidden_states_index]
-                        - self.shared_state.head_block_output[1]
+                        output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
                     )
                 else:
-                    hidden_states_residual = output - self.shared_state.head_block_output
-                self.shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
+                    hidden_states_residual = output - shared_state.head_block_output
+                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
             return output

         output_count = len(outputs_if_skipped)
@@ -194,7 +196,7 @@ class FBCBlockHook(ModelHook):


 def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
-    shared_state = FBCSharedBlockState()
+    state_manager = StateManager(FBCSharedBlockState, (), {})
     remaining_blocks = []

     for name, submodule in module.named_children():
@@ -207,23 +209,23 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
     tail_block_name, tail_block = remaining_blocks.pop(-1)

     logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
-    _apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
+    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

     for name, block in remaining_blocks:
         logger.debug(f"Applying FBCBlockHook to '{name}'")
-        _apply_fbc_block_hook(block, shared_state)
+        _apply_fbc_block_hook(block, state_manager)

     logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
-    _apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
+    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)


-def _apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
+def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
-    hook = FBCHeadBlockHook(state, threshold)
+    hook = FBCHeadBlockHook(state_manager, threshold)
     registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)


-def _apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
+def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
-    hook = FBCBlockHook(state, is_tail)
+    hook = FBCBlockHook(state_manager, is_tail)
     registry.register_hook(hook, _FBC_BLOCK_HOOK)
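
Note: taken together, apply_first_block_cache registers one FBCHeadBlockHook on the first transformer block and an FBCBlockHook on each remaining block, all sharing a single StateManager. A hedged usage sketch (the pipeline class, checkpoint name, import path, and threshold are illustrative assumptions, and after this change get_state() only works once the surrounding caching machinery has set a context on the hooks during denoising):

import torch
from diffusers import FluxPipeline
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# One StateManager, created inside apply_first_block_cache, is shared by the
# head hook (which makes the skip decision) and every remaining block hook.
apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.05))

image = pipe("a photo of a forest at dawn", num_inference_steps=28).images[0]
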

View File

@@ -31,23 +31,19 @@ class BaseState:
         )


-class ContextAwareState(BaseState):
-    def __init__(self, init_args=None, init_kwargs=None):
-        super().__init__()
+class StateManager:
+    def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+        self._state_cls = state_cls
         self._init_args = init_args if init_args is not None else ()
         self._init_kwargs = init_kwargs if init_kwargs is not None else {}
-        self._current_context = None
         self._state_cache = {}
+        self._current_context = None

-    def get_state(self) -> "ContextAwareState":
+    def get_state(self):
         if self._current_context is None:
-            # If no context is set, simply return a dummy object since we're not going to be using it
-            return self
+            raise ValueError("No context is set. Please set a context before retrieving the state.")
         if self._current_context not in self._state_cache.keys():
-            self._state_cache[self._current_context] = ContextAwareState._create_state(
-                self.__class__, self._init_args, self._init_kwargs
-            )
+            self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
         return self._state_cache[self._current_context]

     def set_context(self, name: str) -> None:
@@ -59,30 +55,6 @@ class ContextAwareState(BaseState):
             self._state_cache.pop(name)
         self._current_context = None

-    @staticmethod
-    def _create_state(cls, init_args, init_kwargs) -> "ContextAwareState":
-        return cls(*init_args, **init_kwargs)
-
-    def __getattribute__(self, name):
-        # fmt: off
-        direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
-        # fmt: on
-        if name in direct_attrs or _is_dunder_method(name):
-            return object.__getattribute__(self, name)
-        else:
-            current_state = ContextAwareState.get_state(self)
-            return object.__getattribute__(current_state, name)
-
-    def __setattr__(self, name, value):
-        # fmt: off
-        direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
-        # fmt: on
-        if name in direct_attrs or _is_dunder_method(name):
-            object.__setattr__(self, name, value)
-        else:
-            current_state = ContextAwareState.get_state(self)
-            object.__setattr__(current_state, name, value)


 class ModelHook:
     r"""
@@ -161,10 +133,10 @@ class ModelHook:
         return module

     def _set_context(self, module: torch.nn.Module, name: str) -> None:
-        # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them.
+        # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
-            if isinstance(attr, ContextAwareState):
+            if isinstance(attr, StateManager):
                 attr.set_context(name)
         return module
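
Note: the second file's change replaces ContextAwareState's __getattribute__/__setattr__ proxying with an explicit manager object; callers now fetch the state for the current context instead of reading attributes through a transparent proxy, and a missing context is a hard error rather than a silent dummy. A self-contained sketch of the resulting pattern (the Counter state class and the context names are illustrative, not part of the diff):

class Counter:
    def __init__(self, start: int = 0) -> None:
        self.value = start

class StateManager:
    def __init__(self, state_cls, init_args=(), init_kwargs=None):
        self._state_cls = state_cls
        self._init_args = init_args
        self._init_kwargs = init_kwargs if init_kwargs is not None else {}
        self._state_cache = {}
        self._current_context = None

    def set_context(self, name: str) -> None:
        self._current_context = name

    def get_state(self):
        if self._current_context is None:
            raise ValueError("No context is set. Please set a context before retrieving the state.")
        # Lazily create one state object per named context.
        if self._current_context not in self._state_cache:
            self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
        return self._state_cache[self._current_context]

manager = StateManager(Counter, (10,))
manager.set_context("cond")
manager.get_state().value += 1    # "cond" state starts at 10, now 11
manager.set_context("uncond")
manager.get_state().value += 5    # separate "uncond" state: 15
manager.set_context("cond")
print(manager.get_state().value)  # 11 -- each context keeps its own state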