#!/usr/bin/env python
# pylint: disable=no-member
"""generate batches of images from prompts and upscale them

params: run with `--help`

default workflow runs an infinite loop and prints stats when interrupted:
1. choose a random sampler: look up all available samplers and pick one
2. generate a dynamic prompt based on styles, embeddings, places, artists and suffixes
3. beautify the prompt
4. generate a 3x3 batch of images
5. create an image grid
6. upscale the images with face restoration
"""

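# example invocation (illustrative values; all flags are defined in args() below):
#   python generate.py --config generate.json --random random.json --max 9 --upscale 2 --beautify
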
import argparse
import asyncio
import base64
import io
import json
import logging
import math
import os
import pathlib
import secrets
import time
import sys
import importlib

from random import randrange
from PIL import Image
from PIL.ExifTags import TAGS
from PIL.TiffImagePlugin import ImageFileDirectory_v2

from sdapi import close, get, interrupt, post, session
from util import Map, log, safestring


sd = {} # active configuration, loaded from the json file passed via --config
random = {} # randomization pools, loaded from the json file passed via --random
stats = Map({ 'images': 0, 'wall': 0, 'generate': 0, 'upscale': 0 })
avg = {} # average per-image generate time keyed by sampler


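# grid() below packs a batch into a near-square contact sheet,
# e.g. 9 images -> rows = round(sqrt(9)) = 3 and cols = ceil(9 / 3) = 3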
def grid(data):
    if len(data.image) > 1:
        w, h = data.image[0].size
        rows = round(math.sqrt(len(data.image)))
        cols = math.ceil(len(data.image) / rows)
        image = Image.new('RGB', size = (cols * w, rows * h), color = 'black')
        for i, img in enumerate(data.image):
            image.paste(img, box=(i % cols * w, i // cols * h))
        short = data.info.prompt[:min(len(data.info.prompt), 96)] # limit prompt part of filename to 96 chars
        name = '{seed:0>9} {short}'.format(short = short, seed = data.info.all_seeds[0]) # pylint: disable=consider-using-f-string
        name = safestring(name) + '.jpg'
        f = os.path.join(sd.paths.root, sd.paths.grid, name)
        log.info({ 'grid': { 'name': f, 'size': image.size, 'images': len(data.image) } })
        image.save(f, 'JPEG', exif = exif(data.info, None, 'grid'), optimize = True, quality = 70)
        return image
    return data.image


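# exif() below encodes the generation parameters into a single exif ImageDescription tag,
# producing a string along the lines of (values illustrative):
#   'a prompt | negative ... | seed 123456789 | steps 20 | cfgscale 7 | sampler ... | batch 3 | timestamp ... | model ... | vae ...'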
def exif(info, i = None, op = 'generate'):
    seed = [info.all_seeds[i]] if len(info.all_seeds) > 0 and i is not None else info.all_seeds # always returns list
    seed = ', '.join([str(x) for x in seed]) # int list to str list to single str
    template = '{prompt} | negative {negative_prompt} | seed {s} | steps {steps} | cfgscale {cfg_scale} | sampler {sampler_name} | batch {batch_size} | timestamp {job_timestamp} | model {model} | vae {vae}'.format(s = seed, model = sd.options['sd_model_checkpoint'], vae = sd.options['sd_vae'], **info) # pylint: disable=consider-using-f-string
    if op == 'upscale':
        template += ' | faces gfpgan' if sd.upscale.gfpgan_visibility > 0 else ''
        template += ' | faces codeformer' if sd.upscale.codeformer_visibility > 0 else ''
        template += ' | upscale {resize}x {upscaler}'.format(resize = sd.upscale.upscaling_resize, upscaler = sd.upscale.upscaler_1) if sd.upscale.upscaler_1 != 'None' else '' # pylint: disable=consider-using-f-string
        template += ' | upscale {resize}x {upscaler}'.format(resize = sd.upscale.upscaling_resize, upscaler = sd.upscale.upscaler_2) if sd.upscale.upscaler_2 != 'None' else '' # pylint: disable=consider-using-f-string
    if op == 'grid':
        template += ' | grid {num}'.format(num = sd.generate.batch_size * sd.generate.n_iter) # pylint: disable=consider-using-f-string
    ifd = ImageFileDirectory_v2()
    exif_stream = io.BytesIO()
    _TAGS = {v: k for k, v in TAGS.items()} # enumerate possible exif tags
    ifd[_TAGS['ImageDescription']] = template
    ifd.save(exif_stream)
    val = b'Exif\x00\x00' + exif_stream.getvalue()
    return val


def randomize(lst):
    if len(lst) > 0:
        return secrets.choice(lst)
    else:
        return ''


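# prompt() below builds the final prompt by substituting the <embedding>, <artist>, <style>, <suffix> and <place>
# placeholders; an illustrative (hypothetical) template: 'portrait of <embedding> by <artist>, <style>, <suffix>, <place>'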
def prompt(params): # generate dynamic prompt or use one if provided
    sd.generate.prompt = params.prompt if params.prompt != 'dynamic' else randomize(random.prompts)
    sd.generate.negative_prompt = params.negative if params.negative != 'dynamic' else randomize(random.negative)
    embedding = params.embedding if params.embedding != 'random' else randomize(random.embeddings)
    sd.generate.prompt = sd.generate.prompt.replace('<embedding>', embedding)
    artist = params.artist if params.artist != 'random' else randomize(random.artists)
    sd.generate.prompt = sd.generate.prompt.replace('<artist>', artist)
    style = params.style if params.style != 'random' else randomize(random.styles)
    sd.generate.prompt = sd.generate.prompt.replace('<style>', style)
    suffix = params.suffix if params.suffix != 'random' else randomize(random.suffixes)
    sd.generate.prompt = sd.generate.prompt.replace('<suffix>', suffix)
    place = params.place if params.place != 'random' else randomize(random.places)
    sd.generate.prompt = sd.generate.prompt.replace('<place>', place)
    if params.prompts or params.debug:
        log.info({ 'random initializers': random })
    if params.prompt == 'dynamic':
        log.info({ 'dynamic prompt': sd.generate.prompt })
    return sd.generate.prompt


def sampler(params, options): # find sampler
    if params.sampler == 'random':
        sd.generate.sampler_name = randomize(options.samplers)
        log.info({ 'random sampler': sd.generate.sampler_name })
    else:
        found = [i for i in options.samplers if i.startswith(params.sampler)]
        if len(found) == 0:
            log.error({ 'sampler error': params.sampler, 'available': options.samplers })
            exit()
        sd.generate.sampler_name = found[0]
    return sd.generate.sampler_name


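# minimal programmatic usage sketch for generate() (assumes a running server session and a loaded configuration in `sd`):
#   data = asyncio.run(generate(prompt = 'a photo of a cat', quiet = True))
#   data.name holds the saved file paths, data.image the PIL images, data.b64 the raw payloads, data.info the server metadata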
async def generate(prompt = None, options = None, quiet = False): # pylint: disable=redefined-outer-name
    global sd # pylint: disable=global-statement
    if options:
        sd = Map(options)
    if prompt is not None:
        sd.generate.prompt = prompt
    if not quiet:
        log.info({ 'generate': sd.generate })
    if sd.get('options', None) is None:
        sd['options'] = await get('/sdapi/v1/options')
    names = []
    b64s = []
    images = []
    info = Map({})
    data = await post('/sdapi/v1/txt2img', sd.generate)
    if 'error' in data:
        log.error({ 'generate': data['error'], 'reason': data['reason'] })
        return Map({})
    info = Map(json.loads(data['info']))
    log.debug({ 'info': info })
    images = data['images']
    short = info.prompt[:min(len(info.prompt), 96)] # limit prompt part of filename to 96 chars
    for i in range(len(images)):
        b64s.append(images[i])
        images[i] = Image.open(io.BytesIO(base64.b64decode(images[i].split(',',1)[0])))
        name = '{seed:0>9} {short}'.format(short = short, seed = info.all_seeds[i]) # pylint: disable=consider-using-f-string
        name = safestring(name) + '.jpg'
        f = os.path.join(sd.paths.root, sd.paths.generate, name)
        names.append(f)
        if not quiet:
            log.info({ 'image': { 'name': f, 'size': images[i].size } })
        images[i].save(f, 'JPEG', exif = exif(info, i), optimize = True, quality = 70)
    return Map({ 'name': names, 'image': images, 'b64': b64s, 'info': info })


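# upscale() below reuses the configured /sdapi/v1/extra-single-image payload (upscaling_resize, upscaler_1,
# upscaler_2, gfpgan_visibility, codeformer_visibility) and feeds it each generated image as base64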
async def upscale(data):
    data.upscaled = []
    if sd.upscale.upscaling_resize <= 1:
        return data
    sd.upscale.image = ''
    log.info({ 'upscale': sd.upscale })
    for i in range(len(data.image)):
        f = data.name[i].replace(sd.paths.generate, sd.paths.upscale)
        sd.upscale.image = data.b64[i]
        res = await post('/sdapi/v1/extra-single-image', sd.upscale)
        image = Image.open(io.BytesIO(base64.b64decode(res['image'].split(',',1)[0])))
        data.upscaled.append(image)
        log.info({ 'image': { 'name': f, 'size': image.size } })
        image.save(f, 'JPEG', exif = exif(data.info, i, 'upscale'), optimize = True, quality = 70)
    return data


async def init():
    '''
    import torch
    log.info({ 'torch': torch.__version__, 'available': torch.cuda.is_available() })
    current_device = torch.cuda.current_device()
    mem_free, mem_total = torch.cuda.mem_get_info()
    log.info({ 'cuda': torch.version.cuda, 'available': torch.cuda.is_available(), 'arch': torch.cuda.get_arch_list(), 'device': torch.cuda.get_device_name(current_device), 'memory': { 'free': round(mem_free / 1024 / 1024), 'total': (mem_total / 1024 / 1024) } })
    '''
    options = Map({})
    options.flags = await get('/sdapi/v1/cmd-flags')
    log.debug({ 'flags': options.flags })
    data = await get('/sdapi/v1/sd-models')
    options.models = [obj['title'] for obj in data]
    log.debug({ 'registered models': options.models })
    found = sd.options.sd_model_checkpoint if sd.options.sd_model_checkpoint in options.models else None
    if found is None:
        found = [i for i in options.models if i.startswith(sd.options.sd_model_checkpoint)]
        if len(found) == 0:
            log.error({ 'model error': sd.options.sd_model_checkpoint, 'available': options.models })
            exit()
        sd.options.sd_model_checkpoint = found[0]
    data = await get('/sdapi/v1/samplers')
    options.samplers = [obj['name'] for obj in data]
    log.debug({ 'registered samplers': options.samplers })
    data = await get('/sdapi/v1/upscalers')
    options.upscalers = [obj['name'] for obj in data]
    log.debug({ 'registered upscalers': options.upscalers })
    data = await get('/sdapi/v1/face-restorers')
    options.restorers = [obj['name'] for obj in data]
    log.debug({ 'registered face restorers': options.restorers })
    await interrupt()
    await post('/sdapi/v1/options', sd.options)
    options.options = await get('/sdapi/v1/options')
    log.info({ 'target models': { 'diffuser': options.options['sd_model_checkpoint'], 'vae': options.options['sd_vae'] } })
    log.info({ 'paths': sd.paths })
    options.queue = await get('/queue/status')
    log.info({ 'queue': options.queue })
    pathlib.Path(sd.paths.root).mkdir(parents = True, exist_ok = True)
    pathlib.Path(os.path.join(sd.paths.root, sd.paths.generate)).mkdir(parents = True, exist_ok = True)
    pathlib.Path(os.path.join(sd.paths.root, sd.paths.upscale)).mkdir(parents = True, exist_ok = True)
    pathlib.Path(os.path.join(sd.paths.root, sd.paths.grid)).mkdir(parents = True, exist_ok = True)
    return options


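# args() below loads the json configuration into `sd`; a sketch of the expected layout, inferred from the
# attribute access in this script (key names only, values omitted):
#   { "paths": { "root", "generate", "upscale", "grid" }, "generate": { ... }, "upscale": { ... }, "options": { ... } }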
def args(): # parse cmd arguments
    global sd # pylint: disable=global-statement
    global random # pylint: disable=global-statement
    parser = argparse.ArgumentParser(description = 'sd pipeline')
    parser.add_argument('--config', type = str, default = 'generate.json', required = False, help = 'configuration file')
    parser.add_argument('--random', type = str, default = 'random.json', required = False, help = 'prompt file with randomized sections')
    parser.add_argument('--max', type = int, default = 1, required = False, help = 'maximum number of generated images')
    parser.add_argument('--prompt', type = str, default = 'dynamic', required = False, help = 'prompt')
    parser.add_argument('--negative', type = str, default = 'dynamic', required = False, help = 'negative prompt')
    parser.add_argument('--artist', type = str, default = 'random', required = False, help = 'artist style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--embedding', type = str, default = 'random', required = False, help = 'use embedding, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--style', type = str, default = 'random', required = False, help = 'image style, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--suffix', type = str, default = 'random', required = False, help = 'style suffix, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--place', type = str, default = 'random', required = False, help = 'place locator, used to guide dynamic prompt when prompt is not provided')
    parser.add_argument('--detailer', default = False, action='store_true', help = 'run detailer')
    parser.add_argument('--steps', type = int, default = 0, required = False, help = 'number of steps')
    parser.add_argument('--batch', type = int, default = 0, required = False, help = 'batch size, limited by gpu vram')
    parser.add_argument('--n', type = int, default = 0, required = False, help = 'number of iterations')
    parser.add_argument('--cfg', type = int, default = 0, required = False, help = 'classifier free guidance scale')
    parser.add_argument('--sampler', type = str, default = 'random', required = False, help = 'sampler')
    parser.add_argument('--seed', type = int, default = 0, required = False, help = 'seed, default is random')
    parser.add_argument('--upscale', type = int, default = 0, required = False, help = 'upscale factor, disabled if 0')
    parser.add_argument('--model', type = str, default = '', required = False, help = 'diffusion model')
    parser.add_argument('--vae', type = str, default = '', required = False, help = 'vae model')
    parser.add_argument('--path', type = str, default = '', required = False, help = 'output path')
    parser.add_argument('--width', type = int, default = 0, required = False, help = 'width')
    parser.add_argument('--height', type = int, default = 0, required = False, help = 'height')
    parser.add_argument('--beautify', default = False, action='store_true', help = 'beautify prompt')
    parser.add_argument('--prompts', default = False, action='store_true', help = 'print dynamic prompt templates')
    parser.add_argument('--debug', default = False, action='store_true', help = 'print extra debug information')
    params = parser.parse_args()
    if params.debug:
        log.setLevel(logging.DEBUG)
        log.debug({ 'debug': True })
    log.debug({ 'args': params.__dict__ })
    home = pathlib.Path(sys.argv[0]).parent
    if os.path.isfile(params.config):
        try:
            with open(params.config, 'r', encoding='utf-8') as f:
                data = json.load(f)
                sd = Map(data)
                log.debug({ 'config': sd })
        except Exception as e:
            log.error({ 'config error': params.config, 'exception': e })
            exit()
    elif os.path.isfile(os.path.join(home, params.config)):
        try:
            with open(os.path.join(home, params.config), 'r', encoding='utf-8') as f:
                data = json.load(f)
                sd = Map(data)
                log.debug({ 'config': sd })
        except Exception as e:
            log.error({ 'config error': params.config, 'exception': e })
            exit()
    else:
        log.error({ 'config file not found': params.config })
        exit()
    if params.prompt == 'dynamic':
        log.info({ 'prompt template': params.random })
        if os.path.isfile(params.random):
            try:
                with open(params.random, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    random = Map(data)
                    log.debug({ 'random template': random })
            except Exception:
                log.error({ 'random template error': params.random })
                exit()
        elif os.path.isfile(os.path.join(home, params.random)):
            try:
                with open(os.path.join(home, params.random), 'r', encoding='utf-8') as f:
                    data = json.load(f)
                    random = Map(data)
                    log.debug({ 'random template': random })
            except Exception:
                log.error({ 'random template error': params.random })
                exit()
        else:
            log.error({ 'random template file not found': params.random })
            exit()
    _dynamic = prompt(params)

    sd.paths.root = params.path if params.path != '' else sd.paths.root
    sd.generate.detailer = params.detailer if params.detailer is not None else sd.generate.detailer
    sd.generate.seed = params.seed if params.seed > 0 else sd.generate.seed
    sd.generate.sampler_name = params.sampler if params.sampler != 'random' else sd.generate.sampler_name
    sd.generate.batch_size = params.batch if params.batch > 0 else sd.generate.batch_size
    sd.generate.cfg_scale = params.cfg if params.cfg > 0 else sd.generate.cfg_scale
    sd.generate.n_iter = params.n if params.n > 0 else sd.generate.n_iter
    sd.generate.width = params.width if params.width > 0 else sd.generate.width
    sd.generate.height = params.height if params.height > 0 else sd.generate.height
    sd.generate.steps = params.steps if params.steps > 0 else sd.generate.steps
    sd.upscale.upscaling_resize = params.upscale if params.upscale > 0 else sd.upscale.upscaling_resize
    sd.upscale.codeformer_visibility = 1 if params.detailer else sd.upscale.codeformer_visibility
    sd.options.sd_vae = params.vae if params.vae != '' else sd.options.sd_vae
    sd.options.sd_model_checkpoint = params.model if params.model != '' else sd.options.sd_model_checkpoint
    sd.upscale.upscaler_1 = 'SwinIR_4x' if params.upscale > 1 else sd.upscale.upscaler_1
    if sd.generate.cfg_scale == 0:
        sd.generate.cfg_scale = randrange(5, 10)
    return params


async def main():
    params = args()
    sess = await session()
    if sess is None:
        await close()
        exit()
    options = await init()
    iteration = 0
    while True:
        iteration += 1
        log.info('')
        log.info({ 'iteration': iteration, 'batch': sd.generate.batch_size, 'n': sd.generate.n_iter, 'total': sd.generate.n_iter * sd.generate.batch_size })
        dynamic = prompt(params)
        if params.beautify:
            try:
                promptist = importlib.import_module('modules.promptist')
                sd.generate.prompt = promptist.beautify(dynamic)
            except Exception as e:
                log.error({ 'beautify': e })
        scheduler = sampler(params, options)
        t0 = time.perf_counter()
        data = await generate() # generate returns list of images
        if 'image' not in data:
            break
        stats.images += len(data.image)
        t1 = time.perf_counter()
        if len(data.image) > 0:
            avg[scheduler] = (t1 - t0) / len(data.image)
        stats.generate += t1 - t0
        _image = grid(data)
        data = await upscale(data)
        t2 = time.perf_counter()
        stats.upscale += t2 - t1
        stats.wall += t2 - t0
        its = sd.generate.steps / ((t1 - t0) / len(data.image)) if len(data.image) > 0 else 0
        avg_time = round((t1 - t0) / len(data.image)) if len(data.image) > 0 else 0
        log.info({ 'time' : { 'wall': round(t1 - t0), 'average': avg_time, 'upscale': round(t2 - t1), 'its': round(its, 2) } })
        log.info({ 'generated': stats.images, 'max': params.max, 'progress': round(100 * stats.images / params.max, 1) if params.max > 0 else 'n/a' }) # avoid division by zero when --max is 0 (run until interrupted)
        if params.max != 0 and stats.images >= params.max:
            break


if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        asyncio.run(interrupt())
        asyncio.run(close())
        log.info({ 'interrupt': True })
    finally:
        log.info({ 'sampler performance': avg })
        log.info({ 'stats' : stats })
        asyncio.run(close())