# sdnext/modules/interrogate/moondream3.py
# Moondream 3 Preview VLM Implementation
# Source: https://huggingface.co/moondream/moondream3-preview
# Model: 9.3GB, gated (requires HuggingFace authentication)
# Architecture: Mixture-of-Experts (9B total params, 2B active)
import os
import re
import transformers
from PIL import Image
from modules import shared, devices, sd_models
from modules.interrogate import vqa_detection

# Debug logging - function-based to avoid circular import
debug_enabled = os.environ.get('SD_VQA_DEBUG', None) is not None


def debug(*args, **kwargs):
    if debug_enabled:
        shared.log.trace(*args, **kwargs)
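
# Note: the trace logging used throughout this module is gated by the SD_VQA_DEBUG
# environment variable checked above; setting it to any value (for example
# SD_VQA_DEBUG=1) before launch enables the debug() output.
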
# Global state
moondream3_model = None
loaded = None
image_cache = {} # Cache encoded images for reuse


def get_settings():
    """
    Build settings dict for Moondream 3 API from global VQA options.
    Moondream 3 accepts: temperature, top_p, max_tokens
    """
    settings = {}
    if shared.opts.interrogate_vlm_max_length > 0:
        settings['max_tokens'] = shared.opts.interrogate_vlm_max_length
    if shared.opts.interrogate_vlm_temperature > 0:
        settings['temperature'] = shared.opts.interrogate_vlm_temperature
    if shared.opts.interrogate_vlm_top_p > 0:
        settings['top_p'] = shared.opts.interrogate_vlm_top_p
    return settings if settings else None
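
# Illustrative sketch (not executed): with option values such as
# interrogate_vlm_max_length=512, temperature=0.2 and top_p=0.9 (hypothetical
# numbers, the real values come from shared.opts), get_settings() would return:
#   {'max_tokens': 512, 'temperature': 0.2, 'top_p': 0.9}
# and None when all three options are set to 0.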


def load_model(repo: str):
    """Load Moondream 3 model."""
    global moondream3_model, loaded  # pylint: disable=global-statement
    if moondream3_model is None or loaded != repo:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        moondream3_model = None
        moondream3_model = transformers.AutoModelForCausalLM.from_pretrained(
            repo,
            trust_remote_code=True,
            torch_dtype=devices.dtype,
            cache_dir=shared.opts.hfcache_dir,
        )
        moondream3_model.eval()
        # Initialize KV caches before moving to device (they're lazy by default)
        if hasattr(moondream3_model, '_setup_caches'):
            moondream3_model._setup_caches()  # pylint: disable=protected-access
        # Disable flex_attention decoding (can cause hangs due to torch.compile)
        if hasattr(moondream3_model, 'model') and hasattr(moondream3_model.model, 'use_flex_decoding'):
            moondream3_model.model.use_flex_decoding = False
        loaded = repo
        devices.torch_gc()
    # Move model to active device
    sd_models.move_model(moondream3_model, devices.device)
    return moondream3_model
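
# Illustrative usage (assumes the gated preview repo id from the header and a
# HuggingFace token already configured for transformers, since the model is gated):
#   model = load_model('moondream/moondream3-preview')
# Repeated calls with the same repo reuse the cached instance; a different repo
# triggers a fresh from_pretrained() load.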


def encode_image(image: Image.Image, cache_key: str = None):
    """
    Encode image for reuse across multiple queries.
    Args:
        image: PIL Image
        cache_key: Optional cache key for storing encoded image
    Returns:
        Encoded image tensor
    """
    if cache_key and cache_key in image_cache:
        debug(f'VQA interrogate: handler=moondream3 using cached encoding for cache_key="{cache_key}"')
        return image_cache[cache_key]
    model = load_model(loaded)
    with devices.inference_context():
        encoded = model.encode_image(image)
    if cache_key:
        image_cache[cache_key] = encoded
        debug(f'VQA interrogate: handler=moondream3 cached encoding cache_key="{cache_key}" cache_size={len(image_cache)}')
    return encoded
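
# Illustrative sketch: encoding once and reusing it for several questions about
# the same image (img and the cache key below are hypothetical examples):
#   encoded = encode_image(img, cache_key='img1')
#   # later calls with cache_key='img1' return the cached tensor
# clear_cache() empties the cache when the image is no longer needed.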


def query(image: Image.Image, question: str, repo: str, stream: bool = False,
          temperature: float = None, top_p: float = None, max_tokens: int = None,
          use_cache: bool = False, reasoning: bool = True):
    """
    Visual question answering with optional streaming.
    Args:
        image: PIL Image
        question: Question about the image
        repo: Model repository
        stream: Enable streaming output (generator)
        temperature: Sampling temperature (overrides global setting)
        top_p: Nucleus sampling parameter (overrides global setting)
        max_tokens: Maximum tokens to generate (overrides global setting)
        use_cache: Use cached image encoding if available
        reasoning: Enable reasoning mode (response dict may include a 'reasoning' entry)
    Returns:
        Answer dict or string (or generator if stream=True)
    """
    model = load_model(repo)
    # Build settings - per-call parameters override global settings
    settings = get_settings() or {}
    if temperature is not None:
        settings['temperature'] = temperature
    if top_p is not None:
        settings['top_p'] = top_p
    if max_tokens is not None:
        settings['max_tokens'] = max_tokens
    debug(f'VQA interrogate: handler=moondream3 method=query question="{question}" stream={stream} settings={settings}')
    # Use cached encoding if requested
    if use_cache:
        cache_key = f"{id(image)}_{question}"
        image_input = encode_image(image, cache_key)
    else:
        image_input = image
    with devices.inference_context():
        response = model.query(
            image=image_input,
            question=question,
            stream=stream,
            settings=settings if settings else None,
            reasoning=reasoning
        )
    # Log response structure (for non-streaming)
    if not stream:
        if isinstance(response, dict):
            debug(f'VQA interrogate: handler=moondream3 response_type=dict keys={list(response.keys())}')
            if 'reasoning' in response:
                reasoning_text = response['reasoning'].get('text', '')[:100] + '...' if len(response['reasoning'].get('text', '')) > 100 else response['reasoning'].get('text', '')
                debug(f'VQA interrogate: handler=moondream3 reasoning="{reasoning_text}"')
            if 'answer' in response:
                debug(f'VQA interrogate: handler=moondream3 answer="{response["answer"]}"')
    return response
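
# Illustrative usage (repo id assumed from the module header; per the logging
# above, non-streaming calls typically return a dict with an 'answer' key and,
# when reasoning=True, a 'reasoning' entry):
#   response = query(img, 'What is the dog holding?', 'moondream/moondream3-preview', reasoning=True)
#   answer = response.get('answer') if isinstance(response, dict) else response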


def caption(image: Image.Image, repo: str, length: str = 'normal', stream: bool = False,
            temperature: float = None, top_p: float = None, max_tokens: int = None):
    """
    Generate image captions at different lengths.
    Args:
        image: PIL Image
        repo: Model repository
        length: Caption length - 'short', 'normal', or 'long'
        stream: Enable streaming output (generator)
        temperature: Sampling temperature (overrides global setting)
        top_p: Nucleus sampling parameter (overrides global setting)
        max_tokens: Maximum tokens to generate (overrides global setting)
    Returns:
        Caption dict or string (or generator if stream=True)
    """
    model = load_model(repo)
    # Build settings - per-call parameters override global settings
    settings = get_settings() or {}
    if temperature is not None:
        settings['temperature'] = temperature
    if top_p is not None:
        settings['top_p'] = top_p
    if max_tokens is not None:
        settings['max_tokens'] = max_tokens
    debug(f'VQA interrogate: handler=moondream3 method=caption length={length} stream={stream} settings={settings}')
    with devices.inference_context():
        response = model.caption(
            image,
            length=length,
            stream=stream,
            settings=settings if settings else None
        )
    # Log response structure (for non-streaming)
    if not stream and isinstance(response, dict):
        debug(f'VQA interrogate: handler=moondream3 response_type=dict keys={list(response.keys())}')
    return response
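
# Illustrative usage: length selects the caption style, mirroring the
# 'short'/'normal'/'long' values dispatched from predict() below:
#   short_caption = caption(img, 'moondream/moondream3-preview', length='short')
#   long_caption = caption(img, 'moondream/moondream3-preview', length='long', stream=False)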


def point(image: Image.Image, object_name: str, repo: str):
    """
    Identify coordinates of all instances of a specific object in the image.
    Args:
        image: PIL Image
        object_name: Name of object to locate
        repo: Model repository
    Returns:
        List of (x, y) tuples with coordinates normalized to 0-1 range, or None if not found
        Example: [(0.733, 0.442), (0.5, 0.6)] for 2 instances
    """
    model = load_model(repo)
    debug(f'VQA interrogate: handler=moondream3 method=point object_name="{object_name}"')
    with devices.inference_context():
        result = model.point(image, object_name)
    debug(f'VQA interrogate: handler=moondream3 point_raw_result="{result}" type={type(result)}')
    if isinstance(result, dict):
        debug(f'VQA interrogate: handler=moondream3 point_raw_result_keys={list(result.keys())}')
    points = vqa_detection.parse_points(result)
    if points:
        debug(f'VQA interrogate: handler=moondream3 point_result={len(points)} points found')
        return points
    debug('VQA interrogate: handler=moondream3 point_result=not found')
    return None
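
# Illustrative sketch: converting the normalized points to pixel coordinates
# for an image of size (width, height):
#   points = point(img, 'cat', 'moondream/moondream3-preview')
#   if points:
#       pixels = [(round(x * img.width), round(y * img.height)) for x, y in points]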


def detect(image: Image.Image, object_name: str, repo: str, max_objects: int = 10):
    """
    Detect all instances of a specific object with bounding boxes.
    Args:
        image: PIL Image
        object_name: Name of object to detect
        repo: Model repository
        max_objects: Maximum number of objects to return
    Returns:
        List of detection dicts with keys:
            - 'bbox': [x1, y1, x2, y2] normalized to 0-1
            - 'label': Object label
            - 'confidence': Detection confidence (0-1)
        Returns empty list if no objects found.
    """
    model = load_model(repo)
    debug(f'VQA interrogate: handler=moondream3 method=detect object_name="{object_name}" max_objects={max_objects}')
    with devices.inference_context():
        result = model.detect(image, object_name)
    debug(f'VQA interrogate: handler=moondream3 detect_raw_result="{result}" type={type(result)}')
    if isinstance(result, dict):
        debug(f'VQA interrogate: handler=moondream3 detect_raw_result_keys={list(result.keys())}')
    detections = vqa_detection.parse_detections(result, object_name, max_objects)
    debug(f'VQA interrogate: handler=moondream3 detect_result={len(detections)} objects found')
    return detections
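
# Illustrative sketch: scaling a normalized bbox back to pixel space
# (keys follow the Returns section of the docstring above):
#   detections = detect(img, 'person', 'moondream/moondream3-preview', max_objects=5)
#   for det in detections:
#       x1, y1, x2, y2 = det['bbox']
#       box_px = (int(x1 * img.width), int(y1 * img.height), int(x2 * img.width), int(y2 * img.height))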


def predict(question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False,
            mode: str = None, stream: bool = False, use_cache: bool = False, **kwargs):
    """
    Main entry point for Moondream 3 VQA - auto-detects mode from question.
    Args:
        question: The question/prompt (e.g., "caption", "where is the cat?", "describe this")
        image: PIL Image
        repo: Model repository
        model_name: Display name for logging
        thinking_mode: Enable reasoning mode for query
        mode: Force specific mode ('query', 'caption', 'caption_short', 'caption_long', 'point', 'detect')
        stream: Enable streaming output (for query/caption)
        use_cache: Use cached image encoding (for query)
        **kwargs: Additional parameters (max_objects for detect, etc.)
    Returns:
        Response string (detection data stored on VQA singleton instance.last_detection_data)
        (or generator if stream=True for query/caption modes)
    """
    debug(f'VQA interrogate: handler=moondream3 model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None} mode={mode} stream={stream}')
    # Clean question
    question = question.replace('<', '').replace('>', '').replace('_', ' ') if question else ''
    # Auto-detect mode from question if not specified
    if mode is None:
        question_lower = question.lower()
        # Caption detection
        if question in ['CAPTION', 'caption'] or 'caption' in question_lower:
            if 'more detailed' in question_lower or 'very long' in question_lower:
                mode = 'caption_long'
            elif 'detailed' in question_lower or 'long' in question_lower:
                mode = 'caption_normal'
            elif 'short' in question_lower or 'brief' in question_lower:
                mode = 'caption_short'
            else:
                # Default caption mode (matches vqa.py legacy behavior)
                if question == 'CAPTION':
                    mode = 'caption_short'
                elif question == 'DETAILED CAPTION':
                    mode = 'caption_normal'
                elif question == 'MORE DETAILED CAPTION':
                    mode = 'caption_long'
                else:
                    mode = 'caption_normal'
        # Point detection
        elif 'where is' in question_lower or 'locate' in question_lower or 'find' in question_lower or 'point' in question_lower:
            mode = 'point'
        # Object detection
        elif 'detect' in question_lower or 'bounding box' in question_lower or 'bbox' in question_lower:
            mode = 'detect'
        # Default to query
        else:
            mode = 'query'
    debug(f'VQA interrogate: handler=moondream3 mode_selected={mode}')
    # Dispatch to appropriate method
    try:
        if mode == 'caption_short':
            response = caption(image, repo, length='short', stream=stream)
        elif mode == 'caption_long':
            response = caption(image, repo, length='long', stream=stream)
        elif mode in ['caption', 'caption_normal']:
            response = caption(image, repo, length='normal', stream=stream)
        elif mode == 'point':
            # Extract object name from question - case insensitive, preserve object names
            object_name = question
            for phrase in ['point at', 'where is', 'locate', 'find']:
                object_name = re.sub(rf'\b{phrase}\b', '', object_name, flags=re.IGNORECASE)
            object_name = re.sub(r'[?.!,]', '', object_name).strip()
            object_name = re.sub(r'^\s*the\s+', '', object_name, flags=re.IGNORECASE)
            debug(f'VQA interrogate: handler=moondream3 point_extracted_object="{object_name}"')
            result = point(image, object_name, repo)
            if result:
                from modules.interrogate import vqa
                vqa.get_instance().last_detection_data = {'points': result}
                return vqa_detection.format_points_text(result)
            return "Object not found"
        elif mode == 'detect':
            # Extract object name from question - case insensitive
            object_name = question
            for phrase in ['detect', 'find all', 'bounding box', 'bbox', 'find']:
                object_name = re.sub(rf'\b{phrase}\b', '', object_name, flags=re.IGNORECASE)
            object_name = re.sub(r'[?.!,]', '', object_name).strip()
            object_name = re.sub(r'^\s*the\s+', '', object_name, flags=re.IGNORECASE)
            if ' and ' in object_name.lower():
                object_name = re.split(r'\s+and\s+', object_name, flags=re.IGNORECASE)[0].strip()
            debug(f'VQA interrogate: handler=moondream3 detect_extracted_object="{object_name}"')
            results = detect(image, object_name, repo, max_objects=kwargs.get('max_objects', 10))
            if results:
                from modules.interrogate import vqa
                vqa.get_instance().last_detection_data = {'detections': results}
                return vqa_detection.format_detections_text(results)
            return "No objects detected"
        else:  # mode == 'query'
            if len(question) < 2:
                question = "Describe this image."
            response = query(image, question, repo, stream=stream, use_cache=use_cache, reasoning=thinking_mode)
        debug(f'VQA interrogate: handler=moondream3 response_before_clean="{response}"')
        return response
    except Exception as e:
        from modules import errors
        errors.display(e, 'Moondream3')
        return f"Error: {str(e)}"


def clear_cache():
    """Clear image encoding cache."""
    cache_size = len(image_cache)
    image_cache.clear()
    debug(f'VQA interrogate: handler=moondream3 cleared image cache cache_size_was={cache_size}')
    shared.log.debug(f'Moondream3: Cleared image cache ({cache_size} entries)')


def unload():
    """Release Moondream 3 model from GPU/memory."""
    global moondream3_model, loaded  # pylint: disable=global-statement
    if moondream3_model is not None:
        shared.log.debug(f'Moondream3 unload: model="{loaded}"')
        sd_models.move_model(moondream3_model, devices.cpu, force=True)
        moondream3_model = None
        loaded = None
        clear_cache()
        devices.torch_gc(force=True)
    else:
        shared.log.debug('Moondream3 unload: no model loaded')
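
# Illustrative lifecycle sketch: predict() loads the model on demand, so a
# typical caller only has to release resources afterwards:
#   result = predict('caption', img, 'moondream/moondream3-preview')
#   unload()  # moves the model off the GPU, drops it and clears the image cache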