1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/interrogate/vqa.py
CalamitousFelicitousness 0659759e90 fix(vqa): improve unload logging consistency
Add before/after debug messages when unloading VQA model to match
the pattern used in prompt enhance for better debugging visibility.
2026-01-12 00:17:20 +00:00

1459 lines
68 KiB
Python

import io
import os
import time
import json
import base64
import copy
import torch
import transformers
import transformers.dynamic_module_utils
from PIL import Image
from modules import shared, devices, errors, model_quant, sd_models, sd_models_compile, ui_symbols
from modules.interrogate import vqa_detection
# Debug logging - kept as a plain function to avoid a circular import
# Tracing is opt-in: it is active whenever the SD_VQA_DEBUG environment variable is set
debug_enabled = 'SD_VQA_DEBUG' in os.environ


def debug(*args, **kwargs):
    """Forward the arguments to the shared trace logger, but only when VQA debugging is enabled."""
    if not debug_enabled:
        return
    shared.log.trace(*args, **kwargs)
# Name of the default model; must be one of the keys of vlm_models below
vlm_default = "Alibaba Qwen 2.5 VL 3B"
# Registry of selectable models: friendly display name -> Hugging Face repo id
# VQA.load() picks a loader by substring-matching the lower-cased repo id, so the repo id (not the display name) decides which loader runs
# Entries whose display name embeds ui_symbols.reasoning are the reasoning-capable variants
# Trailing size comments are approximate download sizes as noted by the original authors
vlm_models = {
    "Google Gemma 3 4B": "google/gemma-3-4b-it",
    "Google Gemma 3n E2B": "google/gemma-3n-E2B-it", # 1.5GB
    "Google Gemma 3n E4B": "google/gemma-3n-E4B-it", # 1.5GB
    "Nidum Gemma 3 4B Uncensored": "nidum/Nidum-Gemma-3-4B-it-Uncensored",
    "Allura Gemma 3 Glitter 4B": "allura-org/Gemma-3-Glitter-4B",
    "Alibaba Qwen 2.0 VL 2B": "Qwen/Qwen2-VL-2B-Instruct",
    "Alibaba Qwen 2.5 Omni 3B": "Qwen/Qwen2.5-Omni-3B",
    "Alibaba Qwen 2.5 VL 3B": "Qwen/Qwen2.5-VL-3B-Instruct",
    "Alibaba Qwen 3 VL 2B": "Qwen/Qwen3-VL-2B-Instruct",
    f"Alibaba Qwen 3 VL 2B Thinking {ui_symbols.reasoning}": "Qwen/Qwen3-VL-2B-Thinking",
    "Alibaba Qwen 3 VL 4B": "Qwen/Qwen3-VL-4B-Instruct",
    f"Alibaba Qwen 3 VL 4B Thinking {ui_symbols.reasoning}": "Qwen/Qwen3-VL-4B-Thinking",
    "Alibaba Qwen 3 VL 8B": "Qwen/Qwen3-VL-8B-Instruct",
    f"Alibaba Qwen 3 VL 8B Thinking {ui_symbols.reasoning}": "Qwen/Qwen3-VL-8B-Thinking",
    "XiaomiMiMo MiMo VL 7B RL": "XiaomiMiMo/MiMo-VL-7B-RL-2508", # 8.3GB
    "Huggingface Smol VL2 0.5B": "HuggingFaceTB/SmolVLM-500M-Instruct",
    "Huggingface Smol VL2 2B": "HuggingFaceTB/SmolVLM-Instruct",
    "Apple FastVLM 0.5B": "apple/FastVLM-0.5B",
    "Apple FastVLM 1.5B": "apple/FastVLM-1.5B",
    "Apple FastVLM 7B": "apple/FastVLM-7B",
    "Microsoft Florence 2 Base": "florence-community/Florence-2-base-ft", # 0.5GB
    "Microsoft Florence 2 Large": "florence-community/Florence-2-large-ft", # 1.5GB
    "MiaoshouAI PromptGen 1.5 Base": "Disty0/Florence-2-base-PromptGen-v1.5", # 0.5GB
    "MiaoshouAI PromptGen 1.5 Large": "Disty0/Florence-2-large-PromptGen-v1.5", # 1.5GB
    "MiaoshouAI PromptGen 2.0 Base": "Disty0/Florence-2-base-PromptGen-v2.0", # 0.5GB
    "MiaoshouAI PromptGen 2.0 Large": "Disty0/Florence-2-large-PromptGen-v2.0", # 1.5GB
    "CogFlorence 2.0 Large": "thwri/CogFlorence-2-Large-Freeze", # 1.6GB
    "CogFlorence 2.2 Large": "thwri/CogFlorence-2.2-Large", # 1.6GB
    f"Moondream 2 {ui_symbols.reasoning}": "vikhyatk/moondream2", # 3.7GB
    f"Moondream 3 Preview {ui_symbols.reasoning}": "moondream/moondream3-preview", # 9.3GB (gated)
    "Google Pix Textcaps": "google/pix2struct-textcaps-base", # 1.1GB
    "Google PaliGemma 2 3B": "google/paligemma2-3b-pt-224",
    "Salesforce BLIP Base": "Salesforce/blip-vqa-base", # 1.5GB
    "Salesforce BLIP Large": "Salesforce/blip-vqa-capfilt-large", # 1.5GB
    "Microsoft GIT TextCaps Base": "microsoft/git-base-textcaps", # 0.7GB
    "Microsoft GIT VQA Base": "microsoft/git-base-vqav2", # 0.7GB
    "Microsoft GIT VQA Large": "microsoft/git-large-vqav2", # 1.6GB
    "ToriiGate 0.4 2B": "Minthy/ToriiGate-v0.4-2B",
    "ToriiGate 0.4 7B": "Minthy/ToriiGate-v0.4-7B",
    "ViLT Base": "dandelin/vilt-b32-finetuned-vqa", # 0.5GB
    "JoyCaption Alpha": "fancyfeast/llama-joycaption-alpha-two-hf-llava", # 17.4GB
    "JoyCaption Beta": "fancyfeast/llama-joycaption-beta-one-hf-llava", # 17.4GB
    "JoyTag": "fancyfeast/joytag", # 0.7GB
    "AIDC Ovis2 1B": "AIDC-AI/Ovis2-1B",
    "AIDC Ovis2 2B": "AIDC-AI/Ovis2-2B",
    "AIDC Ovis2 4B": "AIDC-AI/Ovis2-4B",
    "ByteDance Sa2VA 1B": "ByteDance/Sa2VA-1B",
    "ByteDance Sa2VA 4B": "ByteDance/Sa2VA-4B",
}
# System prompt text for chat-style models
# NOTE(review): presumably the default value of the interrogate_vlm_system option read by the handlers - confirm against the settings definition
vlm_system = 'You are image captioning expert, creative, unbiased and uncensored.'
# Common prompts (work with all VLM models)
# These are friendly names; vlm_prompt_mapping translates them to internal tokens
vlm_prompts_common = [
    "Use Prompt",
    "Short Caption",
    "Normal Caption",
    "Long Caption",
]
# Florence-2 specific prompts (only shown for Florence/PromptGen models)
vlm_prompts_florence = [
    "Phrase Grounding",
    "Object Detection",
    "Dense Region Caption",
    "Region Proposal",
    "OCR (Read Text)",
    "OCR with Regions",
    "Analyze",
    "Generate Tags",
    "Mixed Caption",
    "Mixed Caption+",
]
# Moondream specific prompts (shared by Moondream 2 and 3)
# Both require user input naming what to point at / detect (see get_internal_prompt)
vlm_prompts_moondream = [
    "Point at...",
    "Detect all...",
]
# Moondream 2 only prompts (gaze detection not available in Moondream 3)
vlm_prompts_moondream2 = [
    "Detect Gaze",
]
# Mapping from friendly names to internal tokens/commands
# - "Use Prompt" maps to itself
# - angle-bracket values are task tokens (the same set is_florence_task recognizes)
# - POINT_MODE / DETECT_MODE are placeholders that get_internal_prompt expands into
#   "Point at <user prompt>" / "Detect <user prompt>" when a user prompt is supplied
vlm_prompt_mapping = {
    "Use Prompt": "Use Prompt",
    "Short Caption": "<CAPTION>",
    "Normal Caption": "<DETAILED_CAPTION>",
    "Long Caption": "<MORE_DETAILED_CAPTION>",
    "Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
    "Object Detection": "<OD>",
    "Dense Region Caption": "<DENSE_REGION_CAPTION>",
    "Region Proposal": "<REGION_PROPOSAL>",
    "OCR (Read Text)": "<OCR>",
    "OCR with Regions": "<OCR_WITH_REGION>",
    "Analyze": "<ANALYZE>",
    "Generate Tags": "<GENERATE_TAGS>",
    "Mixed Caption": "<MIXED_CAPTION>",
    "Mixed Caption+": "<MIXED_CAPTION_PLUS>",
    "Point at...": "POINT_MODE",
    "Detect all...": "DETECT_MODE",
    "Detect Gaze": "DETECT_GAZE",
}
# Placeholder hints for prompt field based on selected question
# Keyed by the same friendly names as vlm_prompt_mapping; looked up by get_prompt_placeholder,
# which falls back to a generic hint for names that are not listed here
vlm_prompt_placeholders = {
    "Use Prompt": "Enter your question or instruction for the model",
    "Short Caption": "Optional: add specific focus or style instructions",
    "Normal Caption": "Optional: add specific focus or style instructions",
    "Long Caption": "Optional: add specific focus or style instructions",
    "Phrase Grounding": "Optional: specify phrases to ground in the image",
    "Object Detection": "Optional: specify object types to detect",
    "Dense Region Caption": "Optional: add specific instructions",
    "Region Proposal": "Optional: add specific instructions",
    "OCR (Read Text)": "Optional: add specific instructions",
    "OCR with Regions": "Optional: add specific instructions",
    "Analyze": "Optional: add specific analysis instructions",
    "Generate Tags": "Optional: add specific tagging instructions",
    "Mixed Caption": "Optional: add specific instructions",
    "Mixed Caption+": "Optional: add specific instructions",
    "Point at...": "Enter objects to locate, e.g., 'the red car' or 'all the eyes'",
    "Detect all...": "Enter object type to detect, e.g., 'cars' or 'faces'",
    "Detect Gaze": "No input needed - auto-detects face and gaze direction",
}
# Legacy list for backwards compatibility
# Flat concatenation of every prompt list above, regardless of which model supports it
vlm_prompts = vlm_prompts_common + vlm_prompts_florence + vlm_prompts_moondream + vlm_prompts_moondream2
# Default prefill text: used by clean() and the chat handlers whenever the caller passes prefill=None
vlm_prefill = 'Answer: the image shows'
def get_prompts_for_model(model_name: str) -> list:
    """
    Return the friendly prompt names that apply to the selected model.

    Florence-2 / PromptGen models get the common prompts plus the Florence task prompts.
    Moondream models get the common prompts plus point/detect; Moondream 2 additionally
    gets gaze detection. Every other model (or None) gets only the common prompts.
    """
    if model_name is None:
        return vlm_prompts_common
    name = model_name.lower()
    if 'florence' in name or 'promptgen' in name:
        return vlm_prompts_common + vlm_prompts_florence
    if 'moondream' not in name:
        # Not a model family with dedicated prompts
        return vlm_prompts_common
    prompts = vlm_prompts_common + vlm_prompts_moondream
    is_moondream3 = any(tag in name for tag in ('moondream3', 'moondream 3'))
    if not is_moondream3:
        # Gaze detection exists only in Moondream 2
        prompts = prompts + vlm_prompts_moondream2
    return prompts
