Mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] fix typo in attention_processor.py (#5066)
* [LoRA] fix typo in attention_processor.py; fixes #5062
* make style
* make fix-copies; logger call commented out for torch compile
@@ -501,7 +501,7 @@ class LocalBlend:
         alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
         for i, (prompt, words_) in enumerate(zip(prompts, words)):
-            if type(words_) is str:
+            if isinstance(words_, str):
                 words_ = [words_]
             for word in words_:
                 ind = get_word_inds(prompt, word, tokenizer)
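This hunk, like the ones below, swaps an exact-type check (type(x) is T) for isinstance(x, T), which also accepts subclasses. A minimal sketch of the difference, using a hypothetical str subclass that is not part of the diffusers code:

class Prompt(str):
    """Hypothetical str subclass, only for illustration."""
    pass

p = Prompt("a cat riding a bicycle")

print(type(p) is str)      # False: the exact-type check rejects the subclass
print(isinstance(p, str))  # True: isinstance accepts str and any subclass of it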
@@ -565,7 +565,7 @@ class AttentionControlEdit(AttentionStore, abc.ABC):
         self.cross_replace_alpha = get_time_words_attention_alpha(
             prompts, num_steps, cross_replace_steps, self.tokenizer
         ).to(self.device)
-        if type(self_replace_steps) is float:
+        if isinstance(self_replace_steps, float):
             self_replace_steps = 0, self_replace_steps
         self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
         self.local_blend = local_blend  # defined outside and passed in
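For reference, a rough sketch of what the two lines around the change do: a bare float self_replace_steps is expanded into a (start, end) fraction pair and then mapped to concrete step indices. The num_steps value below is chosen only for illustration:

num_steps = 50
self_replace_steps = 0.4  # a bare float means "replace from step 0 up to 40% of the steps"

if isinstance(self_replace_steps, float):
    self_replace_steps = 0, self_replace_steps  # becomes the tuple (0, 0.4)

num_self_replace = (
    int(num_steps * self_replace_steps[0]),
    int(num_steps * self_replace_steps[1]),
)
print(num_self_replace)  # (0, 20): self-attention maps are replaced during steps 0..19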
@@ -645,7 +645,7 @@ class AttentionReweight(AttentionControlEdit):
 def update_alpha_time_word(
     alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
 ):
-    if type(bounds) is float:
+    if isinstance(bounds, float):
         bounds = 0, bounds
     start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
     if word_inds is None:
@@ -659,7 +659,7 @@ def update_alpha_time_word(
 def get_time_words_attention_alpha(
     prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
 ):
-    if type(cross_replace_steps) is not dict:
+    if not isinstance(cross_replace_steps, dict):
         cross_replace_steps = {"default_": cross_replace_steps}
     if "default_" not in cross_replace_steps:
         cross_replace_steps["default_"] = (0.0, 1.0)
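A small standalone sketch of the normalization the hunk above performs: a plain float is wrapped into a dict under the "default_" key, and a full-range default is filled in when missing. The helper name and example values below are illustrative only:

def normalize_cross_replace_steps(cross_replace_steps):
    # Hypothetical helper; mirrors the normalization shown in the diff.
    if not isinstance(cross_replace_steps, dict):
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0.0, 1.0)
    return cross_replace_steps

print(normalize_cross_replace_steps(0.8))                  # {'default_': 0.8}
print(normalize_cross_replace_steps({"cat": (0.0, 0.5)}))  # adds 'default_': (0.0, 1.0)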
@@ -679,9 +679,9 @@ def get_time_words_attention_alpha(
 ### util functions for LocalBlend and ReplacementEdit
 def get_word_inds(text: str, word_place: int, tokenizer):
     split_text = text.split(" ")
-    if type(word_place) is str:
+    if isinstance(word_place, str):
         word_place = [i for i, word in enumerate(split_text) if word_place == word]
-    elif type(word_place) is int:
+    elif isinstance(word_place, int):
         word_place = [word_place]
     out = []
     if len(word_place) > 0:
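The two branches above matter because a string selector is resolved to matching positions while an integer selector is taken as a position directly. A simplified, tokenizer-free sketch of that lookup (it only indexes the whitespace split, so it ignores the subword handling the real get_word_inds does):

def simple_word_inds(text, word_place):
    # Hypothetical simplification of get_word_inds; no tokenizer involved.
    split_text = text.split(" ")
    if isinstance(word_place, str):
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif isinstance(word_place, int):
        word_place = [word_place]
    return word_place

print(simple_word_inds("a photo of a cat", "cat"))  # [4]
print(simple_word_inds("a photo of a cat", 1))      # [1]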
@@ -750,7 +750,7 @@ def get_replacement_mapper(prompts, tokenizer, max_len=77):
 def get_equalizer(
     text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
 ):
-    if type(word_select) is int or type(word_select) is str:
+    if isinstance(word_select, (int, str)):
         word_select = (word_select,)
     equalizer = torch.ones(len(values), 77)
     values = torch.tensor(values, dtype=torch.float32)
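The get_equalizer change also shows that isinstance accepts a tuple of types, which collapses the two exact-type checks into a single call. A tiny sketch with a hypothetical helper name:

def as_selection_tuple(word_select):
    # Hypothetical helper; mirrors the single-item normalization in get_equalizer.
    if isinstance(word_select, (int, str)):
        word_select = (word_select,)
    return word_select

print(as_selection_tuple("cat"))       # ('cat',)
print(as_selection_tuple(3))           # (3,)
print(as_selection_tuple((1, "cat")))  # unchanged: (1, 'cat')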
@@ -8,7 +8,6 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL.Image
 import torch
-from diffusers.utils.torch_utils import randn_tensor
 from PIL import Image
 from transformers import CLIPTokenizer

@@ -22,6 +21,7 @@ from diffusers.utils import (
     logging,
     replace_example_docstring,
 )
+from diffusers.utils.torch_utils import randn_tensor


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -11,7 +11,6 @@ import PIL.Image
 import pycuda.driver as cuda
 import tensorrt as trt
 import torch
-from diffusers.utils.torch_utils import randn_tensor
 from PIL import Image
 from pycuda.tools import make_default_context
 from transformers import CLIPTokenizer
@@ -26,6 +25,7 @@ from diffusers.utils import (
     logging,
     replace_example_docstring,
 )
+from diffusers.utils.torch_utils import randn_tensor


 # Initialize CUDA
@@ -382,7 +382,7 @@ class Attention(nn.Module):
             }

             if hasattr(self.processor, "attention_op"):
-                kwargs["attention_op"] = self.prcoessor.attention_op
+                kwargs["attention_op"] = self.processor.attention_op

             lora_processor = lora_processor_cls(hidden_size, **kwargs)
             lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
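The hunk above is the typo fix itself: the optional attention_op attribute is read from self.processor rather than the misspelled name. A minimal, standalone sketch of the pattern of conditionally forwarding an optional attribute into a constructor's kwargs; the class names and the hidden_size value below are illustrative stand-ins, not the diffusers API:

class OldProcessor:
    # Illustrative stand-in for a processor that may carry an attention_op.
    attention_op = None  # e.g. a custom memory-efficient attention op could live here

class NewProcessor:
    def __init__(self, hidden_size, attention_op=None):
        self.hidden_size = hidden_size
        self.attention_op = attention_op

old = OldProcessor()
kwargs = {}
if hasattr(old, "attention_op"):
    # The fixed line reads the attribute from the correct object before forwarding it.
    kwargs["attention_op"] = old.attention_op

new = NewProcessor(hidden_size=320, **kwargs)
print(new.attention_op)  # None, carried over from the old processor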
@@ -992,7 +992,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         upsample_size = None

         if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            logger.info("Forward upsample size to force interpolation output size.")
+            # Forward upsample size to force interpolation output size.
             forward_upsample_size = True

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
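This last hunk is the "logger call commented out for torch compile" part of the commit message: the logging call inside the forward pass is turned into a plain comment, since such calls can interfere with tracing the model. A rough sketch of the resulting control flow, assuming PyTorch 2.x is available; the function below is made up for illustration and is not the diffusers forward pass:

import torch

def forward_without_logging(sample, default_overall_up_factor=4):
    # Same shape check as in the diff, but with the logging call replaced by a comment.
    forward_upsample_size = False
    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        # Forward upsample size to force interpolation output size.
        forward_upsample_size = True
    return sample * 2, forward_upsample_size

compiled = torch.compile(forward_without_logging)
out, flag = compiled(torch.randn(1, 4, 30, 30))
print(flag)  # True: 30 is not a multiple of 4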