mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
263 lines
7.9 KiB
Python
Executable File
263 lines
7.9 KiB
Python
Executable File
#!/usr/bin/env python
|
|
#pylint: disable=redefined-outer-name
|
|
"""
|
|
helper methods that create an HTTP session with a managed connection pool
|
|
provides async HTTP get/post methods and several helper methods
|
|
"""
|
|
|
|
import io
|
|
import os
|
|
import sys
|
|
import ssl
|
|
import base64
|
|
import asyncio
|
|
import logging
|
|
import aiohttp
|
|
import requests
|
|
import urllib3
|
|
from PIL import Image
|
|
from util import Map, log
|
|
from rich import print # pylint: disable=redefined-builtin
|
|
|
|
|
|
# server connection settings, read once from the environment at import time
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860") # api url root
# optional basic-auth credentials; auth()/authsync() only return credentials when BOTH are set
sd_username = os.environ.get('SDAPI_USR', None)
sd_password = os.environ.get('SDAPI_PWD', None)

# result() closes the shared session after a non-200 response only when this is False
use_session = True
# silence the warning urllib3 emits for the verify=False requests made by getsync()/postsync()
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# NOTE(review): this monkeypatch disables TLS certificate verification for EVERY caller of
# ssl.create_default_context() in the process after this module is imported, not only for this
# module's requests - confirm this process-wide effect is intended
ssl.create_default_context = ssl._create_unverified_context # pylint: disable=protected-access
timeout = aiohttp.ClientTimeout(total = None, sock_connect = 10, sock_read = None) # default value is 5 minutes, we need longer for training
# shared aiohttp.ClientSession; assigned by session() and by get()/post()
sess = None
# when True, result()/resultsync() do not log non-200 responses
quiet = False
|
|
# pick the event-loop policy to extend: the selector policy on Windows when available, else the default one
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
    BaseThreadPolicy = asyncio.WindowsSelectorEventLoopPolicy
else:
    BaseThreadPolicy = asyncio.DefaultEventLoopPolicy


class AnyThreadEventLoopPolicy(BaseThreadPolicy):
    """Event-loop policy whose ``get_event_loop()`` never fails for lack of a loop.

    When the base policy's ``get_event_loop()`` raises ``RuntimeError`` or
    ``AssertionError`` (for example, no loop has been set for the calling thread),
    a new loop is created, registered as the current loop, and returned.
    """

    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        try:
            current = super().get_event_loop()
        except (RuntimeError, AssertionError):
            # no usable loop for this thread: create one and make it the current loop
            current = self.new_event_loop()
            self.set_event_loop(current)
        return current


# install the policy process-wide at import time
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
|
|
|
|
|
def authsync():
    """Return basic-auth credentials for the synchronous ``requests`` client.

    Yields ``requests.auth.HTTPBasicAuth`` built from ``sd_username``/``sd_password``
    when both are set, otherwise ``None`` (no authentication).
    """
    if sd_username is None or sd_password is None:
        return None
    return requests.auth.HTTPBasicAuth(sd_username, sd_password)
|
|
|
|
|
|
def auth():
    """Return basic-auth credentials for the asynchronous aiohttp client.

    Yields ``aiohttp.BasicAuth`` built from ``sd_username``/``sd_password`` when both
    are set, otherwise ``None`` (no authentication).
    """
    if sd_username is None or sd_password is None:
        return None
    return aiohttp.BasicAuth(sd_username, sd_password)
|
|
|
|
|
|
async def result(req):
    """Turn an aiohttp response into the value handed back by get()/post().

    Status 200: the JSON body is decoded; a list is returned unchanged, JSON ``null``
    becomes ``{}``, and anything else is wrapped in ``Map``.
    Any other status: the failure is logged unless ``quiet`` is set, the shared
    session is closed when ``use_session`` is False, and a ``Map`` with
    ``error``/``reason``/``url`` keys is returned.
    """
    if req.status == 200:
        payload = await req.json()
        if isinstance(payload, list):
            res = payload
        else:
            res = {} if payload is None else Map(payload)
        log.debug({ 'request': req.status, 'url': req.url, 'reason': req.reason })
        return res
    if not quiet:
        log.error({ 'request error': req.status, 'reason': req.reason, 'url': req.url })
    if sess is not None and not use_session:
        await sess.close()
    return Map({ 'error': req.status, 'reason': req.reason, 'url': req.url })
