# sdnext/scripts/daam/heatmap.py

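# DAAM (Diffusion Attentive Attribution Maps) heat-map containers and plotting
# helpers: RawHeatMapCollection accumulates raw per-layer/head attention maps,
# GlobalHeatMap aggregates them per prompt token, WordHeatMap wraps a single
# word's map, and plot_overlay_heat_map renders a map on top of an image.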
import io
from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Tuple, Set, Iterable
from matplotlib import pyplot as plt
import numpy as np
import PIL.Image
import spacy.tokens
import torch
import torch.nn.functional as F
from .evaluate import compute_ioa
from .utils import compute_token_merge_indices, cached_nlp, auto_autocast
__all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMap', 'SyntacticHeatMapPair']
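
# Draw the heat map beneath the image with the chosen colormap, then composite the
# image on top using the inverted heat map as its alpha channel, so hot regions
# show through; optionally saves to out_file and always returns the rendered
# figure as a PIL image.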
def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, color_normalize=True, ax=None, cmap='jet'):
    # type: (PIL.Image.Image | np.ndarray, torch.Tensor, str, Path, int, bool, plt.Axes, str) -> PIL.Image.Image
    if ax is None:
        plt.rcParams['font.size'] = 16
        plt.rcParams['figure.facecolor'] = 'black'
        plt.rcParams['text.color'] = 'white'
        plt.rcParams['axes.labelcolor'] = 'white'
        plt.rcParams['xtick.color'] = 'black'
        plt.rcParams['ytick.color'] = 'black'
        plt.clf()
        plt_ = plt
    else:
        plt_ = ax
    with auto_autocast(dtype=torch.float32):
        im = np.array(im)
        if crop is not None:
            heat_map = heat_map.squeeze()[crop:-crop, crop:-crop]
            im = im[crop:-crop, crop:-crop]
        if color_normalize:
            plt_.imshow(heat_map.squeeze().cpu().numpy(), cmap=cmap)
        else:
            heat_map = heat_map.clamp_(min=0, max=1)
            plt_.imshow(heat_map.squeeze().cpu().numpy(), cmap=cmap, vmin=0.0, vmax=1.0)
        im = torch.from_numpy(im).float() / 255
        im = torch.cat((im, (1 - heat_map.unsqueeze(-1))), dim=-1)  # inverted heat map becomes the image's alpha channel
        plt_.imshow(im)
        if word is not None:
            if ax is None:
                plt.title(word)
            else:
                ax.set_title(word)
        if out_file is not None:
            plt.savefig(out_file)
        buf = io.BytesIO()
        plt.savefig(buf, format='png', bbox_inches='tight')
        buf.seek(0)
        image = PIL.Image.open(buf)
        return image
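
# Heat map for a single word of the prompt: expand_as() upsamples the map to the
# image resolution (bicubic), optionally min-max normalizes and thresholds it;
# plot_overlay() renders it over the image; compute_ioa() compares two word maps.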
class WordHeatMap:
    def __init__(self, heatmap: torch.Tensor, word: str = None, word_idx: int = None):
        self.word = word
        self.word_idx = word_idx
        self.heatmap = heatmap

    @property
    def value(self):
        return self.heatmap

    def plot_overlay(self, image, out_file=None, color_normalize=True, ax=None, cmap='jet', **expand_kwargs):
        # type: (PIL.Image.Image | np.ndarray, Path, bool, plt.Axes, str, Dict[str, Any]) -> PIL.Image.Image
        return plot_overlay_heat_map(
            image,
            self.expand_as(image, **expand_kwargs),
            word=self.word,
            out_file=out_file,
            color_normalize=color_normalize,
            ax=ax,
            cmap=cmap,
        )

    def expand_as(self, image, absolute=False, threshold=None, plot=False, **plot_kwargs):
        # type: (PIL.Image.Image, bool, float, bool, Dict[str, Any]) -> torch.Tensor
        im = self.heatmap.unsqueeze(0).unsqueeze(0)
        im = F.interpolate(im.float().detach(), size=(image.size[1], image.size[0]), mode='bicubic')  # PIL size is (width, height); F.interpolate expects (height, width)
        if not absolute:
            im = (im - im.min()) / (im.max() - im.min() + 1e-8)
        if threshold:
            im = (im > threshold).float()
        im = im.cpu().detach().squeeze()
        if plot:
            self.plot_overlay(image, **plot_kwargs)
        return im

    def compute_ioa(self, other: 'WordHeatMap'):
        return compute_ioa(self.heatmap, other.heatmap)

@dataclass
class SyntacticHeatMapPair:
    head_heat_map: WordHeatMap
    dep_heat_map: WordHeatMap
    head_text: str
    dep_text: str
    relation: str

@dataclass
class ParsedHeatMap:
    word_heat_map: WordHeatMap
    token: spacy.tokens.Token
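
# Per-prompt aggregate: heat_maps holds one map per prompt token, and
# compute_word_heat_map() averages the maps of the sub-tokens that make up a word
# (results are cached); parsed_heat_maps() and dependency_relations() walk the
# spaCy parse of the prompt and pair tokens/dependency arcs with their maps.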
class GlobalHeatMap:
    def __init__(self, tokenizer: Any, prompt: str, heat_maps: torch.Tensor):
        self.tokenizer = tokenizer
        self.heat_maps = heat_maps
        self.prompt = prompt
        self.compute_word_heat_map = lru_cache(maxsize=50)(self.compute_word_heat_map)

    def compute_word_heat_map(self, word: str, word_idx: int = None, offset_idx: int = 0) -> WordHeatMap:
        merge_idxs, word_idx = compute_token_merge_indices(self.tokenizer, self.prompt, word, word_idx, offset_idx)
        return WordHeatMap(self.heat_maps[merge_idxs].mean(0), word, word_idx)

    def parsed_heat_maps(self) -> Iterable[ParsedHeatMap]:
        for token in cached_nlp(self.prompt):
            try:
                heat_map = self.compute_word_heat_map(token.text)
                yield ParsedHeatMap(heat_map, token)
            except ValueError:
                pass

    def dependency_relations(self) -> Iterable[SyntacticHeatMapPair]:
        for token in cached_nlp(self.prompt):
            if token.dep_ != 'ROOT':
                try:
                    dep_heat_map = self.compute_word_heat_map(token.text)
                    head_heat_map = self.compute_word_heat_map(token.head.text)
                    yield SyntacticHeatMapPair(head_heat_map, dep_heat_map, token.head.text, token.text, token.dep_)
                except ValueError:
                    pass

RawHeatMapKey = Tuple[int, int, int] # factor, layer, head
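
# Accumulator for raw attention heat maps keyed by (factor, layer, head); update()
# sums incoming maps per key in float32. ids_to_num_maps is kept for bookkeeping
# but is not incremented here.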
class RawHeatMapCollection:
    def __init__(self):
        self.ids_to_heatmaps: Dict[RawHeatMapKey, torch.Tensor] = defaultdict(lambda: 0.0)
        self.ids_to_num_maps: Dict[RawHeatMapKey, int] = defaultdict(lambda: 0)

    def update(self, factor: int, layer_idx: int, head_idx: int, heatmap: torch.Tensor):
        with auto_autocast(dtype=torch.float32):
            key = (factor, layer_idx, head_idx)
            self.ids_to_heatmaps[key] = self.ids_to_heatmaps[key] + heatmap

    def factors(self) -> Set[int]:
        return set(key[0] for key in self.ids_to_heatmaps.keys())

    def layers(self) -> Set[int]:
        return set(key[1] for key in self.ids_to_heatmaps.keys())

    def heads(self) -> Set[int]:
        return set(key[2] for key in self.ids_to_heatmaps.keys())

    def __iter__(self):
        return iter(self.ids_to_heatmaps.items())

    def clear(self):
        self.ids_to_heatmaps.clear()
        self.ids_to_num_maps.clear()
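
# Illustrative usage sketch (not part of the original module), assuming the
# package-relative imports above resolve (run via `python -m`): build a
# WordHeatMap from a random low-resolution map and overlay it on a random image.
# The sizes, the 'example' label, and the output filename are placeholders.
if __name__ == '__main__':
    _image = PIL.Image.fromarray((np.random.rand(64, 64, 3) * 255).astype(np.uint8))
    _heat = torch.rand(16, 16)  # coarse map, as attention layers would produce
    _overlay = WordHeatMap(_heat, word='example').plot_overlay(_image)
    _overlay.save('daam_overlay_example.png')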