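"""Experiment persistence for DAAM: COCO label vocabularies plus the
pickleable ``GenerationExperiment`` dataclass used to save, reload, and
annotate generation runs together with their attention heat maps."""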
from pathlib import Path
from typing import List, Optional, Dict, Any, Union
from dataclasses import dataclass
import json

from transformers import PreTrainedTokenizer, AutoTokenizer
import PIL.Image
import numpy as np
import torch

from .utils import auto_autocast
from .evaluate import load_mask


__all__ = ['GenerationExperiment', 'COCO80_LABELS', 'COCOSTUFF27_LABELS', 'COCO80_INDICES', 'build_word_list_coco80']

COCO80_LABELS: List[str] = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]

COCO80_INDICES: Dict[str, int] = {x: i for i, x in enumerate(COCO80_LABELS)}

UNUSED_LABELS: List[str] = [f'__unused_{i}__' for i in range(1, 200)]

# 'food' and 'furniture' each appear twice: COCO-Stuff distinguishes
# thing/stuff variants of these categories, which collapse to the same
# name in this flat list
COCOSTUFF27_LABELS: List[str] = [
    'electronic', 'appliance', 'food', 'furniture', 'indoor', 'kitchen', 'accessory', 'animal', 'outdoor', 'person',
    'sports', 'vehicle', 'ceiling', 'floor', 'food', 'furniture', 'rawmaterial', 'textile', 'wall', 'window',
    'building', 'ground', 'plant', 'sky', 'solid', 'structural', 'water'
]

COCO80_ONTOLOGY = {
    'two-wheeled vehicle': ['bicycle', 'motorcycle'],
    'vehicle': ['two-wheeled vehicle', 'four-wheeled vehicle'],
    'four-wheeled vehicle': ['bus', 'truck', 'car'],
    'four-legged animals': ['livestock', 'pets', 'wild animals'],
    'livestock': ['cow', 'horse', 'sheep'],
    'pets': ['cat', 'dog'],
    'wild animals': ['elephant', 'bear', 'zebra', 'giraffe'],
    'bags': ['backpack', 'handbag', 'suitcase'],
    'sports boards': ['snowboard', 'surfboard', 'skateboard'],
    'utensils': ['fork', 'knife', 'spoon'],
    'receptacles': ['bowl', 'cup'],
    'fruits': ['banana', 'apple', 'orange'],
    'foods': ['fruits', 'meals', 'desserts'],
    'meals': ['sandwich', 'hot dog', 'pizza'],
    'desserts': ['cake', 'donut'],
    'furniture': ['chair', 'couch', 'bench'],
    'electronics': ['monitors', 'appliances'],
    'monitors': ['tv', 'cell phone', 'laptop'],
    'appliances': ['oven', 'toaster', 'refrigerator']
}

COCO80_TO_27 = {
    'bicycle': 'vehicle', 'car': 'vehicle', 'motorcycle': 'vehicle', 'airplane': 'vehicle', 'bus': 'vehicle',
    'train': 'vehicle', 'truck': 'vehicle', 'boat': 'vehicle', 'traffic light': 'accessory', 'fire hydrant': 'accessory',
    'stop sign': 'accessory', 'parking meter': 'accessory', 'bench': 'furniture', 'bird': 'animal', 'cat': 'animal',
    'dog': 'animal', 'horse': 'animal', 'sheep': 'animal', 'cow': 'animal', 'elephant': 'animal', 'bear': 'animal',
    'zebra': 'animal', 'giraffe': 'animal', 'backpack': 'accessory', 'umbrella': 'accessory', 'handbag': 'accessory',
    'tie': 'accessory', 'suitcase': 'accessory', 'frisbee': 'sports', 'skis': 'sports', 'snowboard': 'sports',
    'sports ball': 'sports', 'kite': 'sports', 'baseball bat': 'sports', 'baseball glove': 'sports',
    'skateboard': 'sports', 'surfboard': 'sports', 'tennis racket': 'sports', 'bottle': 'food', 'wine glass': 'food',
    'cup': 'food', 'fork': 'food', 'knife': 'food', 'spoon': 'food', 'bowl': 'food', 'banana': 'food', 'apple': 'food',
    'sandwich': 'food', 'orange': 'food', 'broccoli': 'food', 'carrot': 'food', 'hot dog': 'food', 'pizza': 'food',
    'donut': 'food', 'cake': 'food', 'chair': 'furniture', 'couch': 'furniture', 'potted plant': 'plant',
    'bed': 'furniture', 'dining table': 'furniture', 'toilet': 'furniture', 'tv': 'electronic', 'laptop': 'electronic',
    'mouse': 'electronic', 'remote': 'electronic', 'keyboard': 'electronic', 'cell phone': 'electronic',
    'microwave': 'appliance', 'oven': 'appliance', 'toaster': 'appliance', 'sink': 'appliance',
    'refrigerator': 'appliance', 'book': 'indoor', 'clock': 'indoor', 'vase': 'indoor', 'scissors': 'indoor',
    'teddy bear': 'indoor', 'hair drier': 'indoor', 'toothbrush': 'indoor'
}

def build_word_list_coco80() -> Dict[str, List[str]]:
    words_map = COCO80_ONTOLOGY.copy()
    # keep only leaf categories: entries whose members are all concrete
    # COCO-80 labels rather than other ontology keys
    words_map = {k: v for k, v in words_map.items() if not any(item in COCO80_ONTOLOGY for item in v)}

    return words_map
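
# For illustration, build_word_list_coco80() keeps leaf entries such as
# {'pets': ['cat', 'dog'], 'livestock': ['cow', 'horse', 'sheep'], ...} and
# drops nested entries like 'vehicle' or 'foods', whose members are
# themselves ontology keys.
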
def _add_mask(masks: Dict[str, torch.Tensor], word: str, mask: torch.Tensor, simplify80: bool = False) -> Dict[str, torch.Tensor]:
    if simplify80:
        word = COCO80_TO_27.get(word, word)

    if word in masks:
        # union with the existing mask for this word, clamped to [0, 1]
        masks[word] = masks[word] + mask
        masks[word].clamp_(0, 1)
    else:
        masks[word] = mask

    return masks
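
# Example: with simplify80=True a mask for 'cat' is stored under the COCO-Stuff
# key 'animal' (via COCO80_TO_27), and a later mask for 'dog' is unioned into
# the same 'animal' entry and clamped to [0, 1].
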
@dataclass
class GenerationExperiment:
    """Class to hold experiment parameters. Pickleable."""
    image: PIL.Image.Image
    global_heat_map: torch.Tensor
    prompt: str

    seed: Optional[int] = None
    id: str = '.'
    path: Optional[Path] = None

    truth_masks: Optional[Dict[str, torch.Tensor]] = None
    prediction_masks: Optional[Dict[str, torch.Tensor]] = None
    annotations: Optional[Dict[str, Any]] = None
    subtype: Optional[str] = '.'
    tokenizer: Optional[AutoTokenizer] = None
    def __post_init__(self):
        if isinstance(self.path, str):
            self.path = Path(self.path)

        # each experiment lives in its own subdirectory named after its id
        self.path = None if self.path is None else self.path / self.id
    def nsfw(self) -> bool:
        # the safety checker blacks out flagged images, so an all-zero
        # image indicates an NSFW-filtered result
        return np.sum(np.array(self.image)) == 0
    def heat_map(self, tokenizer: Optional[AutoTokenizer] = None):
        if tokenizer is None:
            tokenizer = self.tokenizer

        from daam import GlobalHeatMap  # deferred to avoid a circular import
        return GlobalHeatMap(tokenizer, self.prompt, self.global_heat_map)
    def clear_checkpoint(self):
        # also usable unbound with a Path, e.g. GenerationExperiment.clear_checkpoint(some_path)
        path = self if isinstance(self, Path) else self.path

        (path / 'generation.pt').unlink(missing_ok=True)
    def save(self, path: Optional[str] = None, heat_maps: bool = True, tokenizer: Optional[AutoTokenizer] = None):
        if path is None:
            path = self.path
        else:
            path = Path(path) / self.id

        if tokenizer is None:
            tokenizer = self.tokenizer

        (path / self.subtype).mkdir(parents=True, exist_ok=True)
        torch.save(self, path / self.subtype / 'generation.pt')
        self.image.save(path / self.subtype / 'output.png')

        with (path / 'prompt.txt').open('w') as f:
            f.write(self.prompt)

        with (path / 'seed.txt').open('w') as f:
            f.write(str(self.seed))

        if self.truth_masks is not None:
            for name, mask in self.truth_masks.items():
                # write each ground-truth mask as a 4-channel (RGBA) image
                im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy())
                im.save(path / f'{name.lower()}.gt.png')

        if heat_maps and tokenizer is not None:
            self.save_all_heat_maps(tokenizer)

        self.save_annotations()
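
    # On disk, save() produces (for id 'prompt-0001' and the default subtype '.'):
    # prompt.txt, seed.txt, annotations.json (if annotated), and <word>.gt.png
    # ground-truth masks under prompt-0001/, plus generation.pt, output.png, and
    # <word>.heat_map.png under the subtype directory when a tokenizer is available.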
    def save_annotations(self, path: Optional[Path] = None):
        if path is None:
            path = self.path

        if self.annotations is not None:
            with (path / 'annotations.json').open('w') as f:
                json.dump(self.annotations, f)
    def _load_truth_masks(self, simplify80: bool = False) -> Dict[str, torch.Tensor]:
        masks = {}

        for mask_path in self.path.glob('*.gt.png'):
            # mask files are named '<word>.gt.png'
            word = mask_path.name.split('.gt.png')[0].lower()
            mask = load_mask(str(mask_path))
            _add_mask(masks, word, mask, simplify80)

        return masks
    def _load_pred_masks(
            self,
            pred_prefix: str,
            composite: bool = False,
            simplify80: bool = False,
            vocab: Optional[List[str]] = None
    ) -> Dict[str, torch.Tensor]:
        masks = {}

        if vocab is None:
            vocab = UNUSED_LABELS

        if composite:
            try:
                # a composite prediction encodes each word as a distinct pixel
                # value that indexes into the vocabulary
                im = PIL.Image.open(self.path / self.subtype / f'composite.{pred_prefix}.pred.png')
                im = np.array(im)

                for mask_idx in np.unique(im):
                    mask = torch.from_numpy((im == mask_idx).astype(np.float32))
                    _add_mask(masks, vocab[mask_idx], mask, simplify80)
            except FileNotFoundError:
                pass
        else:
            for mask_path in (self.path / self.subtype).glob(f'*.{pred_prefix}.pred.png'):
                mask = load_mask(str(mask_path))
                word = mask_path.name.split(f'.{pred_prefix}.pred')[0].lower()
                _add_mask(masks, word, mask, simplify80)

        return masks
    def clear_prediction_masks(self, name: str):
        path = self if isinstance(self, Path) else self.path
        path = path / self.subtype

        for mask_path in path.glob(f'*.{name}.pred.png'):
            mask_path.unlink()
    def save_prediction_mask(self, mask: torch.Tensor, word: str, name: str):
        path = self if isinstance(self, Path) else self.path
        im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).cpu().byte().numpy())
        im.save(path / self.subtype / f'{word.lower()}.{name}.pred.png')
    def save_heat_map(
            self,
            word: str,
            tokenizer: Optional[PreTrainedTokenizer] = None,
            crop: Optional[int] = None,  # accepted for API compatibility; currently unused
            output_prefix: str = '',
            absolute: bool = False
    ) -> Path:
        from .trace import GlobalHeatMap  # because of cyclical import

        if tokenizer is None:
            tokenizer = self.tokenizer

        with auto_autocast(dtype=torch.float32):
            path = self.path / self.subtype / f'{output_prefix}{word.lower()}.heat_map.png'
            heat_map = GlobalHeatMap(tokenizer, self.prompt, self.global_heat_map)
            heat_map.compute_word_heat_map(word).expand_as(self.image, color_normalize=not absolute, out_file=path, plot=True)

        return path
    def save_all_heat_maps(self, tokenizer: Optional[PreTrainedTokenizer] = None, crop: Optional[int] = None) -> Dict[str, Path]:
        path_map = {}

        if tokenizer is None:
            tokenizer = self.tokenizer

        for word in self.prompt.split(' '):
            try:
                path = self.save_heat_map(word, tokenizer, crop=crop)
                path_map[word] = path
            except Exception:
                # words with no computable heat map are skipped
                pass

        return path_map
    @staticmethod
    def contains_truth_mask(path: Union[str, Path], prompt_id: Optional[str] = None) -> bool:
        if prompt_id is None:
            return any(Path(path).glob('*.gt.png'))
        else:
            return any((Path(path) / prompt_id).glob('*.gt.png'))
    @staticmethod
    def read_seed(path: Union[str, Path], prompt_id: Optional[str] = None) -> int:
        if prompt_id is None:
            return int(Path(path).joinpath('seed.txt').read_text())
        else:
            return int(Path(path).joinpath(prompt_id).joinpath('seed.txt').read_text())
    @staticmethod
    def has_annotations(path: Union[str, Path]) -> bool:
        return Path(path).joinpath('annotations.json').exists()
    @staticmethod
    def has_experiment(path: Union[str, Path], prompt_id: str) -> bool:
        return (Path(path) / prompt_id / 'generation.pt').exists()
    @staticmethod
    def read_prompt(path: Union[str, Path], prompt_id: Optional[str] = None) -> str:
        if prompt_id is None:
            prompt_id = '.'

        with (Path(path) / prompt_id / 'prompt.txt').open('r') as f:
            return f.read().strip()
    def _try_load_annotations(self):
        if not (self.path / 'annotations.json').exists():
            return None

        # use a context manager so the file handle is closed promptly
        with (self.path / 'annotations.json').open() as f:
            return json.load(f)
    def annotate(self, key: str, value: Any) -> 'GenerationExperiment':
        if self.annotations is None:
            self.annotations = {}

        self.annotations[key] = value

        return self
    @classmethod
    def load(
            cls,
            path: str,
            pred_prefix: str = 'daam',
            composite: bool = False,
            simplify80: bool = False,
            vocab: Optional[List[str]] = None,
            subtype: str = '.',
            all_subtypes: bool = False
    ) -> Union['GenerationExperiment', List['GenerationExperiment']]:
        if all_subtypes:
            experiments = []

            for directory in Path(path).iterdir():
                if not directory.is_dir():
                    continue

                try:
                    experiments.append(cls.load(
                        path,
                        pred_prefix=pred_prefix,
                        composite=composite,
                        simplify80=simplify80,
                        vocab=vocab,
                        subtype=directory.name
                    ))
                except Exception:
                    # skip subdirectories that do not contain a valid experiment
                    pass

            return experiments

        path = Path(path)
        exp = torch.load(path / subtype / 'generation.pt')
        exp.subtype = subtype
        exp.path = path
        exp.truth_masks = exp._load_truth_masks(simplify80=simplify80)
        exp.prediction_masks = exp._load_pred_masks(pred_prefix, composite=composite, simplify80=simplify80, vocab=vocab)
        exp.annotations = exp._try_load_annotations()

        return exp
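

# A minimal usage sketch (hypothetical paths and objects; the image, heat map
# tensor, and tokenizer come from a daam trace elsewhere):
#
#   exp = GenerationExperiment(
#       image=pil_image,
#       global_heat_map=heat_map_tensor,
#       prompt='a dog beside a cat',
#       seed=42,
#       id='prompt-0001',
#       path='experiments',
#       tokenizer=tokenizer,
#   )
#   exp.save()  # writes generation.pt, output.png, prompt.txt, seed.txt
#   exp = GenerationExperiment.load('experiments/prompt-0001')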