1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/cli/image-search.py
Vladimir Mandic 10fb362bdc server state history
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-04-25 12:25:30 -04:00

252 lines
11 KiB
Python
Executable File

#!/usr/bin/env python
from typing import Union
import os
import re
import logging
from tqdm.rich import tqdm
import torch
import PIL
import faiss
import numpy as np
import pandas as pd
import transformers
class ImageDB:
    """Image/metadata database with a faiss embedding index built from a CLIP or DINOv2 model."""
    # TODO index: quantize and train faiss index
    # TODO index: clip batch processing
    def __init__(self,
                 name: str = 'db',
                 fmt: str = 'json',
                 cache_dir: str = None,
                 dtype: torch.dtype = torch.float16,
                 device: torch.device = torch.device('cpu'),
                 model: str = 'openai/clip-vit-large-patch14',  # or 'facebook/dinov2-small'
                 debug: bool = False,
                 pbar: bool = True,
                 ):
        """
        Args:
            name: base filename for the on-disk database (data file and faiss index)
            fmt: dataframe serialization format: 'json', 'csv' or 'pickle'
            cache_dir: huggingface cache folder, None for the default location
            dtype: torch dtype for the embedding model
            device: torch device for the embedding model
            model: huggingface model id; must contain 'clip' or 'dino'
            debug: when True log at info level, otherwise at debug level
            pbar: show a progress bar while indexing folders
        """
        self.format = fmt
        self.name = name
        self.cache_dir = cache_dir
        self.processor: transformers.AutoImageProcessor = None
        self.model: transformers.AutoModel = None
        # fix: original read `self.tokenizer = transformers.AutoTokenizer = None`,
        # which also overwrote the library attribute transformers.AutoTokenizer with None
        self.tokenizer: transformers.AutoTokenizer = None
        self.device: torch.device = device
        self.dtype: torch.dtype = dtype
        self.dimension = 768 if 'clip' in model else 384  # embed width: clip-vit-large=768, dinov2-small=384
        self.debug = debug
        self.pbar = pbar
        self.repo = model
        self.df = pd.DataFrame([], columns=['filename', 'timestamp', 'metadata'])  # image/metadata database
        self.index = faiss.IndexFlatL2(self.dimension)  # embed database
        self.err = logging.getLogger(__name__).error
        # fix: removed dead `self.log = logging.getLogger(__name__)` that was immediately overwritten
        self.log = logging.getLogger(__name__).info if self.debug else logging.getLogger(__name__).debug
        # self.init()
        # self.load()
def __str__(self):
    # compact one-line summary: configuration plus record and index counts
    summary = (
        f'db: name="{self.name}" format={self.format} device={self.device}'
        f' dtype={self.dtype} dimension={self.dimension} model="{self.repo}"'
        f' records={len(self.df)} index={self.index.ntotal}'
    )
    return summary
def init(self):  # initialize models
    """Load processor/tokenizer/model for the configured repo on first call; no-op when already loaded."""
    if self.processor is not None and self.model is not None:
        return
    if 'clip' in self.repo:
        self.processor = transformers.CLIPImageProcessor.from_pretrained(self.repo, cache_dir=self.cache_dir)
        self.tokenizer = transformers.CLIPTokenizer.from_pretrained(self.repo, cache_dir=self.cache_dir)
        self.model = transformers.CLIPModel.from_pretrained(self.repo, cache_dir=self.cache_dir).to(device=self.device, dtype=self.dtype)
    elif 'dino' in self.repo:
        self.processor = transformers.AutoImageProcessor.from_pretrained(self.repo, cache_dir=self.cache_dir)
        self.model = transformers.AutoModel.from_pretrained(self.repo, cache_dir=self.cache_dir).to(device=self.device, dtype=self.dtype)
    else:
        self.err(f'db: model="{self.repo}" unknown')
    self.log(f'db: load model="{self.repo}" cache="{self.cache_dir}" device={self.device} dtype={self.dtype}')
