1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/modules/api/api.py
2024-02-07 12:12:54 -05:00

295 lines
17 KiB
Python

from typing import List, Optional
from threading import Lock
from secrets import compare_digest
from fastapi import FastAPI, APIRouter, Depends, Request
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from modules import errors, shared, scripts, ui, postprocessing
from modules.api import models, endpoints, script, train, helpers, server, nvml
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
errors.install()
class Api:
def __init__(self, app: FastAPI, queue_lock: Lock):
self.credentials = {}
if shared.cmd_opts.auth:
for auth in shared.cmd_opts.auth.split(","):
user, password = auth.split(":")
self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()
if shared.cmd_opts.auth_file:
with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
for line in file.readlines():
user, password = line.split(":")
self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
# server api
self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
self.add_api_route("/sdapi/v1/log", server.get_log_buffer, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
self.add_api_route("/sdapi/v1/platform", server.get_platform, methods=["GET"])
self.add_api_route("/sdapi/v1/progress", server.get_progress, methods=["GET"], response_model=models.ResProgress)
self.add_api_route("/sdapi/v1/interrupt", server.post_interrupt, methods=["POST"])
self.add_api_route("/sdapi/v1/skip", server.post_skip, methods=["POST"])
self.add_api_route("/sdapi/v1/shutdown", server.post_shutdown, methods=["POST"])
self.add_api_route("/sdapi/v1/memory", server.get_memory, methods=["GET"], response_model=models.ResMemory)
self.add_api_route("/sdapi/v1/options", server.get_config, methods=["GET"], response_model=models.OptionsModel)
self.add_api_route("/sdapi/v1/options", server.set_config, methods=["POST"])
self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
app.add_api_route("/sdapi/v1/nvml", nvml.get_nvml, methods=["GET"], response_model=List[models.ResNVML])
# core api using locking
self.add_api_route("/sdapi/v1/txt2img", self.post_text2img, methods=["POST"], response_model=models.ResTxt2Img)
self.add_api_route("/sdapi/v1/img2img", self.post_img2img, methods=["POST"], response_model=models.ResImg2Img)
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ResProcessImage)
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ResProcessBatch)
# api dealing with optional scripts
self.add_api_route("/sdapi/v1/scripts", script.get_scripts_list, methods=["GET"], response_model=models.ResScripts)
self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=List[models.ItemScript])
# enumerator api
self.add_api_route("/sdapi/v1/interrogate", endpoints.get_interrogate, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler])
self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler])
self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel])
self.add_api_route("/sdapi/v1/hypernetworks", endpoints.get_hypernetworks, methods=["GET"], response_model=List[models.ItemHypernetwork])
self.add_api_route("/sdapi/v1/face-restorers", endpoints.get_face_restorers, methods=["GET"], response_model=List[models.ItemFaceRestorer])
self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=List[models.ItemStyle])
self.add_api_route("/sdapi/v1/embeddings", endpoints.get_embeddings, methods=["GET"], response_model=models.ResEmbeddings)
self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])
# functional api
self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
self.add_api_route("/sdapi/v1/interrogate", endpoints.post_interrogate, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-checkpoints", endpoints.post_refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])
# train api
self.add_api_route("/sdapi/v1/create/embedding", train.post_create_embedding, methods=["POST"], response_model=models.ResCreate)
self.add_api_route("/sdapi/v1/create/hypernetwork", train.post_create_hypernetwork, methods=["POST"], response_model=models.ResCreate)
self.add_api_route("/sdapi/v1/preprocess", train.post_preprocess, methods=["POST"], response_model=models.ResPreprocess)
self.add_api_route("/sdapi/v1/train/embedding", train.post_train_embedding, methods=["POST"], response_model=models.ResTrain)
self.add_api_route("/sdapi/v1/train/hypernetwork", train.post_train_hypernetwork, methods=["POST"], response_model=models.ResTrain)
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
def add_api_route(self, path: str, endpoint, **kwargs):
if (shared.cmd_opts.auth or shared.cmd_opts.auth_file) and shared.cmd_opts.api_only:
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
return self.app.add_api_route(path, endpoint, **kwargs)
def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
# this is only needed for api-only since otherwise auth is handled in gradio/routes.py
if credentials.username in self.credentials:
if compare_digest(credentials.password, self.credentials[credentials.username]):
return True
raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})
def get_session_start(self, req: Request, agent: Optional[str] = None):
token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure")
user = self.app.tokens.get(token) if hasattr(self.app, 'tokens') else None
shared.log.info(f'Browser session: user={user} client={req.client.host} agent={agent}')
return {}
def prepare_img_gen_request(self, request):
if hasattr(request, "face") and request.face and not request.script_name and (not request.alwayson_scripts or "face" not in request.alwayson_scripts.keys()):
request.script_name = "face"
request.script_args = [
request.face.mode,
request.face.source_images,
request.face.ip_model,
request.face.ip_override_sampler,
request.face.ip_cache_model,
request.face.ip_strength,
request.face.ip_structure,
request.face.id_strength,
request.face.id_conditioning,
request.face.id_cache,
request.face.pm_trigger,
request.face.pm_strength,
request.face.pm_start,
request.face.fs_cache
]
del request.face
if hasattr(request, "ip_adapter") and request.ip_adapter and request.script_name != "IP Adapter" and (not request.alwayson_scripts or "IP Adapter" not in request.alwayson_scripts.keys()):
request.alwayson_scripts = {} if request.alwayson_scripts is None else request.alwayson_scripts
request.alwayson_scripts["IP Adapter"] = {
"args": [request.ip_adapter.adapter, request.ip_adapter.scale, request.ip_adapter.image]
}
del request.ip_adapter
def sanitize_args(self, args: list):
for idx in range(0, len(args)):
if isinstance(args[idx], str) and len(args[idx]) >= 1000:
args[idx] = f"<str {len(args[idx])}>"
def sanitize_img_gen_request(self, request):
if hasattr(request, "alwayson_scripts") and request.alwayson_scripts:
for script_name in request.alwayson_scripts.keys():
script_obj = request.alwayson_scripts[script_name]
if script_obj and "args" in script_obj and script_obj["args"]:
self.sanitize_args(script_obj["args"])
if hasattr(request, "script_args") and request.script_args:
self.sanitize_args(request.script_args)
def post_text2img(self, txt2imgreq: models.ReqTxt2Img):
self.prepare_img_gen_request(txt2imgreq)
script_runner = scripts.scripts_txt2img
if not script_runner.scripts:
script_runner.initialize_scripts(False)
ui.create_ui(None)
if not self.default_script_arg_txt2img:
self.default_script_arg_txt2img = script.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = script.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": helpers.validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": not txt2imgreq.save_images,
"do_not_save_grid": not txt2imgreq.save_images,
})
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
args = vars(populate)
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('face', None)
args.pop('ip_adapter', None)
args.pop('alwayson_scripts', None)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
p.scripts = script_runner
p.outpath_grids = shared.opts.outdir_grids or shared.opts.outdir_txt2img_grids
p.outpath_samples = shared.opts.outdir_samples or shared.opts.outdir_txt2img_samples
shared.state.begin('api-txt2img', api=True)
script_args = script.init_script_args(p, txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
if selectable_scripts is not None:
processed = scripts.scripts_txt2img.run(p, *script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end(api=False)
b64images = list(map(helpers.encode_pil_to_base64, processed.images)) if send_images else []
self.sanitize_img_gen_request(txt2imgreq)
return models.ResTxt2Img(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def post_img2img(self, img2imgreq: models.ReqImg2Img):
self.prepare_img_gen_request(img2imgreq)
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
mask = img2imgreq.mask
if mask:
mask = helpers.decode_base64_to_image(mask)
script_runner = scripts.scripts_img2img
if not script_runner.scripts:
script_runner.initialize_scripts(True)
ui.create_ui(None)
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = script.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = script.get_selectable_script(img2imgreq.script_name, script_runner)
populate = img2imgreq.copy(update={ # Override __init__ params
"sampler_name": helpers.validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": not img2imgreq.save_images,
"do_not_save_grid": not img2imgreq.save_images,
"mask": mask,
})
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
args = vars(populate)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
args.pop('face_id', None)
args.pop('ip_adapter', None)
send_images = args.pop('send_images', True)
args.pop('save_images', None)
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [helpers.decode_base64_to_image(x) for x in init_images]
p.scripts = script_runner
p.outpath_grids = shared.opts.outdir_img2img_grids
p.outpath_samples = shared.opts.outdir_img2img_samples
shared.state.begin('api-img2img', api=True)
script_args = script.init_script_args(p, img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
if selectable_scripts is not None:
processed = scripts.scripts_img2img.run(p, *script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end(api=False)
b64images = list(map(helpers.encode_pil_to_base64, processed.images)) if send_images else []
if not img2imgreq.include_init_images:
img2imgreq.init_images = None
img2imgreq.mask = None
self.sanitize_img_gen_request(img2imgreq)
return models.ResImg2Img(images=b64images, parameters=vars(img2imgreq), info=processed.js())
def set_upscalers(self, req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
return reqDict
def extras_single_image_api(self, req: models.ReqProcessImage):
reqDict = self.set_upscalers(req)
reqDict['image'] = helpers.decode_base64_to_image(reqDict['image'])
with self.queue_lock:
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
return models.ResProcessImage(image=helpers.encode_pil_to_base64(result[0][0]), html_info=result[1])
def extras_batch_images_api(self, req: models.ReqProcessBatch):
reqDict = self.set_upscalers(req)
image_list = reqDict.pop('imageList', [])
image_folder = [helpers.decode_base64_to_image(x.data) for x in image_list]
with self.queue_lock:
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return models.ResProcessBatch(images=list(map(helpers.encode_pil_to_base64, result[0])), html_info=result[1])
def launch(self):
config = {
"listen": shared.cmd_opts.listen,
"port": shared.cmd_opts.port,
"keyfile": shared.cmd_opts.tls_keyfile,
"certfile": shared.cmd_opts.tls_certfile,
"loop": "auto", # auto, asyncio, uvloop
"http": "auto", # auto, h11, httptools
}
from modules.server import UvicornServer
http_server = UvicornServer(self.app, **config)
# from modules.server import HypercornServer
# server = HypercornServer(self.app, **config)
http_server.start()
shared.log.info(f'API server: Uvicorn options={config}')
return http_server
# compatibility items
decode_base64_to_image = helpers.decode_base64_to_image
encode_pil_to_base64 = helpers.encode_pil_to_base64
validate_sampler_name = helpers.validate_sampler_name