|
|
|
|
|
|
def resultsync(req: requests.Response):
    """Turn a ``requests`` response into the value handed back by getsync()/postsync().

    Status 200: the JSON body is decoded (decode errors propagate); a list is returned
    unchanged, JSON ``null`` becomes ``{}``, and anything else is wrapped in ``Map``.
    Any other status: the failure is logged unless ``quiet`` is set and a ``Map`` with
    ``error``/``reason``/``url`` keys is returned.
    """
    status = req.status_code
    if status != 200:
        if not quiet:
            log.error({ 'request error': status, 'reason': req.reason, 'url': req.url })
        return Map({ 'error': status, 'reason': req.reason, 'url': req.url })
    body = req.json()
    if isinstance(body, list):
        res = body
    elif body is None:
        res = {}
    else:
        res = Map(body)
    log.debug({ 'request': status, 'url': req.url, 'reason': req.reason })
    return res
|
|
|
|
|
|
async def get(endpoint: str, json: dict = None):
    """Send an async GET for ``endpoint`` through the shared aiohttp session.

    ``endpoint`` is resolved against ``sd_url`` (the session's base URL). The shared
    session is created on first use and re-created if it has been closed (for example
    by close()), so a closed session is never reused.

    Returns the value produced by result(), or ``{}`` after logging if the request raised.
    """
    global sess # pylint: disable=global-statement
    # post() already guards against a closed session; do the same here so a GET after close() still works
    if sess is None or sess.closed:
        sess = await session()
    try:
        # ssl=False replaces aiohttp's deprecated verify_ssl=False (certificate checks stay disabled)
        async with sess.get(url=endpoint, json=json, ssl=False) as req:
            return await result(req)
    except Exception as err:
        log.error({ 'session': err })
        return {}
