1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/interrogate/tagger.py
CalamitousFelicitousness 6b10f0df4f refactor(caption): address PR review feedback
Rename WD14 module and settings to WaifuDiffusion:
- Rename wd14.py to waifudiffusion.py
- Rename WD14Tagger class to WaifuDiffusionTagger
- Rename WD14_MODELS constant to WAIFUDIFFUSION_MODELS
- Rename settings: wd14_model -> waifudiffusion_model,
  wd14_character_threshold -> waifudiffusion_character_threshold
- Update all log messages from "WD14" to "WaifuDiffusion"

Code quality improvements:
- Simplify threshold parameter defaulting using `or` operator
- Extract save_output logic into _save_tags_to_file() helper with
  isolated error handling to prevent single file failures from
  impacting entire batch
- Fix timing log format consistency (remove 's' suffix)
2026-01-21 11:56:07 +00:00

80 lines
2.4 KiB
Python

# Unified Tagger Interface - Dispatches to WaifuDiffusion or DeepBooru based on model selection
# Provides a common interface for the Booru Tags tab
from modules import shared
# Model name that selects the DeepBooru backend; any other name is dispatched to WaifuDiffusion
DEEPBOORU_MODEL = "DeepBooru"
def get_models() -> list:
    """List every selectable tagger model name.

    Returns:
        A new list whose first entry is the DeepBooru model name, followed by
        the names reported by ``waifudiffusion.get_models()``.
    """
    # Backend is imported at call time rather than at module import
    from modules.interrogate import waifudiffusion

    waifudiffusion_names = waifudiffusion.get_models()
    return [DEEPBOORU_MODEL] + waifudiffusion_names
def refresh_models() -> list:
    """Re-query the available models.

    Returns:
        The result of ``get_models()``; this function keeps no state of its own.
    """
    models = get_models()
    return models
def is_deepbooru(model_name: str) -> bool:
    """Tell whether *model_name* is exactly the DeepBooru model name.

    Args:
        model_name: Model name to test.

    Returns:
        True when the name equals ``DEEPBOORU_MODEL``, False otherwise.
    """
    return DEEPBOORU_MODEL == model_name
def load_model(model_name: str) -> bool:
    """Load the backend that serves *model_name*.

    Args:
        model_name: DeepBooru model name or a WaifuDiffusion model name.

    Returns:
        Whatever the selected backend's ``load_model`` returns.
    """
    if not is_deepbooru(model_name):
        # Every non-DeepBooru name is handled by the WaifuDiffusion backend
        from modules.interrogate import waifudiffusion
        return waifudiffusion.load_model(model_name)

    # DeepBooru's loader takes no model name
    from modules.interrogate import deepbooru
    return deepbooru.load_model()
def unload_model():
    """Unload DeepBooru and WaifuDiffusion, in that order.

    Both backends are unloaded regardless of which one was in use, so memory
    is freed without tracking the active model here.
    """
    from modules.interrogate import deepbooru
    from modules.interrogate import waifudiffusion

    for backend in (deepbooru, waifudiffusion):
        backend.unload_model()
def tag(image, model_name: str = None, **kwargs) -> str:
    """Tag a single image with the backend chosen by *model_name*.

    Args:
        image: PIL Image to tag
        model_name: DeepBooru or a WaifuDiffusion model name; when None, the
            ``waifudiffusion_model`` option from ``shared.opts`` is used
        **kwargs: Extra arguments forwarded unchanged to the backend

    Returns:
        Formatted tag string produced by the backend
    """
    # Fall back to the configured model only when the caller passed nothing
    selected = shared.opts.waifudiffusion_model if model_name is None else model_name

    if not is_deepbooru(selected):
        from modules.interrogate import waifudiffusion
        return waifudiffusion.tag(image, model_name=selected, **kwargs)

    # DeepBooru's tag() is not given a model name
    from modules.interrogate import deepbooru
    return deepbooru.tag(image, **kwargs)
def batch(model_name: str, **kwargs) -> str:
    """Run batch tagging with the backend chosen by *model_name*.

    Args:
        model_name: DeepBooru or a WaifuDiffusion model name
        **kwargs: Extra arguments forwarded unchanged to the backend

    Returns:
        Combined tag results produced by the backend
    """
    if not is_deepbooru(model_name):
        from modules.interrogate import waifudiffusion
        return waifudiffusion.batch(model_name=model_name, **kwargs)

    # Unlike tag(), the DeepBooru batch call is also given model_name
    from modules.interrogate import deepbooru
    return deepbooru.batch(model_name=model_name, **kwargs)