def load(self):  # load db from disk
    """
    Read the dataframe (in the configured format) and the faiss index from disk.
    When the index size and the dataframe length disagree, both are reset to empty
    since records and embeds must stay in lockstep.
    """
    if self.format == 'json' and os.path.exists(f'{self.name}.json'):
        self.df = pd.read_json(f'{self.name}.json')
    elif self.format == 'csv' and os.path.exists(f'{self.name}.csv'):
        self.df = pd.read_csv(f'{self.name}.csv')
    elif self.format == 'pickle' and os.path.exists(f'{self.name}.pkl'):
        self.df = pd.read_pickle(f'{self.name}.pkl')  # fix: was reading '{name}.parquet' while checking for '{name}.pkl'
    if os.path.exists(f'{self.name}.index'):
        self.index = faiss.read_index(f'{self.name}.index')
    if self.index.ntotal != len(self.df):  # index and data out of sync: discard both
        self.err(f'db: index={self.index.ntotal} data={len(self.df)} mismatch')
        self.index = faiss.IndexFlatL2(self.dimension)
        self.df = pd.DataFrame([], columns=['filename', 'timestamp', 'metadata'])
    self.log(f'db: load data={len(self.df)} name={self.name} format={self.format}')  # fix: 'name=' was logged twice
def save(self):  # save db to disk
    """Write the dataframe (in the configured format) and the faiss index to disk."""
    if self.format == 'json':
        self.df.to_json(f'{self.name}.json')
    elif self.format == 'csv':
        self.df.to_csv(f'{self.name}.csv')
    elif self.format == 'pickle':
        self.df.to_pickle(f'{self.name}.pkl')
    faiss.write_index(self.index, f'{self.name}.index')
    self.log(f'db: save data={len(self.df)} name={self.name} format={self.format}')  # fix: 'name=' was logged twice
def normalize(self, embed) -> np.ndarray:
    """Convert a torch embedding tensor to a float32 numpy array and L2-normalize it."""
    vectors = embed.detach().float().cpu().numpy()
    faiss.normalize_L2(vectors)  # normalizes in place
    return vectors
def embedding(self, query: Union[PIL.Image.Image, str]) -> np.ndarray:  # calculate embed for prompt or image
    """
    Compute a normalized embedding for a text prompt, a PIL image, or a path to an image file.

    Args:
        query: PIL image, path to an image file, or text prompt (clip only)
    Returns:
        normalized numpy embedding, or None when the model is not loaded or the repo is unknown
    """
    # fix: annotation was `Union[PIL.Image.Image | str]` (mixing Union[] with the | operator)
    if self.processor is None or self.model is None:
        self.err('db: model not loaded')
        return None  # fix: original fell through and crashed on self.model.to(...)
    if isinstance(query, str) and os.path.exists(query):
        query = PIL.Image.open(query).convert('RGB')
    self.model = self.model.to(self.device)
    with torch.no_grad():
        if 'clip' in self.repo:
            if isinstance(query, str):  # text prompt
                processed = self.tokenizer(text=query, padding=True, return_tensors="pt").to(device=self.device)
                results = self.model.get_text_features(**processed)
            else:  # image
                processed = self.processor(images=query, return_tensors="pt").to(device=self.device, dtype=self.dtype)
                results = self.model.get_image_features(**processed)
        elif 'dino' in self.repo:
            processed = self.processor(images=query, return_tensors="pt").to(device=self.device, dtype=self.dtype)
            results = self.model(**processed)
            results = results.last_hidden_state.mean(dim=1)  # pool patch tokens into a single vector
        else:
            self.err(f'db: model="{self.repo}" unknown')
            return None
    return self.normalize(results)
def add(self, embed, filename=None, metadata=None):  # add embed to db
    """Append one record (filename, timestamp, metadata) to the dataframe and its embed to the faiss index."""
    record = pd.DataFrame([{'filename': filename, 'timestamp': pd.Timestamp.now(), 'metadata': metadata}])
    # first record replaces the empty frame; subsequent ones are concatenated
    self.df = record if len(self.df) == 0 else pd.concat([self.df, record], ignore_index=True)
    self.index.add(embed)
def search(self, filename: str = None, metadata: str = None, embed: np.ndarray = None, k=10, d=1.0):  # search by filename/metadata/prompt-embed/image-embed
    """
    Generator yielding records matching any of the given criteria.

    Args:
        filename: case-insensitive substring match on the filename column
        metadata: case-insensitive substring match on the metadata column
        embed: query embedding searched against the faiss index
        k: number of nearest neighbors for the embed search
        d: maximum distance for embed matches; <= 0 disables the distance filter
    Yields:
        dicts with type/filename/metadata (and distance for embed matches)
    """
    def dct(record, mode: str, distance: float = None):
        # record is an (index, row) tuple as produced by DataFrame.iterrows()
        if distance is not None:
            return {'type': mode, 'filename': record[1]['filename'], 'metadata': record[1]['metadata'], 'distance': round(distance, 2)}
        else:
            return {'type': mode, 'filename': record[1]['filename'], 'metadata': record[1]['metadata']}
    if self.index.ntotal == 0:  # nothing indexed: nothing to search
        return
    self.log(f'db: search k={k} d={d}')
    if embed is not None:
        distances, indexes = self.index.search(embed, k)
        records = self.df.iloc[indexes[0]]
        for record, distance in zip(records.iterrows(), distances[0]):
            if d <= 0 or distance <= d:
                yield dct(record, distance=distance, mode='embed')
    if filename is not None:
        records = self.df[self.df['filename'].str.contains(filename, na=False, case=False)]
        for record in records.iterrows():
            yield dct(record, mode='filename')
    if metadata is not None:
        records = self.df[self.df['metadata'].str.contains(metadata, na=False, case=False)]  # fix: was matching against `filename`
        for record in records.iterrows():
            yield dct(record, mode='metadata')
