# source: <https://huggingface.co/deepseek-ai/deepseek-vl2-tiny>
# implementation: <https://github.com/deepseek-ai/DeepSeek-VL2/tree/main/deepseek_vl2/serve>
"""
- run `git clone https://github.com/deepseek-ai/DeepSeek-VL2 repositories/deepseek-vl2 --depth 1`
- remove hardcoded `python==3.9` requirement due to obsolete attrdict package dependency
- patch transformers due to internal changes as deepseek requires obsolete `transformers==4.38.2`
- deepseek requires `xformers`
- broken flash_attention
"""
import os
import sys
import importlib
from transformers import AutoModelForCausalLM
from modules import shared, devices, paths, sd_models
# model_path = "deepseek-ai/deepseek-vl2-small"
vl_gpt = None
vl_chat_processor = None
loaded_repo = None

class fake_attrdict():
    class AttrDict(dict):  # dot notation access to dictionary attributes
        __getattr__ = dict.get
        __setattr__ = dict.__setitem__
        __delattr__ = dict.__delitem__
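# Illustrative use of the shim (an assumption added for clarity; not called directly in this module):
#   d = fake_attrdict.AttrDict({'size': 384})
#   d.size      -> 384 (attribute-style access)
#   d['size']   -> 384 (regular dict access)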

def load(repo: str):
    """Load DeepSeek VL2 model (experimental)."""
    global vl_gpt, vl_chat_processor, loaded_repo  # pylint: disable=global-statement
    if not shared.cmd_opts.experimental:
        shared.log.error(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}" is experimental-only')
        return False
    folder = os.path.join(paths.script_path, 'repositories', 'deepseek-vl2')
    if not os.path.exists(folder):
        shared.log.error(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}" deepseek-vl2 repo not found')
        return False
    if vl_gpt is None or loaded_repo != repo:
        sys.modules['attrdict'] = fake_attrdict  # shim for the obsolete attrdict package required by deepseek_vl2
        from transformers.models.llama import modeling_llama
        modeling_llama.LlamaFlashAttention2 = modeling_llama.LlamaAttention  # flash_attention is broken, fall back to standard attention
        importlib.import_module('repositories.deepseek-vl2.deepseek_vl2')
        deepseek_vl_models = importlib.import_module('repositories.deepseek-vl2.deepseek_vl2.models')
        vl_chat_processor = deepseek_vl_models.DeepseekVLV2Processor.from_pretrained(repo, cache_dir=shared.opts.hfcache_dir)
        vl_gpt = AutoModelForCausalLM.from_pretrained(
            repo,
            trust_remote_code=True,
            cache_dir=shared.opts.hfcache_dir,
        )
        vl_gpt.to(dtype=devices.dtype)
        vl_gpt.eval()
        loaded_repo = repo
        shared.log.info(f'Interrogate: type=vlm model="DeepSeek VL2" repo="{repo}"')
    sd_models.move_model(vl_gpt, devices.device)
    return True

def unload():
    """Release DeepSeek VL2 model from GPU/memory."""
    global vl_gpt, vl_chat_processor, loaded_repo  # pylint: disable=global-statement
    if vl_gpt is not None:
        shared.log.debug(f'DeepSeek unload: model="{loaded_repo}"')
        sd_models.move_model(vl_gpt, devices.cpu, force=True)
        vl_gpt = None
        vl_chat_processor = None
        loaded_repo = None
        devices.torch_gc(force=True)
    else:
        shared.log.debug('DeepSeek unload: no model loaded')

def predict(question, image, repo):
    """Answer a question about an image using DeepSeek VL2."""
    global vl_gpt  # pylint: disable=global-statement
    if not load(repo):
        return ''
    if len(question) < 2:
        question = "Describe the image."
    question = question.replace('<', '').replace('>', '')  # strip angle brackets so user text cannot collide with chat-template tags
    conversation = [
        {
            "role": "<|User|>",
            "content": f"<image>\n<|ref|>{question}<|/ref|>.",
            # "images": [image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
    prepare_inputs = vl_chat_processor(
        conversations=conversation,
        images=[image],
        force_batchify=True,
        system_prompt=""
    ).to(device=devices.device, dtype=devices.dtype)
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
    inputs_embeds = inputs_embeds.to(device=devices.device, dtype=devices.dtype)
    sd_models.move_model(vl_gpt, devices.device)
    with devices.inference_context():
        outputs = vl_gpt.language.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=vl_chat_processor.tokenizer.eos_token_id,
            bos_token_id=vl_chat_processor.tokenizer.bos_token_id,
            eos_token_id=vl_chat_processor.tokenizer.eos_token_id,
            max_new_tokens=shared.opts.interrogate_vlm_max_length,
            do_sample=False,
            use_cache=True
        )
    vl_gpt = vl_gpt.to(devices.cpu)  # move model back to cpu after generation
    answer = vl_chat_processor.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    print('inputs', prepare_inputs['sft_format'][0])  # debug: formatted prompt
    print('answer', answer)  # debug: raw answer
    return answer
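
# Usage sketch (an assumption for illustration; in sdnext this module is driven by the VLM interrogate pipeline):
#   from PIL import Image
#   img = Image.open('example.jpg').convert('RGB')
#   caption = predict('Describe the image', img, 'deepseek-ai/deepseek-vl2-tiny')  # repo id from the source link above
#   unload()  # release the model when done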