mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
apply ruff
@@ -1,6 +1,6 @@
 import torch
 from dataclasses import dataclass
-from typing import Callable, Optional, List, Dict, Tuple
+from typing import Optional, List, Dict, Tuple
 from .hooks import ModelHook
 import math
 from ..models.attention import Attention
@@ -11,11 +11,9 @@ from ._common import (
 from ..hooks import HookRegistry
 from ..utils import logging
-import re
-from collections import defaultdict
 
 
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+logger = logging.get_logger(__name__)
 
 
 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
 
@@ -70,6 +68,7 @@ class TaylorSeerCacheConfig:
     def get_identifiers_template(self) -> Dict[str, Dict[str, List[str]]]:
         return _CACHE_TEMPLATES
 
+
 class TaylorSeerOutputState:
     """
     Manages the state for Taylor series-based prediction of a single attention output.
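The docstring above describes TaylorSeerOutputState as managing Taylor series-based prediction of a single attention output, but the diff never shows the class body. The following is a minimal hypothetical sketch of the idea only; the finite-difference update rule and the exact meaning of max_order are assumptions, not the diffusers implementation.

import math

import torch


class TaylorSeerOutputStateSketch:
    """Hypothetical sketch: cache a feature tensor plus finite-difference
    estimates of its time derivatives, then extrapolate between refreshes."""

    def __init__(self, max_order: int = 1):
        self.max_order = max_order
        self.factors: list[torch.Tensor] = []  # factors[k] ~ k-th derivative
        self.last_step: int | None = None

    def update(self, features: torch.Tensor, step: int) -> None:
        # Order 0 is the fresh tensor; each higher order is a finite
        # difference between the new and old factor of the order below.
        gap = 1 if self.last_step is None else step - self.last_step
        new_factors = [features]
        for k in range(self.max_order):
            if k >= len(self.factors):
                break  # not enough history yet for this order
            new_factors.append((new_factors[k] - self.factors[k]) / gap)
        self.factors = new_factors
        self.last_step = step

    def predict(self, step: int) -> torch.Tensor:
        # Truncated Taylor expansion around the last real computation:
        #   f(t) ~ sum_k factors[k] * (t - t0)**k / k!
        # (assumes update() has been called at least once)
        dt = float(step - self.last_step)
        out = torch.zeros_like(self.factors[0])
        for k, factor in enumerate(self.factors):
            out = out + factor * dt**k / math.factorial(k)
        return out

On steps where attention is actually computed, update refreshes the cached factors; on skipped steps, predict evaluates the truncated series, which is what makes the cache cheap relative to a full attention forward pass.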
@@ -219,9 +218,7 @@ class TaylorSeerAttentionCacheHook(ModelHook):
         module_dtype = attention_outputs[0].dtype
         self.num_outputs = len(attention_outputs)
         self.states = [
-            TaylorSeerOutputState(
-                self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip
-            )
+            TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip)
             for _ in range(self.num_outputs)
         ]
         for i, features in enumerate(attention_outputs):
@@ -249,7 +246,9 @@ class TaylorSeerAttentionCacheHook(ModelHook):
             attention_outputs = list(attention_outputs)
             is_first_update = self.step_counter == 0  # Only True for the very first step
             for i, features in enumerate(attention_outputs):
-                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
+                self.states[i].update(
+                    features, self.step_counter, self.max_order, self.predict_steps, is_first_update
+                )
             return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
         else:
             # Predict using Taylor series
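This hunk shows the two branches of the hook: real outputs feed update on compute steps, and the else branch predicts via the Taylor series otherwise. Below is a self-contained toy of that control flow; all names and the predict_steps + 1 cadence are illustrative assumptions, shown here in the first-order (linear) special case.

import torch

predict_steps = 3
prev, delta, last_step = None, None, None

for step in range(12):
    if step % (predict_steps + 1) == 0:
        out = torch.full((2, 4), float(step))  # stand-in for a real attention output
        if prev is not None:
            delta = (out - prev) / (step - last_step)  # first-order factor
        prev, last_step = out, step
    else:
        # predict: f(t) ~ f(t0) + f'(t0) * (t - t0)
        out = prev if delta is None else prev + delta * (step - last_step)
    print(step, out.mean().item())

Because the stand-in output here varies linearly with the step index, the extrapolation becomes exact after the second refresh; real attention outputs vary nonlinearly, which is why higher orders and periodic refreshes matter.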
@@ -330,4 +329,4 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee
         is_skip=is_skip,
     )
 
-    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
+    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
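For orientation, _apply_taylorseer_cache_hook ends by registering the hook under _TAYLORSEER_ATTENTION_CACHE_HOOK. A minimal sketch of attaching a ModelHook this way follows; the import path, the no-op hook body, and the check_if_exists_or_initialize call are assumptions drawn from diffusers' hooks module, not confirmed by this diff.

import torch

# Assumed import path: both classes live in diffusers/hooks/hooks.py.
from diffusers.hooks.hooks import HookRegistry, ModelHook


class NoOpHook(ModelHook):
    # Illustrative hook: returns the module's forward output unchanged.
    def post_forward(self, module, output):
        return output


module = torch.nn.Linear(4, 4)
registry = HookRegistry.check_if_exists_or_initialize(module)  # assumed initializer
registry.register_hook(NoOpHook(), "taylorseer_attention_cache")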