mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -77,7 +77,11 @@ class FBCHeadBlockHook(ModelHook):
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
|
||||
|
||||
if isinstance(outputs_if_skipped, tuple):
|
||||
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
|
||||
else:
|
||||
original_hs = outputs_if_skipped
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
is_output_tuple = isinstance(output, tuple)
|
||||
@@ -200,14 +204,14 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
|
||||
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
|
||||
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
logger.debug(f"Apply FBCBlockHook to '{name}'")
|
||||
logger.debug(f"Applying FBCBlockHook to '{name}'")
|
||||
apply_fbc_block_hook(block, shared_state)
|
||||
|
||||
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
|
||||
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
|
||||
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user