mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
848 lines
35 KiB
Python
848 lines
35 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
Tagger Settings Test Suite
|
|
|
|
Tests all WaifuDiffusion and DeepBooru tagger settings to verify they're properly
|
|
mapped and affect output correctly.
|
|
|
|
Usage:
|
|
python cli/test-tagger.py [image_path]
|
|
|
|
If no image path is provided, uses a built-in test image.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
# Add parent directory to path for imports
|
|
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
sys.path.insert(0, script_dir)
|
|
os.chdir(script_dir)
|
|
|
|
# Suppress installer output during import
|
|
os.environ['SD_INSTALL_QUIET'] = '1'
|
|
|
|
# Initialize cmd_args properly with all argument groups
|
|
import modules.cmd_args
|
|
import installer
|
|
|
|
# Add installer args to the parser
|
|
installer.add_args(modules.cmd_args.parser)
|
|
|
|
# Parse with empty args to get defaults
|
|
modules.cmd_args.parsed, _ = modules.cmd_args.parser.parse_known_args([])
|
|
|
|
# Now we can safely import modules that depend on cmd_args
|
|
|
|
|
|
# Default test images (in order of preference)
|
|
DEFAULT_TEST_IMAGES = [
|
|
'html/sdnext-robot-2k.jpg', # SD.Next robot mascot
|
|
'venv/lib/python3.13/site-packages/gradio/test_data/lion.jpg',
|
|
'venv/lib/python3.13/site-packages/gradio/test_data/cheetah1.jpg',
|
|
'venv/lib/python3.13/site-packages/skimage/data/astronaut.png',
|
|
'venv/lib/python3.13/site-packages/skimage/data/coffee.png',
|
|
]
|
|
|
|
|
|
def find_test_image():
|
|
"""Find a suitable test image from defaults."""
|
|
for img_path in DEFAULT_TEST_IMAGES:
|
|
full_path = os.path.join(script_dir, img_path)
|
|
if os.path.exists(full_path):
|
|
return full_path
|
|
return None
|
|
|
|
|
|
def create_test_image():
|
|
"""Create a simple test image as fallback."""
|
|
from PIL import Image, ImageDraw
|
|
img = Image.new('RGB', (512, 512), color=(200, 150, 100))
|
|
draw = ImageDraw.Draw(img)
|
|
draw.ellipse([100, 100, 400, 400], fill=(255, 200, 150), outline=(100, 50, 0))
|
|
draw.rectangle([150, 200, 350, 350], fill=(150, 100, 200))
|
|
return img
|
|
|
|
|
|
class TaggerTest:
|
|
"""Test harness for tagger settings."""
|
|
|
|
def __init__(self):
|
|
self.results = {'passed': [], 'failed': [], 'skipped': []}
|
|
self.test_image = None
|
|
self.waifudiffusion_loaded = False
|
|
self.deepbooru_loaded = False
|
|
|
|
def log_pass(self, msg):
|
|
print(f" [PASS] {msg}")
|
|
self.results['passed'].append(msg)
|
|
|
|
def log_fail(self, msg):
|
|
print(f" [FAIL] {msg}")
|
|
self.results['failed'].append(msg)
|
|
|
|
def log_skip(self, msg):
|
|
print(f" [SKIP] {msg}")
|
|
self.results['skipped'].append(msg)
|
|
|
|
def log_warn(self, msg):
|
|
print(f" [WARN] {msg}")
|
|
self.results['skipped'].append(msg)
|
|
|
|
def setup(self):
|
|
"""Load test image and models."""
|
|
from PIL import Image
|
|
|
|
print("=" * 70)
|
|
print("TAGGER SETTINGS TEST SUITE")
|
|
print("=" * 70)
|
|
|
|
# Get or create test image
|
|
if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
|
|
img_path = sys.argv[1]
|
|
print(f"\nUsing provided image: {img_path}")
|
|
self.test_image = Image.open(img_path).convert('RGB')
|
|
else:
|
|
img_path = find_test_image()
|
|
if img_path:
|
|
print(f"\nUsing default test image: {img_path}")
|
|
self.test_image = Image.open(img_path).convert('RGB')
|
|
else:
|
|
print("\nNo test image found, creating synthetic image...")
|
|
self.test_image = create_test_image()
|
|
|
|
print(f"Image size: {self.test_image.size}")
|
|
|
|
# Load models
|
|
print("\nLoading models...")
|
|
from modules.interrogate import waifudiffusion, deepbooru
|
|
|
|
t0 = time.time()
|
|
self.waifudiffusion_loaded = waifudiffusion.load_model()
|
|
print(f" WaifuDiffusion: {'loaded' if self.waifudiffusion_loaded else 'FAILED'} ({time.time()-t0:.1f}s)")
|
|
|
|
t0 = time.time()
|
|
self.deepbooru_loaded = deepbooru.load_model()
|
|
print(f" DeepBooru: {'loaded' if self.deepbooru_loaded else 'FAILED'} ({time.time()-t0:.1f}s)")
|
|
|
|
def cleanup(self):
|
|
"""Unload models and free memory."""
|
|
print("\n" + "=" * 70)
|
|
print("CLEANUP")
|
|
print("=" * 70)
|
|
|
|
from modules.interrogate import waifudiffusion, deepbooru
|
|
from modules import devices
|
|
|
|
waifudiffusion.unload_model()
|
|
deepbooru.unload_model()
|
|
devices.torch_gc(force=True)
|
|
print(" Models unloaded")
|
|
|
|
def print_summary(self):
|
|
"""Print test summary."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST SUMMARY")
|
|
print("=" * 70)
|
|
|
|
print(f"\n PASSED: {len(self.results['passed'])}")
|
|
for item in self.results['passed']:
|
|
print(f" - {item}")
|
|
|
|
print(f"\n FAILED: {len(self.results['failed'])}")
|
|
for item in self.results['failed']:
|
|
print(f" - {item}")
|
|
|
|
print(f"\n SKIPPED: {len(self.results['skipped'])}")
|
|
for item in self.results['skipped']:
|
|
print(f" - {item}")
|
|
|
|
total = len(self.results['passed']) + len(self.results['failed'])
|
|
if total > 0:
|
|
success_rate = len(self.results['passed']) / total * 100
|
|
print(f"\n SUCCESS RATE: {success_rate:.1f}% ({len(self.results['passed'])}/{total})")
|
|
|
|
print("\n" + "=" * 70)
|
|
|
|
# =========================================================================
|
|
# TEST: ONNX Providers Detection
|
|
# =========================================================================
|
|
def test_onnx_providers(self):
|
|
"""Verify ONNX runtime providers are properly detected."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: ONNX Providers Detection")
|
|
print("=" * 70)
|
|
|
|
from modules import devices
|
|
|
|
# Test 1: onnxruntime can be imported
|
|
try:
|
|
import onnxruntime as ort
|
|
self.log_pass(f"onnxruntime imported: version={ort.__version__}")
|
|
except ImportError as e:
|
|
self.log_fail(f"onnxruntime import failed: {e}")
|
|
return
|
|
|
|
# Test 2: Get available providers
|
|
available = ort.get_available_providers()
|
|
if available and len(available) > 0:
|
|
self.log_pass(f"Available providers: {available}")
|
|
else:
|
|
self.log_fail("No ONNX providers available")
|
|
return
|
|
|
|
# Test 3: devices.onnx is properly configured
|
|
if devices.onnx is not None and len(devices.onnx) > 0:
|
|
self.log_pass(f"devices.onnx configured: {devices.onnx}")
|
|
else:
|
|
self.log_fail(f"devices.onnx not configured: {devices.onnx}")
|
|
|
|
# Test 4: Configured providers exist in available providers
|
|
for provider in devices.onnx:
|
|
if provider in available:
|
|
self.log_pass(f"Provider '{provider}' is available")
|
|
else:
|
|
self.log_fail(f"Provider '{provider}' configured but not available")
|
|
|
|
# Test 5: If WaifuDiffusion loaded, check session providers
|
|
if self.waifudiffusion_loaded:
|
|
from modules.interrogate import waifudiffusion
|
|
if waifudiffusion.tagger.session is not None:
|
|
session_providers = waifudiffusion.tagger.session.get_providers()
|
|
self.log_pass(f"WaifuDiffusion session providers: {session_providers}")
|
|
else:
|
|
self.log_skip("WaifuDiffusion session not initialized")
|
|
|
|
# =========================================================================
|
|
# TEST: Memory Management (Offload/Reload/Unload)
|
|
# =========================================================================
|
|
def get_memory_stats(self):
|
|
"""Get current GPU and CPU memory usage."""
|
|
import torch
|
|
|
|
stats = {}
|
|
|
|
# GPU memory (if CUDA available)
|
|
if torch.cuda.is_available():
|
|
torch.cuda.synchronize()
|
|
stats['gpu_allocated'] = torch.cuda.memory_allocated() / 1024 / 1024 # MB
|
|
stats['gpu_reserved'] = torch.cuda.memory_reserved() / 1024 / 1024 # MB
|
|
else:
|
|
stats['gpu_allocated'] = 0
|
|
stats['gpu_reserved'] = 0
|
|
|
|
# CPU/RAM memory (try psutil, fallback to basic)
|
|
try:
|
|
import psutil
|
|
process = psutil.Process()
|
|
stats['ram_used'] = process.memory_info().rss / 1024 / 1024 # MB
|
|
except ImportError:
|
|
stats['ram_used'] = 0
|
|
|
|
return stats
|
|
|
|
def test_memory_management(self):
|
|
"""Test model offload to RAM, reload to GPU, and unload with memory monitoring."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: Memory Management (Offload/Reload/Unload)")
|
|
print("=" * 70)
|
|
|
|
import torch
|
|
import gc
|
|
from modules import devices
|
|
from modules.interrogate import waifudiffusion, deepbooru
|
|
|
|
# Memory leak tolerance (MB) - some variance is expected
|
|
GPU_LEAK_TOLERANCE_MB = 50
|
|
RAM_LEAK_TOLERANCE_MB = 200
|
|
|
|
# =====================================================================
|
|
# DeepBooru: Test GPU/CPU movement with memory monitoring
|
|
# =====================================================================
|
|
if self.deepbooru_loaded:
|
|
print("\n DeepBooru Memory Management:")
|
|
|
|
# Baseline memory before any operations
|
|
gc.collect()
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
baseline = self.get_memory_stats()
|
|
print(f" Baseline: GPU={baseline['gpu_allocated']:.1f}MB, RAM={baseline['ram_used']:.1f}MB")
|
|
|
|
# Test 1: Check initial state (should be on CPU after load)
|
|
initial_device = next(deepbooru.model.model.parameters()).device
|
|
print(f" Initial device: {initial_device}")
|
|
if initial_device.type == 'cpu':
|
|
self.log_pass("DeepBooru: initial state on CPU")
|
|
else:
|
|
self.log_pass(f"DeepBooru: initial state on {initial_device}")
|
|
|
|
# Test 2: Move to GPU (start)
|
|
deepbooru.model.start()
|
|
gpu_device = next(deepbooru.model.model.parameters()).device
|
|
after_gpu = self.get_memory_stats()
|
|
print(f" After start(): {gpu_device} | GPU={after_gpu['gpu_allocated']:.1f}MB (+{after_gpu['gpu_allocated']-baseline['gpu_allocated']:.1f}MB)")
|
|
if gpu_device.type == devices.device.type:
|
|
self.log_pass(f"DeepBooru: moved to GPU ({gpu_device})")
|
|
else:
|
|
self.log_fail(f"DeepBooru: failed to move to GPU, got {gpu_device}")
|
|
|
|
# Test 3: Run inference while on GPU
|
|
try:
|
|
tags = deepbooru.model.tag_multi(self.test_image, max_tags=3)
|
|
after_infer = self.get_memory_stats()
|
|
print(f" After inference: GPU={after_infer['gpu_allocated']:.1f}MB")
|
|
if tags:
|
|
self.log_pass(f"DeepBooru: inference on GPU works ({tags[:30]}...)")
|
|
else:
|
|
self.log_fail("DeepBooru: inference on GPU returned empty")
|
|
except Exception as e:
|
|
self.log_fail(f"DeepBooru: inference on GPU failed: {e}")
|
|
|
|
# Test 4: Offload to CPU (stop)
|
|
deepbooru.model.stop()
|
|
gc.collect()
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
after_offload = self.get_memory_stats()
|
|
cpu_device = next(deepbooru.model.model.parameters()).device
|
|
print(f" After stop(): {cpu_device} | GPU={after_offload['gpu_allocated']:.1f}MB, RAM={after_offload['ram_used']:.1f}MB")
|
|
if cpu_device.type == 'cpu':
|
|
self.log_pass("DeepBooru: offloaded to CPU")
|
|
else:
|
|
self.log_fail(f"DeepBooru: failed to offload, still on {cpu_device}")
|
|
|
|
# Check GPU memory returned to near baseline after offload
|
|
gpu_diff = after_offload['gpu_allocated'] - baseline['gpu_allocated']
|
|
if gpu_diff <= GPU_LEAK_TOLERANCE_MB:
|
|
self.log_pass(f"DeepBooru: GPU memory cleared after offload (diff={gpu_diff:.1f}MB)")
|
|
else:
|
|
self.log_fail(f"DeepBooru: GPU memory leak after offload (diff={gpu_diff:.1f}MB > {GPU_LEAK_TOLERANCE_MB}MB)")
|
|
|
|
# Test 5: Full cycle - reload and run again
|
|
deepbooru.model.start()
|
|
try:
|
|
tags = deepbooru.model.tag_multi(self.test_image, max_tags=3)
|
|
if tags:
|
|
self.log_pass("DeepBooru: reload cycle works")
|
|
else:
|
|
self.log_fail("DeepBooru: reload cycle returned empty")
|
|
except Exception as e:
|
|
self.log_fail(f"DeepBooru: reload cycle failed: {e}")
|
|
deepbooru.model.stop()
|
|
|
|
# Test 6: Full unload with memory check
|
|
deepbooru.unload_model()
|
|
gc.collect()
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
after_unload = self.get_memory_stats()
|
|
print(f" After unload: GPU={after_unload['gpu_allocated']:.1f}MB, RAM={after_unload['ram_used']:.1f}MB")
|
|
|
|
if deepbooru.model.model is None:
|
|
self.log_pass("DeepBooru: unload successful")
|
|
else:
|
|
self.log_fail("DeepBooru: unload failed, model still exists")
|
|
|
|
# Check for memory leaks after full unload
|
|
gpu_leak = after_unload['gpu_allocated'] - baseline['gpu_allocated']
|
|
ram_leak = after_unload['ram_used'] - baseline['ram_used']
|
|
if gpu_leak <= GPU_LEAK_TOLERANCE_MB:
|
|
self.log_pass(f"DeepBooru: no GPU memory leak after unload (diff={gpu_leak:.1f}MB)")
|
|
else:
|
|
self.log_fail(f"DeepBooru: GPU memory leak detected (diff={gpu_leak:.1f}MB > {GPU_LEAK_TOLERANCE_MB}MB)")
|
|
|
|
if ram_leak <= RAM_LEAK_TOLERANCE_MB:
|
|
self.log_pass(f"DeepBooru: no RAM leak after unload (diff={ram_leak:.1f}MB)")
|
|
else:
|
|
self.log_warn(f"DeepBooru: RAM increased after unload (diff={ram_leak:.1f}MB) - may be caching")
|
|
|
|
# Reload for remaining tests
|
|
deepbooru.load_model()
|
|
|
|
# =====================================================================
|
|
# WaifuDiffusion: Test session lifecycle with memory monitoring
|
|
# =====================================================================
|
|
if self.waifudiffusion_loaded:
|
|
print("\n WaifuDiffusion Memory Management:")
|
|
|
|
# Baseline memory
|
|
gc.collect()
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
baseline = self.get_memory_stats()
|
|
print(f" Baseline: GPU={baseline['gpu_allocated']:.1f}MB, RAM={baseline['ram_used']:.1f}MB")
|
|
|
|
# Test 1: Session exists
|
|
if waifudiffusion.tagger.session is not None:
|
|
self.log_pass("WaifuDiffusion: session loaded")
|
|
else:
|
|
self.log_fail("WaifuDiffusion: session not loaded")
|
|
return
|
|
|
|
# Test 2: Get current providers
|
|
providers = waifudiffusion.tagger.session.get_providers()
|
|
print(f" Active providers: {providers}")
|
|
self.log_pass(f"WaifuDiffusion: using providers {providers}")
|
|
|
|
# Test 3: Run inference
|
|
try:
|
|
tags = waifudiffusion.tagger.predict(self.test_image, max_tags=3)
|
|
after_infer = self.get_memory_stats()
|
|
print(f" After inference: GPU={after_infer['gpu_allocated']:.1f}MB, RAM={after_infer['ram_used']:.1f}MB")
|
|
if tags:
|
|
self.log_pass(f"WaifuDiffusion: inference works ({tags[:30]}...)")
|
|
else:
|
|
self.log_fail("WaifuDiffusion: inference returned empty")
|
|
except Exception as e:
|
|
self.log_fail(f"WaifuDiffusion: inference failed: {e}")
|
|
|
|
# Test 4: Unload session with memory check
|
|
model_name = waifudiffusion.tagger.model_name
|
|
waifudiffusion.unload_model()
|
|
gc.collect()
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
after_unload = self.get_memory_stats()
|
|
print(f" After unload: GPU={after_unload['gpu_allocated']:.1f}MB, RAM={after_unload['ram_used']:.1f}MB")
|
|
|
|
if waifudiffusion.tagger.session is None:
|
|
self.log_pass("WaifuDiffusion: unload successful")
|
|
else:
|
|
self.log_fail("WaifuDiffusion: unload failed, session still exists")
|
|
|
|
# Check for memory leaks after unload
|
|
gpu_leak = after_unload['gpu_allocated'] - baseline['gpu_allocated']
|
|
ram_leak = after_unload['ram_used'] - baseline['ram_used']
|
|
if gpu_leak <= GPU_LEAK_TOLERANCE_MB:
|
|
self.log_pass(f"WaifuDiffusion: no GPU memory leak after unload (diff={gpu_leak:.1f}MB)")
|
|
else:
|
|
self.log_fail(f"WaifuDiffusion: GPU memory leak detected (diff={gpu_leak:.1f}MB > {GPU_LEAK_TOLERANCE_MB}MB)")
|
|
|
|
if ram_leak <= RAM_LEAK_TOLERANCE_MB:
|
|
self.log_pass(f"WaifuDiffusion: no RAM leak after unload (diff={ram_leak:.1f}MB)")
|
|
else:
|
|
self.log_warn(f"WaifuDiffusion: RAM increased after unload (diff={ram_leak:.1f}MB) - may be caching")
|
|
|
|
# Test 5: Reload session
|
|
waifudiffusion.load_model(model_name)
|
|
after_reload = self.get_memory_stats()
|
|
print(f" After reload: GPU={after_reload['gpu_allocated']:.1f}MB, RAM={after_reload['ram_used']:.1f}MB")
|
|
if waifudiffusion.tagger.session is not None:
|
|
self.log_pass("WaifuDiffusion: reload successful")
|
|
else:
|
|
self.log_fail("WaifuDiffusion: reload failed")
|
|
|
|
# Test 6: Inference after reload
|
|
try:
|
|
tags = waifudiffusion.tagger.predict(self.test_image, max_tags=3)
|
|
if tags:
|
|
self.log_pass("WaifuDiffusion: inference after reload works")
|
|
else:
|
|
self.log_fail("WaifuDiffusion: inference after reload returned empty")
|
|
except Exception as e:
|
|
self.log_fail(f"WaifuDiffusion: inference after reload failed: {e}")
|
|
|
|
# Final memory check after full cycle
|
|
gc.collect()
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
final = self.get_memory_stats()
|
|
print(f" Final (after full cycle): GPU={final['gpu_allocated']:.1f}MB, RAM={final['ram_used']:.1f}MB")
|
|
|
|
# =========================================================================
|
|
# TEST: Settings Existence
|
|
# =========================================================================
|
|
def test_settings_exist(self):
|
|
"""Verify all tagger settings exist in shared.opts."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: Settings Existence")
|
|
print("=" * 70)
|
|
|
|
from modules import shared
|
|
|
|
settings = [
|
|
('tagger_threshold', float),
|
|
('tagger_include_rating', bool),
|
|
('tagger_max_tags', int),
|
|
('tagger_sort_alpha', bool),
|
|
('tagger_use_spaces', bool),
|
|
('tagger_escape_brackets', bool),
|
|
('tagger_exclude_tags', str),
|
|
('tagger_show_scores', bool),
|
|
('waifudiffusion_model', str),
|
|
('waifudiffusion_character_threshold', float),
|
|
('interrogate_offload', bool),
|
|
]
|
|
|
|
for setting, _expected_type in settings:
|
|
if hasattr(shared.opts, setting):
|
|
value = getattr(shared.opts, setting)
|
|
self.log_pass(f"{setting} = {value!r}")
|
|
else:
|
|
self.log_fail(f"{setting} - NOT FOUND")
|
|
|
|
# =========================================================================
|
|
# TEST: Parameter Effect - Tests a single parameter on both taggers
|
|
# =========================================================================
|
|
def test_parameter(self, param_name, test_func, waifudiffusion_supported=True, deepbooru_supported=True):
|
|
"""Test a parameter on both WaifuDiffusion and DeepBooru."""
|
|
print(f"\n Testing: {param_name}")
|
|
|
|
if waifudiffusion_supported and self.waifudiffusion_loaded:
|
|
try:
|
|
result = test_func('waifudiffusion')
|
|
if result is True:
|
|
self.log_pass(f"WaifuDiffusion: {param_name}")
|
|
elif result is False:
|
|
self.log_fail(f"WaifuDiffusion: {param_name}")
|
|
else:
|
|
self.log_skip(f"WaifuDiffusion: {param_name} - {result}")
|
|
except Exception as e:
|
|
self.log_fail(f"WaifuDiffusion: {param_name} - {e}")
|
|
elif waifudiffusion_supported:
|
|
self.log_skip(f"WaifuDiffusion: {param_name} - model not loaded")
|
|
|
|
if deepbooru_supported and self.deepbooru_loaded:
|
|
try:
|
|
result = test_func('deepbooru')
|
|
if result is True:
|
|
self.log_pass(f"DeepBooru: {param_name}")
|
|
elif result is False:
|
|
self.log_fail(f"DeepBooru: {param_name}")
|
|
else:
|
|
self.log_skip(f"DeepBooru: {param_name} - {result}")
|
|
except Exception as e:
|
|
self.log_fail(f"DeepBooru: {param_name} - {e}")
|
|
elif deepbooru_supported:
|
|
self.log_skip(f"DeepBooru: {param_name} - model not loaded")
|
|
|
|
def tag(self, tagger, **kwargs):
|
|
"""Helper to call the appropriate tagger."""
|
|
if tagger == 'waifudiffusion':
|
|
from modules.interrogate import waifudiffusion
|
|
return waifudiffusion.tagger.predict(self.test_image, **kwargs)
|
|
else:
|
|
from modules.interrogate import deepbooru
|
|
return deepbooru.model.tag(self.test_image, **kwargs)
|
|
|
|
# =========================================================================
|
|
# TEST: general_threshold
|
|
# =========================================================================
|
|
def test_threshold(self):
|
|
"""Test that threshold affects tag count."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: general_threshold effect")
|
|
print("=" * 70)
|
|
|
|
def check_threshold(tagger):
|
|
tags_high = self.tag(tagger, general_threshold=0.9)
|
|
tags_low = self.tag(tagger, general_threshold=0.1)
|
|
|
|
count_high = len(tags_high.split(', ')) if tags_high else 0
|
|
count_low = len(tags_low.split(', ')) if tags_low else 0
|
|
|
|
print(f" {tagger}: threshold=0.9 -> {count_high} tags, threshold=0.1 -> {count_low} tags")
|
|
|
|
if count_low > count_high:
|
|
return True
|
|
elif count_low == count_high == 0:
|
|
return "no tags returned"
|
|
else:
|
|
return "threshold effect unclear"
|
|
|
|
self.test_parameter('general_threshold', check_threshold)
|
|
|
|
# =========================================================================
|
|
# TEST: max_tags
|
|
# =========================================================================
|
|
def test_max_tags(self):
|
|
"""Test that max_tags limits output."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: max_tags effect")
|
|
print("=" * 70)
|
|
|
|
def check_max_tags(tagger):
|
|
tags_5 = self.tag(tagger, general_threshold=0.1, max_tags=5)
|
|
tags_50 = self.tag(tagger, general_threshold=0.1, max_tags=50)
|
|
|
|
count_5 = len(tags_5.split(', ')) if tags_5 else 0
|
|
count_50 = len(tags_50.split(', ')) if tags_50 else 0
|
|
|
|
print(f" {tagger}: max_tags=5 -> {count_5} tags, max_tags=50 -> {count_50} tags")
|
|
|
|
return count_5 <= 5
|
|
|
|
self.test_parameter('max_tags', check_max_tags)
|
|
|
|
# =========================================================================
|
|
# TEST: use_spaces
|
|
# =========================================================================
|
|
def test_use_spaces(self):
|
|
"""Test that use_spaces converts underscores to spaces."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: use_spaces effect")
|
|
print("=" * 70)
|
|
|
|
def check_use_spaces(tagger):
|
|
tags_under = self.tag(tagger, use_spaces=False, max_tags=10)
|
|
tags_space = self.tag(tagger, use_spaces=True, max_tags=10)
|
|
|
|
print(f" {tagger} use_spaces=False: {tags_under[:50]}...")
|
|
print(f" {tagger} use_spaces=True: {tags_space[:50]}...")
|
|
|
|
# Check if underscores are converted to spaces
|
|
has_underscore_before = '_' in tags_under
|
|
has_underscore_after = '_' in tags_space.replace(', ', ',') # ignore comma-space
|
|
|
|
# If there were underscores before but not after, it worked
|
|
if has_underscore_before and not has_underscore_after:
|
|
return True
|
|
# If there were never underscores, inconclusive
|
|
elif not has_underscore_before:
|
|
return "no underscores in tags to convert"
|
|
else:
|
|
return False
|
|
|
|
self.test_parameter('use_spaces', check_use_spaces)
|
|
|
|
# =========================================================================
|
|
# TEST: escape_brackets
|
|
# =========================================================================
|
|
def test_escape_brackets(self):
|
|
"""Test that escape_brackets escapes special characters."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: escape_brackets effect")
|
|
print("=" * 70)
|
|
|
|
def check_escape_brackets(tagger):
|
|
tags_escaped = self.tag(tagger, escape_brackets=True, max_tags=30, general_threshold=0.1)
|
|
tags_raw = self.tag(tagger, escape_brackets=False, max_tags=30, general_threshold=0.1)
|
|
|
|
print(f" {tagger} escape=True: {tags_escaped[:60]}...")
|
|
print(f" {tagger} escape=False: {tags_raw[:60]}...")
|
|
|
|
# Check for escaped brackets (\\( or \\))
|
|
has_escaped = '\\(' in tags_escaped or '\\)' in tags_escaped
|
|
has_unescaped = '(' in tags_raw.replace('\\(', '') or ')' in tags_raw.replace('\\)', '')
|
|
|
|
if has_escaped:
|
|
return True
|
|
elif has_unescaped:
|
|
# Has brackets but not escaped - fail
|
|
return False
|
|
else:
|
|
return "no brackets in tags to escape"
|
|
|
|
self.test_parameter('escape_brackets', check_escape_brackets)
|
|
|
|
# =========================================================================
|
|
# TEST: sort_alpha
|
|
# =========================================================================
|
|
def test_sort_alpha(self):
|
|
"""Test that sort_alpha sorts tags alphabetically."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: sort_alpha effect")
|
|
print("=" * 70)
|
|
|
|
def check_sort_alpha(tagger):
|
|
tags_conf = self.tag(tagger, sort_alpha=False, max_tags=20, general_threshold=0.1)
|
|
tags_alpha = self.tag(tagger, sort_alpha=True, max_tags=20, general_threshold=0.1)
|
|
|
|
list_conf = [t.strip() for t in tags_conf.split(',')]
|
|
list_alpha = [t.strip() for t in tags_alpha.split(',')]
|
|
|
|
print(f" {tagger} by_confidence: {', '.join(list_conf[:5])}...")
|
|
print(f" {tagger} alphabetical: {', '.join(list_alpha[:5])}...")
|
|
|
|
is_sorted = list_alpha == sorted(list_alpha)
|
|
return is_sorted
|
|
|
|
self.test_parameter('sort_alpha', check_sort_alpha)
|
|
|
|
# =========================================================================
|
|
# TEST: exclude_tags
|
|
# =========================================================================
|
|
def test_exclude_tags(self):
|
|
"""Test that exclude_tags removes specified tags."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: exclude_tags effect")
|
|
print("=" * 70)
|
|
|
|
def check_exclude_tags(tagger):
|
|
tags_all = self.tag(tagger, max_tags=50, general_threshold=0.1, exclude_tags='')
|
|
tag_list = [t.strip().replace(' ', '_') for t in tags_all.split(',')]
|
|
|
|
if len(tag_list) < 2:
|
|
return "not enough tags to test"
|
|
|
|
# Exclude the first tag
|
|
tag_to_exclude = tag_list[0]
|
|
tags_filtered = self.tag(tagger, max_tags=50, general_threshold=0.1, exclude_tags=tag_to_exclude)
|
|
|
|
print(f" {tagger} without exclusion: {tags_all[:50]}...")
|
|
print(f" {tagger} excluding '{tag_to_exclude}': {tags_filtered[:50]}...")
|
|
|
|
# Check if the exact tag was removed by parsing the filtered list
|
|
filtered_list = [t.strip().replace(' ', '_') for t in tags_filtered.split(',')]
|
|
# Also check space variant
|
|
tag_space_variant = tag_to_exclude.replace('_', ' ')
|
|
tag_present = tag_to_exclude in filtered_list or tag_space_variant in [t.strip() for t in tags_filtered.split(',')]
|
|
return not tag_present
|
|
|
|
self.test_parameter('exclude_tags', check_exclude_tags)
|
|
|
|
# =========================================================================
|
|
# TEST: tagger_show_scores (via shared.opts)
|
|
# =========================================================================
|
|
def test_show_scores(self):
|
|
"""Test that tagger_show_scores adds confidence scores."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: tagger_show_scores effect")
|
|
print("=" * 70)
|
|
|
|
from modules import shared
|
|
|
|
def check_show_scores(tagger):
|
|
original = shared.opts.tagger_show_scores
|
|
|
|
shared.opts.tagger_show_scores = False
|
|
tags_no_scores = self.tag(tagger, max_tags=5)
|
|
|
|
shared.opts.tagger_show_scores = True
|
|
tags_with_scores = self.tag(tagger, max_tags=5)
|
|
|
|
shared.opts.tagger_show_scores = original
|
|
|
|
print(f" {tagger} show_scores=False: {tags_no_scores[:50]}...")
|
|
print(f" {tagger} show_scores=True: {tags_with_scores[:50]}...")
|
|
|
|
has_scores = ':' in tags_with_scores and '(' in tags_with_scores
|
|
no_scores = ':' not in tags_no_scores
|
|
|
|
return has_scores and no_scores
|
|
|
|
self.test_parameter('tagger_show_scores', check_show_scores)
|
|
|
|
# =========================================================================
|
|
# TEST: include_rating
|
|
# =========================================================================
|
|
def test_include_rating(self):
|
|
"""Test that include_rating includes/excludes rating tags."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: include_rating effect")
|
|
print("=" * 70)
|
|
|
|
def check_include_rating(tagger):
|
|
tags_no_rating = self.tag(tagger, include_rating=False, max_tags=100, general_threshold=0.01)
|
|
tags_with_rating = self.tag(tagger, include_rating=True, max_tags=100, general_threshold=0.01)
|
|
|
|
print(f" {tagger} include_rating=False: {tags_no_rating[:60]}...")
|
|
print(f" {tagger} include_rating=True: {tags_with_rating[:60]}...")
|
|
|
|
# Rating tags typically start with "rating:" or are like "safe", "questionable", "explicit"
|
|
rating_keywords = ['rating:', 'safe', 'questionable', 'explicit', 'general', 'sensitive']
|
|
|
|
has_rating_before = any(kw in tags_no_rating.lower() for kw in rating_keywords)
|
|
has_rating_after = any(kw in tags_with_rating.lower() for kw in rating_keywords)
|
|
|
|
if has_rating_after and not has_rating_before:
|
|
return True
|
|
elif has_rating_after and has_rating_before:
|
|
return "rating tags appear in both (may need very low threshold)"
|
|
elif not has_rating_after:
|
|
return "no rating tags detected"
|
|
else:
|
|
return False
|
|
|
|
self.test_parameter('include_rating', check_include_rating)
|
|
|
|
# =========================================================================
|
|
# TEST: character_threshold (WaifuDiffusion only)
|
|
# =========================================================================
|
|
def test_character_threshold(self):
|
|
"""Test that character_threshold affects character tag count (WaifuDiffusion only)."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: character_threshold effect (WaifuDiffusion only)")
|
|
print("=" * 70)
|
|
|
|
def check_character_threshold(tagger):
|
|
if tagger != 'waifudiffusion':
|
|
return "not supported"
|
|
|
|
# Character threshold only affects character tags
|
|
# We need an image with character tags to properly test this
|
|
tags_high = self.tag(tagger, character_threshold=0.99, general_threshold=0.5)
|
|
tags_low = self.tag(tagger, character_threshold=0.1, general_threshold=0.5)
|
|
|
|
print(f" {tagger} char_threshold=0.99: {tags_high[:50]}...")
|
|
print(f" {tagger} char_threshold=0.10: {tags_low[:50]}...")
|
|
|
|
# If thresholds are different, the setting is at least being applied
|
|
# Hard to verify without an image with known character tags
|
|
return True # Setting exists and is applied (verified by code inspection)
|
|
|
|
self.test_parameter('character_threshold', check_character_threshold, deepbooru_supported=False)
|
|
|
|
# =========================================================================
|
|
# TEST: Unified Interface
|
|
# =========================================================================
|
|
def test_unified_interface(self):
|
|
"""Test that the unified tagger interface works for both backends."""
|
|
print("\n" + "=" * 70)
|
|
print("TEST: Unified tagger.tag() interface")
|
|
print("=" * 70)
|
|
|
|
from modules.interrogate import tagger
|
|
|
|
# Test WaifuDiffusion through unified interface
|
|
if self.waifudiffusion_loaded:
|
|
try:
|
|
models = tagger.get_models()
|
|
waifudiffusion_model = next((m for m in models if m != 'DeepBooru'), None)
|
|
if waifudiffusion_model:
|
|
tags = tagger.tag(self.test_image, model_name=waifudiffusion_model, max_tags=5)
|
|
print(f" WaifuDiffusion ({waifudiffusion_model}): {tags[:50]}...")
|
|
self.log_pass("Unified interface: WaifuDiffusion")
|
|
except Exception as e:
|
|
self.log_fail(f"Unified interface: WaifuDiffusion - {e}")
|
|
|
|
# Test DeepBooru through unified interface
|
|
if self.deepbooru_loaded:
|
|
try:
|
|
tags = tagger.tag(self.test_image, model_name='DeepBooru', max_tags=5)
|
|
print(f" DeepBooru: {tags[:50]}...")
|
|
self.log_pass("Unified interface: DeepBooru")
|
|
except Exception as e:
|
|
self.log_fail(f"Unified interface: DeepBooru - {e}")
|
|
|
|
def run_all_tests(self):
|
|
"""Run all tests."""
|
|
self.setup()
|
|
|
|
self.test_onnx_providers()
|
|
self.test_memory_management()
|
|
self.test_settings_exist()
|
|
self.test_threshold()
|
|
self.test_max_tags()
|
|
self.test_use_spaces()
|
|
self.test_escape_brackets()
|
|
self.test_sort_alpha()
|
|
self.test_exclude_tags()
|
|
self.test_show_scores()
|
|
self.test_include_rating()
|
|
self.test_character_threshold()
|
|
self.test_unified_interface()
|
|
|
|
self.cleanup()
|
|
self.print_summary()
|
|
|
|
return len(self.results['failed']) == 0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test = TaggerTest()
|
|
success = test.run_all_tests()
|
|
sys.exit(0 if success else 1)
|