mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
186 lines
6.1 KiB
Python
186 lines
6.1 KiB
Python
import io
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Tuple, Set, Iterable
|
|
|
|
from matplotlib import pyplot as plt
|
|
import numpy as np
|
|
import PIL.Image
|
|
import spacy.tokens
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from .evaluate import compute_ioa
|
|
from .utils import compute_token_merge_indices, cached_nlp, auto_autocast
|
|
|
|
# Names exported by a star-import (`from <this module> import *`).
__all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMap', 'SyntacticHeatMapPair']
|
|
|
|
|
|
def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None, cmap='jet'):
    # type: (PIL.Image.Image | np.ndarray, torch.Tensor, str, Path, int, bool, plt.Axes, str) -> PIL.Image.Image
    """Draw ``heat_map`` with ``im`` blended on top and return the rendered figure.

    The heat map is drawn first; the image is then drawn over it with an alpha
    channel equal to ``1 - heat_map``, so the image stays visible where the heat
    is low and the colour map shows through where the heat is high.

    Args:
        im: Image to overlay; converted with ``np.array``. It is divided by 255
            and concatenated with a single alpha channel, so an ``(H, W, C)``
            array with values in ``0..255`` is expected.
        heat_map: Heat map tensor with the same spatial size as ``im``. It is
            used as the alpha source via ``unsqueeze(-1)``, so it is expected to
            be 2-D and to live on the same device as the CPU image tensor.
        word: Optional title for the plot.
        out_file: Optional path; when given the figure is also saved there.
        crop: Number of pixels to remove from every border of both the image
            and the heat map. ``None`` or ``0`` disables cropping.
        color_normalize: If true, let matplotlib scale the colour map to the
            data range; otherwise clamp the heat map to ``[0, 1]`` and use a
            fixed ``0..1`` colour range.
        ax: Optional axes to draw on. When omitted, the global pyplot state is
            restyled (dark theme), cleared and used instead.
        cmap: Matplotlib colour map name for the heat map.

    Returns:
        A ``PIL.Image.Image`` opened from an in-memory PNG of the figure.

    Note:
        Saving goes through ``plt.savefig``, i.e. pyplot's current figure, even
        when ``ax`` is supplied.
    """
    if ax is None:
        # Dark theme for stand-alone plots; this mutates global pyplot settings.
        plt.rcParams['font.size'] = 16
        plt.rcParams['figure.facecolor'] = 'black'
        plt.rcParams['text.color'] = 'white'
        plt.rcParams['axes.labelcolor'] = 'white'
        plt.rcParams['xtick.color'] = 'black'
        plt.rcParams['ytick.color'] = 'black'
        plt.clf()
        plt_ = plt
    else:
        plt_ = ax

    with auto_autocast(dtype=torch.float32):
        im = np.array(im)

        # A crop of 0 would slice `[0:-0]`, which is empty, so only crop when
        # a non-zero border width was requested.
        if crop:
            heat_map = heat_map.squeeze()[crop:-crop, crop:-crop]
            im = im[crop:-crop, crop:-crop]

        if color_normalize:
            plt_.imshow(heat_map.squeeze().cpu().numpy(), cmap=cmap)
        else:
            # Out-of-place clamp: do not modify the tensor owned by the caller.
            heat_map = heat_map.clamp(min=0, max=1)
            plt_.imshow(heat_map.squeeze().cpu().numpy(), cmap=cmap, vmin=0.0, vmax=1.0)

        # Append the inverted heat map as an alpha channel and draw the image
        # on top of the heat map.
        im = torch.from_numpy(im).float() / 255
        im = torch.cat((im, (1 - heat_map.unsqueeze(-1))), dim=-1)
        plt_.imshow(im)

        if word is not None:
            if ax is None:
                plt.title(word)
            else:
                ax.set_title(word)

        if out_file is not None:
            plt.savefig(out_file)

    # Render the current figure to an in-memory PNG and hand it back as an image.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    image = PIL.Image.open(buf)
    return image
