mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
252 lines
11 KiB
Python
Executable File
252 lines
11 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
from typing import Union
|
|
import os
|
|
import re
|
|
import logging
|
|
from tqdm.rich import tqdm
|
|
import torch
|
|
import PIL
|
|
import faiss
|
|
import numpy as np
|
|
import pandas as pd
|
|
import transformers
|
|
|
|
|
|
class ImageDB:
    """Image/metadata database with a faiss L2 vector index for similarity search.

    Records (filename, timestamp, metadata) live in ``self.df`` and their
    embeddings in ``self.index``; the two are kept in lock-step by insertion
    order, so row i of the DataFrame corresponds to vector i of the index.
    Embeddings come from a huggingface CLIP or DINOv2 model, loaded lazily.
    """
    # TODO index: quantize and train faiss index
    # TODO index: clip batch processing
    def __init__(self,
                 name:str='db',
                 fmt:str='json',
                 cache_dir:str=None,
                 dtype:torch.dtype=torch.float16,
                 device:torch.device=torch.device('cpu'),
                 model:str='openai/clip-vit-large-patch14', # 'facebook/dinov2-small'
                 debug:bool=False,
                 pbar:bool=True,
                 ):
        """Create an empty db; models are not loaded until init() is called.

        name: base filename for on-disk artifacts (<name>.json/.csv/.pkl and <name>.index)
        fmt: dataframe serialization format: 'json' | 'csv' | 'pickle'
        cache_dir: huggingface cache folder, or None for the default
        model: huggingface repo; 'clip' or 'dino' must appear in the name
        """
        self.format = fmt
        self.name = name
        self.cache_dir = cache_dir
        self.processor = None  # transformers image processor, created lazily in init()
        self.model = None  # transformers model, created lazily in init()
        self.tokenizer = None  # fixed: original `self.tokenizer = transformers.AutoTokenizer = None` clobbered transformers.AutoTokenizer
        self.device: torch.device = device
        self.dtype: torch.dtype = dtype
        self.dimension = 768 if 'clip' in model else 384  # embed width depends on model family
        self.debug = debug
        self.pbar = pbar
        self.repo = model
        self.df = pd.DataFrame([], columns=['filename', 'timestamp', 'metadata']) # image/metadata database
        self.index = faiss.IndexFlatL2(self.dimension) # embed database
        self.err = logging.getLogger(__name__).error
        # removed dead `self.log = logging.getLogger(__name__)` that was immediately overwritten
        self.log = logging.getLogger(__name__).info if self.debug else logging.getLogger(__name__).debug
        # self.init()
        # self.load()

    def __str__(self):
        return f'db: name="{self.name}" format={self.format} device={self.device} dtype={self.dtype} dimension={self.dimension} model="{self.repo}" records={len(self.df)} index={self.index.ntotal}'

    def init(self): # initialize models
        """Lazily load processor/tokenizer/model for the configured repo; no-op if already loaded."""
        if self.processor is None or self.model is None:
            if 'clip' in self.repo:
                self.processor = transformers.CLIPImageProcessor.from_pretrained(self.repo, cache_dir=self.cache_dir)
                self.tokenizer = transformers.CLIPTokenizer.from_pretrained(self.repo, cache_dir=self.cache_dir)
                self.model = transformers.CLIPModel.from_pretrained(self.repo, cache_dir=self.cache_dir).to(device=self.device, dtype=self.dtype)
            elif 'dino' in self.repo:
                self.processor = transformers.AutoImageProcessor.from_pretrained(self.repo, cache_dir=self.cache_dir)
                self.model = transformers.AutoModel.from_pretrained(self.repo, cache_dir=self.cache_dir).to(device=self.device, dtype=self.dtype)
            else:
                self.err(f'db: model="{self.repo}" unknown')
                return  # fixed: don't log a bogus success line when nothing was loaded
            self.log(f'db: load model="{self.repo}" cache="{self.cache_dir}" device={self.device} dtype={self.dtype}')

    def load(self): # load db from disk
        """Read dataframe and faiss index from disk; reset both if they disagree in size."""
        if self.format == 'json' and os.path.exists(f'{self.name}.json'):
            self.df = pd.read_json(f'{self.name}.json')
        elif self.format == 'csv' and os.path.exists(f'{self.name}.csv'):
            self.df = pd.read_csv(f'{self.name}.csv')
        elif self.format == 'pickle' and os.path.exists(f'{self.name}.pkl'):
            self.df = pd.read_pickle(f'{self.name}.pkl')  # fixed: read '.parquet' while save() writes '.pkl'
        if os.path.exists(f'{self.name}.index'):
            self.index = faiss.read_index(f'{self.name}.index')
        if self.index.ntotal != len(self.df):
            # dataframe rows and index vectors must match 1:1; discard both on mismatch
            self.err(f'db: index={self.index.ntotal} data={len(self.df)} mismatch')
            self.index = faiss.IndexFlatL2(self.dimension)
            self.df = pd.DataFrame([], columns=['filename', 'timestamp', 'metadata'])
        self.log(f'db: load data={len(self.df)} name={self.name} format={self.format}')  # fixed: name= was logged twice

    def save(self): # save db to disk
        """Write dataframe (in the configured format) and faiss index to disk."""
        if self.format == 'json':
            self.df.to_json(f'{self.name}.json')
        elif self.format == 'csv':
            self.df.to_csv(f'{self.name}.csv')
        elif self.format == 'pickle':
            self.df.to_pickle(f'{self.name}.pkl')
        faiss.write_index(self.index, f'{self.name}.index')
        self.log(f'db: save data={len(self.df)} name={self.name} format={self.format}')  # fixed: name= was logged twice

    def normalize(self, embed) -> np.ndarray: # normalize embed before using it
        """Detach a torch embed to a float32 numpy array and L2-normalize it in place."""
        embed = embed.detach().float().cpu().numpy()
        faiss.normalize_L2(embed)
        return embed

    def embedding(self, query: 'Union[PIL.Image.Image, str]') -> np.ndarray: # calculate embed for prompt or image
        """Embed a prompt string, an image, or a path to an image file.

        Returns a normalized (1, dimension) numpy array, or None when the model
        is not loaded or the repo is unknown. Fixed: annotation was the
        malformed `Union[PIL.Image.Image | str]`; quoted so PIL is not needed
        at definition time.
        """
        if self.processor is None or self.model is None:
            self.err('db: model not loaded')
            return None  # fixed: original fell through and crashed on self.model.to(...)
        if isinstance(query, str) and os.path.exists(query):
            query = PIL.Image.open(query).convert('RGB')
        self.model = self.model.to(self.device)
        with torch.no_grad():
            if 'clip' in self.repo:
                if isinstance(query, str):
                    processed = self.tokenizer(text=query, padding=True, return_tensors="pt").to(device=self.device)
                    results = self.model.get_text_features(**processed)
                else:
                    processed = self.processor(images=query, return_tensors="pt").to(device=self.device, dtype=self.dtype)
                    results = self.model.get_image_features(**processed)
            elif 'dino' in self.repo:
                processed = self.processor(images=query, return_tensors="pt").to(device=self.device, dtype=self.dtype)
                results = self.model(**processed)
                results = results.last_hidden_state.mean(dim=1)  # pool patch tokens into one vector
            else:
                self.err(f'db: model="{self.repo}" unknown')
                return None
        return self.normalize(results)

    def add(self, embed, filename=None, metadata=None): # add embed to db
        """Append one record to the dataframe and its embed to the faiss index."""
        rec = pd.DataFrame([{'filename': filename, 'timestamp': pd.Timestamp.now(), 'metadata': metadata}])
        if len(self.df) > 0:
            self.df = pd.concat([self.df, rec], ignore_index=True)
        else:
            self.df = rec
        self.index.add(embed)

    def search(self, filename: str = None, metadata: str = None, embed: np.ndarray = None, k=10, d=1.0): # search by filename/metadata/prompt-embed/image-embed
        """Generator yielding match dicts; d<=0 disables the distance cutoff."""
        def dct(record, mode: str, distance: float = None):
            # record is an iterrows() tuple: (row-label, Series)
            if distance is not None:
                return {'type': mode, 'filename': record[1]['filename'], 'metadata': record[1]['metadata'], 'distance': round(distance, 2)}
            else:
                return {'type': mode, 'filename': record[1]['filename'], 'metadata': record[1]['metadata']}

        if self.index.ntotal == 0:
            return
        self.log(f'db: search k={k} d={d}')
        if embed is not None:
            distances, indexes = self.index.search(embed, k)
            records = self.df.iloc[indexes[0]]
            for record, distance in zip(records.iterrows(), distances[0]):
                if d <= 0 or distance <= d:
                    yield dct(record, distance=distance, mode='embed')
        if filename is not None:
            records = self.df[self.df['filename'].str.contains(filename, na=False, case=False)]
            for record in records.iterrows():
                yield dct(record, mode='filename')
        if metadata is not None:
            records = self.df[self.df['metadata'].str.contains(metadata, na=False, case=False)]  # fixed: filtered on `filename` instead of `metadata`
            for record in records.iterrows():
                yield dct(record, mode='metadata')

    def decode(self, s: bytes): # decode byte-encoded exif metadata
        """Best-effort decode of an exif byte string; returns None when nothing usable remains."""
        remove_prefix = lambda text, prefix: text[len(prefix):] if text.startswith(prefix) else text # pylint: disable=unnecessary-lambda-assignment
        for encoding in ['utf-8', 'utf-16', 'ascii', 'latin_1', 'cp1252', 'cp437']: # try different encodings
            try:
                s = remove_prefix(s, b'UNICODE')
                s = remove_prefix(s, b'ASCII')
                s = remove_prefix(s, b'\x00')
                val = s.decode(encoding, errors="strict")
                val = re.sub(r'[\x00-\x09\n\s\s+]', '', val).strip() # remove remaining special characters, new line breaks, and double empty spaces
                if len(val) == 0: # remove empty strings
                    val = None
                return val
            except Exception:
                pass  # strict decode failed; try the next encoding
        return None

    def metadata(self, image: 'PIL.Image.Image'): # get exif metadata from image
        """Return the decoded exif UserComment (tag 37510), or '' when absent."""
        exif = image._getexif() # pylint: disable=protected-access
        if exif is None:
            return ''
        for k, v in exif.items():
            if k == 37510: # comment
                return self.decode(v)
        return ''

    def image(self, filename: str, image=None): # add file/image to db
        """Embed one file (or pre-opened image) and add it to the db; best-effort."""
        try:
            if image is None:
                image = PIL.Image.open(filename)
            image.load()
            embed = self.embedding(image.convert('RGB'))
            metadata = self.metadata(image)
            image.close()
            self.add(embed, filename=filename, metadata=metadata)
        except Exception as e:
            # best-effort by design: folder scans hit non-image files; keep a debug trace instead of silence
            self.log(f'db: image skip file="{filename}" {e}')

    def folder(self, folder: str): # add all files from folder to db
        """Recursively index every file under `folder`, optionally with a progress bar."""
        files = []
        for root, _subdir, _files in os.walk(folder):
            for f in _files:
                files.append(os.path.join(root, f))
        if self.pbar:
            for f in tqdm(files):
                self.image(filename=f)
        else:
            for f in files:
                self.image(filename=f)

    def offload(self): # offload model to cpu
        """Move the model back to cpu to free accelerator memory; no-op if unloaded."""
        if self.model is not None:
            self.model = self.model.to('cpu')