|
|
|
|
|
|
def getsync(endpoint: str, json: dict = None):
    """Blocking GET of ``sd_url + endpoint`` using ``requests``.

    Certificate verification is disabled and credentials come from authsync().
    Returns the value from resultsync(), or ``{}`` after logging if anything raised
    (connection failure, JSON decode error, ...).
    """
    try:
        response = requests.get(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
        return resultsync(response)
    except Exception as err:
        log.error({ 'session': err })
        return {}
|
|
|
|
|
|
async def post(endpoint: str, json: dict = None):
    """Send an async POST for ``endpoint`` on a freshly created shared session.

    Any currently open shared session is closed first and replaced via session().
    NOTE(review): unlike get(), a POST never reuses the existing session; reuse was
    commented out in this function, so this looks deliberate - confirm before changing.

    Returns the value produced by result(), or ``{}`` after logging if the request raised.
    """
    global sess # pylint: disable=global-statement
    if sess is not None and not sess.closed:
        await sess.close()
    sess = await session()
    try:
        # ssl=False replaces aiohttp's deprecated verify_ssl=False (certificate checks stay disabled)
        async with sess.post(url=endpoint, json=json, ssl=False) as req:
            return await result(req)
    except Exception as err:
        log.error({ 'session': err })
        return {}
|
|
|
|
|
|
def postsync(endpoint: str, json: dict = None):
    """Blocking POST of ``sd_url + endpoint`` using ``requests``.

    Certificate verification is disabled and credentials come from authsync().
    Unlike getsync(), exceptions are not caught here; they propagate to the caller.
    Returns the value from resultsync().
    """
    url = f'{sd_url}{endpoint}'
    response = requests.post(url, json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
    return resultsync(response)
|
|
|
|
|
|
async def interrupt():
    """Interrupt the server's current job when its progress reports ``job_count > 0``.

    Returns the interrupt POST result (after a one-second pause), or
    ``{ 'interrupt': 'idle' }`` when there is nothing to interrupt.
    """
    status = await get('/sdapi/v1/progress?skip_current_image=true')
    busy = 'state' in status and status.state.job_count > 0
    if not busy:
        log.debug({ 'interrupt': 'idle' })
        return { 'interrupt': 'idle' }
    log.debug({ 'interrupt': status.state })
    res = await post('/sdapi/v1/interrupt')
    await asyncio.sleep(1)
    return res
|
|
|
|
|
|
def interruptsync():
    """Blocking variant of interrupt(): interrupt only when progress reports ``job_count > 0``.

    Returns the interrupt POST result, or ``{ 'interrupt': 'idle' }`` when idle.
    """
    status = getsync('/sdapi/v1/progress?skip_current_image=true')
    busy = 'state' in status and status.state.job_count > 0
    if not busy:
        log.debug({ 'interrupt': 'idle' })
        return { 'interrupt': 'idle' }
    log.debug({ 'interrupt': status.state })
    return postsync('/sdapi/v1/interrupt')
|
|
|
|
|
|
async def progress():
    """Fetch generation progress including the live preview image.

    When the response carries a non-null ``current_image`` value, it is base64-decoded
    and replaced in place by a PIL ``Image``. Decoding is best-effort: on failure the
    response is returned unchanged and the error is logged at debug level instead of
    being silently discarded.
    """
    res = await get('/sdapi/v1/progress?skip_current_image=false')
    try:
        if res is not None and res.get('current_image', None) is not None:
            res.current_image = Image.open(io.BytesIO(base64.b64decode(res['current_image'])))
    except Exception as err: # best-effort: progress data is still useful without the preview
        log.debug({ 'progress': 'current image decode failed', 'error': err })
    log.debug({ 'progress': res })
    return res
|
|
|
|
|
|
def progresssync():
    """Blocking progress query (preview image skipped); logs and returns the response."""
    status = getsync('/sdapi/v1/progress?skip_current_image=true')
    log.debug({ 'progress': status })
    return status
|
|
|
|
|
|
def get_log():
    """Fetch ``/sdapi/v1/log``, debug-log every item it iterates over, and return the response."""
    entries = getsync('/sdapi/v1/log')
    for entry in entries:
        log.debug(entry)
    return entries
|
|
|
|
|
|
def get_info():
    """Fetch full, refreshed system status and print it together with the request time.

    ``duration`` is the elapsed time of the request in whole milliseconds. It is measured
    with a monotonic clock and rounded after scaling, so no float noise such as
    ``7.000000000000001`` appears. Returns the status response unchanged.
    """
    import time
    started = time.perf_counter() # monotonic clock: not affected by system clock changes
    res = getsync('/sdapi/v1/system-info/status?full=true&refresh=true')
    elapsed_ms = round(1000 * (time.perf_counter() - started))
    print({ 'duration': elapsed_ms, **res })
    return res
|
|
|
|
|
|
def options():
    """Return the server options and command-line flags as ``{ 'options': ..., 'flags': ... }``."""
    return {
        'options': getsync('/sdapi/v1/options'),
        'flags': getsync('/sdapi/v1/cmd-flags'),
    }
|
|
|
|
|
|
def shutdown():
    """Ask the server to shut down.

    postsync() does not catch errors itself, so any exception from the request is
    caught here and logged at info level instead of being raised.
    """
    try:
        postsync('/sdapi/v1/shutdown')
    except Exception as err:
        log.info({ 'shutdown': err })
|
|
|
|
|
|
async def session():
    """Create a new shared aiohttp session bound to ``sd_url`` and store it in ``sess``.

    A still-open previous session is closed first, so replacing it does not leak its
    connection pool. The session uses the module-level ``timeout`` and the basic-auth
    credentials from auth().

    Returns the new session (the same object now held in the module-level ``sess``).
    """
    global sess # pylint: disable=global-statement
    if sess is not None and not sess.closed:
        await sess.close()
    # reuse the module-level ClientTimeout instead of redefining an identical one here
    sess = aiohttp.ClientSession(timeout = timeout, base_url = sd_url, auth=auth())
    log.debug({ 'sdapi': 'session created', 'endpoint': sd_url })
    return sess
|
|
|
|
|
|
async def close():
    """Close the shared aiohttp session, if any, and forget it.

    Resetting ``sess`` to None lets the next get()/post() create a fresh session instead
    of reusing a closed one. Calling close() again is a no-op until a new session exists.
    """
    global sess # pylint: disable=global-statement
    if sess is None:
        return
    await asyncio.sleep(0)
    if not sess.closed:
        # ClientSession.__aexit__ just awaits close(), so one close() call is sufficient
        await sess.close()
        log.debug({ 'sdapi': 'session closed', 'endpoint': sd_url })
    sess = None
|
|
|
|
|
|
if __name__ == "__main__":
    import json # only used to pretty-print the 'options' command output

    async def _run_and_close(coro):
        """Await *coro*, then close the shared session inside the same event loop.

        Each asyncio.run() creates and finally closes its own loop. Closing the session
        from a later asyncio.run() would touch connections bound to an already-closed
        loop, so the close happens here instead.
        """
        try:
            return await coro
        finally:
            await close()

    sys.argv.pop(0)
    log.setLevel(logging.DEBUG)
    if 'interrupt' in sys.argv:
        asyncio.run(_run_and_close(interrupt()))
    elif 'progress' in sys.argv:
        asyncio.run(_run_and_close(progress()))
    elif 'progresssync' in sys.argv:
        progresssync()
    elif 'options' in sys.argv:
        opt = options()
        log.debug({ 'options' })
        print(json.dumps(opt['options'], indent = 2))
        log.debug({ 'cmd-flags' })
        print(json.dumps(opt['flags'], indent = 2))
    elif 'log' in sys.argv:
        get_log()
    elif 'info' in sys.argv:
        get_info()
    elif 'shutdown' in sys.argv:
        shutdown()
    elif sys.argv:
        res = getsync(sys.argv[0])
        print(res)
    else:
        # previously crashed with IndexError on sys.argv[0]
        log.error({ 'sdapi': 'missing argument', 'usage': 'interrupt | progress | progresssync | options | log | info | shutdown | <endpoint>' })
        sys.exit(1)
    asyncio.run(close(), debug=True) # safety net: close any session that is still open
    asyncio.run(asyncio.sleep(0.5))
|