def get_internal_prompt(friendly_name: str, user_prompt: str = None) -> str:
    """
    Translate a friendly prompt name into the internal token/command sent to the model.

    Names missing from vlm_prompt_mapping are returned unchanged. For the Moondream
    point/detect modes a non-empty user_prompt is turned into the trigger phrase
    ("Point at ..." / "Detect ..."); without a user prompt the raw mode token is returned.
    """
    token = vlm_prompt_mapping.get(friendly_name, friendly_name)
    if user_prompt:
        if token == "POINT_MODE":
            return f"Point at {user_prompt}"
        if token == "DETECT_MODE":
            return f"Detect {user_prompt}"
    return token
def get_prompt_placeholder(friendly_name: str) -> str:
    """Return the hint text shown in the prompt field for the selected question, or a generic hint when none is registered."""
    fallback = "Enter your question or instruction"
    return vlm_prompt_placeholders.get(friendly_name, fallback)
def is_florence_task(question: str) -> bool:
    """
    Tell whether the question selects a Florence-2 task.

    Accepts either a Florence-specific friendly name or, for backwards compatibility,
    a raw internal task token such as '<OD>'. Empty or None questions are never tasks.
    """
    if not question:
        return False
    # Raw task tokens accepted for backwards compatibility
    florence_tokens = (
        '<CAPTION>', '<DETAILED_CAPTION>', '<MORE_DETAILED_CAPTION>', '<CAPTION_TO_PHRASE_GROUNDING>',
        '<OD>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR>', '<OCR_WITH_REGION>',
        '<ANALYZE>', '<GENERATE_TAGS>', '<MIXED_CAPTION>', '<MIXED_CAPTION_PLUS>',
    )
    return question in vlm_prompts_florence or question in florence_tokens
def is_thinking_model(model_name: str) -> bool:
    """
    Tell whether a model is treated as reasoning-capable, judged purely by its name.

    The match is a case-insensitive substring test against a fixed list of markers.
    Empty or None names are never thinking models.
    """
    if not model_name:
        return False
    markers = (
        'thinking',      # Qwen3-VL-*-Thinking models
        'moondream3',    # Moondream 3 supports thinking
        'moondream 3',
        'moondream2',    # Moondream 2 supports reasoning mode
        'moondream 2',
        'mimo',
    )
    lowered = model_name.lower()
    for marker in markers:
        if marker in lowered:
            return True
    return False
def truncate_b64_in_conversation(conversation, front_chars=50, tail_chars=50, threshold=200):
    """
    Return a deep copy of a conversation with long base64 image strings shortened for logging.

    Only string values stored under an "image" key are touched; the input is never modified.
    A shortened value keeps the first `front_chars` and last `tail_chars` characters around a
    marker that states how many characters were dropped.

    Args:
        conversation: nested structure of dicts/lists (chat messages).
        front_chars: number of leading characters to keep (negative values are treated as 0).
        tail_chars: number of trailing characters to keep (negative values are treated as 0).
        threshold: strings no longer than this are left untouched.

    Returns:
        The copied structure with qualifying image strings shortened.
    """
    front_chars = max(front_chars, 0)
    tail_chars = max(tail_chars, 0)

    def shorten(value):
        removed = len(value) - front_chars - tail_chars
        # Skip when nothing would actually be removed: otherwise the "shortened" form
        # would be longer than the original and report a negative count
        if len(value) <= threshold or removed <= 0:
            return value
        # value[-0:] is the whole string, so an empty tail needs explicit handling
        tail = value[-tail_chars:] if tail_chars > 0 else ''
        return f"{value[:front_chars]}...[{removed} chars truncated]...{tail}"

    def walk(node):
        if isinstance(node, dict):
            for key, value in list(node.items()):
                if key == "image" and isinstance(value, str):
                    node[key] = shorten(value)
                elif isinstance(value, (dict, list)):
                    walk(value)
        elif isinstance(node, list):
            for item in node:
                walk(item)

    conv_copy = copy.deepcopy(conversation)
    walk(conv_copy)
    return conv_copy
def keep_think_block_open(text_prompt: str) -> str:
    """
    Drop the closing </think> that follows the last <think> so the model can continue reasoning.

    The prompt is returned unchanged when it has no <think> tag or when no </think> follows the
    last one. Whitespace directly after the removed tag is dropped too: first spaces/tabs,
    then carriage returns/newlines.
    """
    open_tag, close_tag = "<think>", "</think>"
    open_at = text_prompt.rfind(open_tag)
    close_at = text_prompt.find(close_tag, open_at) if open_at != -1 else -1
    if close_at == -1:
        return text_prompt
    resume = close_at + len(close_tag)
    total = len(text_prompt)
    # Two passes on purpose: horizontal whitespace first, then line breaks
    for skippable in (' \t', '\r\n'):
        while resume < total and text_prompt[resume] in skippable:
            resume += 1
    trimmed_prompt = text_prompt[:close_at] + text_prompt[resume:]
    debug('VQA interrogate: keep_think_block_open applied to prompt segment near assistant reply')
    return trimmed_prompt
def b64(image):
    """
    Encode an image as a base64 JPEG string.

    Args:
        image: PIL image, or None.

    Returns:
        Base64 text of the JPEG-encoded image, or '' when image is None.
    """
    if image is None:
        return ''
    # JPEG cannot store alpha or palette data: saving e.g. RGBA/LA/P images raises OSError,
    # so anything outside the modes the JPEG writer accepts is converted to RGB first
    jpeg_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')
    if getattr(image, 'mode', 'RGB') not in jpeg_modes:
        image = image.convert('RGB')
    with io.BytesIO() as stream:
        image.save(stream, 'JPEG')
        return base64.b64encode(stream.getvalue()).decode()
