mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
156 lines
6.7 KiB
Python
156 lines
6.7 KiB
Python
import os
|
|
import time
|
|
from typing import Union
|
|
import threading
|
|
import numpy as np
|
|
from PIL import Image
|
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
|
from modules.shared import log, opts, listdir
|
|
from modules import errors
|
|
from modules.control.units.lite_model import ControlNetLLLite
|
|
|
|
|
|
# Human-readable name of this control type, used in log messages.
what = 'ControlLLLite'

# Verbose tracing is enabled only when SD_CONTROL_DEBUG is present in the environment;
# otherwise `debug` is a no-op that accepts any arguments.
if 'SD_CONTROL_DEBUG' in os.environ:
    debug = log.trace
else:
    def debug(*args, **kwargs): # pylint: disable=unused-argument
        return None
debug('Trace: CONTROL')

# Built-in model ids -> "<hf repo>/<file name without .safetensors>".
predefined_sd15 = {}
predefined_sdxl = {
    'Canny XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny',
    'Canny anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_canny_anime',
    'Depth anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01008016e_sdxl_depth_anime',
    'Blur anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01016032e_sdxl_blur_anime_beta',
    'Pose anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_pose_anime',
    'Replicate anime XL': 'kohya-ss/controlnet-lllite/controllllite_v01032064e_sdxl_replicate_anime_v2',
}

# Cached result of list_models(); replaced wholesale when the list is rebuilt.
models = {}
# Every known model id -> path; locally discovered files are merged in by find_models().
all_models = {**predefined_sd15, **predefined_sdxl}
# Download location for Hugging Face hosted models.
cache_dir = 'models/control/lite'
# Serializes ControlLLLite.load() across threads.
load_lock = threading.Lock()
|
|
|
|
|
|
def find_models():
    """Discover local LLLite ``.safetensors`` files and register them in ``all_models``.

    Scans ``<opts.control_dir>/lite``. Each file is keyed by its path relative to
    that folder, without the extension.

    Returns:
        dict mapping the discovered model name to its file path.
    """
    path = os.path.join(opts.control_dir, 'lite')
    downloaded_models = {}
    for f in listdir(path):
        if not f.endswith('.safetensors'):
            continue
        # NOTE(review): the relpath() below relies on `listdir` yielding entries that
        # already include `path`; joining such an entry with `path` again would double
        # the prefix whenever `path` is relative, so only bare file names are joined.
        full_path = f if os.path.dirname(f) else os.path.join(path, f)
        basename = os.path.splitext(os.path.relpath(full_path, path))[0]
        downloaded_models[basename] = full_path
    all_models.update(downloaded_models)
    return downloaded_models
|
|
|
|
|
|
def list_models(refresh=False):
    """Return the model choices for the currently loaded base model type.

    The result is cached in the module-level ``models`` and reused unless
    ``refresh`` is set. The first entry is always ``'None'``. Predefined ids
    matching the base model type follow, then locally discovered files.
    For an unknown type, both predefined tables are offered.
    """
    import modules.shared
    global models # pylint: disable=global-statement
    if models and not refresh:
        return models
    models = {}
    model_type = modules.shared.sd_model_type
    if model_type == 'none':
        models = ['None']
    else:
        if model_type == 'sdxl':
            predefined = sorted(predefined_sdxl)
        elif model_type == 'sd':
            predefined = sorted(predefined_sd15)
        else:
            log.warning(f'Control {what} model list failed: unknown model type')
            predefined = sorted(predefined_sd15) + sorted(predefined_sdxl)
        models = ['None'] + predefined + sorted(find_models())
    debug(f'Control list {what}: path={cache_dir} models={models}')
    return models
