# sdnext/scripts/daam/utils.py
from functools import lru_cache
from pathlib import Path
import os
import sys
import random
from typing import Optional, TypeVar
import PIL.Image
import matplotlib.pyplot as plt
import numpy as np
import spacy
import torch

__all__ = ['set_seed', 'compute_token_merge_indices', 'plot_mask_heat_map', 'cached_nlp', 'cache_dir', 'auto_device', 'auto_autocast']

T = TypeVar('T')

def auto_device(obj: T = torch.device('cpu')) -> T:
    # Given a device, prefer CUDA when available; given a tensor/module, move it there.
    if isinstance(obj, torch.device):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        return obj.to('cuda')
    return obj
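
# Illustrative usage (not part of the module API): `auto_device` accepts both
# devices and anything with a `.to()` method, e.g. tensors and nn.Module instances.
#   device = auto_device()                      # torch.device('cuda') or torch.device('cpu')
#   model = auto_device(torch.nn.Linear(4, 4))  # moved to CUDA when available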

def auto_autocast(*args, **kwargs):
    # Fall back to a disabled (no-op) autocast context when CUDA is unavailable.
    if not torch.cuda.is_available():
        kwargs['enabled'] = False
    return torch.cuda.amp.autocast(*args, **kwargs)
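
# Illustrative usage (a sketch, not from the original file): mixed precision is
# applied only on CUDA machines; on CPU the context manager does nothing.
#   with auto_autocast(dtype=torch.float16):
#       out = model(x)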

def plot_mask_heat_map(im: PIL.Image.Image, heat_map: torch.Tensor, threshold: float = 0.4):
    # Binarize the heat map at `threshold` and use it to mask out image pixels below it.
    im = torch.from_numpy(np.array(im)).float() / 255
    mask = (heat_map.squeeze() > threshold).float()
    im = im * mask.unsqueeze(-1)
    plt.imshow(im)
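
# Illustrative usage (assumes `heat_map` matches the image's spatial size):
#   plot_mask_heat_map(img, heat_map, threshold=0.5)
#   plt.show()  # the function only draws; showing/saving is left to the caller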

def set_seed(seed: int) -> torch.Generator:
    # Seed every RNG in play (Python, NumPy, torch CPU and all CUDA devices)
    # and return a dedicated generator seeded the same way.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    gen = torch.Generator(device=auto_device())
    gen.manual_seed(seed)
    return gen
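
# Illustrative usage: the returned generator can be passed anywhere torch accepts
# a `generator=` argument, as long as the tensor device matches the generator's device.
#   gen = set_seed(42)
#   noise = torch.randn(1, 4, 64, 64, generator=gen, device=gen.device)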

def cache_dir() -> Path:
    # Linux and other non-macOS POSIX systems
    if os.name == 'posix' and sys.platform != 'darwin':
        xdg = os.environ.get('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
        return Path(xdg, 'daam')
    elif sys.platform == 'darwin':
        # macOS
        return Path(os.path.expanduser('~'), 'Library/Caches/daam')
    else:
        # Windows
        local = os.environ.get('LOCALAPPDATA', None) \
            or os.path.expanduser('~\\AppData\\Local')
        return Path(local, 'daam')
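
# Resulting locations (derived from the branches above):
#   Linux:   $XDG_CACHE_HOME/daam, defaulting to ~/.cache/daam
#   macOS:   ~/Library/Caches/daam
#   Windows: %LOCALAPPDATA%\daam, defaulting to ~\AppData\Local\daam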

def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: Optional[int] = None, offset_idx: int = 0):
    merge_idxs = []
    tokens = tokenizer.tokenize(prompt.lower())
    tokens = [x.replace('</w>', '') for x in tokens]  # New tokenizer uses wordpiece markers.
    if word_idx is None:
        word = word.lower()
        search_tokens = [x.replace('</w>', '') for x in tokenizer.tokenize(word)]  # New tokenizer uses wordpiece markers.
        # Collect the start of every occurrence of the word's token sequence in the prompt.
        start_indices = [x + offset_idx for x in range(len(tokens)) if tokens[x:x + len(search_tokens)] == search_tokens]
        for start_idx in start_indices:
            merge_idxs += [start_idx + i for i in range(len(search_tokens))]
        if not merge_idxs:
            raise ValueError(f'Search word {word} not found in prompt!')
    else:
        merge_idxs.append(word_idx)
    return [x + 1 for x in merge_idxs], word_idx  # Offset by 1 (position 0 is the start-of-text token).
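
# Illustrative usage (a sketch assuming a CLIP-style tokenizer from `transformers`;
# the model name here is only an example):
#   from transformers import CLIPTokenizer
#   tok = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
#   idxs, _ = compute_token_merge_indices(tok, 'a photo of a corgi', 'corgi')
#   # `idxs` are the 1-based token positions covering every occurrence of the word.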

nlp = None


@lru_cache(maxsize=100000)
def cached_nlp(prompt: str, type='en_core_web_md'):
    global nlp
    if nlp is None:
        try:
            nlp = spacy.load(type)
        except OSError:
            # Model is not installed: download it with the running interpreter
            # (more reliable than a bare `python`, which may not be on PATH), then retry.
            os.system(f'{sys.executable} -m spacy download {type}')
            nlp = spacy.load(type)
    return nlp(prompt)
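
if __name__ == '__main__':
    # Minimal smoke test (illustrative, not part of the original module): exercises
    # the device, seeding, and cache-path helpers without touching spaCy or CUDA state.
    device = auto_device()
    gen = set_seed(42)
    print(f'device={device} cache={cache_dir()} seed_ok={gen.initial_seed() == 42}')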