def clean(response, question, prefill=None):
    """
    Normalize a raw model response into a clean caption string.

    Args:
        response: model output; a str, a dict (keys 'reasoning', 'answer', 'caption', 'task' are
            understood, anything else is dumped as JSON), a list (first element is used), or any
            other object (converted with str()).
        question: the question sent to the model; if it is echoed in the response, everything up
            to and including the echo is dropped. May be empty or None.
        prefill: prefill text used for generation; None selects the module default vlm_prefill.

    Returns:
        The cleaned response. Depending on the interrogate_vlm_keep_prefill option the prefill
        text is either guaranteed to lead the result or removed from its start.
    """
    # Markup and chat-template residue removed from the response
    # NOTE(review): the curly quotes are restored on the assumption that the three identical
    # straight-quote entries were curly quotes before text normalization - confirm against upstream
    strip = ['---', '\r', '\t', '**', '"', '\u201c', '\u201d', 'Assistant:', 'Caption:', '<|im_end|>', '<pad>']
    if isinstance(response, str):
        response = response.strip()
    elif isinstance(response, dict):
        text_response = ""
        if 'reasoning' in response and shared.opts.interrogate_vlm_keep_thinking:
            r_text = response['reasoning']
            if isinstance(r_text, dict) and 'text' in r_text:
                r_text = r_text['text']
            text_response += f"Reasoning:\n{r_text}\n\nAnswer:\n"
        if 'answer' in response:
            text_response += response['answer']
        elif 'caption' in response:
            text_response += response['caption']
        elif 'task' in response:
            text_response += response['task']
        elif not text_response:
            # Unknown dict shape: keep the data visible rather than returning nothing
            text_response = json.dumps(response)
        response = text_response
    elif isinstance(response, list):
        # An empty batch yields an empty caption instead of an IndexError
        first = response[0] if response else ''
        response = first if isinstance(first, str) else str(first)
    else:
        response = str(response)

    # Determine prefill text
    prefill_text = vlm_prefill if prefill is None else prefill
    if prefill_text is None:
        prefill_text = ""
    prefill_text = prefill_text.strip()

    # Drop an echoed question; an empty question must be skipped because
    # str.split('') raises ValueError (empty separator)
    question = (question or '').replace('<', '').replace('>', '').replace('_', ' ')
    if question and question in response:
        response = response.split(question, 1)[1]
    # Repeat until stable: removing one marker can bring the pieces of another together
    while any(s in response for s in strip):
        for s in strip:
            response = response.replace(s, '')
    # Collapse double spaces left behind by the removals and turn markdown bullets into dashes
    response = response.replace('  ', ' ').replace('* ', '- ').strip()

    # Handle prefill retention/removal
    if shared.opts.interrogate_vlm_keep_prefill:
        # Add prefill if it's missing from the cleaned response
        if len(prefill_text) > 0 and not response.startswith(prefill_text):
            sep = " "
            if not response or response[0] in ".,!?;:":
                sep = ""
            response = f"{prefill_text}{sep}{response}"
    else:
        # Remove prefill if it's present in the cleaned response
        if len(prefill_text) > 0 and response.startswith(prefill_text):
            response = response[len(prefill_text):].strip()
    return response
def get_kwargs():
    """
    Build the generation keyword arguments from the interrogate_vlm_* options.

    max_new_tokens and do_sample are always present; num_beams, temperature, top_k and top_p
    are included only when the corresponding option is greater than zero.
    """
    opts = shared.opts
    kwargs = {
        'max_new_tokens': opts.interrogate_vlm_max_length,
        'do_sample': opts.interrogate_vlm_do_sample,
    }
    for name in ('num_beams', 'temperature', 'top_k', 'top_p'):
        value = getattr(opts, f'interrogate_vlm_{name}')
        if value > 0:
            kwargs[name] = value
    return kwargs
class VQA:
"""Vision-Language Model interrogation class with per-model self-contained loading."""
def __init__(self):
self.processor = None
self.model = None
self.loaded: str = None
self.last_annotated_image = None
self.last_detection_data = None
def unload(self):
    """Move the loaded VLM to CPU, drop all references to it and run garbage collection."""
    if self.model is None:
        shared.log.debug('VQA unload: no model loaded')
        return
    model_name = self.loaded
    shared.log.debug(f'VQA unload: unloading model="{model_name}"')
    sd_models.move_model(self.model, devices.cpu, force=True)
    # Clear every reference before collecting so the memory can actually be reclaimed
    self.model = None
    self.processor = None
    self.loaded = None
    devices.torch_gc(force=True, reason='vqa unload')
    shared.log.debug(f'VQA unload: model="{model_name}" unloaded')
def load(self, model_name: str = None):
    """
    Pre-load the VLM registered under the given friendly model name.

    Falls back to the interrogate_vlm_model option when no name is given. The loader is chosen
    by substring-matching the lower-cased repo id; the first match wins. Models handled by this
    class are moved to the compute device afterwards; models with their own module
    (moondream3, joytag, joycaption, deepseek) are delegated to that module.
    """
    model_name = model_name or shared.opts.interrogate_vlm_model
    if not model_name:
        shared.log.warning('VQA load: no model specified')
        return
    repo = vlm_models.get(model_name)
    if repo is None:
        shared.log.error(f'VQA load: unknown model="{model_name}"')
        return
    shared.log.debug(f'VQA load: pre-loading model="{model_name}" repo="{repo}"')

    # Dispatch to appropriate loader (same logic as interrogate)
    # Order matters: entries are tried top to bottom and the first match is used
    repo_lower = repo.lower()
    internal_loaders = (
        (('qwen', 'torii', 'mimo'), '_load_qwen'),
        (('gemma',), '_load_gemma'),
        (('smol',), '_load_smol'),
        (('florence',), '_load_florence'),
        (('moondream2',), '_load_moondream'),
        (('git',), '_load_git'),
        (('blip',), '_load_blip'),
        (('vilt',), '_load_vilt'),
        (('pix',), '_load_pix'),
        (('paligemma',), '_load_paligemma'),
        (('ovis',), '_load_ovis'),
        (('sa2',), '_load_sa2'),
        (('fastvlm',), '_load_fastvlm'),
    )
    loader_name = None
    for keywords, candidate in internal_loaders:
        if not any(keyword in repo_lower for keyword in keywords):
            continue
        if candidate == '_load_gemma' and 'pali' in repo_lower:
            # PaliGemma repo ids contain "gemma" but have their own loader further down
            continue
        loader_name = candidate
        break

    if loader_name is not None:
        # Resolve the method only once it is selected
        getattr(self, loader_name)(repo)
        sd_models.move_model(self.model, devices.device)
        shared.log.info(f'VQA load: model="{model_name}" loaded')
        return

    # Models managed by their own modules; imported lazily, only when selected
    if 'moondream3' in repo_lower:
        from modules.interrogate import moondream3
        moondream3.load_model(repo)
    elif 'joytag' in repo_lower:
        from modules.interrogate import joytag
        joytag.load()
    elif 'joycaption' in repo_lower:
        from modules.interrogate import joycaption
        joycaption.load(repo)
    elif 'deepseek' in repo_lower:
        from modules.interrogate import deepseek
        deepseek.load(repo)
    else:
        shared.log.warning(f'VQA load: no pre-loader for model="{model_name}"')
        return
    shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
