from functools import lru_cache
from pathlib import Path
import os
import sys
import random
from typing import TypeVar

import PIL.Image
import matplotlib.pyplot as plt
import numpy as np
import spacy
import torch


__all__ = ['set_seed', 'compute_token_merge_indices', 'plot_mask_heat_map', 'cached_nlp', 'cache_dir', 'auto_device', 'auto_autocast']


T = TypeVar('T')


def auto_device(obj: T = torch.device('cpu')) -> T:
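    """Map a torch.device to the preferred device, or move any other object onto CUDA when it is available."""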
    if isinstance(obj, torch.device):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if torch.cuda.is_available():
        return obj.to('cuda')

    return obj


def auto_autocast(*args, **kwargs):
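    """Build a torch.cuda.amp.autocast context, disabling autocast when CUDA is unavailable."""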
    if not torch.cuda.is_available():
        kwargs['enabled'] = False

    return torch.cuda.amp.autocast(*args, **kwargs)


def plot_mask_heat_map(im: PIL.Image.Image, heat_map: torch.Tensor, threshold: float = 0.4):
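    """Plot `im` with every pixel whose heat-map value is at or below `threshold` masked out (set to black)."""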
    im = torch.from_numpy(np.array(im)).float() / 255
    mask = (heat_map.squeeze() > threshold).float()
    im = im * mask.unsqueeze(-1)
    plt.imshow(im)


def set_seed(seed: int) -> torch.Generator:
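    """Seed the Python, NumPy, and torch RNGs and return a torch.Generator seeded on the auto-selected device."""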
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    gen = torch.Generator(device=auto_device())
    gen.manual_seed(seed)

    return gen


def cache_dir() -> Path:
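    """Return the platform-appropriate cache directory for daam data."""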
    # *nix
    if os.name == 'posix' and sys.platform != 'darwin':
        xdg = os.environ.get('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
        return Path(xdg, 'daam')
    elif sys.platform == 'darwin':
        # Mac OS
        return Path(os.path.expanduser('~'), 'Library/Caches/daam')
    else:
        # Windows
        local = os.environ.get('LOCALAPPDATA', None) \
            or os.path.expanduser('~\\AppData\\Local')
        return Path(local, 'daam')


def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int = None, offset_idx: int = 0):
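    """Find the indices of the prompt tokens that spell out `word` (or use the explicit `word_idx`); the returned indices are shifted by one."""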
    merge_idxs = []
    tokens = tokenizer.tokenize(prompt.lower())
    tokens = [x.replace('</w>', '') for x in tokens]  # New tokenizer uses wordpiece markers.

    if word_idx is None:
        word = word.lower()
        search_tokens = [x.replace('</w>', '') for x in tokenizer.tokenize(word)]  # New tokenizer uses wordpiece markers.
        start_indices = [x + offset_idx for x in range(len(tokens)) if tokens[x:x + len(search_tokens)] == search_tokens]

        for indice in start_indices:
            merge_idxs += [i + indice for i in range(0, len(search_tokens))]

        if not merge_idxs:
            raise ValueError(f'Search word {word} not found in prompt!')
    else:
        merge_idxs.append(word_idx)

    return [x + 1 for x in merge_idxs], word_idx  # Offset by 1.


nlp = None


@lru_cache(maxsize=100000)
def cached_nlp(prompt: str, type='en_core_web_md'):
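    """Run a spaCy pipeline over `prompt`, loading (and downloading, if missing) the model on first use; results are memoized by lru_cache."""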
    global nlp

    if nlp is None:
        try:
            nlp = spacy.load(type)
        except OSError:
            os.system(f'python -m spacy download {type}')
            nlp = spacy.load(type)

    return nlp(prompt)