mirror of https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
Make external VQA handlers (moondream3, joytag, joycaption, deepseek) compatible with the VQA load/unload mechanism for a consistent model lifecycle.
- Add vqa_detection.py with shared detection helpers
- Add load and unload functions to all external handlers
- Replace device_map="auto" with sd_models.move_model in joycaption
- Update dispatcher and moondream handlers to use shared helpers
128 lines
5.8 KiB
Python
# based on <https://huggingface.co/fancyfeast/llama-joycaption-alpha-two-hf-llava>
from dataclasses import dataclass
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from modules import shared, devices, sd_models, model_quant

"""
|
|
Example prompts
|
|
Short description: Write a short description of the image.
|
|
Detailed descriptive: Please provide a detailed description of the image.
|
|
Descriptive: Write a descriptive caption for this image in a formal tone.
|
|
Descriptive (Informal): Write a descriptive caption for this image in a casual tone.
|
|
Training Prompt: Write a stable diffusion prompt for this image.
|
|
MidJourney: Write a MidJourney prompt for this image.
|
|
Booru tag list: Write a list of Booru tags for this image.
|
|
Booru-like tag list: Write a list of Booru-like tags for this image.
|
|
Art Critic: Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.
|
|
Product Listing: Write a caption for this image as though it were a product listing.
|
|
Social Media Post: Write a caption for this image as if it were being used for a social media post.
|
|
Extra Options:
|
|
- If there is a person/character in the image you must refer to them as {name}.
|
|
- Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).
|
|
- Include information about lighting.
|
|
- Include information about camera angle.
|
|
- Include information about whether there is a watermark or not.
|
|
- Include information about whether there are JPEG artifacts or not.
|
|
- If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.
|
|
- Do NOT include anything sexual; keep it PG.
|
|
- Do NOT mention the image's resolution.
|
|
- You MUST include information about the subjective aesthetic quality of the image from low to very high.
|
|
- Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.
|
|
- Do NOT mention any text that is in the image.
|
|
- Specify the depth of field and whether the background is in focus or blurred.
|
|
- If applicable, mention the likely use of artificial or natural lighting sources.
|
|
- Do NOT use any ambiguous language.
|
|
- Include whether the image is sfw, suggestive, or nsfw.
|
|
- ONLY describe the most important elements of the image.
|
|
"""
|
|
|
|
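# The example prompts above are passed as the user message to predict() below,
# e.g. predict("Write a list of Booru tags for this image.", image).
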

@dataclass
class JoyOptions():
    repo: str = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
    temp: float = 0.5
    top_k: int = 10
    top_p: float = 0.9
    max_new_tokens: int = 512
    sample: bool = True

    def __str__(self):
        return f'repo="{self.repo}" temp={self.temp} top_k={self.top_k} top_p={self.top_p} sample={self.sample} tokens={self.max_new_tokens}'


processor: AutoProcessor = None
llava_model: LlavaForConditionalGeneration = None
opts = JoyOptions()

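# load() and unload() implement the shared VQA model lifecycle: load() lazily
# (re)creates the processor/model singletons above, unload() moves the model
# to CPU and releases it.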
def load(repo: str = None):
    """Load JoyCaption model."""
    global llava_model, processor  # pylint: disable=global-statement
    repo = repo or opts.repo
    if llava_model is None or opts.repo != repo:
        opts.repo = repo
        llava_model = None
        shared.log.info(f'Interrogate: type=vlm model="JoyCaption" {str(opts)}')
        processor = AutoProcessor.from_pretrained(repo, max_pixels=1024*1024, cache_dir=shared.opts.hfcache_dir)
        quant_args = model_quant.create_config(module='LLM')
        llava_model = LlavaForConditionalGeneration.from_pretrained(
            repo,
            torch_dtype=devices.dtype,
            cache_dir=shared.opts.hfcache_dir,
            **quant_args,
        )
        llava_model.eval()
        sd_models.move_model(llava_model, devices.device)  # explicit device placement instead of device_map="auto"


def unload():
    """Release JoyCaption model from GPU/memory."""
    global llava_model, processor  # pylint: disable=global-statement
    if llava_model is not None:
        shared.log.debug(f'JoyCaption unload: model="{opts.repo}"')
        sd_models.move_model(llava_model, devices.cpu, force=True)
        llava_model = None
        processor = None
        devices.torch_gc(force=True)  # force garbage collection so VRAM is actually freed
    else:
        shared.log.debug('JoyCaption unload: no model loaded')


@torch.no_grad()
def predict(question: str, image, vqa_model: str = None) -> str:
    opts.max_new_tokens = shared.opts.interrogate_vlm_max_length
    load(vqa_model)

    if len(question) < 2:
        question = "Describe the image."
    question = question.replace('<', '').replace('>', '')  # strip characters that could form special tokens
    convo = [
        { "role": "system", "content": "You are a helpful image captioner." },
        { "role": "user", "content": question },
    ]
    convo_string = processor.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[convo_string], images=[image], return_tensors="pt").to(devices.device)
    inputs['pixel_values'] = inputs['pixel_values'].to(devices.dtype)
    with devices.inference_context():
        generate_ids = llava_model.generate(  # generate the captions
            **inputs,
            # input_ids=inputs['input_ids'],
            # pixel_values=inputs['pixel_values'],
            # attention_mask=inputs['attention_mask'],
            max_new_tokens=opts.max_new_tokens,
            suppress_tokens=None,
            use_cache=True,
            do_sample=opts.sample,
            temperature=opts.temp,
            top_k=opts.top_k,
            top_p=opts.top_p,
        )[0]
    generate_ids = generate_ids[inputs['input_ids'].shape[1]:]  # trim off the prompt
    caption = processor.tokenizer.decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)  # decode the caption
    if shared.opts.interrogate_offload:
        sd_models.move_model(llava_model, devices.cpu, force=True)
    caption = caption.replace('\n\n', '\n').strip()
    return caption
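
# Usage sketch (illustrative only, not part of this module); assumes a PIL image
# and an import path for this handler, which may differ in the actual repo:
#
#   from PIL import Image
#   from modules.interrogate import joycaption  # hypothetical import path
#
#   image = Image.open('example.jpg')
#   caption = joycaption.predict('Write a short description of the image.', image)
#   joycaption.unload()  # release the model and VRAM when done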