def decode(self, s: bytes):  # decode byte-encoded exif metadata
    """Try several encodings to decode an exif comment; return the cleaned string, or None on failure/empty."""
    def strip_prefix(data, prefix):
        # drop a leading marker such as b'UNICODE' that some exif writers prepend
        return data[len(prefix):] if data.startswith(prefix) else data
    for encoding in ['utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437']:
        try:
            s = strip_prefix(s, b'UNICODE')
            s = strip_prefix(s, b'ASCII')
            s = strip_prefix(s, b'\x00')
            val = s.decode(encoding, errors="strict")
            val = re.sub(r'[\x00-\x09\n\s\s+]', '', val).strip()  # drop control chars, whitespace and '+' leftovers
            return val if len(val) > 0 else None  # empty strings become None
        except Exception:
            continue  # decoding failed, try the next encoding
    return None
def metadata(self, image: PIL.Image.Image):  # get exif metadata from image
    """Return the decoded exif user-comment of an image, or '' when absent."""
    exif = image._getexif()  # pylint: disable=protected-access
    if exif is None:
        return ''
    if 37510 in exif:  # 37510 = exif UserComment tag
        return self.decode(exif[37510])
    return ''
def image(self, filename: str, image=None):  # add file/image to db
    """
    Compute embed and exif metadata for a file (or an already-open image) and add it to the db.
    Failures (unreadable or non-image files) are logged and skipped so folder indexing can continue.
    """
    try:
        if image is None:
            image = PIL.Image.open(filename)
            image.load()  # force-read pixel data before the file handle is reused
        embed = self.embedding(image.convert('RGB'))
        metadata = self.metadata(image)
        image.close()
        self.add(embed, filename=filename, metadata=metadata)
    except Exception as e:
        self.log(f'db: skip "{filename}" {e}')  # fix: errors were silently swallowed with a bare pass
def folder(self, folder: str):  # add all files from folder to db
    """Recursively index every file under a folder, with an optional progress bar."""
    files = [os.path.join(root, f) for root, _subdir, names in os.walk(folder) for f in names]
    iterator = tqdm(files) if self.pbar else files
    for f in iterator:
        self.image(filename=f)
def offload(self):  # offload model to cpu
    """Move the embedding model back to cpu to free accelerator memory; no-op when not loaded."""
    if self.model is None:
        return
    self.model = self.model.to('cpu')
if __name__ == '__main__':
    import time
    import argparse
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description='image-search')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--search', action='store_true', help='run search')
    group.add_argument('--index', action='store_true', help='run indexing')
    parser.add_argument('--db', default='db', help='database name')
    parser.add_argument('--model', default='openai/clip-vit-large-patch14', help='huggingface model')
    parser.add_argument('--cache', default='/mnt/models/huggingface', help='cache folder')
    # fix: nargs='*' requires a list default; the original plain-string default
    # made the loops below iterate the cwd path character by character
    parser.add_argument('input', nargs='*', default=[os.getcwd()], help='files or folders to index/search')
    args = parser.parse_args()
    db = ImageDB(
        name=args.db,
        model=args.model,  # e.g. 'facebook/dinov2-small'
        cache_dir=args.cache,
        dtype=torch.bfloat16,
        device=torch.device('cuda'),
        debug=True,
        pbar=True,
    )
    db.init()
    db.load()
    print(db)
    if args.index:
        t0 = time.time()
        for fn in args.input:
            if os.path.isfile(fn):
                db.image(filename=fn)
            elif os.path.isdir(fn):
                db.folder(folder=fn)
        t1 = time.time()
        print('index', t1 - t0)
        db.save()
        db.offload()
    if args.search:
        for ref in args.input:
            emb = db.embedding(ref)
            res = db.search(filename=ref, metadata=ref, embed=emb, k=10, d=0)  # d=0 disables the distance filter
            for r in res:
                print(ref, r)