def _load_fastvlm(self, repo: str):
"""Load FastVLM model and tokenizer."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
quant_args = model_quant.create_config(module='LLM')
self.model = None
self.processor = transformers.AutoTokenizer.from_pretrained(repo, trust_remote_code=True, cache_dir=shared.opts.hfcache_dir)
self.model = transformers.AutoModelForCausalLM.from_pretrained(
repo,
torch_dtype=devices.dtype,
trust_remote_code=True,
cache_dir=shared.opts.hfcache_dir,
**quant_args,
)
self.loaded = repo
devices.torch_gc()
def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str = None):
    """
    Answer a question about an image with a FastVLM model.

    The chat template is rendered as text, split at the "<image>" placeholder, and the two
    halves are tokenized separately so the image token id can be inserted between them.
    Generation is capped at 128 new tokens. Returns the decoded answer string.
    """
    debug(f'VQA interrogate: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
    self._load_fastvlm(repo)
    sd_models.move_model(self.model, devices.device)
    if len(question) < 2:
        question = "Describe the image."
    question = question.replace('<', '').replace('>', '')
    IMAGE_TOKEN_INDEX = -200 # what the model code looks for
    chat = [{"role": "user", "content": f"<image>\n{question}"}]
    rendered = self.processor.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
    before, after = rendered.split("<image>", 1)

    def encode(text):
        return self.processor(text, return_tensors="pt", add_special_tokens=False).input_ids

    before_ids = encode(before)
    after_ids = encode(after)
    image_marker = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=before_ids.dtype)
    input_ids = torch.cat([before_ids, image_marker, after_ids], dim=1).to(devices.device)
    attention_mask = torch.ones_like(input_ids, device=devices.device)
    vision_inputs = self.model.get_vision_tower().image_processor(images=image, return_tensors="pt")
    pixels = vision_inputs["pixel_values"].to(self.model.device, dtype=self.model.dtype)
    with devices.inference_context():
        outputs = self.model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=pixels,
            max_new_tokens=128,
        )
    return self.processor.decode(outputs[0], skip_special_tokens=True)
def _load_qwen(self, repo: str):
"""Load Qwen VL model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
self.model = None
if 'Qwen3-VL' in repo or 'Qwen3VL' in repo:
cls_name = transformers.Qwen3VLForConditionalGeneration
elif 'Qwen2.5-VL' in repo or 'Qwen2_5_VL' in repo or 'MiMo-VL' in repo:
cls_name = transformers.Qwen2_5_VLForConditionalGeneration
elif 'Qwen2-VL' in repo or 'Qwen2VL' in repo:
cls_name = transformers.Qwen2VLForConditionalGeneration
else:
cls_name = transformers.AutoModelForCausalLM
quant_args = model_quant.create_config(module='LLM')
self.model = cls_name.from_pretrained(
repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
**quant_args,
)
self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
if 'LLM' in shared.opts.cuda_compile:
self.model = sd_models_compile.compile_torch(self.model)
self.loaded = repo
devices.torch_gc()
def _qwen(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
    """
    Answer a question about an image with a Qwen-VL family model.

    Args:
        question: user question; angle brackets are removed and underscores become spaces.
        image: PIL image to interrogate.
        repo: Hugging Face repo id of the model to use.
        system_prompt: system message; falls back to the interrogate_vlm_system option.
        model_name: friendly model name, used to detect thinking-capable models.
        prefill: text appended to the prompt so the answer starts with it; None selects vlm_prefill.
        thinking_mode: for thinking-capable models, whether to let the model reason.

    Returns:
        List of decoded responses (one per batch item); the first item has its thinking
        block either relabelled or removed according to the interrogate_vlm_keep_thinking option.
    """
    self._load_qwen(repo)
    sd_models.move_model(self.model, devices.device)
    # Get model class name for logging
    cls_name = self.model.__class__.__name__
    debug(f'VQA interrogate: handler=qwen model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')

    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    system_prompt = system_prompt or shared.opts.interrogate_vlm_system
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": b64(image)},
                {"type": "text", "text": question},
            ],
        }
    ]

    # Standardize prefill (computed once)
    prefill_value = vlm_prefill if prefill is None else prefill
    prefill_text = prefill_value.strip()
    use_prefill = len(prefill_text) > 0
    # Thinking models emit their own <think> tags via the chat template
    # Only models with thinking capability can use thinking mode
    is_thinking = is_thinking_model(model_name)

    if debug_enabled:
        debug(f'VQA interrogate: handler=qwen conversation_roles={[msg["role"] for msg in conversation]}')
        debug(f'VQA interrogate: handler=qwen full_conversation={truncate_b64_in_conversation(conversation)}')
        debug(f'VQA interrogate: handler=qwen is_thinking={is_thinking} thinking_mode={thinking_mode} prefill="{prefill_text}"')

    # Generate base prompt using template
    # Qwen-Thinking template automatically adds "<|im_start|>assistant\n<think>\n" when add_generation_prompt=True
    # tokenize=False is explicit (as in the other handlers): the prompt is extended as text below
    text_prompt = self.processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Manually handle thinking tags and prefill
    if is_thinking and not thinking_mode:
        # User wants to SKIP thinking.
        # Since template opened the block with <think>, we close it immediately.
        text_prompt += "</think>\n"
    if use_prefill:
        # With thinking enabled the prompt still ends in <think>, so the prefill becomes
        # part of the thought process; otherwise it starts the answer
        text_prompt += prefill_text

    if debug_enabled:
        debug(f'VQA interrogate: handler=qwen text_prompt="{text_prompt}"')

    inputs = self.processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(devices.device, devices.dtype)
    gen_kwargs = get_kwargs()
    debug(f'VQA interrogate: handler=qwen generation_kwargs={gen_kwargs} input_ids_shape={inputs.input_ids.shape}')
    # Run generation inside the inference context, like the other handlers of this class
    with devices.inference_context():
        output_ids = self.model.generate(
            **inputs,
            **gen_kwargs,
        )
    debug(f'VQA interrogate: handler=qwen output_ids_shape={output_ids.shape}')

    # Keep only the newly generated tokens of each sequence
    generated_ids = [
        sequence[len(prompt_ids):]
        for prompt_ids, sequence in zip(inputs.input_ids, output_ids)
    ]
    response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    if debug_enabled:
        debug(f'VQA interrogate: handler=qwen response_before_clean="{response}"')

    # Clean up thinking tags
    # Note: <think> is in the prompt, not the response - only </think> appears in generated output
    think_open = '<think>'
    think_close = '</think>'
    if len(response) > 0:
        text = response[0]
        if shared.opts.interrogate_vlm_keep_thinking:
            # Handle case where <think> is in prompt (not response) but </think> is in response
            if think_close in text and think_open not in text:
                text = 'Reasoning:\n' + text.replace(think_close, '\n\nAnswer:')
            else:
                text = text.replace(think_open, 'Reasoning:\n').replace(think_close, '\n\nAnswer:')
        else:
            while think_close in text:
                start = text.find(think_open)
                end = text.find(think_close)
                after_close = end + len(think_close)
                if start != -1 and start < end:
                    # Standard <think>...content...</think> block
                    text = text[:start] + text[after_close:]
                else:
                    # Missing <think> (implied at start) or malformed
                    # Remove from start up to </think>
                    text = text[after_close:]
        response[0] = text
    return response
