1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Aryan
2025-04-05 00:23:10 +02:00
parent c76e1cc17e
commit 46619ea717

View File

@@ -77,7 +77,11 @@ class FBCHeadBlockHook(ModelHook):
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
if isinstance(outputs_if_skipped, tuple):
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
else:
original_hs = outputs_if_skipped
output = self.fn_ref.original_forward(*args, **kwargs)
is_output_tuple = isinstance(output, tuple)
@@ -200,14 +204,14 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Apply FBCBlockHook to '{name}'")
logger.debug(f"Applying FBCBlockHook to '{name}'")
apply_fbc_block_hook(block, shared_state)
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)