1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/webui.py
2026-01-10 13:27:38 -08:00

426 lines
16 KiB
Python

import io
import os
import sys
import time
import glob
import signal
import asyncio
import logging
import importlib
import contextlib
from threading import Thread
from installer import log, git_commit, custom_excepthook, version
from modules import timer
import modules.loader
import modules.hashes
import modules.paths
import modules.devices
from modules import shared
from modules.call_queue import queue_lock, wrap_queued_call, wrap_gradio_gpu_call # pylint: disable=unused-import
import modules.gr_tempdir
import modules.modeldata
import modules.extensions
import modules.modelloader
import modules.sd_checkpoint
import modules.sd_samplers
import modules.scripts_manager
import modules.scripts
import modules.sd_models
import modules.sd_vae
import modules.sd_unet
import modules.sd_hijack
import modules.model_te
import modules.progress
import modules.ui
import modules.txt2img
import modules.img2img
import modules.upscaler
import modules.upscaler_simple
import modules.extra_networks
import modules.ui_extra_networks
import modules.textual_inversion
import modules.script_callbacks
import modules.api.middleware
# Record library-import time only the first time this runs (guarded by the loader flag).
if not modules.loader.initialized:
    timer.startup.record("libraries")
    modules.loader.initialized = True

# Route uncaught exceptions through the installer's handler.
sys.excepthook = custom_excepthook

local_url = None  # filled in by start_ui() from the value returned by demo.launch()
state = shared.state
backend = shared.backend

# An explicit server name wins; otherwise --listen binds all interfaces,
# and None defers to whatever demo.launch() does by default.
server_name = shared.cmd_opts.server_name or ("0.0.0.0" if shared.cmd_opts.listen else None)

# Keyword arguments handed to the FastAPI app via demo.launch(app_kwargs=...).
# The built-in docs/redoc routes stay disabled; a custom handler in api.py serves them.
fastapi_args = dict(
    version=f'0.0.{git_commit}',
    title="SD.Next",
    description="SD.Next",
    docs_url=None,
    redoc_url=None,
)
def initialize():
    """Run the one-time module initialization performed during server startup.

    Stages, in order: checkpoint metadata and hash cache, sampler / VAE / UNet /
    text-encoder lists, model setup, LoRA network list, prompt styles, postprocess
    models (CodeFormer, GFPGAN, YOLO detailer), extension scripts, upscalers, extra
    networks, Hugging Face init and cache check, TLS option validation, and finally
    installation of a SIGINT handler.

    ``timer.startup.record(...)`` follows each stage so startup time is attributed
    per stage. Keep the stage order as-is: later stages may rely on state built by
    earlier ones.
    """
    log.debug('Initializing: modules')
    modules.sd_checkpoint.init_metadata()
    modules.hashes.init_cache()
    modules.sd_samplers.list_samplers()
    timer.startup.record("samplers")
    modules.sd_vae.refresh_vae_list()
    timer.startup.record("vae")
    modules.sd_unet.refresh_unet_list()
    timer.startup.record("unet")
    modules.model_te.refresh_te_list()
    timer.startup.record("te")
    modules.modelloader.cleanup_models()
    modules.sd_models.setup_model()
    timer.startup.record("models")
    from modules.lora import lora_load
    lora_load.list_available_networks()
    timer.startup.record("lora")
    shared.prompt_styles.reload()
    timer.startup.record("styles")
    import modules.postprocess.codeformer_model as codeformer
    codeformer.setup_model(shared.opts.codeformer_models_path)
    # Also register the module under the legacy path, presumably so code that still
    # imports `modules.codeformer_model` keeps working -- NOTE(review): confirm callers.
    sys.modules["modules.codeformer_model"] = codeformer
    import modules.postprocess.gfpgan_model as gfpgan
    gfpgan.setup_model(shared.opts.gfpgan_models_path)
    import modules.postprocess.yolo as yolo
    yolo.initialize()
    timer.startup.record("detailer")
    modules.extensions.list_extensions()
    timer.startup.record("extensions")
    log.info('Load extensions')
    t_timer, t_total = modules.scripts_manager.load_scripts()
    modules.scripts.register_runners()
    timer.startup.record("extensions")
    # Overwrite the recorded value with the total reported by load_scripts().
    timer.startup.records["extensions"] = t_total # scripts can reset the time
    log.debug(f'Extensions init time: {t_timer.summary()}')
    modules.modelloader.load_upscalers()
    timer.startup.record("upscalers")
    modules.ui_extra_networks.initialize()
    modules.ui_extra_networks.register_pages()
    modules.extra_networks.initialize()
    modules.extra_networks.register_default_extra_networks()
    timer.startup.record("networks")
    from modules.models_hf import hf_init, hf_check_cache
    hf_init()
    hf_check_cache()
    # TLS validation runs only when both files were given. A missing file is logged,
    # but the TLS options stay set; only a TypeError from the path checks
    # (e.g. a non-path value) clears both options and disables TLS.
    if shared.cmd_opts.tls_keyfile is not None and shared.cmd_opts.tls_certfile is not None:
        try:
            if not os.path.exists(shared.cmd_opts.tls_keyfile):
                log.error("Invalid path to TLS keyfile given")
            if not os.path.exists(shared.cmd_opts.tls_certfile):
                log.error(f"Invalid path to TLS certfile: '{shared.cmd_opts.tls_certfile}'")
        except TypeError:
            shared.cmd_opts.tls_keyfile = shared.cmd_opts.tls_certfile = None
            log.error("TLS setup invalid, running webui without TLS")
        else:
            log.info("Running with TLS")
        timer.startup.record("tls")
    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(_sig, _frame):
        """Handle SIGINT: log state history, then exit with status 0.

        Before exiting, removes ``*.lock`` files from the current working directory.
        This is best effort: any error while deleting them is ignored.
        """
        log.trace(f'State history: uptime={round(time.time() - shared.state.server_start)} jobs={shared.state.job_history} tasks={shared.state.task_history} latents={shared.state.latent_history} images={shared.state.image_history}')
        log.info('Exiting')
        try:
            for f in glob.glob("*.lock"):
                os.remove(f)
        except Exception:
            pass
        sys.exit(0)
    signal.signal(signal.SIGINT, sigint_handler)