def _load_gemma(self, repo: str):
"""Load Gemma 3 model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
self.model = None
if '3n' in repo:
cls = transformers.Gemma3nForConditionalGeneration # pylint: disable=no-member
else:
cls = transformers.Gemma3ForConditionalGeneration
quant_args = model_quant.create_config(module='LLM')
self.model = cls.from_pretrained(
repo,
torch_dtype=devices.dtype,
cache_dir=shared.opts.hfcache_dir,
**quant_args,
)
if 'LLM' in shared.opts.cuda_compile:
self.model = sd_models_compile.compile_torch(self.model)
self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
self.loaded = repo
devices.torch_gc()
def _gemma(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
    """
    Answer a question about an image with a Gemma 3 model.

    A non-empty prefill is added as a trailing assistant message and the chat template is asked
    to continue that message; if the template rejects those arguments, the prompt is rendered
    with a plain generation prompt instead. Returns the decoded response string with thinking
    tags relabelled or removed according to the interrogate_vlm_keep_thinking option.
    """
    self._load_gemma(repo)
    sd_models.move_model(self.model, devices.device)
    # Model class name is only needed for the debug line
    model_class = self.model.__class__.__name__
    debug(f'VQA interrogate: handler=gemma model_name="{model_name}" model_class="{model_class}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')

    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    system_prompt = system_prompt or shared.opts.interrogate_vlm_system

    # Texts of four characters or fewer are left out of the conversation
    system_parts = []
    if system_prompt is not None and len(system_prompt) > 4:
        system_parts.append({"type": "text", "text": system_prompt})
    user_parts = []
    if question is not None and len(question) > 4:
        user_parts.append({"type": "text", "text": question})
    if image is not None:
        user_parts.append({"type": "image", "image": b64(image)})
    conversation = [
        {"role": "system", "content": system_parts},
        {"role": "user", "content": user_parts},
    ]

    # Resolve prefill: None selects the module default
    prefill_text = (vlm_prefill if prefill is None else prefill).strip()
    use_prefill = len(prefill_text) > 0
    # Thinking models emit their own <think> tags via the chat template
    # Manual toggle OR auto-detection based on model name
    use_thinking = thinking_mode or is_thinking_model(model_name)

    if not use_prefill:
        debug('VQA interrogate: handler=gemma prefill disabled (empty), relying on add_generation_prompt')
    else:
        conversation.append({
            "role": "assistant",
            "content": [{"type": "text", "text": prefill_text}],
        })
        debug(f'VQA interrogate: handler=gemma prefill="{prefill_text}"')

    if debug_enabled:
        debug(f'VQA interrogate: handler=gemma conversation_roles={[msg["role"] for msg in conversation]}')
        debug(f'VQA interrogate: handler=gemma full_conversation={truncate_b64_in_conversation(conversation)}')
        debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
        debug(f'VQA interrogate: handler=gemma template_mode={debug_prefill_mode}')

    if use_prefill:
        template_args = {'add_generation_prompt': False, 'continue_final_message': True}
    else:
        template_args = {'add_generation_prompt': True}
    try:
        text_prompt = self.processor.apply_chat_template(conversation, tokenize=False, **template_args)
    except (TypeError, ValueError) as e:
        debug(f'VQA interrogate: handler=gemma chat_template fallback add_generation_prompt=True: {e}')
        text_prompt = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=False,
        )

    if use_prefill and use_thinking:
        text_prompt = keep_think_block_open(text_prompt)
    if debug_enabled:
        debug(f'VQA interrogate: handler=gemma text_prompt="{text_prompt}"')

    inputs = self.processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device=devices.device, dtype=devices.dtype)
    input_len = inputs["input_ids"].shape[-1]
    gen_kwargs = get_kwargs()
    debug(f'VQA interrogate: handler=gemma generation_kwargs={gen_kwargs} input_len={input_len}')
    with devices.inference_context():
        generation = self.model.generate(
            **inputs,
            **gen_kwargs,
        )
    debug(f'VQA interrogate: handler=gemma output_ids_shape={generation.shape}')
    # Keep only the newly generated tokens of the first sequence
    new_tokens = generation[0][input_len:]
    response = self.processor.decode(new_tokens, skip_special_tokens=True)
    if debug_enabled:
        debug(f'VQA interrogate: handler=gemma response_before_clean="{response}"')

    # Clean up thinking tags (if any remain)
    if shared.opts.interrogate_vlm_keep_thinking:
        return response.replace('<think>', 'Reasoning:\n').replace('</think>', '\n\nAnswer:')
    while '</think>' in response:
        open_at = response.find('<think>')
        close_at = response.find('</think>')
        # A proper <think>...</think> block keeps the text before it; a missing or misplaced
        # opening tag drops everything up to the closing tag
        kept = response[:open_at] if open_at != -1 and open_at < close_at else ''
        response = kept + response[close_at+8:]
    return response
def _load_paligemma(self, repo: str):
"""Load PaliGemma model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
self.processor = transformers.PaliGemmaProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
self.model = None
self.model = transformers.PaliGemmaForConditionalGeneration.from_pretrained(
repo,
cache_dir=shared.opts.hfcache_dir,
torch_dtype=devices.dtype,
)
self.loaded = repo
devices.torch_gc()
def _paligemma(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Answer a question about an image with PaliGemma and return the decoded continuation (prompt tokens excluded)."""
    self._load_paligemma(repo)
    sd_models.move_model(self.model, devices.device)
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    model_inputs = self.processor(text=question, images=image, return_tensors="pt")
    model_inputs = model_inputs.to(devices.device, devices.dtype)
    prompt_len = model_inputs["input_ids"].shape[-1]
    with devices.inference_context():
        generation = self.model.generate(
            **model_inputs,
            **get_kwargs(),
        )
    # Drop the prompt tokens, keep only what the model generated
    new_tokens = generation[0][prompt_len:]
    return self.processor.decode(new_tokens, skip_special_tokens=True)
def _load_ovis(self, repo: str):
"""Load Ovis model (requires flash-attn)."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
self.model = None
self.model = transformers.AutoModelForCausalLM.from_pretrained(
repo,
torch_dtype=devices.dtype,
multimodal_max_length=32768,
trust_remote_code=True,
cache_dir=shared.opts.hfcache_dir,
)
self.loaded = repo
devices.torch_gc()
def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """
    Answer a question about an image with an Ovis model.

    Returns '' (after logging an error) when flash-attn cannot be imported; otherwise the
    decoded model output.
    """
    try:
        import flash_attn # pylint: disable=unused-import
    except Exception:
        shared.log.error(f'Interrogate: vlm="{repo}" flash-attn is not available')
        return ''
    self._load_ovis(repo)
    sd_models.move_model(self.model, devices.device)
    text_tokenizer = self.model.get_text_tokenizer()
    visual_tokenizer = self.model.get_visual_tokenizer()
    max_partition = 9
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    query = f'<image>\n{question}'
    _prompt, input_ids, pixel_values = self.model.preprocess_inputs(query, [image], max_partition=max_partition)
    # Mask out padding positions, then add the batch dimension to both tensors
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    target_device = self.model.device
    input_ids = input_ids.unsqueeze(0).to(device=target_device)
    attention_mask = attention_mask.unsqueeze(0).to(device=self.model.device)
    if pixel_values is not None:
        pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]
    with devices.inference_context():
        output_ids = self.model.generate(
            input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            repetition_penalty=None,
            eos_token_id=self.model.generation_config.eos_token_id,
            pad_token_id=text_tokenizer.pad_token_id,
            use_cache=True,
            **get_kwargs())
    return text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