|
|
|
|
|
|
class WordHeatMap:
    """Heat map associated with a single word of a prompt.

    Attributes (plain public fields set by the constructor):
        heatmap: The heat map tensor. ``expand_as`` unsqueezes it twice before a
            2-D bicubic interpolation, so a 2-D tensor is expected.
        word: The word this heat map belongs to, if known.
        word_idx: Index of the word, if known.
    """

    def __init__(self, heatmap: torch.Tensor, word: str = None, word_idx: int = None):
        self.word = word
        self.word_idx = word_idx
        self.heatmap = heatmap

    @property
    def value(self):
        """Return the underlying heat map tensor (same object as ``heatmap``)."""
        return self.heatmap

    def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, cmap='jet', **expand_kwargs):
        # type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, str, **Any) -> PIL.Image.Image
        """Resize this heat map to ``image`` and plot it as an overlay.

        ``expand_kwargs`` are forwarded to ``expand_as``; the remaining
        arguments are forwarded to ``plot_overlay_heat_map``, whose return
        value is returned unchanged.
        """
        return plot_overlay_heat_map(
            image,
            self.expand_as(image, **expand_kwargs),
            word=self.word,
            out_file=out_file,
            color_normalize=color_normalize,
            ax=ax,
            cmap=cmap,
        )

    def expand_as(self, image, absolute=False, threshold=None, plot=False, **plot_kwargs):
        # type: (PIL.Image.Image, bool, float, bool, **Any) -> torch.Tensor
        """Upsample the heat map to the pixel size of ``image``.

        Args:
            image: Object whose ``size`` is ``(width, height)`` (PIL convention).
            absolute: If false (default), min-max normalize the result to
                ``[0, 1]``; if true, keep the interpolated values as they are.
            threshold: If truthy, binarize the result to ``1.0`` where it is
                greater than ``threshold`` and ``0.0`` elsewhere. ``None`` and
                ``0`` both disable binarization.
            plot: If true, additionally call ``plot_overlay(image, **plot_kwargs)``.
            **plot_kwargs: Forwarded to ``plot_overlay`` when ``plot`` is true.

        Returns:
            A detached CPU tensor of shape ``(height, width)``.
        """
        # PIL reports size as (width, height) while F.interpolate expects the
        # output size as (height, width); swapping keeps the heat map aligned
        # with the image array for non-square images.
        width, height = image.size[0], image.size[1]

        im = self.heatmap.unsqueeze(0).unsqueeze(0)
        im = F.interpolate(im.float().detach(), size=(height, width), mode='bicubic')

        if not absolute:
            # The epsilon guards against division by zero for a constant map.
            im = (im - im.min()) / (im.max() - im.min() + 1e-8)

        if threshold:
            im = (im > threshold).float()

        im = im.cpu().detach().squeeze()

        if plot:
            self.plot_overlay(image, **plot_kwargs)

        return im

    def compute_ioa(self, other: 'WordHeatMap'):
        """Return ``compute_ioa`` (module-level helper) of this heat map and ``other``'s."""
        return compute_ioa(self.heatmap, other.heatmap)
|
|
|
|
|
|
@dataclass
class SyntacticHeatMapPair:
    """Heat maps of a head word and one of its dependents, plus their relation.

    Field order is part of the constructor signature and must not change.
    """

    # Heat map and text of the syntactic head.
    head_heat_map: WordHeatMap
    # Heat map of the dependent word.
    dep_heat_map: WordHeatMap
    head_text: str
    # Text of the dependent word.
    dep_text: str
    # Dependency label linking the dependent to the head.
    relation: str
|
|
|
|
|
|
@dataclass
class ParsedHeatMap:
    """A word heat map paired with the spaCy token it was computed for."""

    # Heat map computed for the token's text.
    word_heat_map: WordHeatMap
    # The parsed token itself.
    token: spacy.tokens.Token