|
|
if __name__ == '__main__':
    import time
    import argparse
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser(description = 'image-search')
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--search', action='store_true', help='run search')
    group.add_argument('--index', action='store_true', help='run indexing')
    parser.add_argument('--db', default='db', help='database name')
    parser.add_argument('--model', default='openai/clip-vit-large-patch14', help='huggingface model')
    parser.add_argument('--cache', default='/mnt/models/huggingface', help='cache folder')
    # fixed: default was the bare string os.getcwd(), which `for fn in args.input`
    # would iterate character-by-character; nargs='*' defaults must be a list
    parser.add_argument('input', nargs='*', default=[os.getcwd()])
    args = parser.parse_args()

    db = ImageDB(
        name=args.db,
        model=args.model, # 'facebook/dinov2-small'
        cache_dir=args.cache,
        dtype=torch.bfloat16,
        device=torch.device('cuda'),
        debug=True,
        pbar=True,
    )
    db.init()
    db.load()
    print(db)

    if args.index:
        t0 = time.time()
        if len(args.input) > 0:
            for fn in args.input:
                if os.path.isfile(fn):
                    db.image(filename=fn)
                elif os.path.isdir(fn):
                    db.folder(folder=fn)
        t1 = time.time()
        print('index', t1-t0)
        db.save()
        db.offload()

    if args.search:
        for ref in args.input:
            emb = db.embedding(ref)
            res = db.search(filename=ref, metadata=ref, embed=emb, k=10, d=0) # d=0: no distance cutoff
            for r in res:
                print(ref, r)