|
|
|
|
|
|
class ControlLLLite():
    """Holder for a single ControlNet-LLLite model and its load state.

    ``model_id`` is a key of the module-level ``all_models`` dict. Its value is
    either a local ``.safetensors`` path or a ``<hf repo>/<file name>`` string
    that is downloaded on load.
    """

    def __init__(self, model_id: str = None, device = None, dtype = None, load_config = None):
        self.model: ControlNetLLLite = None
        self.model_id: str = model_id
        self.device = device
        self.dtype = dtype
        # caller-supplied options are layered over the default download location
        self.load_config = { 'cache_dir': cache_dir }
        if load_config is not None:
            self.load_config.update(load_config)
        if model_id is not None:
            self.load()

    def __str__(self):
        return f' ControlLLLite(id={self.model_id} model={self.model.__class__.__name__})' if self.model_id and self.model else ''

    def reset(self):
        """Drop the loaded model and forget its id."""
        if self.model is not None:
            debug(f'Control {what} model unloaded')
        self.model = None
        self.model_id = None

    def _download(self, model_path: str) -> str:
        """Fetch ``<repo>/<name>.safetensors`` from Hugging Face and return the local file path."""
        import huggingface_hub as hf
        offline_config = {}
        if opts.offline_mode:
            offline_config["local_files_only"] = True
            os.environ['HF_HUB_OFFLINE'] = '1'
        else:
            os.environ.pop('HF_HUB_OFFLINE', None)
            os.unsetenv('HF_HUB_OFFLINE')
        folder, filename = os.path.split(model_path)
        # honour a cache_dir override passed via load_config; defaults to the module cache_dir
        target_dir = self.load_config.get('cache_dir', cache_dir)
        return hf.hf_hub_download(repo_id=folder, filename=f'{filename}.safetensors', cache_dir=target_dir, **offline_config)

    def load(self, model_id: str = None, force: bool = True) -> Union[str, None]:
        """Load ``model_id`` (or reload the current id) under ``load_lock``.

        Args:
            model_id: key of ``all_models``; ``None``/``'None'`` falls back to the
                current id, and if that is also unset the holder is reset.
            force: when False, loading the already-loaded id is skipped.

        Returns:
            A status message on load success or failure. Returns None when nothing
            was loaded: reset, unknown id, empty or missing path, or already loaded.
        """
        with load_lock:
            try:
                t0 = time.time()
                model_id = model_id or self.model_id
                if model_id is None or model_id == 'None':
                    self.reset()
                    return None
                if model_id not in all_models:
                    log.error(f'Control {what} unknown model: id="{model_id}" available={list(all_models)}')
                    return None
                model_path = all_models[model_id]
                if model_path == '':
                    return None
                if model_path is None:
                    log.error(f'Control {what} model load failed: id="{model_id}" error=unknown model id')
                    return None
                if model_id == self.model_id and not force:
                    return None
                log.debug(f'Control {what} model loading: id="{model_id}" path="{model_path}" {self.load_config}')
                if not model_path.endswith('.safetensors'):
                    model_path = self._download(model_path)
                model = ControlNetLLLite(model_path)
                if self.device is not None:
                    model.to(self.device)
                if self.dtype is not None:
                    model.to(self.dtype)
                # publish model and id together: a failure above leaves the previous pair intact
                self.model = model
                self.model_id = model_id
                t1 = time.time()
                log.debug(f'Control {what} model loaded: id="{model_id}" path="{model_path}" time={t1-t0:.2f}')
                return f'{what} loaded model: {model_id}'
            except Exception as e:
                log.error(f'Control {what} model load failed: id="{model_id}" error={e}')
                errors.display(e, f'Control {what} load')
                return f'{what} failed to load model: {model_id}'
|
|
|
|
|
|
class ControlLLitePipeline():
    """Attaches ControlNet-LLLite models to a diffusers pipeline and detaches them again."""

    def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline]):
        self.pipeline = pipeline
        self.nets = []

    def apply(self, controlnet: Union[ControlNetLLLite, list[ControlNetLLLite]], image, conditioning):
        """Apply each net to the pipeline with its conditioning image and weight.

        Args:
            controlnet: a single net or a list of nets.
            image: a PIL image or a list of them. Images are paired with nets
                round-robin and converted to RGB arrays. If None, nothing is applied.
            conditioning: a scalar weight or a list of weights, paired with nets round-robin.
        """
        if image is None:
            return
        self.nets = [controlnet] if isinstance(controlnet, ControlNetLLLite) else controlnet
        debug(f'Control {what} apply: models={len(self.nets)} image={image} conditioning={conditioning}')
        # any real scalar (e.g. an int 1 or a numpy float) is a single weight, not only Python floats
        weight = [conditioning] if isinstance(conditioning, (int, float, np.integer, np.floating)) else conditioning
        images = [image] if isinstance(image, Image.Image) else image
        images = [i.convert('RGB') for i in images]
        # round-robin pairing below divides by these lengths
        if len(images) == 0 or len(weight) == 0:
            log.warning(f'Control {what} apply: no images or weights provided')
            return
        for i, cn in enumerate(self.nets):
            cn.apply(pipe=self.pipeline, cond=np.asarray(images[i % len(images)]), weight=weight[i % len(weight)])

    def restore(self):
        """Remove all applied LLLite nets from the pipeline and forget them."""
        from modules.control.units.lite_model import clear_all_lllite
        clear_all_lllite()
        self.nets = []
|