|
|
|
|
|
|
class GlobalHeatMap:
    """Heat maps for a whole prompt, with per-word lookup helpers.

    ``heat_maps`` is indexed along its first dimension with the indices
    returned by ``compute_token_merge_indices`` (presumably one map per prompt
    token — confirm against the producer of this tensor).
    """

    def __init__(self, tokenizer: Any, prompt: str, heat_maps: torch.Tensor):
        self.tokenizer = tokenizer
        self.heat_maps = heat_maps
        self.prompt = prompt

        # Shadow the method on this instance with a memoized wrapper so that
        # repeated lookups of the same arguments reuse the result (at most 50
        # distinct argument combinations are kept).
        memoize = lru_cache(maxsize=50)
        self.compute_word_heat_map = memoize(self.compute_word_heat_map)

    def compute_word_heat_map(self, word: str, word_idx: int = None, offset_idx: int = 0) -> WordHeatMap:
        """Build the heat map for ``word`` by averaging the selected maps.

        The indices to average, and the resolved word index, come from
        ``compute_token_merge_indices``; any exception it raises propagates.
        """
        token_idxs, resolved_idx = compute_token_merge_indices(
            self.tokenizer, self.prompt, word, word_idx, offset_idx
        )
        averaged = self.heat_maps[token_idxs].mean(0)
        return WordHeatMap(averaged, word, resolved_idx)

    def parsed_heat_maps(self) -> Iterable[ParsedHeatMap]:
        """Yield a ``ParsedHeatMap`` for every parsed token of the prompt.

        Tokens for which a ``ValueError`` is raised are silently skipped.
        """
        for tok in cached_nlp(self.prompt):
            try:
                yield ParsedHeatMap(self.compute_word_heat_map(tok.text), tok)
            except ValueError:
                continue

    def dependency_relations(self) -> Iterable[SyntacticHeatMapPair]:
        """Yield head/dependent heat map pairs for every non-root token.

        Pairs for which a ``ValueError`` is raised are silently skipped.
        """
        for tok in cached_nlp(self.prompt):
            # The root token has no head to pair with.
            if tok.dep_ == 'ROOT':
                continue

            try:
                dep_map = self.compute_word_heat_map(tok.text)
                head_text = tok.head.text
                head_map = self.compute_word_heat_map(head_text)

                yield SyntacticHeatMapPair(head_map, dep_map, head_text, tok.text, tok.dep_)
            except ValueError:
                continue
|
|
|
|
|
|
RawHeatMapKey = Tuple[int, int, int]  # factor, layer, head


class RawHeatMapCollection:
    """Accumulates raw heat maps keyed by ``(factor, layer, head)``.

    Missing keys read as ``0.0`` in ``ids_to_heatmaps`` and ``0`` in
    ``ids_to_num_maps``, so the first ``update`` for a key simply stores the
    incoming heat map.

    NOTE(review): nothing in this class increments ``ids_to_num_maps``; it is
    only created and cleared here.
    """

    def __init__(self):
        self.ids_to_heatmaps: Dict[RawHeatMapKey, torch.Tensor] = defaultdict(float)
        self.ids_to_num_maps: Dict[RawHeatMapKey, int] = defaultdict(int)

    def update(self, factor: int, layer_idx: int, head_idx: int, heatmap: torch.Tensor):
        """Add ``heatmap`` to the running sum stored for the given key."""
        key = (factor, layer_idx, head_idx)
        with auto_autocast(dtype=torch.float32):
            self.ids_to_heatmaps[key] = self.ids_to_heatmaps[key] + heatmap

    def factors(self) -> Set[int]:
        """Return the distinct first key components seen so far."""
        return {key[0] for key in self.ids_to_heatmaps}

    def layers(self) -> Set[int]:
        """Return the distinct second key components seen so far."""
        return {key[1] for key in self.ids_to_heatmaps}

    def heads(self) -> Set[int]:
        """Return the distinct third key components seen so far."""
        return {key[2] for key in self.ids_to_heatmaps}

    def __iter__(self):
        """Iterate over ``(key, accumulated_heat_map)`` pairs."""
        return iter(self.ids_to_heatmaps.items())

    def clear(self):
        """Drop all accumulated heat maps and counters."""
        self.ids_to_heatmaps.clear()
        self.ids_to_num_maps.clear()
|