Mirror of https://github.com/vladmandic/sdnext.git, synced 2026-01-27 15:02:48 +03:00
Add before/after debug messages when unloading VQA model to match the pattern used in prompt enhance for better debugging visibility.
1459 lines
68 KiB
Python
import io
import os
import time
import json
import base64
import copy
import torch
import transformers
import transformers.dynamic_module_utils
from PIL import Image
from modules import shared, devices, errors, model_quant, sd_models, sd_models_compile, ui_symbols
from modules.interrogate import vqa_detection

# Debug logging - function-based to avoid circular import
debug_enabled = os.environ.get('SD_VQA_DEBUG', None) is not None


def debug(*args, **kwargs):
    if debug_enabled:
        shared.log.trace(*args, **kwargs)
|
|
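# Illustrative note: any value for SD_VQA_DEBUG enables trace-level logging in this module,
# e.g. setting SD_VQA_DEBUG=1 in the environment before launching (the exact launch command
# depends on the local setup and is not assumed here).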
|
|
vlm_default = "Alibaba Qwen 2.5 VL 3B"
|
|
vlm_models = {
|
|
"Google Gemma 3 4B": "google/gemma-3-4b-it",
|
|
"Google Gemma 3n E2B": "google/gemma-3n-E2B-it", # 1.5GB
|
|
"Google Gemma 3n E4B": "google/gemma-3n-E4B-it", # 1.5GB
|
|
"Nidum Gemma 3 4B Uncensored": "nidum/Nidum-Gemma-3-4B-it-Uncensored",
|
|
"Allura Gemma 3 Glitter 4B": "allura-org/Gemma-3-Glitter-4B",
|
|
"Alibaba Qwen 2.0 VL 2B": "Qwen/Qwen2-VL-2B-Instruct",
|
|
"Alibaba Qwen 2.5 Omni 3B": "Qwen/Qwen2.5-Omni-3B",
|
|
"Alibaba Qwen 2.5 VL 3B": "Qwen/Qwen2.5-VL-3B-Instruct",
|
|
"Alibaba Qwen 3 VL 2B": "Qwen/Qwen3-VL-2B-Instruct",
|
|
f"Alibaba Qwen 3 VL 2B Thinking {ui_symbols.reasoning}": "Qwen/Qwen3-VL-2B-Thinking",
|
|
"Alibaba Qwen 3 VL 4B": "Qwen/Qwen3-VL-4B-Instruct",
|
|
f"Alibaba Qwen 3 VL 4B Thinking {ui_symbols.reasoning}": "Qwen/Qwen3-VL-4B-Thinking",
|
|
"Alibaba Qwen 3 VL 8B": "Qwen/Qwen3-VL-8B-Instruct",
|
|
f"Alibaba Qwen 3 VL 8B Thinking {ui_symbols.reasoning}": "Qwen/Qwen3-VL-8B-Thinking",
|
|
"XiaomiMiMo MiMo VL 7B RL": "XiaomiMiMo/MiMo-VL-7B-RL-2508", # 8.3GB
|
|
"Huggingface Smol VL2 0.5B": "HuggingFaceTB/SmolVLM-500M-Instruct",
|
|
"Huggingface Smol VL2 2B": "HuggingFaceTB/SmolVLM-Instruct",
|
|
"Apple FastVLM 0.5B": "apple/FastVLM-0.5B",
|
|
"Apple FastVLM 1.5B": "apple/FastVLM-1.5B",
|
|
"Apple FastVLM 7B": "apple/FastVLM-7B",
|
|
"Microsoft Florence 2 Base": "florence-community/Florence-2-base-ft", # 0.5GB
|
|
"Microsoft Florence 2 Large": "florence-community/Florence-2-large-ft", # 1.5GB
|
|
"MiaoshouAI PromptGen 1.5 Base": "Disty0/Florence-2-base-PromptGen-v1.5", # 0.5GB
|
|
"MiaoshouAI PromptGen 1.5 Large": "Disty0/Florence-2-large-PromptGen-v1.5", # 1.5GB
|
|
"MiaoshouAI PromptGen 2.0 Base": "Disty0/Florence-2-base-PromptGen-v2.0", # 0.5GB
|
|
"MiaoshouAI PromptGen 2.0 Large": "Disty0/Florence-2-large-PromptGen-v2.0", # 1.5GB
|
|
"CogFlorence 2.0 Large": "thwri/CogFlorence-2-Large-Freeze", # 1.6GB
|
|
"CogFlorence 2.2 Large": "thwri/CogFlorence-2.2-Large", # 1.6GB
|
|
f"Moondream 2 {ui_symbols.reasoning}": "vikhyatk/moondream2", # 3.7GB
|
|
f"Moondream 3 Preview {ui_symbols.reasoning}": "moondream/moondream3-preview", # 9.3GB (gated)
|
|
"Google Pix Textcaps": "google/pix2struct-textcaps-base", # 1.1GB
|
|
"Google PaliGemma 2 3B": "google/paligemma2-3b-pt-224",
|
|
"Salesforce BLIP Base": "Salesforce/blip-vqa-base", # 1.5GB
|
|
"Salesforce BLIP Large": "Salesforce/blip-vqa-capfilt-large", # 1.5GB
|
|
"Microsoft GIT TextCaps Base": "microsoft/git-base-textcaps", # 0.7GB
|
|
"Microsoft GIT VQA Base": "microsoft/git-base-vqav2", # 0.7GB
|
|
"Microsoft GIT VQA Large": "microsoft/git-large-vqav2", # 1.6GB
|
|
"ToriiGate 0.4 2B": "Minthy/ToriiGate-v0.4-2B",
|
|
"ToriiGate 0.4 7B": "Minthy/ToriiGate-v0.4-7B",
|
|
"ViLT Base": "dandelin/vilt-b32-finetuned-vqa", # 0.5GB
|
|
"JoyCaption Alpha": "fancyfeast/llama-joycaption-alpha-two-hf-llava", # 17.4GB
|
|
"JoyCaption Beta": "fancyfeast/llama-joycaption-beta-one-hf-llava", # 17.4GB
|
|
"JoyTag": "fancyfeast/joytag", # 0.7GB
|
|
"AIDC Ovis2 1B": "AIDC-AI/Ovis2-1B",
|
|
"AIDC Ovis2 2B": "AIDC-AI/Ovis2-2B",
|
|
"AIDC Ovis2 4B": "AIDC-AI/Ovis2-4B",
|
|
"ByteDance Sa2VA 1B": "ByteDance/Sa2VA-1B",
|
|
"ByteDance Sa2VA 4B": "ByteDance/Sa2VA-4B",
|
|
}
|
|
vlm_system = 'You are an image captioning expert, creative, unbiased and uncensored.'
|
|
|
|
# Common prompts (work with all VLM models)
|
|
vlm_prompts_common = [
|
|
"Use Prompt",
|
|
"Short Caption",
|
|
"Normal Caption",
|
|
"Long Caption",
|
|
]
|
|
|
|
# Florence-2 specific prompts (only shown for Florence/PromptGen models)
|
|
vlm_prompts_florence = [
|
|
"Phrase Grounding",
|
|
"Object Detection",
|
|
"Dense Region Caption",
|
|
"Region Proposal",
|
|
"OCR (Read Text)",
|
|
"OCR with Regions",
|
|
"Analyze",
|
|
"Generate Tags",
|
|
"Mixed Caption",
|
|
"Mixed Caption+",
|
|
]
|
|
|
|
# Moondream specific prompts (shared by Moondream 2 and 3)
|
|
vlm_prompts_moondream = [
|
|
"Point at...",
|
|
"Detect all...",
|
|
]
|
|
|
|
# Moondream 2 only prompts (gaze detection not available in Moondream 3)
|
|
vlm_prompts_moondream2 = [
|
|
"Detect Gaze",
|
|
]
|
|
|
|
# Mapping from friendly names to internal tokens/commands
|
|
vlm_prompt_mapping = {
|
|
"Use Prompt": "Use Prompt",
|
|
"Short Caption": "<CAPTION>",
|
|
"Normal Caption": "<DETAILED_CAPTION>",
|
|
"Long Caption": "<MORE_DETAILED_CAPTION>",
|
|
"Phrase Grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
|
|
"Object Detection": "<OD>",
|
|
"Dense Region Caption": "<DENSE_REGION_CAPTION>",
|
|
"Region Proposal": "<REGION_PROPOSAL>",
|
|
"OCR (Read Text)": "<OCR>",
|
|
"OCR with Regions": "<OCR_WITH_REGION>",
|
|
"Analyze": "<ANALYZE>",
|
|
"Generate Tags": "<GENERATE_TAGS>",
|
|
"Mixed Caption": "<MIXED_CAPTION>",
|
|
"Mixed Caption+": "<MIXED_CAPTION_PLUS>",
|
|
"Point at...": "POINT_MODE",
|
|
"Detect all...": "DETECT_MODE",
|
|
"Detect Gaze": "DETECT_GAZE",
|
|
}
|
|
|
|
# Placeholder hints for prompt field based on selected question
|
|
vlm_prompt_placeholders = {
|
|
"Use Prompt": "Enter your question or instruction for the model",
|
|
"Short Caption": "Optional: add specific focus or style instructions",
|
|
"Normal Caption": "Optional: add specific focus or style instructions",
|
|
"Long Caption": "Optional: add specific focus or style instructions",
|
|
"Phrase Grounding": "Optional: specify phrases to ground in the image",
|
|
"Object Detection": "Optional: specify object types to detect",
|
|
"Dense Region Caption": "Optional: add specific instructions",
|
|
"Region Proposal": "Optional: add specific instructions",
|
|
"OCR (Read Text)": "Optional: add specific instructions",
|
|
"OCR with Regions": "Optional: add specific instructions",
|
|
"Analyze": "Optional: add specific analysis instructions",
|
|
"Generate Tags": "Optional: add specific tagging instructions",
|
|
"Mixed Caption": "Optional: add specific instructions",
|
|
"Mixed Caption+": "Optional: add specific instructions",
|
|
"Point at...": "Enter objects to locate, e.g., 'the red car' or 'all the eyes'",
|
|
"Detect all...": "Enter object type to detect, e.g., 'cars' or 'faces'",
|
|
"Detect Gaze": "No input needed - auto-detects face and gaze direction",
|
|
}
|
|
|
|
# Legacy list for backwards compatibility
|
|
vlm_prompts = vlm_prompts_common + vlm_prompts_florence + vlm_prompts_moondream + vlm_prompts_moondream2
|
|
|
|
vlm_prefill = 'Answer: the image shows'
|
|
|
|
|
|
def get_prompts_for_model(model_name: str) -> list:
|
|
"""Get available prompts based on selected model."""
|
|
if model_name is None:
|
|
return vlm_prompts_common
|
|
|
|
model_lower = model_name.lower()
|
|
|
|
# Check for Florence-2 / PromptGen models
|
|
if 'florence' in model_lower or 'promptgen' in model_lower:
|
|
return vlm_prompts_common + vlm_prompts_florence
|
|
|
|
# Check for Moondream models (Moondream 2 has gaze detection, Moondream 3 does not)
|
|
if 'moondream' in model_lower:
|
|
if 'moondream3' in model_lower or 'moondream 3' in model_lower:
|
|
return vlm_prompts_common + vlm_prompts_moondream
|
|
else: # Moondream 2 includes gaze detection
|
|
return vlm_prompts_common + vlm_prompts_moondream + vlm_prompts_moondream2
|
|
|
|
# Default: common prompts only
|
|
return vlm_prompts_common
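
# Illustrative examples of the name-based selection above:
#   get_prompts_for_model("Microsoft Florence 2 Base")  -> common + Florence-2 task prompts
#   get_prompts_for_model("Moondream 2")                -> common + Moondream prompts + "Detect Gaze"
#   get_prompts_for_model("Alibaba Qwen 2.5 VL 3B")     -> common prompts only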
|
|
|
|
|
|
def get_internal_prompt(friendly_name: str, user_prompt: str = None) -> str:
|
|
"""Convert friendly prompt name to internal token/command."""
|
|
internal = vlm_prompt_mapping.get(friendly_name, friendly_name)
|
|
|
|
# Handle Moondream point/detect modes - prepend trigger phrase
|
|
if internal == "POINT_MODE" and user_prompt:
|
|
return f"Point at {user_prompt}"
|
|
elif internal == "DETECT_MODE" and user_prompt:
|
|
return f"Detect {user_prompt}"
|
|
|
|
return internal
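
# Illustrative examples of the mapping behaviour:
#   get_internal_prompt("Short Caption")               -> "<CAPTION>"
#   get_internal_prompt("Point at...", "the red car")  -> "Point at the red car"
#   get_internal_prompt("What color is the sky?")      -> "What color is the sky?" (unknown names pass through)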
|
|
|
|
|
|
def get_prompt_placeholder(friendly_name: str) -> str:
|
|
"""Get placeholder text for the prompt field based on selected question."""
|
|
return vlm_prompt_placeholders.get(friendly_name, "Enter your question or instruction")
|
|
|
|
|
|
def is_florence_task(question: str) -> bool:
|
|
"""Check if the question is a Florence-2 task token (either friendly name or internal token)."""
|
|
if not question:
|
|
return False
|
|
# Check if it's a Florence-specific friendly name
|
|
if question in vlm_prompts_florence:
|
|
return True
|
|
# Check if it's an internal Florence-2 task token (for backwards compatibility)
|
|
florence_tokens = ['<CAPTION>', '<DETAILED_CAPTION>', '<MORE_DETAILED_CAPTION>', '<CAPTION_TO_PHRASE_GROUNDING>',
|
|
'<OD>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<OCR>', '<OCR_WITH_REGION>',
|
|
'<ANALYZE>', '<GENERATE_TAGS>', '<MIXED_CAPTION>', '<MIXED_CAPTION_PLUS>']
|
|
return question in florence_tokens
|
|
|
|
|
|
def is_thinking_model(model_name: str) -> bool:
|
|
"""Check if the model supports thinking mode based on its name."""
|
|
if not model_name:
|
|
return False
|
|
model_lower = model_name.lower()
|
|
# Check for known thinking models
|
|
thinking_indicators = [
|
|
'thinking', # Qwen3-VL-*-Thinking models
|
|
'moondream3', # Moondream 3 supports thinking
|
|
'moondream 3',
|
|
'moondream2', # Moondream 2 supports reasoning mode
|
|
'moondream 2',
|
|
'mimo',
|
|
]
|
|
return any(indicator in model_lower for indicator in thinking_indicators)
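
# Illustrative examples (detection is purely substring-based on the model name):
#   is_thinking_model("Alibaba Qwen 3 VL 4B Thinking")  -> True   ('thinking')
#   is_thinking_model("Moondream 2")                    -> True   ('moondream 2')
#   is_thinking_model("Google Gemma 3 4B")              -> False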
|
|
|
|
|
|
def truncate_b64_in_conversation(conversation, front_chars=50, tail_chars=50, threshold=200):
|
|
"""
|
|
Deep copy a conversation structure and truncate long base64 image strings for logging.
|
|
Preserves front and tail of base64 strings with truncation indicator.
|
|
"""
|
|
conv_copy = copy.deepcopy(conversation)
|
|
|
|
def truncate_recursive(obj):
|
|
if isinstance(obj, dict):
|
|
for key, value in obj.items():
|
|
if key == "image" and isinstance(value, str) and len(value) > threshold:
|
|
# Truncate the base64 image string
|
|
truncated_count = len(value) - front_chars - tail_chars
|
|
obj[key] = f"{value[:front_chars]}...[{truncated_count} chars truncated]...{value[-tail_chars:]}"
|
|
elif isinstance(value, (dict, list)):
|
|
truncate_recursive(value)
|
|
elif isinstance(obj, list):
|
|
for item in obj:
|
|
truncate_recursive(item)
|
|
|
|
truncate_recursive(conv_copy)
|
|
return conv_copy
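
# Illustrative sketch of the truncation (values are hypothetical):
#   conv = [{"role": "user", "content": [{"type": "image", "image": "A" * 1000}]}]
#   truncate_b64_in_conversation(conv)
#   -> the "image" value becomes "AAA...[900 chars truncated]...AAA" (first and last 50 chars kept);
#      the original conversation is untouched because a deep copy is modified.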
|
|
|
|
|
|
def keep_think_block_open(text_prompt: str) -> str:
|
|
"""Remove the closing </think> of the final assistant message so the model can continue reasoning."""
|
|
think_open = "<think>"
|
|
think_close = "</think>"
|
|
last_open = text_prompt.rfind(think_open)
|
|
if last_open == -1:
|
|
return text_prompt
|
|
close_index = text_prompt.find(think_close, last_open)
|
|
if close_index == -1:
|
|
return text_prompt
|
|
# Skip any whitespace immediately following the closing tag
|
|
end_close = close_index + len(think_close)
|
|
while end_close < len(text_prompt) and text_prompt[end_close] in (' ', '\t'):
|
|
end_close += 1
|
|
while end_close < len(text_prompt) and text_prompt[end_close] in ('\r', '\n'):
|
|
end_close += 1
|
|
trimmed_prompt = text_prompt[:close_index] + text_prompt[end_close:]
|
|
debug('VQA interrogate: keep_think_block_open applied to prompt segment near assistant reply')
|
|
return trimmed_prompt
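
# Illustrative example (hypothetical prompt fragment):
#   keep_think_block_open("<|im_start|>assistant\n<think>draft</think>\nAnswer: the image shows")
#   -> "<|im_start|>assistant\n<think>draftAnswer: the image shows"
#   The closing </think> and the whitespace right after it are removed, so generation continues
#   inside the still-open think block.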
|
|
|
|
|
|
def b64(image):
|
|
if image is None:
|
|
return ''
|
|
with io.BytesIO() as stream:
|
|
image.save(stream, 'JPEG')
|
|
values = stream.getvalue()
|
|
encoded = base64.b64encode(values).decode()
|
|
return encoded
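
# Illustrative behaviour: b64(None) returns '' while b64(img) returns the base64-encoded JPEG
# bytes of a PIL image, e.g. b64(Image.new('RGB', (8, 8))) starts with "/9j/" (JPEG magic bytes).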
|
|
|
|
|
|
def clean(response, question, prefill=None):
|
|
    strip = ['---', '\r', '\t', '**', '"', '“', '”', 'Assistant:', 'Caption:', '<|im_end|>', '<pad>']
|
|
if isinstance(response, str):
|
|
response = response.strip()
|
|
elif isinstance(response, dict):
|
|
text_response = ""
|
|
if 'reasoning' in response and shared.opts.interrogate_vlm_keep_thinking:
|
|
r_text = response['reasoning']
|
|
if isinstance(r_text, dict) and 'text' in r_text:
|
|
r_text = r_text['text']
|
|
text_response += f"Reasoning:\n{r_text}\n\nAnswer:\n"
|
|
|
|
if 'answer' in response:
|
|
text_response += response['answer']
|
|
elif 'caption' in response:
|
|
text_response += response['caption']
|
|
elif 'task' in response:
|
|
text_response += response['task']
|
|
else:
|
|
if not text_response:
|
|
text_response = json.dumps(response)
|
|
response = text_response
|
|
elif isinstance(response, list):
|
|
response = response[0]
|
|
else:
|
|
response = str(response)
|
|
|
|
# Determine prefill text
|
|
prefill_text = vlm_prefill if prefill is None else prefill
|
|
if prefill_text is None:
|
|
prefill_text = ""
|
|
prefill_text = prefill_text.strip()
|
|
|
|
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
|
if question in response:
|
|
response = response.split(question, 1)[1]
|
|
while any(s in response for s in strip):
|
|
for s in strip:
|
|
response = response.replace(s, '')
|
|
    response = response.replace('  ', ' ').replace('* ', '- ').strip()
|
|
|
|
# Handle prefill retention/removal
|
|
if shared.opts.interrogate_vlm_keep_prefill:
|
|
# Add prefill if it's missing from the cleaned response
|
|
if len(prefill_text) > 0 and not response.startswith(prefill_text):
|
|
sep = " "
|
|
if not response or response[0] in ".,!?;:":
|
|
sep = ""
|
|
response = f"{prefill_text}{sep}{response}"
|
|
else:
|
|
# Remove prefill if it's present in the cleaned response
|
|
if len(prefill_text) > 0 and response.startswith(prefill_text):
|
|
response = response[len(prefill_text):].strip()
|
|
|
|
return response
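
# Illustrative examples (results depend on shared.opts, so these are sketches only):
#   with interrogate_vlm_keep_prefill enabled and the default prefill:
#     clean("a red car **outdoors**", "<CAPTION>")             -> "Answer: the image shows a red car outdoors"
#   with interrogate_vlm_keep_prefill disabled:
#     clean("Answer: the image shows a red car", "<CAPTION>")  -> "a red car"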
|
|
|
|
|
|
def get_kwargs():
|
|
kwargs = {
|
|
'max_new_tokens': shared.opts.interrogate_vlm_max_length,
|
|
'do_sample': shared.opts.interrogate_vlm_do_sample,
|
|
}
|
|
if shared.opts.interrogate_vlm_num_beams > 0:
|
|
kwargs['num_beams'] = shared.opts.interrogate_vlm_num_beams
|
|
if shared.opts.interrogate_vlm_temperature > 0:
|
|
kwargs['temperature'] = shared.opts.interrogate_vlm_temperature
|
|
if shared.opts.interrogate_vlm_top_k > 0:
|
|
kwargs['top_k'] = shared.opts.interrogate_vlm_top_k
|
|
if shared.opts.interrogate_vlm_top_p > 0:
|
|
kwargs['top_p'] = shared.opts.interrogate_vlm_top_p
|
|
return kwargs
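
# Illustrative result (all values come from shared.opts; the ones shown here are hypothetical,
# and num_beams/temperature/top_k/top_p are omitted when their setting is 0):
#   get_kwargs() -> {'max_new_tokens': 512, 'do_sample': False, 'num_beams': 1, 'temperature': 0.15, 'top_k': 50, 'top_p': 0.95}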
|
|
|
|
|
|
class VQA:
|
|
"""Vision-Language Model interrogation class with per-model self-contained loading."""
|
|
|
|
def __init__(self):
|
|
self.processor = None
|
|
self.model = None
|
|
self.loaded: str = None
|
|
self.last_annotated_image = None
|
|
self.last_detection_data = None
|
|
|
|
def unload(self):
|
|
"""Release VLM model from GPU/memory."""
|
|
if self.model is not None:
|
|
model_name = self.loaded
|
|
shared.log.debug(f'VQA unload: unloading model="{model_name}"')
|
|
sd_models.move_model(self.model, devices.cpu, force=True)
|
|
self.model = None
|
|
self.processor = None
|
|
self.loaded = None
|
|
devices.torch_gc(force=True, reason='vqa unload')
|
|
shared.log.debug(f'VQA unload: model="{model_name}" unloaded')
|
|
else:
|
|
shared.log.debug('VQA unload: no model loaded')
|
|
|
|
def load(self, model_name: str = None):
|
|
"""Load VLM model into memory for the specified model name."""
|
|
model_name = model_name or shared.opts.interrogate_vlm_model
|
|
if not model_name:
|
|
shared.log.warning('VQA load: no model specified')
|
|
return
|
|
repo = vlm_models.get(model_name)
|
|
if repo is None:
|
|
shared.log.error(f'VQA load: unknown model="{model_name}"')
|
|
return
|
|
|
|
shared.log.debug(f'VQA load: pre-loading model="{model_name}" repo="{repo}"')
|
|
|
|
# Dispatch to appropriate loader (same logic as interrogate)
|
|
repo_lower = repo.lower()
|
|
if 'qwen' in repo_lower or 'torii' in repo_lower or 'mimo' in repo_lower:
|
|
self._load_qwen(repo)
|
|
elif 'gemma' in repo_lower and 'pali' not in repo_lower:
|
|
self._load_gemma(repo)
|
|
elif 'smol' in repo_lower:
|
|
self._load_smol(repo)
|
|
elif 'florence' in repo_lower:
|
|
self._load_florence(repo)
|
|
elif 'moondream2' in repo_lower:
|
|
self._load_moondream(repo)
|
|
elif 'git' in repo_lower:
|
|
self._load_git(repo)
|
|
elif 'blip' in repo_lower:
|
|
self._load_blip(repo)
|
|
elif 'vilt' in repo_lower:
|
|
self._load_vilt(repo)
|
|
elif 'pix' in repo_lower:
|
|
self._load_pix(repo)
|
|
elif 'paligemma' in repo_lower:
|
|
self._load_paligemma(repo)
|
|
elif 'ovis' in repo_lower:
|
|
self._load_ovis(repo)
|
|
elif 'sa2' in repo_lower:
|
|
self._load_sa2(repo)
|
|
elif 'fastvlm' in repo_lower:
|
|
self._load_fastvlm(repo)
|
|
elif 'moondream3' in repo_lower:
|
|
from modules.interrogate import moondream3
|
|
moondream3.load_model(repo)
|
|
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
|
|
return
|
|
elif 'joytag' in repo_lower:
|
|
from modules.interrogate import joytag
|
|
joytag.load()
|
|
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
|
|
return
|
|
elif 'joycaption' in repo_lower:
|
|
from modules.interrogate import joycaption
|
|
joycaption.load(repo)
|
|
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
|
|
return
|
|
elif 'deepseek' in repo_lower:
|
|
from modules.interrogate import deepseek
|
|
deepseek.load(repo)
|
|
shared.log.info(f'VQA load: model="{model_name}" loaded (external handler)')
|
|
return
|
|
else:
|
|
shared.log.warning(f'VQA load: no pre-loader for model="{model_name}"')
|
|
return
|
|
|
|
sd_models.move_model(self.model, devices.device)
|
|
shared.log.info(f'VQA load: model="{model_name}" loaded')
|
|
|
|
def _load_fastvlm(self, repo: str):
|
|
"""Load FastVLM model and tokenizer."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
quant_args = model_quant.create_config(module='LLM')
|
|
self.model = None
|
|
self.processor = transformers.AutoTokenizer.from_pretrained(repo, trust_remote_code=True, cache_dir=shared.opts.hfcache_dir)
|
|
self.model = transformers.AutoModelForCausalLM.from_pretrained(
|
|
repo,
|
|
torch_dtype=devices.dtype,
|
|
trust_remote_code=True,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
**quant_args,
|
|
)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str = None):
|
|
debug(f'VQA interrogate: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
|
|
self._load_fastvlm(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
if len(question) < 2:
|
|
question = "Describe the image."
|
|
question = question.replace('<', '').replace('>', '')
|
|
IMAGE_TOKEN_INDEX = -200 # what the model code looks for
|
|
messages = [{"role": "user", "content": f"<image>\n{question}"}]
|
|
rendered = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
|
pre, post = rendered.split("<image>", 1)
|
|
pre_ids = self.processor(pre, return_tensors="pt", add_special_tokens=False).input_ids
|
|
post_ids = self.processor(post, return_tensors="pt", add_special_tokens=False).input_ids
|
|
img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
|
|
input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1)
|
|
input_ids = input_ids.to(devices.device)
|
|
attention_mask = torch.ones_like(input_ids, device=devices.device)
|
|
px = self.model.get_vision_tower().image_processor(images=image, return_tensors="pt")
|
|
px = px["pixel_values"].to(self.model.device, dtype=self.model.dtype)
|
|
with devices.inference_context():
|
|
outputs = self.model.generate(
|
|
inputs=input_ids,
|
|
attention_mask=attention_mask,
|
|
images=px,
|
|
max_new_tokens=128,
|
|
)
|
|
answer = self.processor.decode(outputs[0], skip_special_tokens=True)
|
|
return answer
|
|
|
|
def _load_qwen(self, repo: str):
|
|
"""Load Qwen VL model and processor."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.model = None
|
|
if 'Qwen3-VL' in repo or 'Qwen3VL' in repo:
|
|
cls_name = transformers.Qwen3VLForConditionalGeneration
|
|
elif 'Qwen2.5-VL' in repo or 'Qwen2_5_VL' in repo or 'MiMo-VL' in repo:
|
|
cls_name = transformers.Qwen2_5_VLForConditionalGeneration
|
|
elif 'Qwen2-VL' in repo or 'Qwen2VL' in repo:
|
|
cls_name = transformers.Qwen2VLForConditionalGeneration
|
|
else:
|
|
cls_name = transformers.AutoModelForCausalLM
|
|
quant_args = model_quant.create_config(module='LLM')
|
|
self.model = cls_name.from_pretrained(
|
|
repo,
|
|
torch_dtype=devices.dtype,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
**quant_args,
|
|
)
|
|
self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
|
|
if 'LLM' in shared.opts.cuda_compile:
|
|
self.model = sd_models_compile.compile_torch(self.model)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _qwen(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
|
|
self._load_qwen(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
# Get model class name for logging
|
|
cls_name = self.model.__class__.__name__
|
|
debug(f'VQA interrogate: handler=qwen model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
|
|
|
|
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
|
system_prompt = system_prompt or shared.opts.interrogate_vlm_system
|
|
conversation = [
|
|
{
|
|
"role": "system",
|
|
"content": [{"type": "text", "text": system_prompt}],
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "image", "image": b64(image)},
|
|
{"type": "text", "text": question},
|
|
],
|
|
}
|
|
]
|
|
        # Standardize prefill (use the provided value or the module default)
        prefill_value = vlm_prefill if prefill is None else prefill
        prefill_text = prefill_value.strip()
        use_prefill = len(prefill_text) > 0

        # Thinking models emit their own <think> tags via the chat template
        # Only models with thinking capability can use thinking mode
        is_thinking = is_thinking_model(model_name)
|
|
|
|
if debug_enabled:
|
|
debug(f'VQA interrogate: handler=qwen conversation_roles={[msg["role"] for msg in conversation]}')
|
|
debug(f'VQA interrogate: handler=qwen full_conversation={truncate_b64_in_conversation(conversation)}')
|
|
debug(f'VQA interrogate: handler=qwen is_thinking={is_thinking} thinking_mode={thinking_mode} prefill="{prefill_text}"')
|
|
|
|
# Generate base prompt using template
|
|
# Qwen-Thinking template automatically adds "<|im_start|>assistant\n<think>\n" when add_generation_prompt=True
|
|
try:
|
|
text_prompt = self.processor.apply_chat_template(
|
|
conversation,
|
|
add_generation_prompt=True,
|
|
)
|
|
except (TypeError, ValueError) as e:
|
|
debug(f'VQA interrogate: handler=qwen chat_template fallback add_generation_prompt=True: {e}')
|
|
text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
|
|
|
|
# Manually handle thinking tags and prefill
|
|
if is_thinking:
|
|
if not thinking_mode:
|
|
# User wants to SKIP thinking.
|
|
# Since template opened the block with <think>, we close it immediately.
|
|
text_prompt += "</think>\n"
|
|
if use_prefill:
|
|
text_prompt += prefill_text
|
|
else:
|
|
# User wants thinking. Prompt already ends in <think>.
|
|
# If prefill is provided, it becomes part of the thought process.
|
|
if use_prefill:
|
|
text_prompt += prefill_text
|
|
else:
|
|
# Standard model (not forcing <think>)
|
|
if use_prefill:
|
|
text_prompt += prefill_text
|
|
|
|
if debug_enabled:
|
|
debug(f'VQA interrogate: handler=qwen text_prompt="{text_prompt}"')
|
|
inputs = self.processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt")
|
|
inputs = inputs.to(devices.device, devices.dtype)
|
|
gen_kwargs = get_kwargs()
|
|
debug(f'VQA interrogate: handler=qwen generation_kwargs={gen_kwargs} input_ids_shape={inputs.input_ids.shape}')
|
|
output_ids = self.model.generate(
|
|
**inputs,
|
|
**gen_kwargs,
|
|
)
|
|
debug(f'VQA interrogate: handler=qwen output_ids_shape={output_ids.shape}')
|
|
        generated_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, output_ids)
        ]
|
|
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
|
if debug_enabled:
|
|
debug(f'VQA interrogate: handler=qwen response_before_clean="{response}"')
|
|
# Clean up thinking tags
|
|
# Note: <think> is in the prompt, not the response - only </think> appears in generated output
|
|
if len(response) > 0:
|
|
text = response[0]
|
|
if shared.opts.interrogate_vlm_keep_thinking:
|
|
# Handle case where <think> is in prompt (not response) but </think> is in response
|
|
if '</think>' in text and '<think>' not in text:
|
|
text = 'Reasoning:\n' + text.replace('</think>', '\n\nAnswer:')
|
|
else:
|
|
text = text.replace('<think>', 'Reasoning:\n').replace('</think>', '\n\nAnswer:')
|
|
else:
|
|
while '</think>' in text:
|
|
start = text.find('<think>')
|
|
end = text.find('</think>')
|
|
|
|
if start != -1 and start < end:
|
|
# Standard <think>...content...</think> block
|
|
text = text[:start] + text[end+8:]
|
|
else:
|
|
# Missing <think> (implied at start) or malformed
|
|
# Remove from start up to </think>
|
|
text = text[end+8:]
|
|
response[0] = text
|
|
return response
|
|
|
|
def _load_gemma(self, repo: str):
|
|
"""Load Gemma 3 model and processor."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.model = None
|
|
if '3n' in repo:
|
|
cls = transformers.Gemma3nForConditionalGeneration # pylint: disable=no-member
|
|
else:
|
|
cls = transformers.Gemma3ForConditionalGeneration
|
|
quant_args = model_quant.create_config(module='LLM')
|
|
self.model = cls.from_pretrained(
|
|
repo,
|
|
torch_dtype=devices.dtype,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
**quant_args,
|
|
)
|
|
if 'LLM' in shared.opts.cuda_compile:
|
|
self.model = sd_models_compile.compile_torch(self.model)
|
|
self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _gemma(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
|
|
self._load_gemma(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
# Get model class name for logging
|
|
cls_name = self.model.__class__.__name__
|
|
debug(f'VQA interrogate: handler=gemma model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
|
|
|
|
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
|
system_prompt = system_prompt or shared.opts.interrogate_vlm_system
|
|
|
|
system_content = []
|
|
if system_prompt is not None and len(system_prompt) > 4:
|
|
system_content.append({"type": "text", "text": system_prompt})
|
|
|
|
user_content = []
|
|
if question is not None and len(question) > 4:
|
|
user_content.append({"type": "text", "text": question})
|
|
if image is not None:
|
|
user_content.append({"type": "image", "image": b64(image)})
|
|
conversation = [
|
|
{"role": "system", "content": system_content},
|
|
{"role": "user", "content": user_content},
|
|
]
|
|
# Add prefill if provided
|
|
prefill_value = vlm_prefill if prefill is None else prefill
|
|
prefill_text = prefill_value.strip()
|
|
use_prefill = len(prefill_text) > 0
|
|
# Thinking models emit their own <think> tags via the chat template
|
|
# Use manual toggle OR auto-detection based on model name
|
|
use_thinking = thinking_mode or is_thinking_model(model_name)
|
|
if use_prefill:
|
|
conversation.append({
|
|
"role": "assistant",
|
|
"content": [{"type": "text", "text": prefill_text}],
|
|
})
|
|
debug(f'VQA interrogate: handler=gemma prefill="{prefill_text}"')
|
|
else:
|
|
debug('VQA interrogate: handler=gemma prefill disabled (empty), relying on add_generation_prompt')
|
|
if debug_enabled:
|
|
debug(f'VQA interrogate: handler=gemma conversation_roles={[msg["role"] for msg in conversation]}')
|
|
debug(f'VQA interrogate: handler=gemma full_conversation={truncate_b64_in_conversation(conversation)}')
|
|
debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
|
|
debug(f'VQA interrogate: handler=gemma template_mode={debug_prefill_mode}')
|
|
try:
|
|
if use_prefill:
|
|
text_prompt = self.processor.apply_chat_template(
|
|
conversation,
|
|
add_generation_prompt=False,
|
|
continue_final_message=True,
|
|
tokenize=False,
|
|
)
|
|
else:
|
|
text_prompt = self.processor.apply_chat_template(
|
|
conversation,
|
|
add_generation_prompt=True,
|
|
tokenize=False,
|
|
)
|
|
except (TypeError, ValueError) as e:
|
|
debug(f'VQA interrogate: handler=gemma chat_template fallback add_generation_prompt=True: {e}')
|
|
text_prompt = self.processor.apply_chat_template(
|
|
conversation,
|
|
add_generation_prompt=True,
|
|
tokenize=False,
|
|
)
|
|
if use_prefill and use_thinking:
|
|
text_prompt = keep_think_block_open(text_prompt)
|
|
if debug_enabled:
|
|
debug(f'VQA interrogate: handler=gemma text_prompt="{text_prompt}"')
|
|
inputs = self.processor(
|
|
text=[text_prompt],
|
|
images=[image],
|
|
padding=True,
|
|
return_tensors="pt",
|
|
).to(device=devices.device, dtype=devices.dtype)
|
|
input_len = inputs["input_ids"].shape[-1]
|
|
gen_kwargs = get_kwargs()
|
|
debug(f'VQA interrogate: handler=gemma generation_kwargs={gen_kwargs} input_len={input_len}')
|
|
with devices.inference_context():
|
|
generation = self.model.generate(
|
|
**inputs,
|
|
**gen_kwargs,
|
|
)
|
|
debug(f'VQA interrogate: handler=gemma output_ids_shape={generation.shape}')
|
|
generation = generation[0][input_len:]
|
|
response = self.processor.decode(generation, skip_special_tokens=True)
|
|
if debug_enabled:
|
|
debug(f'VQA interrogate: handler=gemma response_before_clean="{response}"')
|
|
|
|
# Clean up thinking tags (if any remain)
|
|
if shared.opts.interrogate_vlm_keep_thinking:
|
|
response = response.replace('<think>', 'Reasoning:\n').replace('</think>', '\n\nAnswer:')
|
|
else:
|
|
text = response
|
|
while '</think>' in text:
|
|
start = text.find('<think>')
|
|
end = text.find('</think>')
|
|
if start != -1 and start < end:
|
|
text = text[:start] + text[end+8:]
|
|
else:
|
|
text = text[end+8:]
|
|
response = text
|
|
|
|
return response
|
|
|
|
def _load_paligemma(self, repo: str):
|
|
"""Load PaliGemma model and processor."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.processor = transformers.PaliGemmaProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
|
self.model = None
|
|
self.model = transformers.PaliGemmaForConditionalGeneration.from_pretrained(
|
|
repo,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
torch_dtype=devices.dtype,
|
|
)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _paligemma(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
|
self._load_paligemma(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
|
model_inputs = self.processor(text=question, images=image, return_tensors="pt").to(devices.device, devices.dtype)
|
|
input_len = model_inputs["input_ids"].shape[-1]
|
|
with devices.inference_context():
|
|
generation = self.model.generate(
|
|
**model_inputs,
|
|
**get_kwargs(),
|
|
)
|
|
generation = generation[0][input_len:]
|
|
response = self.processor.decode(generation, skip_special_tokens=True)
|
|
return response
|
|
|
|
def _load_ovis(self, repo: str):
|
|
"""Load Ovis model (requires flash-attn)."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.model = None
|
|
self.model = transformers.AutoModelForCausalLM.from_pretrained(
|
|
repo,
|
|
torch_dtype=devices.dtype,
|
|
multimodal_max_length=32768,
|
|
trust_remote_code=True,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
|
try:
|
|
import flash_attn # pylint: disable=unused-import
|
|
except Exception:
|
|
shared.log.error(f'Interrogate: vlm="{repo}" flash-attn is not available')
|
|
return ''
|
|
self._load_ovis(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
text_tokenizer = self.model.get_text_tokenizer()
|
|
visual_tokenizer = self.model.get_visual_tokenizer()
|
|
max_partition = 9
|
|
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
|
question = f'<image>\n{question}'
|
|
_prompt, input_ids, pixel_values = self.model.preprocess_inputs(question, [image], max_partition=max_partition)
|
|
attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
|
|
input_ids = input_ids.unsqueeze(0).to(device=self.model.device)
|
|
attention_mask = attention_mask.unsqueeze(0).to(device=self.model.device)
|
|
if pixel_values is not None:
|
|
pixel_values = pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)
|
|
pixel_values = [pixel_values]
|
|
with devices.inference_context():
|
|
output_ids = self.model.generate(
|
|
input_ids,
|
|
pixel_values=pixel_values,
|
|
attention_mask=attention_mask,
|
|
repetition_penalty=None,
|
|
eos_token_id=self.model.generation_config.eos_token_id,
|
|
pad_token_id=text_tokenizer.pad_token_id,
|
|
use_cache=True,
|
|
**get_kwargs())
|
|
response = text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
|
return response
|
|
|
|
def _load_smol(self, repo: str):
|
|
"""Load SmolVLM model and processor."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.model = None
|
|
quant_args = model_quant.create_config(module='LLM')
|
|
self.model = transformers.AutoModelForVision2Seq.from_pretrained(
|
|
repo,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
torch_dtype=devices.dtype,
|
|
**quant_args,
|
|
)
|
|
self.processor = transformers.AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
|
|
if 'LLM' in shared.opts.cuda_compile:
|
|
self.model = sd_models_compile.compile_torch(self.model)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _smol(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
|
|
self._load_smol(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
# Get model class name for logging
|
|
cls_name = self.model.__class__.__name__
|
|
debug(f'VQA interrogate: handler=smol model_name="{model_name}" model_class="{cls_name}" repo="{repo}" question="{question}" system_prompt="{system_prompt}" image_size={image.size if image else None}')
|
|
|
|
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
|
system_prompt = system_prompt or shared.opts.interrogate_vlm_system
|
|
conversation = [
|
|
{
|
|
"role": "system",
|
|
"content": [{"type": "text", "text": system_prompt}],
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "image", "image": b64(image)},
|
|
{"type": "text", "text": question},
|
|
],
|
|
}
|
|
]
|
|
# Add prefill if provided
|
|
prefill_value = vlm_prefill if prefill is None else prefill
|
|
prefill_text = prefill_value.strip()
|
|
use_prefill = len(prefill_text) > 0
|
|
# Thinking models emit their own <think> tags via the chat template
|
|
# Use manual toggle OR auto-detection based on model name
|
|
use_thinking = thinking_mode or is_thinking_model(model_name)
|
|
if use_prefill:
|
|
conversation.append({
|
|
"role": "assistant",
|
|
"content": [{"type": "text", "text": prefill_text}],
|
|
})
|
|
debug(f'VQA interrogate: handler=smol prefill="{prefill_text}"')
|
|
else:
|
|
debug('VQA interrogate: handler=smol prefill disabled (empty), relying on add_generation_prompt')
|
|
if debug_enabled:
|
|
debug(f'VQA interrogate: handler=smol conversation_roles={[msg["role"] for msg in conversation]}')
|
|
debug(f'VQA interrogate: handler=smol full_conversation={truncate_b64_in_conversation(conversation)}')
|
|
debug_prefill_mode = 'add_generation_prompt=False continue_final_message=True' if use_prefill else 'add_generation_prompt=True'
|
|
debug(f'VQA interrogate: handler=smol template_mode={debug_prefill_mode}')
|
|
try:
|
|
if use_prefill:
|
|
text_prompt = self.processor.apply_chat_template(
|
|
conversation,
|
|
add_generation_prompt=False,
|
|
continue_final_message=True,
|
|
)
|
|
else:
|
|
text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
|
|
except (TypeError, ValueError) as e:
|
|
debug(f'VQA interrogate: handler=smol chat_template fallback add_generation_prompt=True: {e}')
|
|
text_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
|
|
if use_prefill and use_thinking:
|
|
text_prompt = keep_think_block_open(text_prompt)
|
|
if debug_enabled:
|
|
debug(f'VQA interrogate: handler=smol text_prompt="{text_prompt}"')
|
|
inputs = self.processor(text=text_prompt, images=[image], padding=True, return_tensors="pt")
|
|
inputs = inputs.to(devices.device, devices.dtype)
|
|
gen_kwargs = get_kwargs()
|
|
debug(f'VQA interrogate: handler=smol generation_kwargs={gen_kwargs}')
|
|
output_ids = self.model.generate(
|
|
**inputs,
|
|
**gen_kwargs,
|
|
)
|
|
debug(f'VQA interrogate: handler=smol output_ids_shape={output_ids.shape}')
|
|
response = self.processor.batch_decode(output_ids, skip_special_tokens=True)
|
|
if debug_enabled:
|
|
debug(f'VQA interrogate: handler=smol response_before_clean="{response}"')
|
|
|
|
# Clean up thinking tags
|
|
if len(response) > 0:
|
|
text = response[0]
|
|
if shared.opts.interrogate_vlm_keep_thinking:
|
|
text = text.replace('<think>', 'Reasoning:\n').replace('</think>', '\n\nAnswer:')
|
|
else:
|
|
while '</think>' in text:
|
|
start = text.find('<think>')
|
|
end = text.find('</think>')
|
|
if start != -1 and start < end:
|
|
text = text[:start] + text[end+8:]
|
|
else:
|
|
text = text[end+8:]
|
|
response[0] = text
|
|
|
|
return response
|
|
|
|
def _load_git(self, repo: str):
|
|
"""Load Microsoft GIT model and processor."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.model = None
|
|
self.model = transformers.GitForCausalLM.from_pretrained(
|
|
repo,
|
|
torch_dtype=devices.dtype,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
)
|
|
self.processor = transformers.GitProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _git(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
|
self._load_git(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
|
|
git_dict = {}
|
|
git_dict['pixel_values'] = pixel_values.to(devices.device, devices.dtype)
|
|
if len(question) > 0:
|
|
input_ids = self.processor(text=question, add_special_tokens=False).input_ids
|
|
input_ids = [self.processor.tokenizer.cls_token_id] + input_ids
|
|
input_ids = torch.tensor(input_ids).unsqueeze(0)
|
|
git_dict['input_ids'] = input_ids.to(devices.device)
|
|
with devices.inference_context():
|
|
generated_ids = self.model.generate(**git_dict)
|
|
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
return response
|
|
|
|
def _load_blip(self, repo: str):
|
|
"""Load Salesforce BLIP model and processor."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.model = None
|
|
self.model = transformers.BlipForQuestionAnswering.from_pretrained(
|
|
repo,
|
|
torch_dtype=devices.dtype,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
)
|
|
self.processor = transformers.BlipProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _blip(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
|
self._load_blip(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
inputs = self.processor(image, question, return_tensors="pt")
|
|
inputs = inputs.to(devices.device, devices.dtype)
|
|
with devices.inference_context():
|
|
outputs = self.model.generate(**inputs)
|
|
response = self.processor.decode(outputs[0], skip_special_tokens=True)
|
|
return response
|
|
|
|
def _load_vilt(self, repo: str):
|
|
"""Load ViLT model and processor."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.model = None
|
|
self.model = transformers.ViltForQuestionAnswering.from_pretrained(
|
|
repo,
|
|
torch_dtype=devices.dtype,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
)
|
|
self.processor = transformers.ViltProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _vilt(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
|
self._load_vilt(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
inputs = self.processor(image, question, return_tensors="pt")
|
|
inputs = inputs.to(devices.device)
|
|
with devices.inference_context():
|
|
outputs = self.model(**inputs)
|
|
logits = outputs.logits
|
|
idx = logits.argmax(-1).item()
|
|
response = self.model.config.id2label[idx]
|
|
return response
|
|
|
|
def _load_pix(self, repo: str):
|
|
"""Load Pix2Struct model and processor."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.model = None
|
|
self.model = transformers.Pix2StructForConditionalGeneration.from_pretrained(
|
|
repo,
|
|
torch_dtype=devices.dtype,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
)
|
|
self.processor = transformers.Pix2StructProcessor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _pix(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
|
self._load_pix(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
if len(question) > 0:
|
|
inputs = self.processor(images=image, text=question, return_tensors="pt").to(devices.device)
|
|
else:
|
|
inputs = self.processor(images=image, return_tensors="pt").to(devices.device)
|
|
with devices.inference_context():
|
|
outputs = self.model.generate(**inputs)
|
|
response = self.processor.decode(outputs[0], skip_special_tokens=True)
|
|
return response
|
|
|
|
def _load_moondream(self, repo: str):
|
|
"""Load Moondream 2 model and tokenizer."""
|
|
if self.model is None or self.loaded != repo:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo}"')
|
|
self.model = None
|
|
self.model = transformers.AutoModelForCausalLM.from_pretrained(
|
|
repo,
|
|
revision="2025-06-21",
|
|
trust_remote_code=True,
|
|
torch_dtype=devices.dtype,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
)
|
|
self.processor = transformers.AutoTokenizer.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
|
|
self.loaded = repo
|
|
self.model.eval()
|
|
devices.torch_gc()
|
|
|
|
def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False):
|
|
debug(f'VQA interrogate: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}')
|
|
self._load_moondream(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
|
with devices.inference_context():
|
|
if question == 'CAPTION':
|
|
response = self.model.caption(image, length="short")['caption']
|
|
elif question == 'DETAILED CAPTION':
|
|
response = self.model.caption(image, length="normal")['caption']
|
|
elif question == 'MORE DETAILED CAPTION':
|
|
response = self.model.caption(image, length="long")['caption']
|
|
elif question.lower().startswith('point at ') or question == 'POINT_MODE':
|
|
target = question[9:].strip() if question.lower().startswith('point at ') else ''
|
|
if not target:
|
|
return "Please specify an object to locate"
|
|
debug(f'VQA interrogate: handler=moondream method=point target="{target}"')
|
|
result = self.model.point(image, target)
|
|
debug(f'VQA interrogate: handler=moondream point_raw_result={result}')
|
|
points = vqa_detection.parse_points(result)
|
|
if points:
|
|
self.last_detection_data = {'points': points}
|
|
return vqa_detection.format_points_text(points)
|
|
return "Object not found"
|
|
elif question.lower().startswith('detect ') or question == 'DETECT_MODE':
|
|
target = question[7:].strip() if question.lower().startswith('detect ') else ''
|
|
if not target:
|
|
return "Please specify an object to detect"
|
|
debug(f'VQA interrogate: handler=moondream method=detect target="{target}"')
|
|
result = self.model.detect(image, target)
|
|
debug(f'VQA interrogate: handler=moondream detect_raw_result={result}')
|
|
detections = vqa_detection.parse_detections(result, target)
|
|
if detections:
|
|
self.last_detection_data = {'detections': detections}
|
|
return vqa_detection.format_detections_text(detections, include_confidence=False)
|
|
return "No objects detected"
|
|
elif question == 'DETECT_GAZE' or question.lower() == 'detect gaze':
|
|
debug('VQA interrogate: handler=moondream method=detect_gaze')
|
|
faces = self.model.detect(image, "face")
|
|
debug(f'VQA interrogate: handler=moondream detect_gaze faces={faces}')
|
|
if faces.get('objects'):
|
|
eye_x, eye_y = vqa_detection.calculate_eye_position(faces['objects'][0])
|
|
result = self.model.detect_gaze(image, eye=(eye_x, eye_y))
|
|
debug(f'VQA interrogate: handler=moondream detect_gaze result={result}')
|
|
if result.get('gaze'):
|
|
gaze = result['gaze']
|
|
self.last_detection_data = {'points': [(gaze['x'], gaze['y'])]}
|
|
return f"Gaze direction: ({gaze['x']:.3f}, {gaze['y']:.3f})"
|
|
return "No face/gaze detected"
|
|
else:
|
|
debug(f'VQA interrogate: handler=moondream method=query question="{question}" reasoning={thinking_mode}')
|
|
result = self.model.query(image, question, reasoning=thinking_mode)
|
|
response = result['answer']
|
|
debug(f'VQA interrogate: handler=moondream query_result keys={list(result.keys()) if isinstance(result, dict) else "not dict"}')
|
|
if thinking_mode and 'reasoning' in result:
|
|
reasoning_text = result['reasoning'].get('text', '') if isinstance(result['reasoning'], dict) else str(result['reasoning'])
|
|
debug(f'VQA interrogate: handler=moondream reasoning_text="{reasoning_text[:100]}..."')
|
|
if shared.opts.interrogate_vlm_keep_thinking:
|
|
response = f"Reasoning:\n{reasoning_text}\n\nAnswer:\n{response}"
|
|
# When keep_thinking is False, just use the answer (reasoning is discarded)
|
|
return response
|
|
|
|
def _load_florence(self, repo: str, revision: str = None):
|
|
"""Load Florence-2 model and processor."""
|
|
_get_imports = transformers.dynamic_module_utils.get_imports
|
|
|
|
def get_imports(f):
|
|
R = _get_imports(f)
|
|
if "flash_attn" in R:
|
|
R.remove("flash_attn") # flash_attn is optional
|
|
return R
|
|
|
|
# Handle revision splitting and caching
|
|
cache_key = repo
|
|
effective_revision = revision
|
|
repo_name = repo
|
|
|
|
if repo and '@' in repo:
|
|
repo_name, revision_from_repo = repo.split('@')
|
|
effective_revision = revision_from_repo
|
|
|
|
if self.model is None or self.loaded != cache_key:
|
|
shared.log.debug(f'Interrogate load: vlm="{repo_name}" revision="{effective_revision}" path="{shared.opts.hfcache_dir}"')
|
|
transformers.dynamic_module_utils.get_imports = get_imports
|
|
self.model = None
|
|
quant_args = model_quant.create_config(module='LLM')
|
|
            self.model = transformers.Florence2ForConditionalGeneration.from_pretrained(
                repo_name,
                revision=effective_revision,
                torch_dtype=devices.dtype,
                cache_dir=shared.opts.hfcache_dir,
                **quant_args,
            )
|
|
self.processor = transformers.AutoProcessor.from_pretrained(repo_name, max_pixels=1024*1024, trust_remote_code=True, revision=effective_revision, cache_dir=shared.opts.hfcache_dir)
|
|
transformers.dynamic_module_utils.get_imports = _get_imports
|
|
self.loaded = cache_key
|
|
self.model.eval()
|
|
devices.torch_gc()
|
|
|
|
def _florence(self, question: str, image: Image.Image, repo: str, revision: str = None, model_name: str = None): # pylint: disable=unused-argument
|
|
self._load_florence(repo, revision)
|
|
sd_models.move_model(self.model, devices.device)
|
|
if question.startswith('<'):
|
|
task = question.split('>', 1)[0] + '>'
|
|
else:
|
|
task = '<MORE_DETAILED_CAPTION>'
|
|
inputs = self.processor(text=task, images=image, return_tensors="pt")
|
|
input_ids = inputs['input_ids'].to(devices.device)
|
|
pixel_values = inputs['pixel_values'].to(devices.device, devices.dtype)
|
|
with devices.inference_context():
|
|
generated_ids = self.model.generate(
|
|
input_ids=input_ids,
|
|
pixel_values=pixel_values,
|
|
**get_kwargs()
|
|
)
|
|
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
|
        response = self.processor.post_process_generation(generated_text, task=task, image_size=(image.width, image.height))
|
|
return response
|
|
|
|
def _load_sa2(self, repo: str):
|
|
"""Load SA2VA model and tokenizer."""
|
|
if self.model is None or self.loaded != repo:
|
|
self.model = None
|
|
self.model = transformers.AutoModel.from_pretrained(
|
|
repo,
|
|
torch_dtype=devices.dtype,
|
|
low_cpu_mem_usage=True,
|
|
use_flash_attn=False,
|
|
trust_remote_code=True)
|
|
self.model = self.model.eval()
|
|
self.processor = transformers.AutoTokenizer.from_pretrained(
|
|
repo,
|
|
trust_remote_code=True,
|
|
use_fast=False,
|
|
)
|
|
self.loaded = repo
|
|
devices.torch_gc()
|
|
|
|
def _sa2(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
|
self._load_sa2(repo)
|
|
sd_models.move_model(self.model, devices.device)
|
|
if question.startswith('<'):
|
|
task = question.split('>', 1)[0] + '>'
|
|
else:
|
|
task = '<MORE_DETAILED_CAPTION>'
|
|
input_dict = {
|
|
'image': image,
|
|
'text': f'<image>{task}',
|
|
'past_text': '',
|
|
'mask_prompts': None,
|
|
'tokenizer': self.processor,
|
|
}
|
|
return_dict = self.model.predict_forward(**input_dict)
|
|
response = return_dict["prediction"] # the text format answer
|
|
return response
|
|
|
|
def interrogate(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False, quiet: bool = False) -> str:
|
|
"""
|
|
Main entry point for VQA interrogation. Returns string answer.
|
|
Detection data stored in self.last_detection_data for annotated image creation.
|
|
"""
|
|
self.last_annotated_image = None
|
|
self.last_detection_data = None
|
|
jobid = shared.state.begin('Interrogate LLM')
|
|
t0 = time.time()
|
|
model_name = model_name or shared.opts.interrogate_vlm_model
|
|
prefill = vlm_prefill if prefill is None else prefill # Use provided prefill when specified
|
|
if isinstance(image, list):
|
|
image = image[0] if len(image) > 0 else None
|
|
if isinstance(image, dict) and 'name' in image:
|
|
image = Image.open(image['name'])
|
|
if isinstance(image, Image.Image):
|
|
if image.width > 768 or image.height > 768:
|
|
image.thumbnail((768, 768), Image.Resampling.LANCZOS)
|
|
if image.mode != 'RGB':
|
|
image = image.convert('RGB')
|
|
if image is None:
|
|
shared.log.error(f'VQA interrogate: model="{model_name}" error="No input image provided"')
|
|
shared.state.end(jobid)
|
|
return 'Error: No input image provided. Please upload or select an image.'
|
|
|
|
# Convert friendly prompt names to internal tokens/commands
|
|
if question == "Use Prompt":
|
|
# Use content from Prompt field directly - requires user input
|
|
if not prompt or len(prompt.strip()) < 2:
|
|
shared.log.error(f'VQA interrogate: model="{model_name}" error="Please enter a prompt"')
|
|
shared.state.end(jobid)
|
|
return 'Error: Please enter a question or instruction in the Prompt field.'
|
|
question = prompt
|
|
elif question in vlm_prompt_mapping:
|
|
# Check if this is a mode that requires user input (Point/Detect)
|
|
raw_mapping = vlm_prompt_mapping.get(question)
|
|
if raw_mapping in ("POINT_MODE", "DETECT_MODE"):
|
|
# These modes require user input in the prompt field
|
|
if not prompt or len(prompt.strip()) < 2:
|
|
shared.log.error(f'VQA interrogate: model="{model_name}" error="Please specify what to find in the prompt field"')
|
|
shared.state.end(jobid)
|
|
return 'Error: Please specify what to find in the prompt field (e.g., "the red car" or "faces").'
|
|
# Convert friendly name to internal token (handles Point/Detect prefix)
|
|
question = get_internal_prompt(question, prompt)
|
|
# else: question is already an internal token or custom text
|
|
|
|
from modules import modelloader
|
|
modelloader.hf_login()
|
|
|
|
try:
|
|
if model_name is None:
|
|
shared.log.error(f'Interrogate: type=vlm model="{model_name}" no model selected')
|
|
shared.state.end(jobid)
|
|
return ''
|
|
vqa_model = vlm_models.get(model_name, None)
|
|
if vqa_model is None:
|
|
shared.log.error(f'Interrogate: type=vlm model="{model_name}" unknown')
|
|
shared.state.end(jobid)
|
|
return ''
|
|
|
|
handler = 'unknown'
|
|
if 'git' in vqa_model.lower():
|
|
handler = 'git'
|
|
answer = self._git(question, image, vqa_model, model_name)
|
|
elif 'vilt' in vqa_model.lower():
|
|
handler = 'vilt'
|
|
answer = self._vilt(question, image, vqa_model, model_name)
|
|
elif 'blip' in vqa_model.lower():
|
|
handler = 'blip'
|
|
answer = self._blip(question, image, vqa_model, model_name)
|
|
elif 'pix' in vqa_model.lower():
|
|
handler = 'pix'
|
|
answer = self._pix(question, image, vqa_model, model_name)
|
|
elif 'moondream3' in vqa_model.lower():
|
|
handler = 'moondream3'
|
|
from modules.interrogate import moondream3
|
|
answer = moondream3.predict(question, image, vqa_model, model_name, thinking_mode=thinking_mode)
|
|
elif 'moondream2' in vqa_model.lower():
|
|
handler = 'moondream'
|
|
answer = self._moondream(question, image, vqa_model, model_name, thinking_mode)
|
|
elif 'florence' in vqa_model.lower():
|
|
handler = 'florence'
|
|
answer = self._florence(question, image, vqa_model, None, model_name)
|
|
elif 'qwen' in vqa_model.lower() or 'torii' in vqa_model.lower() or 'mimo' in vqa_model.lower():
|
|
handler = 'qwen'
|
|
answer = self._qwen(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
|
|
elif 'smol' in vqa_model.lower():
|
|
handler = 'smol'
|
|
answer = self._smol(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
|
|
elif 'joytag' in vqa_model.lower():
|
|
handler = 'joytag'
|
|
from modules.interrogate import joytag
|
|
answer = joytag.predict(image)
|
|
elif 'joycaption' in vqa_model.lower():
|
|
handler = 'joycaption'
|
|
from modules.interrogate import joycaption
|
|
answer = joycaption.predict(question, image, vqa_model)
|
|
elif 'deepseek' in vqa_model.lower():
|
|
handler = 'deepseek'
|
|
from modules.interrogate import deepseek
|
|
answer = deepseek.predict(question, image, vqa_model)
|
|
elif 'paligemma' in vqa_model.lower():
|
|
handler = 'paligemma'
|
|
answer = self._paligemma(question, image, vqa_model, model_name)
|
|
elif 'gemma' in vqa_model.lower():
|
|
handler = 'gemma'
|
|
answer = self._gemma(question, image, vqa_model, system_prompt, model_name, prefill, thinking_mode)
|
|
elif 'ovis' in vqa_model.lower():
|
|
handler = 'ovis'
|
|
answer = self._ovis(question, image, vqa_model, model_name)
|
|
elif 'sa2' in vqa_model.lower():
|
|
handler = 'sa2'
|
|
answer = self._sa2(question, image, vqa_model, model_name)
|
|
elif 'fastvlm' in vqa_model.lower():
|
|
handler = 'fastvlm'
|
|
answer = self._fastvlm(question, image, vqa_model, model_name)
|
|
else:
|
|
answer = 'unknown model'
|
|
except Exception as e:
|
|
errors.display(e, 'VQA')
|
|
answer = 'error'
|
|
|
|
if shared.opts.interrogate_offload and self.model is not None:
|
|
sd_models.move_model(self.model, devices.cpu, force=True)
|
|
devices.torch_gc(force=True, reason='vqa')
|
|
|
|
# Clean the answer
|
|
answer = clean(answer, question, prefill)
|
|
|
|
# Create annotated image if detection data is available
|
|
if self.last_detection_data and isinstance(self.last_detection_data, dict) and image:
|
|
detections = self.last_detection_data.get('detections', None)
|
|
points = self.last_detection_data.get('points', None)
|
|
if detections or points:
|
|
self.last_annotated_image = vqa_detection.draw_bounding_boxes(image, detections or [], points)
|
|
debug(f'VQA interrogate: handler={handler} created annotated image detections={len(detections) if detections else 0} points={len(points) if points else 0}')
|
|
|
|
debug(f'VQA interrogate: handler={handler} response_after_clean="{answer}" has_annotation={self.last_annotated_image is not None}')
|
|
t1 = time.time()
|
|
if not quiet:
|
|
shared.log.debug(f'Interrogate: type=vlm model="{model_name}" repo="{vqa_model}" args={get_kwargs()} time={t1-t0:.2f}')
|
|
shared.state.end(jobid)
|
|
return answer
|
|
|
|
def batch(self, model_name, system_prompt, batch_files, batch_folder, batch_str, question, prompt, write, append, recursive, prefill=None, thinking_mode=False):
|
|
class BatchWriter:
|
|
def __init__(self, folder, mode='w'):
|
|
self.folder = folder
|
|
self.csv = None
|
|
self.file = None
|
|
self.mode = mode
|
|
|
|
def add(self, file, prompt_text):
|
|
txt_file = os.path.splitext(file)[0] + ".txt"
|
|
if self.mode == 'a':
|
|
prompt_text = '\n' + prompt_text
|
|
with open(os.path.join(self.folder, txt_file), self.mode, encoding='utf-8') as f:
|
|
f.write(prompt_text)
|
|
|
|
def close(self):
|
|
if self.file is not None:
|
|
self.file.close()
|
|
|
|
files = []
|
|
if batch_files is not None:
|
|
files += [f.name for f in batch_files]
|
|
if batch_folder is not None:
|
|
files += [f.name for f in batch_folder]
|
|
if batch_str is not None and len(batch_str) > 0 and os.path.exists(batch_str) and os.path.isdir(batch_str):
|
|
from modules.files_cache import list_files
|
|
files += list(list_files(batch_str, ext_filter=['.png', '.jpg', '.jpeg', '.webp', '.jxl'], recursive=recursive))
|
|
if len(files) == 0:
|
|
shared.log.warning('Interrogate batch: type=vlm no images')
|
|
return ''
|
|
jobid = shared.state.begin('Interrogate batch')
|
|
prompts = []
|
|
if write:
|
|
mode = 'w' if not append else 'a'
|
|
writer = BatchWriter(os.path.dirname(files[0]), mode=mode)
|
|
orig_offload = shared.opts.interrogate_offload
|
|
shared.opts.interrogate_offload = False
|
|
import rich.progress as rp
|
|
pbar = rp.Progress(rp.TextColumn('[cyan]Caption:'), rp.BarColumn(), rp.MofNCompleteColumn(), rp.TaskProgressColumn(), rp.TimeRemainingColumn(), rp.TimeElapsedColumn(), rp.TextColumn('[cyan]{task.description}'), console=shared.console)
|
|
with pbar:
|
|
task = pbar.add_task(total=len(files), description='starting...')
|
|
for file in files:
|
|
pbar.update(task, advance=1, description=file)
|
|
try:
|
|
if shared.state.interrupted:
|
|
break
|
|
img = Image.open(file)
|
|
caption = self.interrogate(question, system_prompt, prompt, img, model_name, prefill, thinking_mode, quiet=True)
|
|
# Save annotated image if available
|
|
if self.last_annotated_image and write:
|
|
annotated_path = os.path.splitext(file)[0] + "_annotated.png"
|
|
self.last_annotated_image.save(annotated_path)
|
|
prompts.append(caption)
|
|
if write:
|
|
writer.add(file, caption)
|
|
except Exception as e:
|
|
shared.log.error(f'Interrogate batch: {e}')
|
|
if write:
|
|
writer.close()
|
|
shared.opts.interrogate_offload = orig_offload
|
|
shared.state.end(jobid)
|
|
return '\n\n'.join(prompts)
|
|
|
|
|
|
# Module-level singleton instance
|
|
_instance = None
|
|
|
|
|
|
def get_instance() -> VQA:
|
|
"""Get or create the singleton VQA instance."""
|
|
global _instance # pylint: disable=global-statement
|
|
if _instance is None:
|
|
_instance = VQA()
|
|
return _instance
|
|
|
|
|
|
# Backwards-compatible module-level functions
|
|
def interrogate(*args, **kwargs):
|
|
return get_instance().interrogate(*args, **kwargs)
|
|
|
|
|
|
def unload_model():
|
|
return get_instance().unload()
|
|
|
|
|
|
def load_model(model_name: str = None):
|
|
return get_instance().load(model_name)
|
|
|
|
|
|
def get_last_annotated_image():
|
|
return get_instance().last_annotated_image
|
|
|
|
|
|
def batch(*args, **kwargs):
|
|
return get_instance().batch(*args, **kwargs)
|
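
# Minimal usage sketch (illustrative only): assumes the sdnext runtime is initialized so that
# shared.opts/devices are available, that this module is importable as modules.interrogate.vqa,
# and that 'example.jpg' exists locally.
#   from PIL import Image
#   from modules.interrogate import vqa
#   img = Image.open('example.jpg')
#   caption = vqa.interrogate(question='Short Caption', image=img, model_name='Alibaba Qwen 2.5 VL 3B')
#   print(caption)
#   vqa.unload_model()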