def _load_smol(self, repo: str):
"""Load SmolVLM model and processor."""
if self.model is None or self.loaded != repo:
shared.log.debug(f'Interrogate load: vlm="{repo}"')
self.model = None
quant_args = model_quant.create_config(module='LLM')
self.model = transformers.AutoModelForVision2Seq.from_pretrained(
repo,
cache_dir=shared.opts.hfcache_dir,
torch_dtype=devices.dtype,
**quant_args,
)
self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
if 'LLM' in shared.opts.cuda_compile:
self.model = sd_models_compile.compile_torch(self.model)
self.loaded = repo
devices.torch_gc()
def _smol(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
    """
    Run a SmolVLM chat-style query about `image`.

    Args:
        question: user text; '<' and '>' are removed and '_' becomes a space before use.
        image: image sent to the processor and embedded in the conversation via b64().
        repo: model repository passed to _load_smol.
        system_prompt: system message text; falls back to shared.opts.interrogate_vlm_system when falsy.
        model_name: used for debug logging and for is_thinking_model() detection.
        prefill: assistant prefill text; None means use the module-level vlm_prefill.
        thinking_mode: manual switch that is OR-ed with is_thinking_model(model_name).

    Returns:
        The result of processor.batch_decode(); when non-empty, its first entry has the
        <think> tags either relabelled or removed depending on
        shared.opts.interrogate_vlm_keep_thinking.
    """
    self._load_smol(repo)
    sd_models.move_model(self.model, devices.device)
    # Get model class name for logging
    cls_name = self.model.__class__.__name__
    debug(f'VQA interrogate: handler=smol model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    system_prompt = system_prompt or shared.opts.interrogate_vlm_system
    conversation = [
        {
            "role": "system",
            "content": [{"type": "text", "text": system_prompt}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": b64(image)},
                {"type": "text", "text": question},
            ],
        }
    ]
    # Add prefill if provided
    prefill_value = vlm_prefill if prefill is None else prefill
    prefill_text = prefill_value.strip()
    use_prefill = len(prefill_text) > 0
    # Thinking models emit their own <think> tags via the chat template
    # Use manual toggle OR auto-detection based on model name
    use_thinking = thinking_mode or is_thinking_model(model_name)
    if use_prefill:
        # the prefill becomes a trailing assistant message that the template is asked to continue
        conversation.append({
            "role": "assistant",
            "content": [{"type": "text", "text": prefill_text}],
        })
        debug(f'VQA interrogate: handler=smol prefill="{prefill_text}"')
    else:
        debug('VQA interrogate: handler=smol prefill disabled (empty), relying on add_generation_prompt')
    if debug_enabled:
        # guarded so the conversation dump is only built when debug output is enabled
        debug(f'VQA interrogate: handler=smol conversation_roles={[msg["role"] for msg in conversation]}')
        debug(f'VQA interrogate: handler=smol full_conversation={truncate_b64_in_conversation(conversation)}')
        debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
        debug(f'VQA interrogate: handler=smol template_mode={debug_prefill_mode}')
    try:
        if use_prefill:
            text_prompt = self.processor.apply_chat_template(
                conversation,
                add_generation_prompt=False,
                continue_final_message=True,
            )
        else:
            text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
    except (TypeError, ValueError) as e:
        # fallback: retry with a plain generation prompt when the template call above is rejected
        debug(f'VQA interrogate: handler=smol chat_template fallback add_generation_prompt=True: {e}')
        text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
    if use_prefill and use_thinking:
        text_prompt = keep_think_block_open(text_prompt)
    if debug_enabled:
        debug(f'VQA interrogate: handler=smol text_prompt="{text_prompt}"')
    inputs = self.processor(text=text_prompt, images=[image], padding=True, return_tensors="pt")
    inputs = inputs.to(devices.device, devices.dtype)
    gen_kwargs = get_kwargs()
    debug(f'VQA interrogate: handler=smol generation_kwargs={gen_kwargs}')
    output_ids = self.model.generate(
        **inputs,
        **gen_kwargs,
    )
    debug(f'VQA interrogate: handler=smol output_ids_shape={output_ids.shape}')
    response = self.processor.batch_decode(output_ids, skip_special_tokens=True)
    if debug_enabled:
        debug(f'VQA interrogate: handler=smol response_before_clean="{response}"')
    # Clean up thinking tags
    if len(response) > 0:
        text = response[0]
        if shared.opts.interrogate_vlm_keep_thinking:
            # keep the reasoning, but replace the tags with readable labels
            text = text.replace('<think>', 'Reasoning:\n').replace('</think>', '\n\nAnswer:')
        else:
            # remove every think block; 8 == len('</think>')
            while '</think>' in text:
                start = text.find('<think>')
                end = text.find('</think>')
                if start != -1 and start < end:
                    # well-formed block: cut from the opening tag through the closing tag
                    text = text[:start] + text[end+8:]
                else:
                    # closing tag without a preceding opening tag: drop everything up to and including it
                    text = text[end+8:]
        response[0] = text
    return response
def _load_git(self, repo: str):
    """Make `repo` the loaded Microsoft GIT model/processor pair; does nothing when it already is."""
    if self.model is not None and self.loaded == repo:
        return
    shared.log.debug(f'Interrogate load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.GitForCausalLM.from_pretrained(repo, torch_dtype=devices.dtype, cache_dir=hf_cache)
    self.processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    devices.torch_gc()
def _git(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """
    Caption `image` with a Microsoft GIT model.

    A non-empty `question` is tokenized without special tokens, prefixed with the
    tokenizer's CLS token id and passed to generation as `input_ids`.
    Returns the first decoded sequence.
    """
    self._load_git(repo)
    sd_models.move_model(self.model, devices.device)
    encoded_image = self.processor(images=image, return_tensors="pt")
    generate_inputs = {'pixel_values': encoded_image.pixel_values.to(devices.device, devices.dtype)}
    if len(question) > 0:
        token_ids = self.processor(text=question, add_special_tokens=False).input_ids
        prompt_ids = torch.tensor([self.processor.tokenizer.cls_token_id] + token_ids)
        generate_inputs['input_ids'] = prompt_ids.unsqueeze(0).to(devices.device)
    with devices.inference_context():
        generated = self.model.generate(**generate_inputs)
        decoded = self.processor.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]
def _load_blip(self, repo: str):
    """Make `repo` the loaded Salesforce BLIP model/processor pair; does nothing when it already is."""
    if self.model is not None and self.loaded == repo:
        return
    shared.log.debug(f'Interrogate load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.BlipForQuestionAnswering.from_pretrained(repo, torch_dtype=devices.dtype, cache_dir=hf_cache)
    self.processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    devices.torch_gc()
def _blip(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Answer `question` about `image` with a BLIP question-answering model and return the decoded first output."""
    self._load_blip(repo)
    sd_models.move_model(self.model, devices.device)
    encoded = self.processor(image, question, return_tensors="pt").to(devices.device, devices.dtype)
    with devices.inference_context():
        generated = self.model.generate(**encoded)
    return self.processor.decode(generated[0], skip_special_tokens=True)
def _load_vilt(self, repo: str):
    """Make `repo` the loaded ViLT model/processor pair; does nothing when it already is."""
    if self.model is not None and self.loaded == repo:
        return
    shared.log.debug(f'Interrogate load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.ViltForQuestionAnswering.from_pretrained(repo, torch_dtype=devices.dtype, cache_dir=hf_cache)
    self.processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    devices.torch_gc()
def _vilt(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Answer `question` about `image` with ViLT: returns the `id2label` entry of the highest-scoring logit."""
    self._load_vilt(repo)
    sd_models.move_model(self.model, devices.device)
    encoded = self.processor(image, question, return_tensors="pt").to(devices.device)
    with devices.inference_context():
        result = self.model(**encoded)
    best = result.logits.argmax(-1).item()
    return self.model.config.id2label[best]
def _load_pix(self, repo: str):
    """Make `repo` the loaded Pix2Struct model/processor pair; does nothing when it already is."""
    if self.model is not None and self.loaded == repo:
        return
    shared.log.debug(f'Interrogate load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.Pix2StructForConditionalGeneration.from_pretrained(repo, torch_dtype=devices.dtype, cache_dir=hf_cache)
    self.processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    devices.torch_gc()
def _pix(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """Run Pix2Struct on `image`; `question` is passed as text only when it is non-empty. Returns the decoded first output."""
    self._load_pix(repo)
    sd_models.move_model(self.model, devices.device)
    processor_args = {'images': image, 'return_tensors': "pt"}
    if len(question) > 0:
        processor_args['text'] = question
    encoded = self.processor(**processor_args).to(devices.device)
    with devices.inference_context():
        generated = self.model.generate(**encoded)
    return self.processor.decode(generated[0], skip_special_tokens=True)
def _load_moondream(self, repo: str):
    """Make `repo` the loaded Moondream 2 model/tokenizer pair (model pinned to revision 2025-06-21); does nothing when it already is."""
    if self.model is not None and self.loaded == repo:
        return
    shared.log.debug(f'Interrogate load: vlm="{repo}"')
    self.model = None  # release the reference to any previous model before loading
    hf_cache = shared.opts.hfcache_dir
    self.model = transformers.AutoModelForCausalLM.from_pretrained(repo, revision="2025-06-21", trust_remote_code=True, torch_dtype=devices.dtype, cache_dir=hf_cache)
    self.processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=hf_cache)
    self.loaded = repo
    self.model.eval()
    devices.torch_gc()
def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False):
    """
    Run a Moondream 2 task selected by `question`.

    `question` is normalized first ('<' and '>' removed, '_' replaced by a space), so the
    internal tokens arrive here as 'CAPTION', 'POINT MODE', 'DETECT GAZE', etc.
    Dispatch, in order:
      - 'CAPTION' / 'DETAILED CAPTION' / 'MORE DETAILED CAPTION': model.caption() with short / normal / long length
      - 'detect gaze' (case-insensitive): face detection followed by model.detect_gaze()
      - 'point at <target>' or 'POINT MODE': model.point()
      - 'detect <target>' or 'DETECT MODE': model.detect()
      - anything else: model.query(), with reasoning enabled when `thinking_mode` is set

    Point / detect / gaze results are stored in self.last_detection_data when something is found.
    Returns the text answer, or a fixed message when nothing was found or no target was given.
    """
    debug(f'VQA interrogate: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}')
    self._load_moondream(repo)
    sd_models.move_model(self.model, devices.device)
    question = question.replace('<', '').replace('>', '').replace('_', ' ')
    command = question.lower()
    with devices.inference_context():
        if question == 'CAPTION':
            response = self.model.caption(image, length="short")['caption']
        elif question == 'DETAILED CAPTION':
            response = self.model.caption(image, length="normal")['caption']
        elif question == 'MORE DETAILED CAPTION':
            response = self.model.caption(image, length="long")['caption']
        elif command == 'detect gaze':
            # Must be tested before the generic 'detect ' prefix below, otherwise the gaze
            # request is handled as an object detection for the target "gaze".
            debug('VQA interrogate: handler=moondream method=detect_gaze')
            faces = self.model.detect(image, "face")
            debug(f'VQA interrogate: handler=moondream detect_gaze faces={faces}')
            if faces.get('objects'):
                # only the first detected face is used to derive the eye position
                eye_x, eye_y = vqa_detection.calculate_eye_position(faces['objects'][0])
                result = self.model.detect_gaze(image, eye=(eye_x, eye_y))
                debug(f'VQA interrogate: handler=moondream detect_gaze result={result}')
                if result.get('gaze'):
                    gaze = result['gaze']
                    self.last_detection_data = {'points': [(gaze['x'], gaze['y'])]}
                    return f"Gaze direction: ({gaze['x']:.3f}, {gaze['y']:.3f})"
            return "No face/gaze detected"
        elif command.startswith('point at ') or question == 'POINT MODE':
            # 'POINT MODE' is the normalized form of the POINT_MODE token and carries no target
            target = question[9:].strip() if command.startswith('point at ') else ''
            if not target:
                return "Please specify an object to locate"
            debug(f'VQA interrogate: handler=moondream method=point target="{target}"')
            result = self.model.point(image, target)
            debug(f'VQA interrogate: handler=moondream point_raw_result={result}')
            points = vqa_detection.parse_points(result)
            if points:
                self.last_detection_data = {'points': points}
                return vqa_detection.format_points_text(points)
            return "Object not found"
        elif command.startswith('detect ') or question == 'DETECT MODE':
            # 'DETECT MODE' is the normalized form of the DETECT_MODE token and carries no target
            target = question[7:].strip() if command.startswith('detect ') else ''
            if not target:
                return "Please specify an object to detect"
            debug(f'VQA interrogate: handler=moondream method=detect target="{target}"')
            result = self.model.detect(image, target)
            debug(f'VQA interrogate: handler=moondream detect_raw_result={result}')
            detections = vqa_detection.parse_detections(result, target)
            if detections:
                self.last_detection_data = {'detections': detections}
                return vqa_detection.format_detections_text(detections, include_confidence=False)
            return "No objects detected"
        else:
            debug(f'VQA interrogate: handler=moondream method=query question="{question}" reasoning={thinking_mode}')
            result = self.model.query(image, question, reasoning=thinking_mode)
            response = result['answer']
            debug(f'VQA interrogate: handler=moondream query_result keys={list(result.keys()) if isinstance(result, dict) else "not dict"}')
            if thinking_mode and 'reasoning' in result:
                reasoning_text = result['reasoning'].get('text', '') if isinstance(result['reasoning'], dict) else str(result['reasoning'])
                debug(f'VQA interrogate: handler=moondream reasoning_text="{reasoning_text[:100]}..."')
                if shared.opts.interrogate_vlm_keep_thinking:
                    response = f"Reasoning:\n{reasoning_text}\n\nAnswer:\n{response}"
                # When keep_thinking is False, just use the answer (reasoning is discarded)
    return response
def _load_florence(self, repo: str, revision: str = None):
    """
    Load Florence-2 model and processor.

    Args:
        repo: repository id, optionally with a revision suffix as `name@revision`.
            The full string (suffix included) is the key used to decide whether a reload is needed.
        revision: revision to load; overridden by the `@revision` suffix of `repo` when present.

    While loading, `transformers.dynamic_module_utils.get_imports` is temporarily replaced so that
    `flash_attn` is dropped from the detected import list; the original function is always restored.
    """
    _get_imports = transformers.dynamic_module_utils.get_imports

    def get_imports(f):
        R = _get_imports(f)
        if "flash_attn" in R:
            R.remove("flash_attn") # flash_attn is optional
        return R

    # Handle revision splitting and caching
    cache_key = repo
    effective_revision = revision
    repo_name = repo
    if repo and '@' in repo:
        # maxsplit=1 keeps the two-way unpacking valid even if the string contains another '@'
        repo_name, effective_revision = repo.split('@', 1)
    if self.model is None or self.loaded != cache_key:
        shared.log.debug(f'Interrogate load: vlm="{repo_name}" revision="{effective_revision}" path="{shared.opts.hfcache_dir}"')
        transformers.dynamic_module_utils.get_imports = get_imports
        try:
            self.model = None
            quant_args = model_quant.create_config(module='LLM')
            # NOTE(review): both `dtype` and `torch_dtype` are passed with different values; which one
            # takes effect depends on the installed transformers version - confirm the intended dtype.
            self.model = transformers.Florence2ForConditionalGeneration.from_pretrained(
                repo_name,
                dtype=torch.bfloat16,
                revision=effective_revision,
                torch_dtype=devices.dtype,
                cache_dir=shared.opts.hfcache_dir,
                **quant_args,
            )
            self.processor = transformers.AutoProcessor.from_pretrained(repo_name, max_pixels=1024*1024, trust_remote_code=True, revision=effective_revision, cache_dir=shared.opts.hfcache_dir)
        finally:
            # undo the monkey-patch even when loading fails, so the patched function does not stay
            # installed globally (and does not get wrapped again on the next attempt)
            transformers.dynamic_module_utils.get_imports = _get_imports
        self.loaded = cache_key
        self.model.eval()
        devices.torch_gc()
def _florence(self, question: str, image: Image.Image, repo: str, revision: str = None, model_name: str = None): # pylint: disable=unused-argument
    """
    Run a Florence-2 task on `image`.

    When `question` starts with '<', its leading '<...>' token is used as the task;
    otherwise the task is '<MORE_DETAILED_CAPTION>'.
    Returns whatever the processor's post_process_generation() produces.
    """
    self._load_florence(repo, revision)
    sd_models.move_model(self.model, devices.device)
    task = (question.partition('>')[0] + '>') if question.startswith('<') else '<MORE_DETAILED_CAPTION>'
    encoded = self.processor(text=task, images=image, return_tensors="pt")
    token_ids = encoded['input_ids'].to(devices.device)
    pixels = encoded['pixel_values'].to(devices.device, devices.dtype)
    with devices.inference_context():
        generated = self.model.generate(
            input_ids=token_ids,
            pixel_values=pixels,
            **get_kwargs()
        )
        decoded = self.processor.batch_decode(generated, skip_special_tokens=False)[0]
        # NOTE(review): the literal string "task" is passed here, not the `task` variable computed above.
        # Downstream cleanup may depend on the resulting key - confirm before changing.
        return self.processor.post_process_generation(decoded, task="task", image_size=(image.width, image.height))
def _load_sa2(self, repo: str):
    """
    Load SA2VA model and tokenizer.

    Does nothing when `repo` is already the loaded model. Like the other loaders in this
    class, the load is logged and downloads go to the configured `shared.opts.hfcache_dir`.
    """
    if self.model is None or self.loaded != repo:
        shared.log.debug(f'Interrogate load: vlm="{repo}"')
        self.model = None  # release the reference to any previous model before loading
        self.model = transformers.AutoModel.from_pretrained(
            repo,
            torch_dtype=devices.dtype,
            low_cpu_mem_usage=True,
            use_flash_attn=False,
            trust_remote_code=True,
            cache_dir=shared.opts.hfcache_dir,
        )
        self.model = self.model.eval()
        self.processor = transformers.AutoTokenizer.from_pretrained(
            repo,
            trust_remote_code=True,
            use_fast=False,
            cache_dir=shared.opts.hfcache_dir,
        )
        self.loaded = repo
        devices.torch_gc()
def _sa2(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
    """
    Run SA2VA on `image` and return the "prediction" entry of its result.

    When `question` starts with '<', its leading '<...>' token is used as the task text;
    otherwise '<MORE_DETAILED_CAPTION>' is used.
    """
    self._load_sa2(repo)
    sd_models.move_model(self.model, devices.device)
    task = (question.partition('>')[0] + '>') if question.startswith('<') else '<MORE_DETAILED_CAPTION>'
    result = self.model.predict_forward(
        image=image,
        text=f'<image>{task}',
        past_text='',
        mask_prompts=None,
        tokenizer=self.processor,
    )
    return result["prediction"] # the text format answer
def interrogate(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False, quiet: bool = False) -> str:
    """
    Main entry point for VQA interrogation. Returns string answer.
    Detection data stored in self.last_detection_data for annotated image creation.

    Args:
        question: friendly prompt name, internal token or free text; "Use Prompt" means take `prompt` verbatim.
        system_prompt: forwarded to the handlers that accept one (qwen, smol, gemma).
        prompt: user text, required for "Use Prompt" and for the point/detect modes.
        image: PIL image, a list (first item is used) or a dict with a 'name' path to open.
        model_name: key into vlm_models; defaults to shared.opts.interrogate_vlm_model.
        prefill: assistant prefill; None means use the module-level vlm_prefill.
        thinking_mode: forwarded to the handlers that support reasoning.
        quiet: when True the final summary log line is skipped.

    Returns:
        The cleaned answer, an 'Error: ...' message for invalid input, or '' when the model is missing/unknown.
    """
    # reset per-call outputs so stale annotations from a previous call are never reported
    self.last_annotated_image = None
    self.last_detection_data = None
    jobid = shared.state.begin('Interrogate LLM')
    t0 = time.time()
    model_name = model_name or shared.opts.interrogate_vlm_model
    prefill = vlm_prefill if prefill is None else prefill # Use provided prefill when specified
    # normalize the accepted image input forms to a single PIL image (or None)
    if isinstance(image, list):
        image = image[0] if len(image) > 0 else None
    if isinstance(image, dict) and 'name' in image:
        image = Image.open(image['name'])
    if isinstance(image, Image.Image):
        if image.width > 768 or image.height > 768:
            # thumbnail() resizes in place, so the image object passed in is modified
            image.thumbnail((768, 768), Image.Resampling.LANCZOS)
        if image.mode != 'RGB':
            image = image.convert('RGB')
    if image is None:
        shared.log.error(f'VQA interrogate: model="{model_name}" error="No input image provided"')
        shared.state.end(jobid)
        return 'Error: No input image provided. Please upload or select an image.'
    # Convert friendly prompt names to internal tokens/commands
    if question == "Use Prompt":
        # Use content from Prompt field directly - requires user input
        if not prompt or len(prompt.strip()) < 2:
            shared.log.error(f'VQA interrogate: model="{model_name}" error="Please enter a prompt"')
            shared.state.end(jobid)
            return 'Error: Please enter a question or instruction in the Prompt field.'
        question = prompt
    elif question in vlm_prompt_mapping:
        # Check if this is a mode that requires user input (Point/Detect)
        raw_mapping = vlm_prompt_mapping.get(question)
        if raw_mapping in ("POINT_MODE", "DETECT_MODE"):
            # These modes require user input in the prompt field
            if not prompt or len(prompt.strip()) < 2:
                shared.log.error(f'VQA interrogate: model="{model_name}" error="Please specify what to find in the prompt field"')
                shared.state.end(jobid)
                return 'Error: Please specify what to find in the prompt field (e.g., "the red car" or "faces").'
        # Convert friendly name to internal token (handles Point/Detect prefix)
        question = get_internal_prompt(question, prompt)
    # else: question is already an internal token or custom text
    from modules import modelloader
    modelloader.hf_login()
    try:
        if model_name is None:
            shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
            shared.state.end(jobid)
            return ''
        vqa_model = vlm_models.get(model_name, None)
        if vqa_model is None:
            shared.log.error(f'Interrogate: type=vlm model="{model_name}" unknown')
            shared.state.end(jobid)
            return ''
        handler = 'unknown'
        # Dispatch by substring of the repo id. The order matters: the first match wins,
        # e.g. 'moondream3' is tested before 'moondream2' and 'paligemma' before 'gemma'.
        if 'git' in vqa_model.lower():
            handler = 'git'
            answer = self._git(question, image, vqa_model, model_name)
        elif 'vilt' in vqa_model.lower():
            handler = 'vilt'
            answer = self._vilt(question, image, vqa_model, model_name)
        elif 'blip' in vqa_model.lower():
            handler = 'blip'
            answer = self._blip(question, image, vqa_model, model_name)
        elif 'pix' in vqa_model.lower():
            handler = 'pix'
            answer = self._pix(question, image, vqa_model, model_name)
        elif 'moondream3' in vqa_model.lower():
            handler = 'moondream3'
            from modules.interrogate import moondream3
            answer = moondream3.predict(question, image, vqa_model, model_name, thinking_mode=thinking_mode)
        elif 'moondream2' in vqa_model.lower():
            handler = 'moondream'
            answer = self._moondream(question, image, vqa_model, model_name, thinking_mode)
        elif 'florence' in vqa_model.lower():
            handler = 'florence'
            answer = self._florence(question, image, vqa_model, None, model_name)
        elif 'qwen' in vqa_model.lower() or 'torii' in vqa_model.lower() or 'mimo' in vqa_model.lower():
            handler = 'qwen'
            answer = self._qwen(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
        elif 'smol' in vqa_model.lower():
            handler = 'smol'
            answer = self._smol(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
        elif 'joytag' in vqa_model.lower():
            handler = 'joytag'
            from modules.interrogate import joytag
            answer = joytag.predict(image)
        elif 'joycaption' in vqa_model.lower():
            handler = 'joycaption'
            from modules.interrogate import joycaption
            answer = joycaption.predict(question, image, vqa_model)
        elif 'deepseek' in vqa_model.lower():
            handler = 'deepseek'
            from modules.interrogate import deepseek
            answer = deepseek.predict(question, image, vqa_model)
        elif 'paligemma' in vqa_model.lower():
            handler = 'paligemma'
            answer = self._paligemma(question, image, vqa_model, model_name)
        elif 'gemma' in vqa_model.lower():
            handler = 'gemma'
            answer = self._gemma(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
        elif 'ovis' in vqa_model.lower():
            handler = 'ovis'
            answer = self._ovis(question, image, vqa_model, model_name)
        elif 'sa2' in vqa_model.lower():
            handler = 'sa2'
            answer = self._sa2(question, image, vqa_model, model_name)
        elif 'fastvlm' in vqa_model.lower():
            handler = 'fastvlm'
            answer = self._fastvlm(question, image, vqa_model, model_name)
        else:
            answer = 'unknown model'
    except Exception as e:
        # any handler failure is reported and turned into the literal answer 'error'
        errors.display(e, 'VQA')
        answer = 'error'
    if shared.opts.interrogate_offload and self.model is not None:
        sd_models.move_model(self.model, devices.cpu, force=True)
    devices.torch_gc(force=True, reason='vqa')
    # Clean the answer
    answer = clean(answer, question, prefill)
    # Create annotated image if detection data is available
    if self.last_detection_data and isinstance(self.last_detection_data, dict) and image:
        detections = self.last_detection_data.get('detections', None)
        points = self.last_detection_data.get('points', None)
        if detections or points:
            self.last_annotated_image = vqa_detection.draw_bounding_boxes(image, detections or [], points)
            debug(f'VQA interrogate: handler={handler} created annotated image detections={len(detections) if detections else 0} points={len(points) if points else 0}')
    debug(f'VQA interrogate: handler={handler} response_after_clean="{answer}" has_annotation={self.last_annotated_image is not None}')
    t1 = time.time()
    if not quiet:
        shared.log.debug(f'Interrogate: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}')
    shared.state.end(jobid)
    return answer
def batch(self, model_name, system_prompt, batch_files, batch_folder, batch_str, question, prompt, write, append, recursive, prefill=None, thinking_mode=False):
    """
    Caption a set of images and optionally write each caption next to its image.

    Args:
        model_name, system_prompt, question, prompt, prefill, thinking_mode: forwarded to interrogate().
        batch_files, batch_folder: optional iterables of objects exposing a `.name` file path.
        batch_str: optional directory path; its image files are added (recursively when `recursive`).
        write: when true, write `<image>.txt` captions (and `<image>_annotated.png` when an annotation exists).
        append: when true, captions are appended to existing text files instead of overwriting them.
        recursive: passed to the directory listing for `batch_str`.

    Returns:
        All captions joined by a blank line, or '' when no images were found.

    Model offloading is disabled for the duration of the batch; the previous setting and the
    job state are restored in a `finally` block so they cannot leak if the batch is aborted.
    """
    class BatchWriter:
        """Writes one caption text file per image."""
        def __init__(self, folder, mode='w'):
            self.folder = folder
            self.csv = None
            self.file = None
            self.mode = mode

        def add(self, file, prompt_text):
            txt_file = os.path.splitext(file)[0] + ".txt"
            if self.mode == 'a':
                prompt_text = '\n' + prompt_text
            with open(os.path.join(self.folder, txt_file), self.mode, encoding='utf-8') as f:
                f.write(prompt_text)

        def close(self):
            if self.file is not None:
                self.file.close()

    files = []
    if batch_files is not None:
        files += [f.name for f in batch_files]
    if batch_folder is not None:
        files += [f.name for f in batch_folder]
    if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str):
        from modules.files_cache import list_files
        files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive))
    if len(files) == 0:
        shared.log.warning('Interrogate batch: type=vlm no images')
        return ''
    jobid = shared.state.begin('Interrogate batch')
    prompts = []
    writer = None
    # keep the model on the device between images; restored in the finally block below
    orig_offload = shared.opts.interrogate_offload
    shared.opts.interrogate_offload = False
    try:
        if write:
            mode = 'w' if not append else 'a'
            writer = BatchWriter(os.path.dirname(files[0]), mode=mode)
        import rich.progress as rp
        pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
        with pbar:
            task = pbar.add_task(total=len(files), description='starting...')
            for file in files:
                pbar.update(task, advance=1, description=file)
                try:
                    if shared.state.interrupted:
                        break
                    img = Image.open(file)
                    caption = self.interrogate(question, system_prompt, prompt, img, model_name, prefill, thinking_mode, quiet=True)
                    # Save annotated image if available
                    if self.last_annotated_image and write:
                        annotated_path = os.path.splitext(file)[0] + "_annotated.png"
                        self.last_annotated_image.save(annotated_path)
                    prompts.append(caption)
                    if write:
                        writer.add(file, caption)
                except Exception as e:
                    # a failure on one file is logged and the batch continues with the next file
                    shared.log.error(f'Interrogate batch: {e}')
    finally:
        # always undo the global option override and close the job, even if the loop is aborted
        if writer is not None:
            writer.close()
        shared.opts.interrogate_offload = orig_offload
        shared.state.end(jobid)
    return '\n\n'.join(prompts)
# Module-level singleton instance
_instance = None


def get_instance() -> VQA:
    """Return the shared VQA object, constructing it the first time it is requested."""
    global _instance # pylint: disable=global-statement
    if _instance is not None:
        return _instance
    _instance = VQA()
    return _instance
# Backwards-compatible module-level functions
def interrogate(*args, **kwargs):
    """Forward all arguments to `interrogate` on the shared VQA instance and return its result."""
    vqa = get_instance()
    return vqa.interrogate(*args, **kwargs)
def unload_model():
    """Call `unload` on the shared VQA instance and return its result."""
    vqa = get_instance()
    return vqa.unload()
def load_model(model_name: str = None):
    """Call `load(model_name)` on the shared VQA instance and return its result."""
    vqa = get_instance()
    return vqa.load(model_name)
def get_last_annotated_image():
    """Return the `last_annotated_image` attribute of the shared VQA instance."""
    vqa = get_instance()
    return vqa.last_annotated_image
def batch(*args, **kwargs):
    """Forward all arguments to `batch` on the shared VQA instance and return its result."""
    vqa = get_instance()
    return vqa.batch(*args, **kwargs)