# pylint: disable=global-statement
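"""Image processing helpers: per-image quality checks (blur, dynamic range,
similarity), mediapipe-based face/body detection and selfie segmentation, and
upscale/restore/interrogate steps driven through the sdapi REST helpers.

The main entrypoint is file(), which runs the requested steps on a single
image and returns a Result describing what was done.
"""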

import os
import io
import math
import base64
import numpy as np
import mediapipe as mp
from PIL import Image, ImageOps
from pi_heif import register_heif_opener
from skimage.metrics import structural_similarity as ssim
from scipy.stats import beta

import util
import sdapi
import options

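# lazily created mediapipe models plus per-run caches: all_images holds grayscale
# thumbnails of every processed image for the SSIM similarity check, and
# all_images_by_type counts saved images per type to build sequential filenames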
face_model = None
body_model = None
segmentation_model = None
all_images = []
all_images_by_type = {}


class Result():
    def __init__(self, typ: str, fn: str, tag: str = None, requested: list = None):
        self.type = typ
        self.input = fn
        self.output = ''
        self.basename = ''
        self.message = ''
        self.image = None
        self.caption = ''
        self.tag = tag
        self.tags = []
        self.ops = []
        self.steps = requested if requested is not None else [] # avoid sharing a mutable default argument


def detect_blur(image: Image):
    # based on <https://github.com/karthik9319/Blur-Detection/>
    bw = ImageOps.grayscale(image)
    cx, cy = image.size[0] // 2, image.size[1] // 2
    fft = np.fft.fft2(bw)
    fftShift = np.fft.fftshift(fft)
    # zero out the low-frequency center of the spectrum, then measure what survives
    fftShift[cy - options.process.blur_samplesize: cy + options.process.blur_samplesize, cx - options.process.blur_samplesize: cx + options.process.blur_samplesize] = 0
    fftShift = np.fft.ifftshift(fftShift)
    recon = np.fft.ifft2(fftShift)
    magnitude = np.log(np.abs(recon))
    mean = round(np.mean(magnitude), 2)
    return mean


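# detect_dynamicrange scores the brightness histogram against a bell-shaped
# beta(2, 2) reference using cosine similarity:
#   score = dot(p, q) / sqrt(sum(p ** 2) * sum(q ** 2))
# 1.0 means the histogram matches the reference exactly; values near 0 mean the
# brightness mass sits where the reference has almost none (e.g. a nearly black
# or blown-out image)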
def detect_dynamicrange(image: Image):
    # based on <https://towardsdatascience.com/measuring-enhancing-image-quality-attributes-234b0f250e10>
    data = np.asarray(image)
    image = np.float32(data)
    RGB = [0.299, 0.587, 0.114]
    height, width = image.shape[:2] # pylint: disable=unsubscriptable-object
    brightness_image = np.sqrt(image[..., 0] ** 2 * RGB[0] + image[..., 1] ** 2 * RGB[1] + image[..., 2] ** 2 * RGB[2]) # pylint: disable=unsubscriptable-object
    hist, _ = np.histogram(brightness_image, bins=256, range=(0, 255))
    img_brightness_pmf = hist / (height * width)
    dist = beta(2, 2)
    ys = dist.pdf(np.linspace(0, 1, 256))
    ref_pmf = ys / np.sum(ys)
    dot_product = np.dot(ref_pmf, img_brightness_pmf)
    squared_dist_a = np.sum(ref_pmf ** 2)
    squared_dist_b = np.sum(img_brightness_pmf ** 2)
    res = dot_product / math.sqrt(squared_dist_a * squared_dist_b)
    return round(res, 2)


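# detect_similar returns the highest SSIM score against every image processed so
# far and then registers the current image in the cache, so each call both
# checks and records; reset() clears the cache between runs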
def detect_similar(image: Image):
    img = image.resize((options.process.similarity_size, options.process.similarity_size))
    img = ImageOps.grayscale(img)
    data = np.array(img)
    similarity = 0
    for i in all_images:
        val = ssim(data, i, data_range=255, channel_axis=None, gradient=False, full=False)
        if val > similarity:
            similarity = val
    all_images.append(data)
    return similarity


def segmentation(res: Result):
    global segmentation_model
    if segmentation_model is None:
        segmentation_model = mp.solutions.selfie_segmentation.SelfieSegmentation(model_selection=options.process.segmentation_model)
    data = np.array(res.image)
    results = segmentation_model.process(data)
    condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > 0.1
    background = np.zeros(data.shape, dtype=np.uint8)
    background[:] = options.process.segmentation_background
    data = np.where(condition, data, background) # consider using a joint bilateral filter instead of pure combine
    segmented = Image.fromarray(data)
    res.image = segmented
    res.ops.append('segmentation')
    return res


def unload():
    global face_model, body_model, segmentation_model
    face_model = None
    body_model = None
    segmentation_model = None


def encode(img):
    with io.BytesIO() as stream:
        img.save(stream, 'JPEG')
        values = stream.getvalue()
        encoded = base64.b64encode(values).decode()
        return encoded


def reset():
    unload()
    global all_images, all_images_by_type
    all_images_by_type = {}
    all_images = []


def upscale_restore_image(res: Result, upscale: bool = False, restore: bool = False):
    kwargs = util.Map({
        'image': encode(res.image),
        'codeformer_visibility': 0.0,
        'codeformer_weight': 0.0,
    })
    if res.image.width >= options.process.target_size and res.image.height >= options.process.target_size:
        upscale = False # already large enough, no need to upscale
    if upscale:
        kwargs.upscaler_1 = 'SwinIR_4x'
        kwargs.upscaling_resize = 2
        res.ops.append('upscale')
    if restore:
        kwargs.codeformer_visibility = 1.0
        kwargs.codeformer_weight = 0.2
        res.ops.append('restore')
    if upscale or restore:
        result = sdapi.postsync('/sdapi/v1/extra-single-image', kwargs)
        if 'image' not in result:
            res.message = 'failed to upscale/restore image'
        else:
            res.image = Image.open(io.BytesIO(base64.b64decode(result['image'])))
    return res


def interrogate_image(res: Result, tag: str = None):
    caption = ''
    tags = []
    for model in options.process.interrogate_model:
        json = util.Map({ 'image': encode(res.image), 'model': model })
        result = sdapi.postsync('/sdapi/v1/interrogate', json)
        if model == 'clip':
            caption = result.caption if 'caption' in result else ''
            caption = caption.split(',')[0].replace(' a ', ' ').strip()
            if tag is not None:
                caption = res.tag + ', ' + caption
        if model == 'deepdanbooru':
            raw = result.caption if 'caption' in result else '' # do not shadow the tag parameter
            tags = raw.split(',')
            tags = [t.replace('(', '').replace(')', '').replace('\\', '').split(':')[0].strip() for t in tags]
            if tag is not None:
                for t in res.tag.split(',')[::-1]:
                    tags.insert(0, t.strip())
            pos = 0 if len(tags) == 0 else 1
            words = caption.split(' ')
            if len(words) > 1: # guard against captions shorter than two words
                tags.insert(pos, words[1])
            tags = [t for t in tags if len(t) > 2]
    if len(tags) > options.process.tag_limit:
        tags = tags[:options.process.tag_limit]
    res.caption = caption
    res.tags = tags
    res.ops.append('interrogate')
    return res


def resize_image(res: Result):
    resized = res.image
    resized.thumbnail((options.process.target_size, options.process.target_size), Image.Resampling.HAMMING)
    res.image = resized
    res.ops.append('resize')
    return res


def square_image(res: Result):
    size = max(res.image.width, res.image.height)
    squared = Image.new('RGB', (size, size))
    squared.paste(res.image, ((size - res.image.width) // 2, (size - res.image.height) // 2))
    res.image = squared
    res.ops.append('square')
    return res


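# process_face and process_body crop to the first detection: the relative
# bounding box (or visible pose landmarks) is padded by face_pad/body_pad as a
# fraction of image size, then clamped to the image borders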
def process_face(res: Result):
    res.ops.append('face')
    global face_model
    if face_model is None:
        face_model = mp.solutions.face_detection.FaceDetection(min_detection_confidence=options.process.face_score, model_selection=options.process.face_model)
    results = face_model.process(np.array(res.image))
    if results.detections is None:
        res.message = 'no face detected'
        res.image = None
        return res
    box = results.detections[0].location_data.relative_bounding_box
    if box.xmin < 0 or box.ymin < 0 or (box.xmin + box.width) > 1 or (box.ymin + box.height) > 1:
        res.message = 'face out of frame'
        res.image = None
        return res
    x = max(0, (box.xmin - options.process.face_pad / 2) * res.image.width)
    y = max(0, (box.ymin - options.process.face_pad / 2) * res.image.height)
    w = min(res.image.width, (box.width + options.process.face_pad) * res.image.width)
    h = min(res.image.height, (box.height + options.process.face_pad) * res.image.height)
    res.image = res.image.crop((x, y, x + w, y + h))
    return res


def process_body(res: Result):
    res.ops.append('body')
    global body_model
    if body_model is None:
        body_model = mp.solutions.pose.Pose(static_image_mode=True, min_detection_confidence=options.process.body_score, model_complexity=options.process.body_model)
    results = body_model.process(np.array(res.image))
    if results.pose_landmarks is None:
        res.message = 'no body detected'
        res.image = None
        return res
    # collect padded coordinates of all landmarks that are visible enough
    x0 = [res.image.width * (i.x - options.process.body_pad / 2) for i in results.pose_landmarks.landmark if i.visibility > options.process.body_visibility]
    y0 = [res.image.height * (i.y - options.process.body_pad / 2) for i in results.pose_landmarks.landmark if i.visibility > options.process.body_visibility]
    x1 = [res.image.width * (i.x + options.process.body_pad / 2) for i in results.pose_landmarks.landmark if i.visibility > options.process.body_visibility]
    y1 = [res.image.height * (i.y + options.process.body_pad / 2) for i in results.pose_landmarks.landmark if i.visibility > options.process.body_visibility]
    if len(x0) < options.process.body_parts:
        res.message = f'insufficient body parts detected: {len(x0)}'
        res.image = None
        return res
    res.image = res.image.crop((max(0, min(x0)), max(0, min(y0)), min(res.image.width, max(x1)), min(res.image.height, max(y1))))
    return res


def process_original(res: Result):
    res.ops.append('original')
    return res


def save_image(res: Result, folder: str):
    if res.image is None or folder is None:
        return res
    all_images_by_type[res.type] = all_images_by_type.get(res.type, 0) + 1
    res.basename = os.path.basename(res.input).split('.')[0]
    res.basename = str(all_images_by_type[res.type]).rjust(3, '0') + '-' + res.type + '-' + res.basename
    res.basename = os.path.join(folder, res.basename)
    res.output = res.basename + options.process.format
    res.image.save(res.output)
    res.image.close()
    res.ops.append('save')
    return res


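# file() is the per-image pipeline: open (with HEIF support and EXIF rotation),
# crop (face/body/original), validate (blur/range/similarity), post-process
# (upscale/restore, interrogate, resize, square, segment), then save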
def file(filename: str, folder: str, tag = None, requested = None):
    requested = requested if requested is not None else [] # avoid sharing a mutable default argument
    # initialize result object
    res = Result(fn=filename, typ='unknown', tag=tag, requested=requested)
    # open image
    try:
        register_heif_opener()
        res.image = Image.open(filename)
        if res.image.mode == 'RGBA':
            res.image = res.image.convert('RGB')
        res.image = ImageOps.exif_transpose(res.image) # rotate image according to EXIF orientation
    except Exception as e:
        res.message = f'error opening: {e}'
        return res
    # primary steps
    if 'face' in requested:
        res.type = 'face'
        res = process_face(res)
    elif 'body' in requested:
        res.type = 'body'
        res = process_body(res)
    elif 'original' in requested:
        res.type = 'original'
        res = process_original(res)
    # validation steps; skip remaining checks once the image has been rejected
    if res.image is None:
        return res
    if 'blur' in requested:
        res.ops.append('blur')
        val = detect_blur(res.image)
        if val > options.process.blur_score:
            res.message = f'blur check failed: {val}'
            res.image = None
    if 'range' in requested and res.image is not None:
        res.ops.append('range')
        val = detect_dynamicrange(res.image)
        if val < options.process.range_score:
            res.message = f'dynamic range check failed: {val}'
            res.image = None
    if 'similarity' in requested and res.image is not None:
        res.ops.append('similarity')
        val = detect_similar(res.image)
        if val > options.process.similarity_score:
            res.message = f'similarity check failed: {val}'
            res.image = None
    if res.image is None:
        return res
    # post processing steps
    res = upscale_restore_image(res, 'upscale' in requested, 'restore' in requested)
    if res.image.width < options.process.target_size or res.image.height < options.process.target_size:
        res.message = f'low resolution: [{res.image.width}, {res.image.height}]'
        res.image = None
        return res
    if 'interrogate' in requested:
        res = interrogate_image(res, tag)
    if 'resize' in requested:
        res = resize_image(res)
    if 'square' in requested:
        res = square_image(res)
    if 'segment' in requested:
        res = segmentation(res)
    # finally save image
    res = save_image(res, folder)
    return res
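

# minimal driver sketch (not part of the original module): walk a folder and run
# the full pipeline on every image; the step list and command-line handling below
# are illustrative placeholders, not values taken from this repository
if __name__ == '__main__':
    import sys
    src = sys.argv[1] if len(sys.argv) > 1 else '.'
    dst = sys.argv[2] if len(sys.argv) > 2 else None # save_image() skips saving when folder is None
    steps = ['face', 'blur', 'range', 'similarity', 'upscale', 'restore', 'interrogate', 'resize', 'square', 'segment']
    for name in sorted(os.listdir(src)):
        result = file(os.path.join(src, name), dst, tag=None, requested=steps)
        print(result.input, result.type, result.ops, result.message)
    unload()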