1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/api/api.py
Vladimir Mandic 2aa917b58e add /sdapi/v1/modules endpoint
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-08-12 15:09:08 -04:00

182 lines
11 KiB
Python

from typing import List, Optional
from threading import Lock
from secrets import compare_digest
from fastapi import FastAPI, APIRouter, Depends, Request
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.exceptions import HTTPException
from modules import errors, shared, postprocessing
from modules.api import models, endpoints, script, helpers, server, generate, process, control, docs, gpu
# runs at import time, before the Api class is defined
# NOTE(review): presumably installs the project's error/exception handling hooks - confirm in modules.errors
errors.install()
class Api:
    """Registers the `/sdapi/v1/*` REST endpoints on a FastAPI application.

    The instance keeps the basic-auth credentials parsed from the command line
    options, the shared queue lock handed to the generate/process/control
    handlers, and the handler objects themselves.
    """

    def __init__(self, app: FastAPI, queue_lock: Lock):
        """Collect credentials, optionally create docs pages and build the handler objects.

        Args:
            app: FastAPI application that routes are later registered on.
            queue_lock: lock passed to the generate/process/control handlers and
                held around the extras (postprocessing) calls in this class.

        Raises:
            ValueError: if a non-empty auth entry is not in `user:password` form.
        """
        # username -> password, filled from `--auth` and/or the auth file
        self.credentials = {}
        if shared.cmd_opts.auth:
            for auth in shared.cmd_opts.auth.split(","):
                self._add_credentials(auth)
        if shared.cmd_opts.auth_file:
            with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
                for line in file:
                    self._add_credentials(line)
        self.router = APIRouter()
        if shared.cmd_opts.docs:
            docs.create_docs(app)
            docs.create_redocs(app)
        self.app = app
        self.queue_lock = queue_lock
        self.generate = generate.APIGenerate(queue_lock)
        self.process = process.APIProcess(queue_lock)
        self.control = control.APIControl(queue_lock)
        # compatibility api
        self.text2imgapi = self.generate.post_text2img
        self.img2imgapi = self.generate.post_img2img

    def _add_credentials(self, entry: str):
        """Parse a single `user:password` entry and store it in `self.credentials`.

        Blank entries (empty lines in the auth file, trailing commas in the
        option value) are ignored. Only the first colon separates user from
        password, so passwords may themselves contain colons. Double quotes and
        surrounding whitespace are removed from both parts.

        Raises:
            ValueError: if a non-empty entry contains no colon.
        """
        entry = entry.strip()
        if not entry:
            return
        user, separator, password = entry.partition(":")
        if not separator:
            # do not echo the entry itself: it may contain a secret
            raise ValueError('Invalid auth entry: expected "user:password"')
        self.credentials[user.replace('"', '').strip()] = password.replace('"', '').strip()

    def register(self):
        """Register all built-in API routes, then let the optional API modules register theirs."""
        # server api
        self.add_api_route("/sdapi/v1/motd", server.get_motd, methods=["GET"], response_model=str)
        self.add_api_route("/sdapi/v1/log", server.get_log, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/log", server.post_log, methods=["POST"])
        self.add_api_route("/sdapi/v1/start", self.get_session_start, methods=["GET"])
        self.add_api_route("/sdapi/v1/version", server.get_version, methods=["GET"])
        self.add_api_route("/sdapi/v1/status", server.get_status, methods=["GET"], response_model=models.ResStatus)
        self.add_api_route("/sdapi/v1/platform", server.get_platform, methods=["GET"])
        self.add_api_route("/sdapi/v1/progress", server.get_progress, methods=["GET"], response_model=models.ResProgress)
        self.add_api_route("/sdapi/v1/history", server.get_history, methods=["GET"], response_model=List[models.ResHistory])
        self.add_api_route("/sdapi/v1/interrupt", server.post_interrupt, methods=["POST"])
        self.add_api_route("/sdapi/v1/skip", server.post_skip, methods=["POST"])
        self.add_api_route("/sdapi/v1/shutdown", server.post_shutdown, methods=["POST"])
        self.add_api_route("/sdapi/v1/memory", server.get_memory, methods=["GET"], response_model=models.ResMemory)
        self.add_api_route("/sdapi/v1/options", server.get_config, methods=["GET"], response_model=models.OptionsModel)
        self.add_api_route("/sdapi/v1/options", server.set_config, methods=["POST"])
        self.add_api_route("/sdapi/v1/cmd-flags", server.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
        self.add_api_route("/sdapi/v1/gpu", gpu.get_gpu_status, methods=["GET"], response_model=List[models.ResGPU])

        # core api using locking
        self.add_api_route("/sdapi/v1/txt2img", self.generate.post_text2img, methods=["POST"], response_model=models.ResTxt2Img)
        self.add_api_route("/sdapi/v1/img2img", self.generate.post_img2img, methods=["POST"], response_model=models.ResImg2Img)
        self.add_api_route("/sdapi/v1/control", self.control.post_control, methods=["POST"], response_model=control.ResControl)
        self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ResProcessImage)
        self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ResProcessBatch)
        self.add_api_route("/sdapi/v1/preprocess", self.process.post_preprocess, methods=["POST"])
        self.add_api_route("/sdapi/v1/mask", self.process.post_mask, methods=["POST"])
        self.add_api_route("/sdapi/v1/detect", self.process.post_detect, methods=["POST"])
        self.add_api_route("/sdapi/v1/prompt-enhance", self.process.post_prompt_enhance, methods=["POST"], response_model=models.ResPromptEnhance)

        # api dealing with optional scripts
        self.add_api_route("/sdapi/v1/scripts", script.get_scripts_list, methods=["GET"], response_model=models.ResScripts)
        self.add_api_route("/sdapi/v1/script-info", script.get_script_info, methods=["GET"], response_model=List[models.ItemScript])

        # enumerator api
        self.add_api_route("/sdapi/v1/preprocessors", self.process.get_preprocess, methods=["GET"], response_model=List[process.ItemPreprocess])
        self.add_api_route("/sdapi/v1/masking", self.process.get_mask, methods=["GET"], response_model=process.ItemMask)
        self.add_api_route("/sdapi/v1/interrogate", endpoints.get_interrogate, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/samplers", endpoints.get_samplers, methods=["GET"], response_model=List[models.ItemSampler])
        self.add_api_route("/sdapi/v1/upscalers", endpoints.get_upscalers, methods=["GET"], response_model=List[models.ItemUpscaler])
        self.add_api_route("/sdapi/v1/sd-models", endpoints.get_sd_models, methods=["GET"], response_model=List[models.ItemModel])
        self.add_api_route("/sdapi/v1/controlnets", endpoints.get_controlnets, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/face-restorers", endpoints.get_detailers, methods=["GET"], response_model=List[models.ItemDetailer])
        self.add_api_route("/sdapi/v1/prompt-styles", endpoints.get_prompt_styles, methods=["GET"], response_model=List[models.ItemStyle])
        self.add_api_route("/sdapi/v1/embeddings", endpoints.get_embeddings, methods=["GET"], response_model=models.ResEmbeddings)
        self.add_api_route("/sdapi/v1/sd-vae", endpoints.get_sd_vaes, methods=["GET"], response_model=List[models.ItemVae])
        self.add_api_route("/sdapi/v1/extensions", endpoints.get_extensions_list, methods=["GET"], response_model=List[models.ItemExtension])
        self.add_api_route("/sdapi/v1/extra-networks", endpoints.get_extra_networks, methods=["GET"], response_model=List[models.ItemExtraNetwork])

        # functional api
        self.add_api_route("/sdapi/v1/png-info", endpoints.post_pnginfo, methods=["POST"], response_model=models.ResImageInfo)
        self.add_api_route("/sdapi/v1/interrogate", endpoints.post_interrogate, methods=["POST"])
        self.add_api_route("/sdapi/v1/vqa", endpoints.post_vqa, methods=["POST"])
        self.add_api_route("/sdapi/v1/checkpoint", endpoints.get_checkpoint, methods=["GET"])
        self.add_api_route("/sdapi/v1/checkpoint", endpoints.set_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-checkpoints", endpoints.post_refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/unload-checkpoint", endpoints.post_unload_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", endpoints.post_reload_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/lock-checkpoint", endpoints.post_lock_checkpoint, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-vae", endpoints.post_refresh_vae, methods=["POST"])
        self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)
        self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"])

        # the modules below are imported here, at registration time, rather than at the top of the file
        # lora api
        from modules.api import loras
        loras.register_api()

        # gallery api
        from modules.api import gallery
        gallery.register_api(self.app)

        # nudenet api
        from modules.api import nudenet
        nudenet.register_api()

        # civitai api
        from modules.civitai import api_civitai
        api_civitai.register_api()

    def add_api_route(self, path: str, endpoint, **kwargs):
        """Register `endpoint` under `path` and, when a subpath is configured, also under `<subpath><path>`.

        Args:
            path: route path, e.g. "/sdapi/v1/motd".
            endpoint: callable handling the request.
            **kwargs: forwarded unchanged to `FastAPI.add_api_route`.
        """
        if shared.opts.subpath:  # neither None nor empty
            self.app.add_api_route(f'{shared.opts.subpath}{path}', endpoint, **kwargs)
        self.app.add_api_route(path, endpoint, **kwargs)

    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
        """Validate HTTP basic credentials against `self.credentials`.

        Returns:
            True when the username is known and the password matches.

        Raises:
            HTTPException: status 401 for an unknown user or wrong password.
        """
        # this is only needed for api-only since otherwise auth is handled in gradio/routes.py
        expected = self.credentials.get(credentials.username)
        if expected is not None:
            # compare as bytes: compare_digest raises TypeError for str values containing non-ASCII characters
            if compare_digest(credentials.password.encode("utf8"), expected.encode("utf8")):
                return True
        raise HTTPException(status_code=401, detail="Unauthorized", headers={"WWW-Authenticate": "Basic"})

    def get_session_start(self, req: Request, agent: Optional[str] = None):
        """Log the start of a browser session and return an empty dict.

        The user is looked up in `self.app.tokens` (when the app has that
        attribute) using the `access-token` or `access-token-unsecure` cookie.
        """
        token = req.cookies.get("access-token") or req.cookies.get("access-token-unsecure")
        user = self.app.tokens.get(token) if hasattr(self.app, 'tokens') else None
        # client address may be unavailable, in which case req.client is None
        client = req.client.host if req.client is not None else None
        shared.log.info(f'Browser session: user={user} client={client} agent={agent}')
        return {}

    def set_upscalers(self, req: dict):
        """Return the request attributes as a dict with the upscaler keys renamed.

        `upscaler_1` / `upscaler_2` become `extras_upscaler_1` /
        `extras_upscaler_2` (None when absent).

        NOTE: despite the annotation, `req` must be an object with a `__dict__`
        (callers in this class pass request model instances), since `vars()` is
        applied to it. The result is a shallow copy, so the request object
        itself is not modified.
        """
        reqDict = dict(vars(req))
        reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
        reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
        return reqDict

    def extras_single_image_api(self, req: models.ReqProcessImage):
        """Run extras postprocessing on one base64-encoded image.

        Returns:
            models.ResProcessImage with the first result image re-encoded to
            base64 and the html info returned by `postprocessing.run_extras`.
        """
        reqDict = self.set_upscalers(req)
        reqDict['image'] = helpers.decode_base64_to_image(reqDict['image'])
        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
        return models.ResProcessImage(image=helpers.encode_pil_to_base64(result[0][0]), html_info=result[1])

    def extras_batch_images_api(self, req: models.ReqProcessBatch):
        """Run extras postprocessing on a list of base64-encoded images.

        Returns:
            models.ResProcessBatch with every result image re-encoded to base64
            and the html info returned by `postprocessing.run_extras`.
        """
        reqDict = self.set_upscalers(req)
        image_list = reqDict.pop('imageList', [])
        image_folder = [helpers.decode_base64_to_image(x.data) for x in image_list]
        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
        return models.ResProcessBatch(images=list(map(helpers.encode_pil_to_base64, result[0])), html_info=result[1])

    def launch(self):
        """Start a Uvicorn server for the app using the command line options and return it."""
        config = {
            "listen": shared.cmd_opts.listen,
            "port": shared.cmd_opts.port,
            "keyfile": shared.cmd_opts.tls_keyfile,
            "certfile": shared.cmd_opts.tls_certfile,
            "loop": "auto",  # auto, asyncio, uvloop
            "http": "auto",  # auto, h11, httptools
        }
        from modules.server import UvicornServer
        http_server = UvicornServer(self.app, **config)
        # alternative server implementation:
        # from modules.server import HypercornServer
        # server = HypercornServer(self.app, **config)
        http_server.start()
        shared.log.info(f'API server: Uvicorn options={config}')
        return http_server
# compatibility items: module-level aliases for the helper functions defined in modules.api.helpers
# NOTE(review): presumably kept so existing code that looks these names up on this module keeps working - confirm before removing
decode_base64_to_image = helpers.decode_base64_to_image
encode_pil_to_base64 = helpers.encode_pil_to_base64
validate_sampler_name = helpers.validate_sampler_name