from functools import lru_cache
from pathlib import Path
import os
import sys
import random
from typing import TypeVar

import PIL.Image
import matplotlib.pyplot as plt
import numpy as np
import spacy
import torch

__all__ = ['set_seed', 'compute_token_merge_indices', 'plot_mask_heat_map', 'cached_nlp', 'cache_dir',
           'auto_device', 'auto_autocast']

T = TypeVar('T')


def auto_device(obj: T = torch.device('cpu')) -> T:
    """Move `obj` to CUDA if available; with no argument, return the preferred device."""
    if isinstance(obj, torch.device):
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if torch.cuda.is_available():
        return obj.to('cuda')

    return obj


def auto_autocast(*args, **kwargs):
    """Return a CUDA autocast context, automatically disabled when CUDA is unavailable."""
    if not torch.cuda.is_available():
        kwargs['enabled'] = False

    return torch.cuda.amp.autocast(*args, **kwargs)


def plot_mask_heat_map(im: PIL.Image.Image, heat_map: torch.Tensor, threshold: float = 0.4):
    """Plot `im` with pixels zeroed out wherever `heat_map` falls below `threshold`."""
    im = torch.from_numpy(np.array(im)).float() / 255
    mask = (heat_map.squeeze() > threshold).float()
    im = im * mask.unsqueeze(-1)
    plt.imshow(im)


def set_seed(seed: int) -> torch.Generator:
    """Seed Python, NumPy, and PyTorch RNGs, returning a seeded generator on the preferred device."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    gen = torch.Generator(device=auto_device())
    gen.manual_seed(seed)

    return gen


def cache_dir() -> Path:
    # *nix
    if os.name == 'posix' and sys.platform != 'darwin':
        xdg = os.environ.get('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
        return Path(xdg, 'daam')
    elif sys.platform == 'darwin':
        # macOS
        return Path(os.path.expanduser('~'), 'Library/Caches/daam')
    else:
        # Windows
        local = os.environ.get('LOCALAPPDATA', None) \
                or os.path.expanduser('~\\AppData\\Local')
        return Path(local, 'daam')


def compute_token_merge_indices(tokenizer, prompt: str, word: str, word_idx: int = None, offset_idx: int = 0):
    """Find the token positions that `word` occupies in `prompt`, offset by 1 for the BOS token."""
    merge_idxs = []
    tokens = tokenizer.tokenize(prompt.lower())
    tokens = [x.replace('</w>', '') for x in tokens]  # New tokenizer uses wordpiece markers.

    if word_idx is None:
        word = word.lower()
        search_tokens = [x.replace('</w>', '') for x in tokenizer.tokenize(word)]  # New tokenizer uses wordpiece markers.
        start_indices = [x + offset_idx for x in range(len(tokens))
                         if tokens[x:x + len(search_tokens)] == search_tokens]

        for idx in start_indices:
            merge_idxs += [i + idx for i in range(0, len(search_tokens))]

        if not merge_idxs:
            raise ValueError(f'Search word {word} not found in prompt!')
    else:
        merge_idxs.append(word_idx)

    return [x + 1 for x in merge_idxs], word_idx  # Offset by 1.


nlp = None


@lru_cache(maxsize=100000)
def cached_nlp(prompt: str, type='en_core_web_md'):
    """Parse `prompt` with spaCy, lazily loading (and downloading, if needed) the pipeline."""
    global nlp

    if nlp is None:
        try:
            nlp = spacy.load(type)
        except OSError:
            os.system(f'python -m spacy download {type}')
            nlp = spacy.load(type)

    return nlp(prompt)
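

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's public surface).
# It shows how the helpers above compose. The CLIP checkpoint name below is an
# assumption chosen because DAAM targets Stable Diffusion's CLIP tokenizer, and
# it presumes the `transformers` package is installed alongside this module.
if __name__ == '__main__':
    from transformers import CLIPTokenizer  # assumption: transformers is available

    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

    # Seed all RNGs and get a generator suitable for e.g. diffusers pipelines.
    gen = set_seed(0)

    # Map the word 'dog' to its token positions in the prompt (1-indexed
    # because of the +1 BOS offset applied by compute_token_merge_indices).
    idxs, _ = compute_token_merge_indices(tokenizer, 'a photo of a dog', 'dog')
    print(idxs)  # [5]: 'dog' is the fifth prompt token after the BOS offset

    # Mixed-precision context that degrades gracefully on CPU-only machines.
    with auto_autocast():
        print(auto_device(torch.ones(4)).device)  # cuda:0 if available, else cpu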