1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

apply ruff

This commit is contained in:
toilaluan
2025-11-17 13:24:20 +07:00
parent 9290b5895f
commit d929ab28a7

View File

@@ -1,6 +1,6 @@
import torch
from dataclasses import dataclass
from typing import Callable, Optional, List, Dict, Tuple
from typing import Optional, List, Dict, Tuple
from .hooks import ModelHook
import math
from ..models.attention import Attention
@@ -11,11 +11,9 @@ from ._common import (
from ..hooks import HookRegistry
from ..utils import logging
import re
from collections import defaultdict
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
logger = logging.get_logger(__name__)
_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
@@ -70,6 +68,7 @@ class TaylorSeerCacheConfig:
def get_identifiers_template(self) -> Dict[str, Dict[str, List[str]]]:
return _CACHE_TEMPLATES
class TaylorSeerOutputState:
"""
Manages the state for Taylor series-based prediction of a single attention output.
@@ -219,9 +218,7 @@ class TaylorSeerAttentionCacheHook(ModelHook):
module_dtype = attention_outputs[0].dtype
self.num_outputs = len(attention_outputs)
self.states = [
TaylorSeerOutputState(
self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip
)
TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip)
for _ in range(self.num_outputs)
]
for i, features in enumerate(attention_outputs):
@@ -249,7 +246,9 @@ class TaylorSeerAttentionCacheHook(ModelHook):
attention_outputs = list(attention_outputs)
is_first_update = self.step_counter == 0 # Only True for the very first step
for i, features in enumerate(attention_outputs):
self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
self.states[i].update(
features, self.step_counter, self.max_order, self.predict_steps, is_first_update
)
return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
else:
# Predict using Taylor series
@@ -330,4 +329,4 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee
is_skip=is_skip,
)
registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)