mirror of https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
324 lines
16 KiB
Python
from typing import Union
import os
import time
import concurrent.futures
from modules import shared, errors, sd_models, sd_models_compile, files_cache
from modules.lora import network, lora_overrides, lora_convert, lora_diffusers
from modules.lora import lora_common as l


lora_cache = {}
available_networks = {}
available_network_aliases = {}
forbidden_network_aliases = {}
available_network_hash_lookup = {}
dump_lora_keys = os.environ.get('SD_LORA_DUMP', None) is not None
exclude_errors = [
    "'ChronoEditTransformer3DModel'",
]


def lora_dump(lora, dct):
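    """Write the sorted LoRA state-dict keys and the model's network_layer_mapping keys to text files in the system temp directory."""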
    import tempfile
    sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
    ty = shared.sd_model_type
    cn = sd_model.__class__.__name__
    shared.log.trace(f'LoRA dump: type={ty} model={cn} fn="{lora}"')
    bn = os.path.splitext(os.path.basename(lora))[0]
    fn = os.path.join(tempfile.gettempdir(), f'LoRA-{ty}-{cn}-{bn}.txt')
    with open(fn, 'w', encoding='utf8') as f:
        keys = sorted(dct.keys())
        shared.log.trace(f'LoRA dump: type=LoRA fn="{fn}" keys={len(keys)}')
        for line in keys:
            f.write(line + "\n")
    fn = os.path.join(tempfile.gettempdir(), f'Model-{ty}-{cn}.txt')
    with open(fn, 'w', encoding='utf8') as f:
        keys = sd_model.network_layer_mapping.keys()
        shared.log.trace(f'LoRA dump: type=Mapping fn="{fn}" keys={len(keys)}')
        for line in keys:
            f.write(line + "\n")


def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> Union[network.Network, None]:
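    """Load a LoRA network from disk: read the state dict, convert kohya-style keys when needed, and match every weight to a model module. Returns the cached network when available, or None if the model is not loaded or no keys matched."""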
    if not shared.sd_loaded:
        return None

    sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
    cached = lora_cache.get(name, None)
    if l.debug:
        shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" {"cached" if cached else ""}')
    if cached is not None:
        return cached
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)
    state_dict = sd_models.read_state_dict(network_on_disk.filename, what='network')
    if shared.sd_model_type in ['f1', 'chroma']: # if kohya flux lora, convert state_dict
        state_dict = lora_convert._convert_kohya_flux_lora_to_diffusers(state_dict) or state_dict # pylint: disable=protected-access
    if shared.sd_model_type == 'sd3': # if kohya sd3 lora, convert state_dict
        try:
            state_dict = lora_convert._convert_kohya_sd3_lora_to_diffusers(state_dict) or state_dict # pylint: disable=protected-access
        except ValueError: # EAFP for diffusers PEFT keys
            pass
    lora_convert.assign_network_names_to_compvis_modules(sd_model)
    keys_failed_to_match = {}
    matched_networks = {}
    bundle_embeddings = {}
    dtypes = []
    convert = lora_convert.KeyConvert()
    if dump_lora_keys:
        lora_dump(network_on_disk.filename, state_dict)
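
    # walk every tensor in the state dict and map its key to a model module;
    # kohya prefixes (clip_l/clip_g/t5/unet/transformer), diffusers PEFT keys (lora_A/lora_B) and bundled embeddings each take a separate path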
    for key_network, weight in state_dict.items():
        parts = key_network.split('.')
        if parts[0] == "bundle_emb":
            emb_name, vec_name = parts[1], key_network.split(".", 2)[-1]
            emb_dict = bundle_embeddings.get(emb_name, {})
            emb_dict[vec_name] = weight
            bundle_embeddings[emb_name] = emb_dict
            continue
        if parts[0] in ["clip_l","clip_g","t5","unet","transformer"]:
            network_part = []
            while parts[-1] in ["alpha","weight","lora_up","lora_down"]:
                network_part.insert(0,parts[-1])
                parts = parts[0:-1]
            network_part = ".".join(network_part)
            key_network_without_network_parts = "_".join(parts)
            if key_network_without_network_parts.startswith("unet") or key_network_without_network_parts.startswith("transformer"):
                key_network_without_network_parts = "lora_" + key_network_without_network_parts
            key_network_without_network_parts = key_network_without_network_parts.replace("clip_g","lora_te2").replace("clip_l","lora_te")
            # TODO lora: add t5 key support for sd35/f1
        elif len(parts) > 5: # messy handler for diffusers peft lora
            key_network_without_network_parts = '_'.join(parts[:-2])
            if not key_network_without_network_parts.startswith('lora_'):
                key_network_without_network_parts = 'lora_' + key_network_without_network_parts
            network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')
        else:
            key_network_without_network_parts, network_part = key_network.split(".", 1)
        key, sd_module = convert(key_network_without_network_parts)
        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue
        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
        matched_networks[key].w[network_part] = weight
        if weight.dtype not in dtypes:
            dtypes.append(weight.dtype)
    network_types = []
    del state_dict
    module_errors = 0
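    # instantiate a network module for each matched key; the first module type that accepts the weights wins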
    for key, weights in matched_networks.items():
        net_module = None
        for nettype in l.module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                network_types.append(nettype.__class__.__name__)
                break
        if net_module is None:
            module_errors += 1
            if l.debug:
                shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}')
        else:
            net.modules[key] = net_module
    if module_errors > 0:
        shared.log.error(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" errors={module_errors} empty modules')
    if len(keys_failed_to_match) > 0:
        shared.log.warning(f'Network load: type=LoRA name="{name}" type={set(network_types)} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}')
        if l.debug:
            shared.log.debug(f'Network load: type=LoRA name="{name}" unmatched={keys_failed_to_match}')
    else:
        shared.log.debug(f'Network load: type=LoRA name="{name}" type={set(network_types)} keys={len(matched_networks)} dtypes={dtypes} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers}')
    if len(matched_networks) == 0:
        return None
    lora_cache[name] = net
    net.bundle_embeddings = bundle_embeddings
    return net


def maybe_recompile_model(names, te_multipliers):
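    """Check whether a compiled model was compiled against a different LoRA set; when it was, reload the model weights so it can be recompiled. Returns (recompile_model, skip_lora_load)."""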
    sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
    recompile_model = False
    skip_lora_load = False
    if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled:
        if len(names) == len(shared.compiled_model_state.lora_model):
            for i, name in enumerate(names):
                if shared.compiled_model_state.lora_model[i] != f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}":
                    recompile_model = True
                    shared.compiled_model_state.lora_model = []
                    break
            if not recompile_model:
                skip_lora_load = True
                if len(l.loaded_networks) > 0 and l.debug:
                    shared.log.debug('Model Compile: Skipping LoRA loading')
                return recompile_model, skip_lora_load
        else:
            recompile_model = True
            shared.compiled_model_state.lora_model = []
    if recompile_model:
        current_task = sd_models.get_diffusers_task(shared.sd_model)
        shared.log.debug(f'Compile: task={current_task} force model reload')
        backup_cuda_compile = shared.opts.cuda_compile
        backup_scheduler = getattr(sd_model, "scheduler", None)
        sd_models.unload_model_weights(op='model')
        shared.opts.cuda_compile = []
        sd_models.reload_model_weights(op='model')
        shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, current_task)
        shared.opts.cuda_compile = backup_cuda_compile
        if backup_scheduler is not None:
            sd_model.scheduler = backup_scheduler
    return recompile_model, skip_lora_load


def list_available_networks():
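    """Rescan the LoRA directory and rebuild the registries of available networks, aliases, and hash lookups."""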
    t0 = time.time()
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    forbidden_network_aliases.update({"none": 1, "Addams": 1})
    if not os.path.exists(shared.cmd_opts.lora_dir):
        shared.log.warning(f'LoRA directory not found: path="{shared.cmd_opts.lora_dir}"')

    def add_network(filename):
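        # register a single file as a NetworkOnDisk entry; aliases seen more than once are marked forbidden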
        if not os.path.isfile(filename):
            return
        name = os.path.splitext(os.path.basename(filename))[0]
        name = name.replace('.', '_')
        try:
            entry = network.NetworkOnDisk(name, filename)
            available_networks[entry.name] = entry
            if entry.alias in available_network_aliases:
                forbidden_network_aliases[entry.alias.lower()] = 1
            available_network_aliases[entry.name] = entry
            if entry.shorthash:
                available_network_hash_lookup[entry.shorthash] = entry
        except OSError as e: # catches FileNotFoundError, PermissionError, etc.
            shared.log.error(f'LoRA: filename="{filename}" {e}')

    candidates = sorted(files_cache.list_files(shared.cmd_opts.lora_dir, ext_filter=[".pt", ".ckpt", ".safetensors"]))
    with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
        for fn in candidates:
            executor.submit(add_network, fn)
    t1 = time.time()
    l.timer.list = t1 - t0
    shared.log.info(f'Available LoRAs: path="{shared.cmd_opts.lora_dir}" items={len(available_networks)} forbidden={len(forbidden_network_aliases)} time={t1 - t0:.2f}')


def network_download(name):
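    """Resolve a name to a NetworkOnDisk: existing local paths are used directly and '/huggingface.co/<org>/<repo>/<file>' references are fetched via hf_hub_download; returns None otherwise."""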
    from huggingface_hub import hf_hub_download
    if os.path.exists(name):
        return network.NetworkOnDisk(name, name)
    parts = name.split('/')
    if len(parts) >= 5 and parts[1] == 'huggingface.co':
        repo_id = f'{parts[2]}/{parts[3]}'
        filename = '/'.join(parts[4:])
        fn = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=shared.opts.hfcache_dir)
        return network.NetworkOnDisk(name, fn)
    return None


def gather_networks(names):
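    """Map requested names to NetworkOnDisk entries, rescanning the LoRA directory once when a name is unknown; path-like names are resolved via network_download."""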
    networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()
        networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
    for i in range(len(names)):
        if names[i].startswith('/'):
            networks_on_disk[i] = network_download(names[i])
    return networks_on_disk


def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None, lora_modules=None):
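    """Load and activate the named networks on the current model, applying per-network text-encoder/unet multipliers; dispatches to diffusers, nunchaku, or native safetensors loading per network."""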
    networks_on_disk = gather_networks(names)
    failed_to_load_networks = []
    recompile_model, skip_lora_load = maybe_recompile_model(names, te_multipliers)
    sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)

    l.loaded_networks.clear()
    lora_diffusers.diffuser_loaded.clear()
    lora_diffusers.diffuser_scales.clear()
    t0 = time.time()

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = None
        if network_on_disk is not None:
            shorthash = getattr(network_on_disk, 'shorthash', '').lower()
            if l.debug:
                shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" hash="{shorthash}"')
            try:
                lora_scale = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier
                lora_module = lora_modules[i] if lora_modules and len(lora_modules) > i else None
                if recompile_model and shared.compiled_model_state is not None:
                    shared.compiled_model_state.lora_model.append(f"{name}:{lora_scale}")
                lora_method = lora_overrides.get_method(shorthash)
                if lora_method == 'diffusers':
                    net = lora_diffusers.load_diffusers(name, network_on_disk, lora_scale, lora_module)
                elif lora_method == 'nunchaku':
                    pass # handled directly from extra_networks_lora.load_nunchaku
                else:
                    net = load_safetensors(name, network_on_disk)
                if net is not None:
                    net.mentioned_name = name
                    network_on_disk.read_hash()
            except Exception as e:
                shared.log.error(f'Network load: type=LoRA file="{network_on_disk.filename}" {e}')
                if l.debug:
                    errors.display(e, 'LoRA')
                continue
        if net is None:
            failed_to_load_networks.append(name)
            shared.log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed')
            continue
        if hasattr(sd_model, 'embedding_db'):
            sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)
        net.te_multiplier = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else shared.opts.extra_networks_default_multiplier
        net.dyn_dim = dyn_dims[i] if dyn_dims else shared.opts.extra_networks_default_multiplier
        l.loaded_networks.append(net)
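
    # dicts preserve insertion order, so evicting the first key drops the oldest cached network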
    while len(lora_cache) > shared.opts.lora_in_memory_limit:
        name = next(iter(lora_cache))
        lora_cache.pop(name, None)
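
    # activate the adapters that were loaded through diffusers: apply per-adapter scales and optionally fuse the weights into the base model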
    if not skip_lora_load and len(lora_diffusers.diffuser_loaded) > 0:
        shared.log.debug(f'Network load: type=LoRA loaded={lora_diffusers.diffuser_loaded} available={sd_model.get_list_adapters()} active={sd_model.get_active_adapters()} scales={lora_diffusers.diffuser_scales}')
        try:
            t1 = time.time()
            if l.debug:
                shared.log.trace(f'Network load: type=LoRA list={sd_model.get_list_adapters()}')
                shared.log.trace(f'Network load: type=LoRA active={sd_model.get_active_adapters()}')
            sd_model.set_adapters(adapter_names=lora_diffusers.diffuser_loaded, adapter_weights=lora_diffusers.diffuser_scales)
        except Exception as e:
            if str(e) not in exclude_errors:
                shared.log.error(f'Network load: type=LoRA action=strength {str(e)}')
                if l.debug:
                    errors.display(e, 'LoRA')
        try:
            if shared.opts.lora_fuse_diffusers and not lora_overrides.disable_fuse():
                sd_model.fuse_lora(adapter_names=lora_diffusers.diffuser_loaded, lora_scale=1.0, fuse_unet=True, fuse_text_encoder=True) # diffusers with fuse uses fixed scale since later apply does the scaling
                sd_model.unload_lora_weights()
            l.timer.activate += time.time() - t1
        except Exception as e:
            shared.log.error(f'Network load: type=LoRA action=fuse {str(e)}')
            if l.debug:
                errors.display(e, 'LoRA')
        shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, force=True, silent=True) # some layers may end up on cpu without hook

    if len(l.loaded_networks) > 0 and l.debug:
        shared.log.debug(f'Network load: type=LoRA loaded={[n.name for n in l.loaded_networks]} cache={list(lora_cache)} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers}')
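
    # recompile so the newly activated LoRA set is baked into the compiled model, restoring the recorded lora_model list afterwards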
    if recompile_model:
        shared.log.info("Network load: type=LoRA recompiling model")
        if shared.compiled_model_state is not None:
            backup_lora_model = shared.compiled_model_state.lora_model
        else:
            backup_lora_model = []
        if 'Model' in shared.opts.cuda_compile:
            sd_model = sd_models_compile.compile_diffusers(sd_model)
        if shared.compiled_model_state is not None:
            shared.compiled_model_state.lora_model = backup_lora_model

    l.timer.load = time.time() - t0