diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
index b232e6465c..31ee08c34d 100644
--- a/src/diffusers/hooks/first_block_cache.py
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -21,7 +21,7 @@ from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
 from ._helpers import TransformerBlockRegistry
-from .hooks import ContextAwareState, HookRegistry, ModelHook
+from .hooks import BaseState, HookRegistry, ModelHook, StateManager


 logger = get_logger(__name__)  # pylint: disable=invalid-name
@@ -49,7 +49,7 @@ class FirstBlockCacheConfig:
     threshold: float = 0.05


-class FBCSharedBlockState(ContextAwareState):
+class FBCSharedBlockState(BaseState):
     def __init__(self) -> None:
         super().__init__()

@@ -66,8 +66,8 @@ class FBCSharedBlockState(ContextAwareState):
 class FBCHeadBlockHook(ModelHook):
     _is_stateful = True

-    def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
-        self.shared_state = shared_state
+    def __init__(self, state_manager: StateManager, threshold: float):
+        self.state_manager = state_manager
         self.threshold = threshold
         self._metadata = None

@@ -91,24 +91,24 @@ class FBCHeadBlockHook(ModelHook):
         else:
             hidden_states_residual = output - original_hidden_states

+        shared_state: FBCSharedBlockState = self.state_manager.get_state()
         hidden_states, encoder_hidden_states = None, None
         should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
-        self.shared_state.should_compute = should_compute
+        shared_state.should_compute = should_compute

         if not should_compute:
             # Apply caching
             if is_output_tuple:
                 hidden_states = (
-                    self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
+                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
                 )
             else:
-                hidden_states = self.shared_state.tail_block_residuals[0] + output
+                hidden_states = shared_state.tail_block_residuals[0] + output

             if self._metadata.return_encoder_hidden_states_index is not None:
                 assert is_output_tuple
                 encoder_hidden_states = (
-                    self.shared_state.tail_block_residuals[1]
-                    + output[self._metadata.return_encoder_hidden_states_index]
+                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
                 )

             if is_output_tuple:
@@ -126,20 +126,21 @@ class FBCHeadBlockHook(ModelHook):
                 head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
             else:
                 head_block_output = output
-            self.shared_state.head_block_output = head_block_output
-            self.shared_state.head_block_residual = hidden_states_residual
+            shared_state.head_block_output = head_block_output
+            shared_state.head_block_residual = hidden_states_residual

         return output

     def reset_state(self, module):
-        self.shared_state.reset()
+        self.state_manager.reset()
         return module

     @torch.compiler.disable
     def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
-        if self.shared_state.head_block_residual is None:
+        shared_state = self.state_manager.get_state()
+        if shared_state.head_block_residual is None:
             return True
-        prev_hidden_states_residual = self.shared_state.head_block_residual
+        prev_hidden_states_residual = shared_state.head_block_residual
         absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
         prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
         diff = (absmean / prev_hidden_states_absmean).item()
@@ -147,9 +148,9 @@ class FBCHeadBlockHook(ModelHook):


 class FBCBlockHook(ModelHook):
-    def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
+    def __init__(self, state_manager: StateManager, is_tail: bool = False):
         super().__init__()
-        self.shared_state = shared_state
+        self.state_manager = state_manager
         self.is_tail = is_tail
         self._metadata = None

@@ -166,21 +167,22 @@ class FBCBlockHook(ModelHook):
         if self._metadata.return_encoder_hidden_states_index is not None:
             original_encoder_hidden_states = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]

-        if self.shared_state.should_compute:
+        shared_state = self.state_manager.get_state()
+
+        if shared_state.should_compute:
             output = self.fn_ref.original_forward(*args, **kwargs)
             if self.is_tail:
                 hidden_states_residual = encoder_hidden_states_residual = None
                 if isinstance(output, tuple):
                     hidden_states_residual = (
-                        output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0]
+                        output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
                     )
                     encoder_hidden_states_residual = (
-                        output[self._metadata.return_encoder_hidden_states_index]
-                        - self.shared_state.head_block_output[1]
+                        output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
                     )
                 else:
-                    hidden_states_residual = output - self.shared_state.head_block_output
-                self.shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
+                    hidden_states_residual = output - shared_state.head_block_output
+                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
             return output

         output_count = len(outputs_if_skipped)
@@ -194,7 +196,7 @@ class FBCBlockHook(ModelHook):


 def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
-    shared_state = FBCSharedBlockState()
+    state_manager = StateManager(FBCSharedBlockState, (), {})
     remaining_blocks = []

     for name, submodule in module.named_children():
@@ -207,23 +209,23 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf
     tail_block_name, tail_block = remaining_blocks.pop(-1)

     logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
-    _apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
+    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

     for name, block in remaining_blocks:
         logger.debug(f"Applying FBCBlockHook to '{name}'")
-        _apply_fbc_block_hook(block, shared_state)
+        _apply_fbc_block_hook(block, state_manager)

     logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
-    _apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
+    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)


-def _apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
+def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
-    hook = FBCHeadBlockHook(state, threshold)
+    hook = FBCHeadBlockHook(state_manager, threshold)
     registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)


-def _apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
+def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
-    hook = FBCBlockHook(state, is_tail)
+    hook = FBCBlockHook(state_manager, is_tail)
     registry.register_hook(hook, _FBC_BLOCK_HOOK)
diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py
index 4ca5761f75..3b39829fc5 100644
--- a/src/diffusers/hooks/hooks.py
+++ b/src/diffusers/hooks/hooks.py
@@ -31,23 +31,19 @@ class BaseState:
         )


-class ContextAwareState(BaseState):
-    def __init__(self, init_args=None, init_kwargs=None):
-        super().__init__()
-
+class StateManager:
+    def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+        self._state_cls = state_cls
         self._init_args = init_args if init_args is not None else ()
         self._init_kwargs = init_kwargs if init_kwargs is not None else {}
-        self._current_context = None
         self._state_cache = {}
+        self._current_context = None

-    def get_state(self) -> "ContextAwareState":
+    def get_state(self):
         if self._current_context is None:
-            # If no context is set, simply return a dummy object since we're not going to be using it
-            return self
+            raise ValueError("No context is set. Please set a context before retrieving the state.")
         if self._current_context not in self._state_cache.keys():
-            self._state_cache[self._current_context] = ContextAwareState._create_state(
-                self.__class__, self._init_args, self._init_kwargs
-            )
+            self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
         return self._state_cache[self._current_context]

     def set_context(self, name: str) -> None:
@@ -59,30 +55,6 @@ class ContextAwareState(BaseState):
             self._state_cache.pop(name)
         self._current_context = None

-    @staticmethod
-    def _create_state(cls, init_args, init_kwargs) -> "ContextAwareState":
-        return cls(*init_args, **init_kwargs)
-
-    def __getattribute__(self, name):
-        # fmt: off
-        direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
-        # fmt: on
-        if name in direct_attrs or _is_dunder_method(name):
-            return object.__getattribute__(self, name)
-        else:
-            current_state = ContextAwareState.get_state(self)
-            return object.__getattribute__(current_state, name)
-
-    def __setattr__(self, name, value):
-        # fmt: off
-        direct_attrs = ("get_state", "set_context", "reset", "_init_args", "_init_kwargs", "_current_context", "_state_cache", "_create_state")
-        # fmt: on
-        if name in direct_attrs or _is_dunder_method(name):
-            object.__setattr__(self, name, value)
-        else:
-            current_state = ContextAwareState.get_state(self)
-            object.__setattr__(current_state, name, value)
-

 class ModelHook:
     r"""
@@ -161,10 +133,10 @@ class ModelHook:
         return module

     def _set_context(self, module: torch.nn.Module, name: str) -> None:
-        # Iterate over all attributes of the hook to see if any of them have the type `ContextAwareState`. If so, call `set_context` on them.
+        # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
         for attr_name in dir(self):
             attr = getattr(self, attr_name)
-            if isinstance(attr, ContextAwareState):
+            if isinstance(attr, StateManager):
                 attr.set_context(name)
         return module
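
Usage sketch (not part of the patch): the snippet below is a minimal illustration of how the new StateManager is expected to be driven, based only on the interfaces visible in this diff (the StateManager(state_cls, init_args, init_kwargs) constructor, set_context, and get_state). The import paths are assumed from the file paths above, and the context names "cond" and "uncond" are hypothetical examples, not names used by the library.

    from diffusers.hooks.first_block_cache import FBCSharedBlockState
    from diffusers.hooks.hooks import StateManager

    # One manager is shared by all FBC hooks; each named context lazily gets its
    # own FBCSharedBlockState instance the first time get_state() is called.
    # Calling get_state() before any set_context() raises ValueError.
    state_manager = StateManager(FBCSharedBlockState, (), {})

    state_manager.set_context("cond")          # hypothetical context name
    cond_state = state_manager.get_state()     # creates FBCSharedBlockState() for "cond"
    cond_state.should_compute = True

    state_manager.set_context("uncond")        # hypothetical context name
    uncond_state = state_manager.get_state()   # separate state object for "uncond"
    assert uncond_state is not cond_state      # per-context isolation

Compared with the removed ContextAwareState, which proxied attribute access through __getattribute__/__setattr__, this design keeps the state object a plain BaseState and makes the context switch an explicit get_state() call inside each hook.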