def load_model():
    """Optionally preload the base and refiner models, then register option-change hooks.

    Preloading happens when checkpoint autoload is enabled or a checkpoint was passed
    on the command line. Reading ``shared.sd_model`` / ``shared.sd_refiner`` in worker
    threads presumably triggers their lazy load. Both threads are joined inside a
    'Load model' state job.

    Afterwards, changes to the model, refiner, VAE, UNet and text-encoder options are
    hooked to queued reload calls, and ``temp_dir`` changes to the temp-dir handler.
    """
    modules.modeldata.model_data.locked = False
    autoload = shared.opts.sd_checkpoint_autoload or shared.cmd_opts.ckpt is not None
    log.info(f'Model: autoload={autoload} selected="{shared.opts.sd_model_checkpoint}"')
    if autoload:
        jobid = shared.state.begin('Load model')
        workers = [
            Thread(target=lambda: shared.sd_model),
            Thread(target=lambda: shared.sd_refiner),
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()
        shared.state.end(jobid)
    timer.startup.record("checkpoint")

    # Each callback runs through wrap_queued_call and is not invoked at registration (call=False).
    queued_reloads = (
        ("sd_model_checkpoint", lambda: modules.sd_models.reload_model_weights(op='model')),
        ("sd_model_refiner", lambda: modules.sd_models.reload_model_weights(op='refiner')),
        ("sd_vae", lambda: modules.sd_vae.reload_vae_weights()),
        ("sd_unet", lambda: modules.sd_unet.load_unet(shared.sd_model)),
        ("sd_text_encoder", lambda: modules.sd_models.reload_text_encoder()),
    )
    for option_name, reload_fn in queued_reloads:
        shared.opts.onchange(option_name, wrap_queued_call(reload_fn), call=False)
    shared.opts.onchange("temp_dir", modules.gr_tempdir.on_tmpdir_changed)
    timer.startup.record("onchange")
def create_api(app):
    """Construct the SD.Next REST API object on top of the given FastAPI app.

    The API is given the shared ``queue_lock`` used by the UI call queue. Routes are
    not registered here; ``start_ui()`` calls ``register()`` on the returned object.
    """
    log.debug('API initialize')
    from modules.api.api import Api
    return Api(app, queue_lock)
def async_policy():
    """Install a process-wide asyncio event-loop policy usable from any thread.

    The base policy is the selector-based policy on Windows when available, otherwise
    the default policy. The installed policy behaves as follows:

    - ``get_event_loop()`` creates and sets a new loop when the base policy raises
      ``RuntimeError`` / ``AssertionError`` (e.g. in a thread without a loop).
    - The loop obtained at construction time reports unhandled exceptions to ``log.error``.
    """
    if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
        _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy
    else:
        _BasePolicy = asyncio.DefaultEventLoopPolicy

    class AnyThreadEventLoopPolicy(_BasePolicy):
        def __init__(self):
            super().__init__()
            self.loop = self.get_event_loop()
            self.loop.set_exception_handler(self.handle_exception)

        def handle_exception(self, _loop, context):
            """Loop exception handler; asyncio calls it as ``handler(loop, context)``.

            The previous one-argument signature made every call fail with TypeError,
            so nothing was ever logged here.
            """
            msg = context.get("exception", context.get("message"))
            log.error(f"AsyncIO loop: {msg}")

        def get_event_loop(self) -> asyncio.AbstractEventLoop:
            try:
                self.loop = super().get_event_loop()
            except (RuntimeError, AssertionError):
                # No usable loop for this thread: create one and make it current.
                self.loop = self.new_event_loop()
                self.set_event_loop(self.loop)
            return self.loop

    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
def get_external_ip():
    """Return the IPv4 address this machine's hostname resolves to.

    Returns ``None`` if that address is a loopback one (``127.*``) or if name
    resolution fails for any reason.
    """
    import socket
    try:
        address = socket.gethostbyname(socket.gethostname())
    except Exception:
        return None
    return None if address.startswith('127.') else address
def get_remote_ip():
    """Ask api.ipify.org for this machine's public IP address.

    This is a best-effort lookup with a 2 second timeout. Any failure (network error,
    non-JSON body, missing ``ip`` key) yields ``None``.
    """
    import requests
    try:
        return requests.get('https://api.ipify.org?format=json', timeout=2).json()['ip']
    except Exception:
        return None
def start_common():
    """Run the startup steps shared by first start and restart.

    Steps:
    - Log non-default base paths and create the configured directories.
    - Install the asyncio policy and run module initialization.
    - Map the removed ``backend=original`` option to ``diffusers``.
    - Store the installed diffusers commit in options when it is known (best effort).
    - Optionally clean the temp dir.
    """
    log.debug('Entering start sequence')
    data_dir = shared.cmd_opts.data_dir
    if data_dir:
        log.info(f'Base path: data="{data_dir}"')
    models_dir = shared.cmd_opts.models_dir
    if models_dir and models_dir != 'models':
        log.info(f'Base path: models="{models_dir}"')
    modules.paths.create_paths(shared.opts)
    async_policy()
    initialize()
    if shared.cmd_opts.backend == 'original':
        shared.log.error('Legacy option: backend=original is no longer supported')
        shared.cmd_opts.backend = 'diffusers'
    # Best effort: the installer may not expose diffusers_commit.
    with contextlib.suppress(Exception):
        from installer import diffusers_commit
        if diffusers_commit != 'unknown':
            shared.opts.diffusers_version = diffusers_commit  # update installed diffusers version
    if shared.opts.clean_temp_dir_at_start:
        modules.gr_tempdir.cleanup_tmpdr()
        timer.startup.record("cleanup")
def mount_subpath(app):
    """Mount the gradio UI (``shared.demo``) onto ``app`` under a URL subpath, if configured.

    A non-empty ``cmd_opts.subpath`` overrides ``opts.subpath``. The value is then
    given a leading ``/`` if it lacks one and is written back to ``opts`` before
    mounting. Nothing happens when no subpath is set.
    """
    if shared.cmd_opts.subpath:
        shared.opts.subpath = shared.cmd_opts.subpath
    if not shared.opts.subpath:
        return
    import gradio
    if not shared.opts.subpath.startswith('/'):
        shared.opts.subpath = '/' + shared.opts.subpath
    gradio.mount_gradio_app(app, shared.demo, path=shared.opts.subpath)
    shared.log.info(f'Mounted: subpath="{shared.opts.subpath}"')
def _load_auth_creds():
    """Collect non-empty ``user:password`` entries from ``cmd_opts.auth`` and ``cmd_opts.auth_file``."""
    creds = []
    if shared.cmd_opts.auth:
        # Inline value: surrounding double quotes and newlines are dropped, entries are comma-separated.
        raw = shared.cmd_opts.auth.strip('"').replace('\n', '')
        creds += [x.strip() for x in raw.split(',') if x.strip()]
    if shared.cmd_opts.auth_file:
        if not os.path.exists(shared.cmd_opts.auth_file):
            log.error(f"Invalid path to auth file: '{shared.cmd_opts.auth_file}'")
        else:
            with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file:
                for line in file:  # each line may hold several comma-separated entries
                    creds += [x.strip() for x in line.split(',') if x.strip()]
    return creds


def _get_allowed_paths():
    """Return directories gradio may serve files from: this file's folder plus existing data/model/extra dirs."""
    paths = [os.path.dirname(__file__)]
    for candidate in (shared.cmd_opts.data_dir, shared.cmd_opts.models_dir):
        if candidate is not None and os.path.isdir(candidate):
            paths.append(candidate)
    if shared.cmd_opts.allowed_paths is not None:
        paths += [p for p in shared.cmd_opts.allowed_paths if os.path.isdir(p)]
    return paths


def start_ui():
    """Create and launch the gradio UI, then attach middleware, the REST API and callbacks.

    Side effects:
    - Sets ``shared.demo``, ``shared.api`` and the module-level ``local_url``.
    - Logs the local, external, public and share URLs as applicable.

    Returns:
        The FastAPI (Starlette) app instance returned by ``shared.demo.launch``.
    """
    global local_url  # pylint: disable=global-statement
    log.debug('UI start sequence')
    log.debug(f'UI image support: kanvas={version.get("kanvas", "unknown")}')
    modules.script_callbacks.before_ui_callback()
    timer.startup.record("before-ui")
    shared.demo = modules.ui.create_ui(timer.startup)
    timer.startup.record("ui")
    if shared.cmd_opts.disable_queue:
        log.info('Server queues disabled')
        shared.demo.progress_tracking = False
    else:
        shared.demo.queue(concurrency_count=64)

    gradio_auth_creds = _load_auth_creds()
    if gradio_auth_creds:
        log.info(f'Authentication enabled: users={len(gradio_auth_creds)}')
    # Split only on the first ':' so passwords may themselves contain ':'.
    # Splitting on every ':' produced 3+-tuples, and only the first password fragment was kept.
    gradio_auth = [tuple(cred.split(':', 1)) for cred in gradio_auth_creds] if gradio_auth_creds else None

    allowed_paths = _get_allowed_paths()
    shared.log.debug(f'Root paths: {allowed_paths}')

    stdout = io.StringIO()  # captures (and discards) whatever launch() prints to stdout
    with contextlib.redirect_stdout(stdout):
        app, local_url, share_url = shared.demo.launch( # app is FastAPI(Starlette) instance
            share=shared.cmd_opts.share,
            server_name=server_name,
            # 7860 is passed as None, presumably so gradio applies its own default -- NOTE(review)
            server_port=shared.cmd_opts.port if shared.cmd_opts.port != 7860 else None,
            ssl_keyfile=shared.cmd_opts.tls_keyfile,
            ssl_certfile=shared.cmd_opts.tls_certfile,
            ssl_verify=not shared.cmd_opts.tls_selfsign,
            debug=False,
            auth=gradio_auth,
            prevent_thread_lock=True,
            max_threads=64,
            show_api=False,
            quiet=True,
            favicon_path='html/favicon.svg',
            allowed_paths=allowed_paths,
            app_kwargs=fastapi_args,
            _frontend=shared.cmd_opts.share,
        )
    if shared.cmd_opts.data_dir is not None:
        modules.gr_tempdir.register_tmp_file(shared.demo, os.path.join(shared.cmd_opts.data_dir, 'x'))
    shared.log.info(f'Local URL: {local_url}')
    if shared.cmd_opts.listen:
        if not gradio_auth_creds:
            shared.log.warning('Public URL: enabled without authentication')
        if shared.cmd_opts.insecure:
            shared.log.warning('Public URL: enabled with insecure flag')
        proto = 'https' if shared.cmd_opts.tls_keyfile is not None else 'http'
        external_ip = get_external_ip()
        if external_ip is not None:
            shared.log.info(f'External URL: {proto}://{external_ip}:{shared.cmd_opts.port}')
        public_ip = get_remote_ip()
        if public_ip is not None:
            shared.log.info(f'Public URL: {proto}://{public_ip}:{shared.cmd_opts.port}')
    if shared.cmd_opts.docs:
        # Strip only a trailing slash; the old `[:-1]` chopped any last character and failed on None.
        base_url = (local_url or '').rstrip('/')
        shared.log.info(f'API docs: {base_url}/docs')
        shared.log.info(f'API redocs: {base_url}/redocs')
    if share_url is not None:
        shared.log.info(f'Share URL: {share_url}')
    shared.demo.server.wants_restart = False
    modules.api.middleware.setup_middleware(app, shared.cmd_opts)
    timer.startup.record("launch")

    shared.api = create_api(app)
    shared.api.register()
    modules.progress.setup_progress_api()
    modules.ui_extra_networks.init_api()
    timer.startup.record("api")

    modules.script_callbacks.app_started_callback(shared.demo, app)
    timer.startup.record("app-started")

    # Report slow script setup (>30ms) and component creation (>5ms) times.
    time_sorted = sorted(modules.scripts_manager.time_setup.items(), key=lambda x: x[1], reverse=True)
    time_script = [f'{k}:{round(v,3)}' for (k,v) in time_sorted if v > 0.03]
    time_total = sum(modules.scripts_manager.time_setup.values())
    shared.log.debug(f'Scripts setup: time={time_total:.3f} {time_script}')
    time_component = [f'{k}:{round(v,3)}' for (k,v) in modules.scripts_manager.time_component.items() if v > 0.005]
    if time_component:
        shared.log.debug(f'Scripts components: {time_component}')
    return app
def webui(restart=False):
    """Start (or restart) the whole application and return the running gradio server.

    Args:
        restart: when True, fire the app-reload and script-unloaded callbacks first.
            At the end, reload every loaded ``modules.ui*`` module instead of running
            the first-start-only steps (sharing logger handlers and browser autolaunch).

    Returns:
        ``shared.demo.server``.
    """
    if restart:
        modules.script_callbacks.app_reload_callback()
        modules.script_callbacks.script_unloaded_callback()
    start_common()
    app = start_ui()
    modules.script_callbacks.after_ui_callback()
    modules.sd_models.write_metadata()
    load_model()
    mount_subpath(app)
    shared.opts.save()
    if shared.cmd_opts.profile:
        for k, v in modules.script_callbacks.callback_map.items():
            shared.log.debug(f'Registered callbacks: {k}={len(v)} {[c.script for c in v]}')
    debug = log.trace if os.environ.get('SD_SCRIPT_DEBUG', None) is not None else lambda *args, **kwargs: None
    debug('Trace: SCRIPTS')
    for m in modules.scripts_manager.scripts_data:
        debug(f' {m}')
    debug('Loaded postprocessing scripts:')
    for m in modules.scripts_manager.postprocessing_scripts_data:
        debug(f' {m}')
    modules.script_callbacks.print_timers()
    if shared.cmd_opts.profile:
        log.info(f"Launch time: {timer.launch.summary(min_time=0)}")
        log.info(f"Installer time: {timer.init.summary(min_time=0)}")
        log.info(f"Startup time: {timer.startup.summary(min_time=0)}")
    else:
        timer.startup.add('launch', timer.launch.get_total())
        # The installer phase is tracked by timer.init (see the profile branch above).
        # The old code added the launch total here a second time.
        timer.startup.add('installer', timer.init.get_total())
        log.info(f"Startup time: {timer.startup.summary()}")
    timer.startup.reset()
    if not restart:
        # Override all loggers (except uvicorn*/sd*) to use the same handlers as the main logger.
        # Each logger gets its own list copy; the handler objects themselves stay shared.
        # Aliasing the list let a library's addHandler/removeHandler mutate the main logger.
        logger_names = list(logging.root.manager.loggerDict) # pylint: disable=no-member
        for logger in [logging.getLogger(name) for name in logger_names]:
            if logger.name.startswith('uvicorn') or logger.name.startswith('sd'):
                continue
            logger.handlers = list(log.handlers)
        # autolaunch only on initial start
        if (shared.opts.autolaunch or shared.cmd_opts.autolaunch) and local_url is not None:
            shared.cmd_opts.autolaunch = False
            shared.log.info('Launching browser')
            import webbrowser
            webbrowser.open(local_url, new=2, autoraise=True)
    else:
        # Snapshot the matching modules first, since reloading may touch sys.modules.
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)
    return shared.demo.server


if __name__ == "__main__":
    webui()