mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
175 lines
11 KiB
Python
175 lines
11 KiB
Python
from typing import List, Optional
|
|
from threading import Lock
|
|
from secrets import compare_digest
|
|
from fastapi import FastAPI, APIRouter, Depends, Request
|
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|
from fastapi.exceptions import HTTPException
|
|
from modules import errors, shared, postprocessing
|
|
from modules.api import models, endpoints, script, helpers, server, nvml, generate, process, control, gallery, loras, docs
|
|
|
|
|
|
# NOTE(review): runs at import time; presumably installs the project's global
# error/exception reporting from modules.errors — confirm in that module
errors.install()
|
|
|
|
|
|
class Api:
    """Register the SD.Next HTTP API (`/sdapi/v1/...`) on a FastAPI app and serve it.

    Holds the optional HTTP Basic credentials (used only in api-only mode), the
    queue lock shared with the generate/process/control sub-APIs, and the
    handlers mounted by `register()`.
    """

    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Create the API wrapper.

        Args:
            app: FastAPI application that routes are registered on.
            queue_lock: lock passed to the sub-APIs and held here while extras
                (postprocessing) requests run.
        """
        self.credentials = self._load_credentials()
        self.router = APIRouter()
        if shared.cmd_opts.docs:
            docs.create_docs(app)
            docs.create_redocs(app)
        self.app = app
        self.queue_lock = queue_lock
        self.generate = generate.APIGenerate(queue_lock)
        self.process = process.APIProcess(queue_lock)
        self.control = control.APIControl(queue_lock)
        # compatibility api
        self.text2imgapi = self.generate.post_text2img
        self.img2imgapi = self.generate.post_img2img

    @staticmethod
    def _parse_credential(entry: str) -> Optional[tuple]:
        """Parse one `user:password` entry into `(user, password)`.

        Returns None for blank entries (e.g. empty lines or a trailing comma).
        Only the first ':' separates user from password, so passwords may contain ':'.
        Double quotes and surrounding whitespace are removed from both parts.

        Raises:
            ValueError: if a non-blank entry contains no ':' separator.
        """
        entry = entry.strip()
        if not entry:
            return None
        user, separator, password = entry.partition(':')
        if not separator:
            raise ValueError("credential entry has no ':' separator")
        return user.replace('"', '').strip(), password.replace('"', '').strip()

    def _load_credentials(self) -> dict:
        """Collect credentials from `cmd_opts.auth` (comma-separated) and `cmd_opts.auth_file` (one per line).

        Entries from the auth file are read after `cmd_opts.auth`, so a later entry for
        the same user overrides an earlier one. Malformed entries are logged and skipped.
        """
        entries = []
        if shared.cmd_opts.auth:
            entries.extend(shared.cmd_opts.auth.split(","))
        if shared.cmd_opts.auth_file:
            with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
                entries.extend(file)
        credentials = {}
        for entry in entries:
            try:
                parsed = self._parse_credential(entry)
            except ValueError:
                # do not echo the entry itself: it may contain a password
                shared.log.warning('API auth: skipping malformed credential entry without ":" separator')
                continue
            if parsed is not None:
                user, password = parsed
                credentials[user] = password
        return credentials

    def register(self):
        """Mount every API route on `self.app` via `add_api_route`, then the gallery API."""
        # server api
        self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
        self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"])
        self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
        self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
        self.add_api_route("/sdapi/v1/status", server.get_status, methods=["GET"], response_model=models.ResStatus)
        self.add_api_route("/sdapi/v1/platform", server.get_platform, methods=["GET"])
        self.add_api_route("/sdapi/v1/progress", server.get_progress, methods=["GET"], response_model=models.ResProgress)
        self.add_api_route("/sdapi/v1/history", server.get_history, methods=["GET"], response_model=List[models.ResHistory])
        self.add_api_route("/sdapi/v1/interrupt", server.post_interrupt, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", server.post_skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/shutdown", server.post_shutdown, methods=["POST"])
        self.add_api_route("/sdapi/v1/memory", server.get_memory, methods=["GET"], response_model=models.ResMemory)
        self.add_api_route("/sdapi/v1/options", server.get_config, methods=["GET"], response_model=models.OptionsModel)
        self.add_api_route("/sdapi/v1/options", server.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
        self.add_api_route("/sdapi/v1/nvml", nvml.get_nvml, methods=["GET"], response_model=List[models.ResNVML])

        # core api using locking
        self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img)
        self.add_api_route("/sdapi/v1/img2img", self.generate.post_img2img, methods=["POST"], response_model=models.ResImg2Img)
        self.add_api_route("/sdapi/v1/control", self.control.post_control, methods=["POST"], response_model=control.ResControl)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ResProcessImage)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ResProcessBatch)
        self.add_api_route("/sdapi/v1/preprocess", self.process.post_preprocess, methods=["POST"])
        self.add_api_route("/sdapi/v1/mask", self.process.post_mask, methods=["POST"])
        self.add_api_route("/sdapi/v1/detect", self.process.post_detect, methods=["POST"])
        self.add_api_route("/sdapi/v1/prompt-enhance", self.process.post_prompt_enhance, methods=["POST"], response_model=models.ResPromptEnhance)

        # api dealing with optional scripts
        self.add_api_route("/sdapi/v1/scripts", script.get_scripts_list, methods=["GET"], response_model=models.ResScripts)
        self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=List[models.ItemScript])

        # enumerator api
        self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess])
        self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask)
        self.add_api_route("/sdapi/v1/interrogate", endpoints.get_interrogate, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler])
        self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler])
        self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel])
        self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/hypernetworks", endpoints.get_hypernetworks, methods=["GET"], response_model=List[models.ItemHypernetwork])
        self.add_api_route("/sdapi/v1/face-restorers", endpoints.get_detailers, methods=["GET"], response_model=List[models.ItemDetailer])
        self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=List[models.ItemStyle])
        self.add_api_route("/sdapi/v1/embeddings", endpoints.get_embeddings, methods=["GET"], response_model=models.ResEmbeddings)
        self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
        self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
        self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])

        # functional api
        self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
        self.add_api_route("/sdapi/v1/interrogate", endpoints.post_interrogate, methods=["POST"])
        self.add_api_route("/sdapi/v1/vqa", endpoints.post_vqa, methods=["POST"])
        self.add_api_route("/sdapi/v1/checkpoint", endpoints.get_checkpoint, methods=["GET"])
        self.add_api_route("/sdapi/v1/refresh-checkpoints", endpoints.post_refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])
        self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)

        # lora api
        if shared.native:
            self.add_api_route("/sdapi/v1/lora", loras.get_lora, methods=["GET"], response_model=dict)
            self.add_api_route("/sdapi/v1/loras", loras.get_loras, methods=["GET"], response_model=List[dict])
            self.add_api_route("/sdapi/v1/refresh-loras", loras.post_refresh_loras, methods=["POST"])

        # gallery api
        gallery.register_api(self.app)

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register `endpoint` at `path`, and also at `<opts.subpath><path>` when a subpath is set.

        When `cmd_opts.auth` or `cmd_opts.auth_file` is set in api-only mode, the HTTP
        Basic `auth` check is appended to the route's dependencies. Dependencies passed
        by the caller are kept rather than overwritten.
        """
        if (shared.cmd_opts.auth or shared.cmd_opts.auth_file) and shared.cmd_opts.api_only:
            kwargs['dependencies'] = [*(kwargs.get('dependencies') or []), Depends(self.auth)]
        subpath = shared.opts.subpath
        if subpath is not None and len(subpath) > 0:
            self.app.add_api_route(f'{subpath}{path}', endpoint, **kwargs)
        self.app.add_api_route(path, endpoint, **kwargs)

    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """FastAPI dependency that validates HTTP Basic credentials against `self.credentials`.

        Returns:
            True when the username is known and the password matches.

        Raises:
            HTTPException: 401 with a `WWW-Authenticate: Basic` header otherwise.
        """
        # this is only needed for api-only since otherwise auth is handled in gradio/routes.py
        expected = self.credentials.get(credentials.username)
        # compare UTF-8 bytes: secrets.compare_digest raises TypeError for non-ASCII str,
        # which would turn a non-ASCII stored password into a 500 instead of 401/200
        if expected is not None and compare_digest(credentials.password.encode('utf-8'), expected.encode('utf-8')):
            return True
        raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})

    def get_session_start(self, req: Request, agent: Optional[str] = None):
        """Log the start of a browser session (user from the access-token cookie, client host, agent).

        Returns an empty dict.
        """
        token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure")
        user = self.app.tokens.get(token) if hasattr(self.app, 'tokens') else None
        # Request.client can be None (e.g. some ASGI transports / test clients)
        client = req.client.host if req.client is not None else None
        shared.log.info(f'Browser session: user={user} client={client} agent={agent}')
        return {}

    def set_upscalers(self, req):
        """Return a copy of the request fields with `upscaler_1/2` renamed to `extras_upscaler_1/2`.

        Args:
            req: a request model (its instance attributes are used) or a plain dict.
                The input is not modified.

        Returns:
            New dict of keyword arguments; missing upscaler keys map to None.
        """
        req_dict = dict(req) if isinstance(req, dict) else dict(vars(req))
        req_dict['extras_upscaler_1'] = req_dict.pop('upscaler_1', None)
        req_dict['extras_upscaler_2'] = req_dict.pop('upscaler_2', None)
        return req_dict

    def extras_single_image_api(self, req: models.ReqProcessImage):
        """Run postprocessing (extras mode 0) on one base64 image under the queue lock."""
        req_dict = self.set_upscalers(req)
        req_dict['image'] = helpers.decode_base64_to_image(req_dict['image'])
        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **req_dict)
        return models.ResProcessImage(image=helpers.encode_pil_to_base64(result[0][0]), html_info=result[1])

    def extras_batch_images_api(self, req: models.ReqProcessBatch):
        """Run postprocessing (extras mode 1) on a list of base64 images under the queue lock."""
        req_dict = self.set_upscalers(req)
        image_list = req_dict.pop('imageList', [])
        image_folder = [helpers.decode_base64_to_image(x.data) for x in image_list]
        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **req_dict)
        return models.ResProcessBatch(images=list(map(helpers.encode_pil_to_base64, result[0])), html_info=result[1])

    def launch(self):
        """Start a Uvicorn server for `self.app` using the listen/port/TLS command-line options.

        Returns:
            The started `UvicornServer` instance.
        """
        config = {
            "listen": shared.cmd_opts.listen,
            "port": shared.cmd_opts.port,
            "keyfile": shared.cmd_opts.tls_keyfile,
            "certfile": shared.cmd_opts.tls_certfile,
            "loop": "auto", # auto, asyncio, uvloop
            "http": "auto", # auto, h11, httptools
        }
        from modules.server import UvicornServer
        http_server = UvicornServer(self.app, **config)
        # from modules.server import HypercornServer
        # server = HypercornServer(self.app, **config)
        http_server.start()
        shared.log.info(f'API server: Uvicorn options={config}')
        return http_server
|
|
|
|
|
|
# compatibility items: module-level aliases of the helpers functions, kept so that
# code importing these names from this module keeps working (presumably older callers)
decode_base64_to_image, encode_pil_to_base64, validate_sampler_name = (
    helpers.decode_base64_to_image,
    helpers.encode_pil_to_base64,
    helpers.validate_sampler_name,
)
|