1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/cli/sdapi.py
Vladimir Mandic d36b16d03f refactor api auth
2023-05-23 14:31:22 -04:00

233 lines
7.2 KiB
Python
Executable File

#!/usr/bin/env python
#pylint: disable=redefined-outer-name
"""
helper methods that create an HTTP session with a managed connection pool
provides async HTTP get/post methods and several helper methods
"""
import os
import sys
import ssl
import asyncio
import logging
import aiohttp
import requests
import urllib3
from util import Map, log
# root url of the server api; can be overridden through the SDAPI_URL environment variable
sd_url = os.environ.get('SDAPI_URL', "http://127.0.0.1:7860") # automatic1111 api url root
# read by result(): when False, the shared session is closed after a non-200 response
use_session = True
# turn off the urllib3 InsecureRequestWarning category
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# NOTE(review): this rebinds the ssl module's default context factory to the unverified one for the whole process - confirm that is intended
ssl.create_default_context = ssl._create_unverified_context # pylint: disable=protected-access
# no total or read limit, 10 seconds to connect; NOTE(review): session() builds its own identical ClientTimeout and no live code visible here reads this name
timeout = aiohttp.ClientTimeout(total = None, sock_connect = 10, sock_read = None) # default value is 5 minutes, we need longer for training
# shared aiohttp client session, assigned by session()
sess = None
# when True, result() and resultsync() do not log non-200 responses
quiet = False
# pick the selector-based policy on windows when this python provides it, otherwise the default policy
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
    BaseThreadPolicy = asyncio.WindowsSelectorEventLoopPolicy
else:
    BaseThreadPolicy = asyncio.DefaultEventLoopPolicy


class AnyThreadEventLoopPolicy(BaseThreadPolicy):
    """
    event loop policy that creates and registers a new loop
    whenever the base policy cannot supply one for the calling thread
    """

    def get_event_loop(self) -> asyncio.AbstractEventLoop:
        """return the current event loop, creating and installing one if the base policy raises"""
        try:
            current = super().get_event_loop()
        except (RuntimeError, AssertionError):
            current = self.new_event_loop()
            self.set_event_loop(current)
        return current


# install the policy for the whole process as soon as the module is imported
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
def authsync():
    """
    build basic-auth credentials for the sync (requests) calls

    reads SDAPI_USR and SDAPI_PWD from the environment
    returns: requests HTTPBasicAuth when both variables are set, otherwise None
    """
    user = os.environ.get('SDAPI_USR')
    secret = os.environ.get('SDAPI_PWD')
    if user is None or secret is None:
        return None
    return requests.auth.HTTPBasicAuth(user, secret)
def auth():
    """
    build basic-auth credentials for the async (aiohttp) session

    reads SDAPI_USR and SDAPI_PWD from the environment
    returns: aiohttp BasicAuth when both variables are set, otherwise None
    """
    user = os.environ.get('SDAPI_USR')
    secret = os.environ.get('SDAPI_PWD')
    if user is None or secret is None:
        return None
    return aiohttp.BasicAuth(user, secret)
async def result(req):
    """
    convert a completed async response into a plain result object

    req: response exposing `status`, `reason`, `url` and an awaitable `json()`
    returns: Map with `error`, `reason` and `url` keys when the status is not 200,
    otherwise the decoded body: a list is returned as-is, a null body becomes `{}`,
    anything else is wrapped in Map
    """
    if req.status != 200:
        if not quiet:
            log.error({ 'request error': req.status, 'reason': req.reason, 'url': req.url })
        # with session reuse turned off, the shared session is closed after a failed request
        if not use_session and sess is not None:
            await sess.close()
        return Map({ 'error': req.status, 'reason': req.reason, 'url': req.url })
    payload = await req.json()
    if isinstance(payload, list):
        res = payload
    elif payload is None:
        res = {}
    else:
        res = Map(payload)
    log.debug({ 'request': req.status, 'url': req.url, 'reason': req.reason })
    return res
def resultsync(req: requests.Response):
    """
    convert a completed sync response into a plain result object

    req: requests response
    returns: Map with `error`, `reason` and `url` keys when the status code is not 200,
    otherwise the decoded body: a list is returned as-is, a null body becomes `{}`,
    anything else is wrapped in Map
    """
    if req.status_code != 200:
        if not quiet:
            log.error({ 'request error': req.status_code, 'reason': req.reason, 'url': req.url })
        return Map({ 'error': req.status_code, 'reason': req.reason, 'url': req.url })
    payload = req.json()
    if isinstance(payload, list):
        res = payload
    elif payload is None:
        res = {}
    else:
        res = Map(payload)
    log.debug({ 'request': req.status_code, 'url': req.url, 'reason': req.reason })
    return res
async def get(endpoint: str, json: dict = None):
    """
    send an async GET request through the shared session

    endpoint: path passed as the request url; the session is created with sd_url as its base url
    json: optional payload sent as the request body
    returns: whatever result() produces for the response, or `{}` when the request itself raises
    """
    global sess # pylint: disable=global-statement
    # a closed session cannot send requests, so create a new one when there is none or it was closed
    if sess is None or sess.closed:
        sess = await session()
    try:
        # ssl=False skips certificate verification; it replaces the deprecated verify_ssl=False
        async with sess.get(url=endpoint, json=json, ssl=False) as req:
            return await result(req)
    except Exception as err:
        # best-effort call: failures are logged and reported as an empty result
        log.error({ 'session': err })
        return {}
