
[LoRA] fix typo in attention_processor.py (#5066)

* [LoRA] fix typo in attention_processor.py

fixes #5062

* make style

* make fix-copies, logger commented for torch compile
Kashif Rasul authored 2023-09-16 14:43:18 +02:00, committed by GitHub
parent 38a664a3d6
commit 73bb97adfc
5 changed files with 11 additions and 11 deletions

View File

@@ -501,7 +501,7 @@ class LocalBlend:
         alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
         for i, (prompt, words_) in enumerate(zip(prompts, words)):
-            if type(words_) is str:
+            if isinstance(words_, str):
                 words_ = [words_]
             for word in words_:
                 ind = get_word_inds(prompt, word, tokenizer)
@@ -565,7 +565,7 @@ class AttentionControlEdit(AttentionStore, abc.ABC):
         self.cross_replace_alpha = get_time_words_attention_alpha(
             prompts, num_steps, cross_replace_steps, self.tokenizer
         ).to(self.device)
-        if type(self_replace_steps) is float:
+        if isinstance(self_replace_steps, float):
             self_replace_steps = 0, self_replace_steps
         self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
         self.local_blend = local_blend  # defined externally and passed in
@@ -645,7 +645,7 @@ class AttentionReweight(AttentionControlEdit):
 def update_alpha_time_word(
     alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None
 ):
-    if type(bounds) is float:
+    if isinstance(bounds, float):
         bounds = 0, bounds
     start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
     if word_inds is None:
@@ -659,7 +659,7 @@ def update_alpha_time_word(
 def get_time_words_attention_alpha(
     prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77
 ):
-    if type(cross_replace_steps) is not dict:
+    if not isinstance(cross_replace_steps, dict):
         cross_replace_steps = {"default_": cross_replace_steps}
     if "default_" not in cross_replace_steps:
         cross_replace_steps["default_"] = (0.0, 1.0)
@@ -679,9 +679,9 @@ def get_time_words_attention_alpha(
 ### util functions for LocalBlend and ReplacementEdit
 def get_word_inds(text: str, word_place: int, tokenizer):
     split_text = text.split(" ")
-    if type(word_place) is str:
+    if isinstance(word_place, str):
         word_place = [i for i, word in enumerate(split_text) if word_place == word]
-    elif type(word_place) is int:
+    elif isinstance(word_place, int):
         word_place = [word_place]
     out = []
     if len(word_place) > 0:
@@ -750,7 +750,7 @@ def get_replacement_mapper(prompts, tokenizer, max_len=77):
 def get_equalizer(
     text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
 ):
-    if type(word_select) is int or type(word_select) is str:
+    if isinstance(word_select, (int, str)):
         word_select = (word_select,)
     equalizer = torch.ones(len(values), 77)
     values = torch.tensor(values, dtype=torch.float32)
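
The prompt-to-prompt hunks above are all the `make style` pass rewriting exact-type checks (`type(x) is T`) as `isinstance(x, T)`. A minimal standalone sketch of the behavioral difference (illustrative only, not code from the pipeline):

class Prompt(str):
    """A toy str subclass, standing in for any wrapper a caller might pass."""


p = Prompt("a photo of a cat")

print(type(p) is str)              # False: exact-type check rejects the subclass
print(isinstance(p, str))          # True: isinstance() follows the inheritance chain
print(isinstance(p, (int, str)))   # True: a tuple of types covers the get_equalizer case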

View File

@@ -8,7 +8,6 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL.Image
 import torch
-from diffusers.utils.torch_utils import randn_tensor
 from PIL import Image
 from transformers import CLIPTokenizer
@@ -22,6 +21,7 @@ from diffusers.utils import (
     logging,
     replace_example_docstring,
 )
+from diffusers.utils.torch_utils import randn_tensor
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
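
This hunk (and the TensorRT one below) only relocates `from diffusers.utils.torch_utils import randn_tensor` to where the `make style` import sorter wants it: after the `from diffusers.utils import (...)` block. A rough sketch of the resulting order, assuming isort-style alphabetical sorting within one import group (the real file imports more names than shown):

import torch

from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor  # "utils" sorts before "utils.torch_utils"

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
noise = randn_tensor((1, 4, 64, 64), generator=torch.Generator().manual_seed(0))
logger.info("noise shape: %s", noise.shape)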

View File

@@ -11,7 +11,6 @@ import PIL.Image
 import pycuda.driver as cuda
 import tensorrt as trt
 import torch
-from diffusers.utils.torch_utils import randn_tensor
 from PIL import Image
 from pycuda.tools import make_default_context
 from transformers import CLIPTokenizer
@@ -26,6 +25,7 @@ from diffusers.utils import (
     logging,
     replace_example_docstring,
 )
+from diffusers.utils.torch_utils import randn_tensor
 
 # Initialize CUDA

View File

@@ -382,7 +382,7 @@ class Attention(nn.Module):
             }
             if hasattr(self.processor, "attention_op"):
-                kwargs["attention_op"] = self.prcoessor.attention_op
+                kwargs["attention_op"] = self.processor.attention_op
             lora_processor = lora_processor_cls(hidden_size, **kwargs)
             lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
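
This is the typo the commit title refers to (fixes #5062): the `hasattr` guard checks the correctly spelled `self.processor`, but the assignment read `self.prcoessor`, so building a LoRA-enabled xFormers attention processor failed with an `AttributeError` as soon as the guard passed. A standalone toy reproduction of that failure mode (class and method names here are illustrative, not the diffusers API):

class ToyXFormersProcessor:
    attention_op = None  # stands in for an optional xformers attention op


class ToyAttention:
    def __init__(self, processor):
        self.processor = processor

    def lora_kwargs(self, hidden_size):
        kwargs = {"hidden_size": hidden_size}
        if hasattr(self.processor, "attention_op"):
            # Pre-fix behaviour: the guard passes, then the misspelled attribute blows up.
            kwargs["attention_op"] = self.prcoessor.attention_op
        return kwargs


try:
    ToyAttention(ToyXFormersProcessor()).lora_kwargs(320)
except AttributeError as err:
    print(err)  # 'ToyAttention' object has no attribute 'prcoessor'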

View File

@@ -992,7 +992,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         upsample_size = None
 
         if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
-            logger.info("Forward upsample size to force interpolation output size.")
+            # Forward upsample size to force interpolation output size.
             forward_upsample_size = True
 
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
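
The last hunk (propagated into the flat UNet by `make fix-copies`) replaces a `logger.info` call inside the forward pass with a plain comment; per the commit message this is done for `torch.compile`, since logging calls in `forward` can introduce graph breaks in TorchDynamo. A minimal sketch of checking that a forward pass compiles without breaks (assumes PyTorch 2.x; the module here is a toy, not the UNet):

import torch


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        # Forward upsample size to force interpolation output size.
        # (Keeping the note as a comment, as the diff does, avoids a logging call
        # that the compiler would otherwise have to trace or break the graph on.)
        return self.proj(x)


compiled = torch.compile(ToyBlock(), fullgraph=True)  # fullgraph=True errors on any graph break
print(compiled(torch.randn(2, 8)).shape)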