1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/modules/interrogate/deepbooru.py
vladmandic 33d4a4999d lint deepbooru
Signed-off-by: vladmandic <mandic00@live.com>
2026-01-24 18:51:11 +01:00

329 lines
11 KiB
Python

import os
import re
import threading
import torch
import numpy as np
from PIL import Image
from modules import modelloader, paths, devices, shared
re_special = re.compile(r'([\\()])')
load_lock = threading.Lock()
class DeepDanbooru:
def __init__(self):
self.model = None
def load(self):
with load_lock:
if self.model is not None:
return
model_path = os.path.join(paths.models_path, "DeepDanbooru")
shared.log.debug(f'Interrogate load: module=DeepDanbooru folder="{model_path}"')
files = modelloader.load_models(
model_path=model_path,
model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',
ext_filter=[".pt"],
download_name='model-resnet_custom_v3.pt',
)
from modules.interrogate.deepbooru_model import DeepDanbooruModel
self.model = DeepDanbooruModel()
self.model.load_state_dict(torch.load(files[0], map_location="cpu"))
self.model.eval()
self.model.to(devices.cpu, devices.dtype)
def start(self):
self.load()
self.model.to(devices.device)
def stop(self):
if shared.opts.interrogate_offload:
self.model.to(devices.cpu)
devices.torch_gc()
def tag(self, pil_image, **kwargs):
self.start()
res = self.tag_multi(pil_image, **kwargs)
self.stop()
return res
def tag_multi(
self,
pil_image,
general_threshold: float = None,
include_rating: bool = None,
exclude_tags: str = None,
max_tags: int = None,
sort_alpha: bool = None,
use_spaces: bool = None,
escape_brackets: bool = None,
):
"""Run inference and return formatted tag string.
Args:
pil_image: PIL Image to tag
general_threshold: Threshold for tag scores (0-1)
include_rating: Whether to include rating tags
exclude_tags: Comma-separated tags to exclude
max_tags: Maximum number of tags to return
sort_alpha: Sort tags alphabetically vs by confidence
use_spaces: Use spaces instead of underscores
escape_brackets: Escape parentheses/brackets in tags
Returns:
Formatted tag string
"""
# Use settings defaults if not specified
general_threshold = general_threshold or shared.opts.tagger_threshold
include_rating = include_rating if include_rating is not None else shared.opts.tagger_include_rating
exclude_tags = exclude_tags or shared.opts.tagger_exclude_tags
max_tags = max_tags or shared.opts.tagger_max_tags
sort_alpha = sort_alpha if sort_alpha is not None else shared.opts.tagger_sort_alpha
use_spaces = use_spaces if use_spaces is not None else shared.opts.tagger_use_spaces
escape_brackets = escape_brackets if escape_brackets is not None else shared.opts.tagger_escape_brackets
if isinstance(pil_image, list):
pil_image = pil_image[0] if len(pil_image) > 0 else None
if isinstance(pil_image, dict) and 'name' in pil_image:
pil_image = Image.open(pil_image['name'])
if pil_image is None:
return ''
pic = pil_image.resize((512, 512), resample=Image.Resampling.LANCZOS).convert("RGB")
a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
with devices.inference_context():
x = torch.from_numpy(a).to(device=devices.device, dtype=devices.dtype)
y = self.model(x)[0].detach().float().cpu().numpy()
probability_dict = {}
for current, probability in zip(self.model.tags, y):
if probability < general_threshold:
continue
if current.startswith("rating:") and not include_rating:
continue
probability_dict[current] = probability
if sort_alpha:
tags = sorted(probability_dict)
else:
tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
res = []
filtertags = {x.strip().replace(' ', '_') for x in exclude_tags.split(",")}
for filtertag in [x for x in tags if x not in filtertags]:
probability = probability_dict[filtertag]
tag_outformat = filtertag
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if escape_brackets:
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
if shared.opts.tagger_show_scores:
tag_outformat = f"({tag_outformat}:{probability:.2f})"
res.append(tag_outformat)
if max_tags > 0 and len(res) > max_tags:
res = res[:max_tags]
return ", ".join(res)
model = DeepDanbooru()
def _save_tags_to_file(img_path, tags_str: str, save_append: bool) -> bool:
"""Save tags to a text file with error handling.
Args:
img_path: Path to the image file
tags_str: Tags string to save
save_append: If True, append to existing file; otherwise overwrite
Returns:
True if save succeeded, False otherwise
"""
try:
txt_path = img_path.with_suffix('.txt')
if save_append and txt_path.exists():
with open(txt_path, 'a', encoding='utf-8') as f:
f.write(f', {tags_str}')
else:
with open(txt_path, 'w', encoding='utf-8') as f:
f.write(tags_str)
return True
except Exception as e:
shared.log.error(f'DeepBooru batch: failed to save file="{img_path}" error={e}')
return False
def get_models() -> list:
"""Return list of available DeepBooru models (just one)."""
return ["DeepBooru"]
def load_model(model_name: str = None) -> bool: # pylint: disable=unused-argument
"""Load the DeepBooru model."""
try:
model.load()
return model.model is not None
except Exception as e:
shared.log.error(f'DeepBooru load: {e}')
return False
def unload_model():
"""Unload the DeepBooru model and free memory."""
if model.model is not None:
shared.log.debug('DeepBooru unload')
model.model = None
devices.torch_gc(force=True)
def tag(image, **kwargs) -> str:
"""Tag an image using DeepBooru.
Args:
image: PIL Image to tag
**kwargs: Tagger parameters (general_threshold, include_rating, exclude_tags,
max_tags, sort_alpha, use_spaces, escape_brackets)
Returns:
Formatted tag string
"""
import time
t0 = time.time()
jobid = shared.state.begin('DeepBooru Tag')
shared.log.info(f'DeepBooru: image_size={image.size if image else None}')
try:
result = model.tag(image, **kwargs)
shared.log.debug(f'DeepBooru: complete time={time.time()-t0:.2f} tags={len(result.split(", ")) if result else 0}')
except Exception as e:
result = f"Exception {type(e)}"
shared.log.error(f'DeepBooru: {e}')
shared.state.end(jobid)
return result
def batch(
model_name: str, # pylint: disable=unused-argument
batch_files: list,
batch_folder: str,
batch_str: str,
save_output: bool = True,
save_append: bool = False,
recursive: bool = False,
**kwargs
) -> str:
"""Process multiple images in batch mode.
Args:
model_name: Model name (ignored, only DeepBooru available)
batch_files: List of file paths
batch_folder: Folder path from file picker
batch_str: Folder path as string
save_output: Save caption to .txt files
save_append: Append to existing caption files
recursive: Recursively process subfolders
**kwargs: Additional arguments (for interface compatibility)
Returns:
Combined tag results
"""
import time
from pathlib import Path
import rich.progress as rp
# Load model
model.load()
# Collect image files
image_files = []
image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif'}
# From file picker
if batch_files:
for f in batch_files:
if isinstance(f, dict):
image_files.append(Path(f['name']))
elif hasattr(f, 'name'):
image_files.append(Path(f.name))
else:
image_files.append(Path(f))
# From folder picker
if batch_folder:
folder_path = None
if isinstance(batch_folder, list) and len(batch_folder) > 0:
f = batch_folder[0]
if isinstance(f, dict):
folder_path = Path(f['name']).parent
elif hasattr(f, 'name'):
folder_path = Path(f.name).parent
if folder_path and folder_path.is_dir():
if recursive:
for ext in image_extensions:
image_files.extend(folder_path.rglob(f'*{ext}'))
else:
for ext in image_extensions:
image_files.extend(folder_path.glob(f'*{ext}'))
# From string path
if batch_str and batch_str.strip():
folder_path = Path(batch_str.strip())
if folder_path.is_dir():
if recursive:
for ext in image_extensions:
image_files.extend(folder_path.rglob(f'*{ext}'))
else:
for ext in image_extensions:
image_files.extend(folder_path.glob(f'*{ext}'))
# Remove duplicates while preserving order
seen = set()
unique_files = []
for f in image_files:
f_resolved = f.resolve()
if f_resolved not in seen:
seen.add(f_resolved)
unique_files.append(f)
image_files = unique_files
if not image_files:
shared.log.warning('DeepBooru batch: no images found')
return ''
t0 = time.time()
jobid = shared.state.begin('DeepBooru Batch')
shared.log.info(f'DeepBooru batch: images={len(image_files)} write={save_output} append={save_append} recursive={recursive}')
results = []
model.start()
# Progress bar
pbar = rp.Progress(rp.TextColumn('[cyan]DeepBooru:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
with pbar:
task = pbar.add_task(total=len(image_files), description='starting...')
for img_path in image_files:
pbar.update(task, advance=1, description=str(img_path.name))
try:
if shared.state.interrupted:
shared.log.info('DeepBooru batch: interrupted')
break
image = Image.open(img_path)
tags_str = model.tag_multi(image, **kwargs)
if save_output:
_save_tags_to_file(img_path, tags_str, save_append)
results.append(f'{img_path.name}: {tags_str[:100]}...' if len(tags_str) > 100 else f'{img_path.name}: {tags_str}')
except Exception as e:
shared.log.error(f'DeepBooru batch: file="{img_path}" error={e}')
results.append(f'{img_path.name}: ERROR - {e}')
model.stop()
elapsed = time.time() - t0
shared.log.info(f'DeepBooru batch: complete images={len(results)} time={elapsed:.1f}s')
shared.state.end(jobid)
return '\n'.join(results)