def getsync(endpoint: str, json: dict = None):
    """
    send a blocking GET request to sd_url + endpoint

    json: optional payload sent as the request body
    returns: whatever resultsync() produces for the response, or `{}` when the request raises
    """
    try:
        response = requests.get(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
        return resultsync(response)
    except Exception as err:
        log.error({ 'session': err })
        return {}
async def post(endpoint: str, json: dict = None):
    """
    send an async POST request on a freshly created session

    endpoint: path passed as the request url; the session is created with sd_url as its base url
    json: optional payload sent as the request body
    returns: whatever result() produces for the response, or `{}` when the request itself raises
    """
    global sess # pylint: disable=global-statement
    # every POST gets its own session: close the current one if it is still open, then create a new one
    if sess is not None and not sess.closed:
        await sess.close()
    sess = await session()
    try:
        # ssl=False skips certificate verification; it replaces the deprecated verify_ssl=False
        async with sess.post(url=endpoint, json=json, ssl=False) as req:
            return await result(req)
    except Exception as err:
        # best-effort call: failures are logged and reported as an empty result
        log.error({ 'session': err })
        return {}
def postsync(endpoint: str, json: dict = None):
    """
    send a blocking POST request to sd_url + endpoint

    json: optional payload sent as the request body
    returns: whatever resultsync() produces for the response
    exceptions raised by the request are not caught here
    """
    response = requests.post(f'{sd_url}{endpoint}', json=json, verify=False, auth=authsync()) # pylint: disable=missing-timeout
    return resultsync(response)
async def interrupt():
    """
    interrupt the server when it reports active jobs (async)

    returns: the response of the interrupt request when the progress state has job_count > 0,
    otherwise `{ 'interrupt': 'idle' }`
    """
    res = await get('/sdapi/v1/progress?skip_current_image=true')
    busy = 'state' in res and res.state.job_count > 0
    if not busy:
        log.debug({ 'interrupt': 'idle' })
        return { 'interrupt': 'idle' }
    log.debug({ 'interrupt': res.state })
    res = await post('/sdapi/v1/interrupt')
    # pause for one second after the interrupt request before returning
    await asyncio.sleep(1)
    return res
def interruptsync():
    """
    interrupt the server when it reports active jobs (blocking)

    returns: the response of the interrupt request when the progress state has job_count > 0,
    otherwise `{ 'interrupt': 'idle' }`
    """
    res = getsync('/sdapi/v1/progress?skip_current_image=true')
    busy = 'state' in res and res.state.job_count > 0
    if not busy:
        log.debug({ 'interrupt': 'idle' })
        return { 'interrupt': 'idle' }
    log.debug({ 'interrupt': res.state })
    return postsync('/sdapi/v1/interrupt')
async def progress():
    """fetch the progress endpoint (without the current image) using the async client, log it and return it"""
    status = await get('/sdapi/v1/progress?skip_current_image=true')
    log.debug({ 'progress': status })
    return status
def progresssync():
    """fetch the progress endpoint (without the current image) using the blocking client, log it and return it"""
    status = getsync('/sdapi/v1/progress?skip_current_image=true')
    log.debug({ 'progress': status })
    return status
def options():
    """
    fetch the server options and command line flags with blocking requests

    returns: dict with the options response under `options` and the cmd-flags response under `flags`
    """
    return {
        'options': getsync('/sdapi/v1/options'),
        'flags': getsync('/sdapi/v1/cmd-flags'),
    }
def shutdown():
    """post to the shutdown endpoint; any exception raised by the request is logged at info level and not re-raised"""
    try:
        postsync('/sdapi/v1/shutdown')
    except Exception as err:
        log.info({ 'shutdown': err })
async def session():
    """
    create a new aiohttp client session and store it as the shared module session

    the session uses sd_url as its base url and the credentials returned by auth()
    returns: the newly created session
    """
    global sess # pylint: disable=global-statement
    # no total or read limit, only connecting is limited to 10 seconds
    limits = aiohttp.ClientTimeout(total = None, sock_connect = 10, sock_read = None) # default value is 5 minutes, we need longer for training
    sess = aiohttp.ClientSession(timeout = limits, base_url = sd_url, auth=auth())
    log.debug({ 'sdapi': 'session created', 'endpoint': sd_url })
    return sess
async def close():
    """
    close the shared client session if one exists

    the module-level session is reset to None afterwards so that a later request
    creates a fresh session instead of reusing a closed one
    """
    global sess # pylint: disable=global-statement
    if sess is None:
        return
    # yield to the event loop once before closing
    await asyncio.sleep(0)
    # close() is sufficient on its own; the session's __aexit__ only calls close() again
    await sess.close()
    sess = None
    log.debug({ 'sdapi': 'session closed', 'endpoint': sd_url })
if __name__ == "__main__":

    async def _run_async_commands():
        """run the async commands named on the command line, then close the session, all in one event loop"""
        if 'interrupt' in sys.argv:
            await interrupt()
        if 'progress' in sys.argv:
            await progress()
        # the session is closed in the same loop that created it
        await close()
        # short grace period in that same loop after the session is closed
        await asyncio.sleep(0.5)

    log.setLevel(logging.DEBUG)
    asyncio.run(_run_async_commands(), debug=True)
    # the remaining commands use the blocking client and do not touch the shared session
    if 'progresssync' in sys.argv:
        progresssync()
    if 'options' in sys.argv:
        opt = options()
        log.debug({ 'options' })
        import json
        print(json.dumps(opt['options'], indent = 2))
        log.debug({ 'cmd-flags' })
        print(json.dumps(opt['flags'], indent = 2))
    if 'shutdown' in sys.argv:
        shutdown()