# sdnext/scripts/daam/experiment.py
from pathlib import Path
from typing import List, Optional, Dict, Any, Union
from dataclasses import dataclass
import json
from transformers import PreTrainedTokenizer, AutoTokenizer
import PIL.Image
import numpy as np
import torch
from .utils import auto_autocast
from .evaluate import load_mask
__all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS', 'COCO80_INDICES', 'build_word_list_coco80']
COCO80_LABELS: List[str] = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]
COCO80_INDICES: Dict[str, int] = {x: i for i, x in enumerate(COCO80_LABELS)}
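# e.g. COCO80_INDICES['person'] == 0 and COCO80_INDICES['toothbrush'] == 79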
UNUSED_LABELS: List[str] = [f'__unused_{i}__' for i in range(1, 200)]  # placeholder vocabulary for composite prediction masks
# 'food' and 'furniture' each appear twice: once as a "things" supercategory and once as a "stuff" supercategory
COCOSTUFF27_LABELS: List[str] = [
    'electronic', 'appliance', 'food', 'furniture', 'indoor', 'kitchen', 'accessory', 'animal', 'outdoor', 'person',
    'sports', 'vehicle', 'ceiling', 'floor', 'food', 'furniture', 'rawmaterial', 'textile', 'wall', 'window',
    'building', 'ground', 'plant', 'sky', 'solid', 'structural', 'water'
]
COCO80_ONTOLOGY = {
    'two-wheeled vehicle': ['bicycle', 'motorcycle'],
    'vehicle': ['two-wheeled vehicle', 'four-wheeled vehicle'],
    'four-wheeled vehicle': ['bus', 'truck', 'car'],
    'four-legged animals': ['livestock', 'pets', 'wild animals'],
    'livestock': ['cow', 'horse', 'sheep'],
    'pets': ['cat', 'dog'],
    'wild animals': ['elephant', 'bear', 'zebra', 'giraffe'],
    'bags': ['backpack', 'handbag', 'suitcase'],
    'sports boards': ['snowboard', 'surfboard', 'skateboard'],
    'utensils': ['fork', 'knife', 'spoon'],
    'receptacles': ['bowl', 'cup'],
    'fruits': ['banana', 'apple', 'orange'],
    'foods': ['fruits', 'meals', 'desserts'],
    'meals': ['sandwich', 'hot dog', 'pizza'],
    'desserts': ['cake', 'donut'],
    'furniture': ['chair', 'couch', 'bench'],
    'electronics': ['monitors', 'appliances'],
    'monitors': ['tv', 'cell phone', 'laptop'],
    'appliances': ['oven', 'toaster', 'refrigerator']
}
COCO80_TO_27 = {
    'bicycle': 'vehicle', 'car': 'vehicle', 'motorcycle': 'vehicle', 'airplane': 'vehicle', 'bus': 'vehicle',
    'train': 'vehicle', 'truck': 'vehicle', 'boat': 'vehicle', 'traffic light': 'accessory', 'fire hydrant': 'accessory',
    'stop sign': 'accessory', 'parking meter': 'accessory', 'bench': 'furniture', 'bird': 'animal', 'cat': 'animal',
    'dog': 'animal', 'horse': 'animal', 'sheep': 'animal', 'cow': 'animal', 'elephant': 'animal', 'bear': 'animal',
    'zebra': 'animal', 'giraffe': 'animal', 'backpack': 'accessory', 'umbrella': 'accessory', 'handbag': 'accessory',
    'tie': 'accessory', 'suitcase': 'accessory', 'frisbee': 'sports', 'skis': 'sports', 'snowboard': 'sports',
    'sports ball': 'sports', 'kite': 'sports', 'baseball bat': 'sports', 'baseball glove': 'sports',
    'skateboard': 'sports', 'surfboard': 'sports', 'tennis racket': 'sports', 'bottle': 'food', 'wine glass': 'food',
    'cup': 'food', 'fork': 'food', 'knife': 'food', 'spoon': 'food', 'bowl': 'food', 'banana': 'food', 'apple': 'food',
    'sandwich': 'food', 'orange': 'food', 'broccoli': 'food', 'carrot': 'food', 'hot dog': 'food', 'pizza': 'food',
    'donut': 'food', 'cake': 'food', 'chair': 'furniture', 'couch': 'furniture', 'potted plant': 'plant',
    'bed': 'furniture', 'dining table': 'furniture', 'toilet': 'furniture', 'tv': 'electronic', 'laptop': 'electronic',
    'mouse': 'electronic', 'remote': 'electronic', 'keyboard': 'electronic', 'cell phone': 'electronic',
    'microwave': 'appliance', 'oven': 'appliance', 'toaster': 'appliance', 'sink': 'appliance',
    'refrigerator': 'appliance', 'book': 'indoor', 'clock': 'indoor', 'vase': 'indoor', 'scissors': 'indoor',
    'teddy bear': 'indoor', 'hair drier': 'indoor', 'toothbrush': 'indoor'
}

def build_word_list_coco80() -> Dict[str, List[str]]:
    # keep only leaf groupings, i.e. entries whose members are concrete COCO labels
    # rather than other ontology keys
    return {k: v for k, v in COCO80_ONTOLOGY.items() if not any(item in COCO80_ONTOLOGY for item in v)}
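
# For illustration, the result keeps groups such as 'pets': ['cat', 'dog'] and
# 'utensils': ['fork', 'knife', 'spoon'], while 'vehicle' and 'foods' are dropped
# because their members are themselves ontology keys.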

def _add_mask(masks: Dict[str, torch.Tensor], word: str, mask: torch.Tensor, simplify80: bool = False) -> Dict[str, torch.Tensor]:
    if simplify80:
        word = COCO80_TO_27.get(word, word)
    if word in masks:
        # merge with the existing mask for this word (binary union)
        masks[word] = masks[word] + mask
        masks[word].clamp_(0, 1)
    else:
        masks[word] = mask
    return masks
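
# Example: with simplify80=True, masks for 'cat' and 'dog' both map to 'animal',
# so the second call unions the new mask into masks['animal'] and clamps it to {0, 1}.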

@dataclass
class GenerationExperiment:
    """Class to hold experiment parameters. Pickleable."""
    image: PIL.Image.Image
    global_heat_map: torch.Tensor
    prompt: str
    seed: Optional[int] = None
    id: str = '.'
    path: Optional[Path] = None
    truth_masks: Optional[Dict[str, torch.Tensor]] = None
    prediction_masks: Optional[Dict[str, torch.Tensor]] = None
    annotations: Optional[Dict[str, Any]] = None
    subtype: Optional[str] = '.'
    tokenizer: Optional[AutoTokenizer] = None

    def __post_init__(self):
        if isinstance(self.path, str):
            self.path = Path(self.path)
        self.path = None if self.path is None else self.path / self.id

    def nsfw(self) -> bool:
        # the safety checker replaces flagged images with an all-black image
        return np.sum(np.array(self.image)) == 0

    def heat_map(self, tokenizer: Optional[AutoTokenizer] = None):
        if tokenizer is None:
            tokenizer = self.tokenizer
        from .trace import GlobalHeatMap  # deferred import to avoid a circular dependency
        return GlobalHeatMap(tokenizer, self.prompt, self.global_heat_map)
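    # e.g. exp.heat_map().compute_word_heat_map('dog') yields the aggregated
    # cross-attention heat map for the token(s) of 'dog' in the prompt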

    def clear_checkpoint(self):
        # also usable unbound with a Path, e.g. GenerationExperiment.clear_checkpoint(exp_path)
        path = self if isinstance(self, Path) else self.path
        (path / 'generation.pt').unlink(missing_ok=True)

    def save(self, path: Optional[str] = None, heat_maps: bool = True, tokenizer: Optional[AutoTokenizer] = None):
        if path is None:
            path = self.path
        else:
            path = Path(path) / self.id
        if tokenizer is None:
            tokenizer = self.tokenizer
        (path / self.subtype).mkdir(parents=True, exist_ok=True)
        torch.save(self, path / self.subtype / 'generation.pt')
        self.image.save(path / self.subtype / 'output.png')
        with (path / 'prompt.txt').open('w') as f:
            f.write(self.prompt)
        with (path / 'seed.txt').open('w') as f:
            f.write(str(self.seed))
        if self.truth_masks is not None:
            for name, mask in self.truth_masks.items():
                im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy())
                im.save(path / f'{name.lower()}.gt.png')
        if heat_maps and tokenizer is not None:
            self.save_all_heat_maps(tokenizer)
        self.save_annotations(path)  # keep annotations alongside the rest of this save

    def save_annotations(self, path: Optional[Path] = None):
        if path is None:
            path = self.path
        if self.annotations is not None:
            with (path / 'annotations.json').open('w') as f:
                json.dump(self.annotations, f)

    def _load_truth_masks(self, simplify80: bool = False) -> Dict[str, torch.Tensor]:
        masks = {}
        for mask_path in self.path.glob('*.gt.png'):
            word = mask_path.name.split('.gt.png')[0].lower()
            mask = load_mask(str(mask_path))
            _add_mask(masks, word, mask, simplify80)
        return masks

    def _load_pred_masks(self, pred_prefix: str, composite: bool = False, simplify80: bool = False, vocab: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
        masks = {}
        if vocab is None:
            vocab = UNUSED_LABELS
        if composite:
            # a composite image encodes one vocabulary index per pixel
            try:
                im = PIL.Image.open(self.path / self.subtype / f'composite.{pred_prefix}.pred.png')
                im = np.array(im)
                for mask_idx in np.unique(im):
                    mask = torch.from_numpy((im == mask_idx).astype(np.float32))
                    _add_mask(masks, vocab[mask_idx], mask, simplify80)
            except FileNotFoundError:
                pass
        else:
            for mask_path in (self.path / self.subtype).glob(f'*.{pred_prefix}.pred.png'):
                mask = load_mask(str(mask_path))
                word = mask_path.name.split(f'.{pred_prefix}.pred')[0].lower()
                _add_mask(masks, word, mask, simplify80)
        return masks

    def clear_prediction_masks(self, name: str):
        path = self if isinstance(self, Path) else self.path
        path = path / self.subtype
        for mask_path in path.glob(f'*.{name}.pred.png'):
            mask_path.unlink()

    def save_prediction_mask(self, mask: torch.Tensor, word: str, name: str):
        path = self if isinstance(self, Path) else self.path
        im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy())
        im.save(path / self.subtype / f'{word.lower()}.{name}.pred.png')

    def save_heat_map(
            self,
            word: str,
            tokenizer: Optional[PreTrainedTokenizer] = None,
            crop: Optional[int] = None,  # currently unused; kept for interface compatibility
            output_prefix: str = '',
            absolute: bool = False
    ) -> Path:
        from .trace import GlobalHeatMap  # deferred import to avoid a circular dependency
        if tokenizer is None:
            tokenizer = self.tokenizer
        with auto_autocast(dtype=torch.float32):
            path = self.path / self.subtype / f'{output_prefix}{word.lower()}.heat_map.png'
            heat_map = GlobalHeatMap(tokenizer, self.prompt, self.global_heat_map)
            heat_map.compute_word_heat_map(word).expand_as(self.image, color_normalize=not absolute, out_file=path, plot=True)
        return path

    def save_all_heat_maps(self, tokenizer: Optional[PreTrainedTokenizer] = None, crop: Optional[int] = None) -> Dict[str, Path]:
        path_map = {}
        if tokenizer is None:
            tokenizer = self.tokenizer
        for word in self.prompt.split(' '):
            try:
                path = self.save_heat_map(word, tokenizer, crop=crop)
                path_map[word] = path
            except Exception:
                # not every prompt word yields a computable heat map; skip failures
                pass
        return path_map

    @staticmethod
    def contains_truth_mask(path: Union[str, Path], prompt_id: Optional[str] = None) -> bool:
        if prompt_id is None:
            return any(Path(path).glob('*.gt.png'))
        else:
            return any((Path(path) / prompt_id).glob('*.gt.png'))

    @staticmethod
    def read_seed(path: Union[str, Path], prompt_id: Optional[str] = None) -> int:
        if prompt_id is None:
            return int(Path(path).joinpath('seed.txt').read_text())
        else:
            return int(Path(path).joinpath(prompt_id).joinpath('seed.txt').read_text())

    @staticmethod
    def has_annotations(path: Union[str, Path]) -> bool:
        return Path(path).joinpath('annotations.json').exists()

    @staticmethod
    def has_experiment(path: Union[str, Path], prompt_id: str) -> bool:
        return (Path(path) / prompt_id / 'generation.pt').exists()

    @staticmethod
    def read_prompt(path: Union[str, Path], prompt_id: Optional[str] = None) -> str:
        if prompt_id is None:
            prompt_id = '.'
        with (Path(path) / prompt_id / 'prompt.txt').open('r') as f:
            return f.read().strip()

    def _try_load_annotations(self):
        if not (self.path / 'annotations.json').exists():
            return None
        with (self.path / 'annotations.json').open() as f:
            return json.load(f)

    def annotate(self, key: str, value: Any) -> 'GenerationExperiment':
        if self.annotations is None:
            self.annotations = {}
        self.annotations[key] = value
        return self

    @classmethod
    def load(
            cls,
            path: Union[str, Path],
            pred_prefix: str = 'daam',
            composite: bool = False,
            simplify80: bool = False,
            vocab: Optional[List[str]] = None,
            subtype: str = '.',
            all_subtypes: bool = False
    ) -> Union['GenerationExperiment', List['GenerationExperiment']]:
        if all_subtypes:
            experiments = []
            for directory in Path(path).iterdir():
                if not directory.is_dir():
                    continue
                try:
                    experiments.append(cls.load(
                        path,
                        pred_prefix=pred_prefix,
                        composite=composite,
                        simplify80=simplify80,
                        vocab=vocab,
                        subtype=directory.name
                    ))
                except Exception:
                    # skip subtype directories without a loadable generation.pt
                    pass
            return experiments
        path = Path(path)
        # generation.pt holds the full pickled experiment, not just tensors,
        # so newer torch versions need weights_only=False
        exp = torch.load(path / subtype / 'generation.pt', weights_only=False)
        exp.subtype = subtype
        exp.path = path
        exp.truth_masks = exp._load_truth_masks(simplify80=simplify80)
        exp.prediction_masks = exp._load_pred_masks(pred_prefix, composite=composite, simplify80=simplify80, vocab=vocab)
        exp.annotations = exp._try_load_annotations()
        return exp
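

if __name__ == '__main__':
    # Minimal usage sketch; the experiment directory and tokenizer name below are
    # illustrative placeholders, not values shipped with this module.
    tok = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    experiment = GenerationExperiment.load('experiments/example', pred_prefix='daam')
    print(experiment.prompt, experiment.seed)
    for w, p in experiment.save_all_heat_maps(tok).items():
        print(f'{w}: {p}')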