1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/lora/lora_load.py
vladmandic bfbe4af598 fix lora load
Signed-off-by: vladmandic <mandic00@live.com>
2026-01-21 08:45:45 +01:00

324 lines
16 KiB
Python

from typing import Union
import os
import time
import concurrent
from modules import shared, errors, sd_models, sd_models_compile, files_cache
from modules.lora import network, lora_overrides, lora_convert, lora_diffusers
from modules.lora import lora_common as l
lora_cache = {}
available_networks = {}
available_network_aliases = {}
forbidden_network_aliases = {}
available_network_hash_lookup = {}
dump_lora_keys = os.environ.get('SD_LORA_DUMP', None) is not None
exclude_errors = [
"'ChronoEditTransformer3DModel'",
]
def lora_dump(lora, dct):
import tempfile
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
ty = shared.sd_model_type
cn = sd_model.__class__.__name__
shared.log.trace(f'LoRA dump: type={ty} model={cn} fn="{lora}"')
bn = os.path.splitext(os.path.basename(lora))[0]
fn = os.path.join(tempfile.gettempdir(), f'LoRA-{ty}-{cn}-{bn}.txt')
with open(fn, 'w', encoding='utf8') as f:
keys = sorted(dct.keys())
shared.log.trace(f'LoRA dump: type=LoRA fn="{fn}" keys={len(keys)}')
for line in keys:
f.write(line + "\n")
fn = os.path.join(tempfile.gettempdir(), f'Model-{ty}-{cn}.txt')
with open(fn, 'w', encoding='utf8') as f:
keys = sd_model.network_layer_mapping.keys()
shared.log.trace(f'LoRA dump: type=Mapping fn="{fn}" keys={len(keys)}')
for line in keys:
f.write(line + "\n")
def load_safetensors(name, network_on_disk: network.NetworkOnDisk) -> Union[network.Network, None]:
if not shared.sd_loaded:
return None
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
cached = lora_cache.get(name, None)
if l.debug:
shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" type=lora {"cached" if cached else ""}')
if cached is not None:
return cached
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
state_dict = sd_models.read_state_dict(network_on_disk.filename, what='network')
if shared.sd_model_type in ['f1', 'chroma']: # if kohya flux lora, convert state_dict
state_dict = lora_convert._convert_kohya_flux_lora_to_diffusers(state_dict) or state_dict # pylint: disable=protected-access
if shared.sd_model_type == 'sd3': # if kohya flux lora, convert state_dict
try:
state_dict = lora_convert._convert_kohya_sd3_lora_to_diffusers(state_dict) or state_dict # pylint: disable=protected-access
except ValueError: # EAFP for diffusers PEFT keys
pass
lora_convert.assign_network_names_to_compvis_modules(sd_model)
keys_failed_to_match = {}
matched_networks = {}
bundle_embeddings = {}
dtypes = []
convert = lora_convert.KeyConvert()
if dump_lora_keys:
lora_dump(network_on_disk.filename, state_dict)
for key_network, weight in state_dict.items():
parts = key_network.split('.')
if parts[0] == "bundle_emb":
emb_name, vec_name = parts[1], key_network.split(".", 2)[-1]
emb_dict = bundle_embeddings.get(emb_name, {})
emb_dict[vec_name] = weight
bundle_embeddings[emb_name] = emb_dict
continue
if parts[0] in ["clip_l","clip_g","t5","unet","transformer"]:
network_part = []
while parts[-1] in ["alpha","weight","lora_up","lora_down"]:
network_part.insert(0,parts[-1])
parts = parts[0:-1]
network_part = ".".join(network_part)
key_network_without_network_parts = "_".join(parts)
if key_network_without_network_parts.startswith("unet") or key_network_without_network_parts.startswith("transformer"):
key_network_without_network_parts = "lora_" + key_network_without_network_parts
key_network_without_network_parts = key_network_without_network_parts.replace("clip_g","lora_te2").replace("clip_l","lora_te")
# TODO lora: add t5 key support for sd35/f1
elif len(parts) > 5: # messy handler for diffusers peft lora
key_network_without_network_parts = '_'.join(parts[:-2])
if not key_network_without_network_parts.startswith('lora_'):
key_network_without_network_parts = 'lora_' + key_network_without_network_parts
network_part = '.'.join(parts[-2:]).replace('lora_A', 'lora_down').replace('lora_B', 'lora_up')
else:
key_network_without_network_parts, network_part = key_network.split(".", 1)
key, sd_module = convert(key_network_without_network_parts)
if sd_module is None:
keys_failed_to_match[key_network] = key
continue
if key not in matched_networks:
matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
matched_networks[key].w[network_part] = weight
if weight.dtype not in dtypes:
dtypes.append(weight.dtype)
network_types = []
state_dict = None
del state_dict
module_errors = 0
for key, weights in matched_networks.items():
net_module = None
for nettype in l.module_types:
net_module = nettype.create_module(net, weights)
if net_module is not None:
network_types.append(nettype.__class__.__name__)
break
if net_module is None:
module_errors += 1
if l.debug:
shared.log.error(f'LoRA unhandled: name={name} key={key} weights={weights.w.keys()}')
else:
net.modules[key] = net_module
if module_errors > 0:
shared.log.error(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" errors={module_errors} empty modules')
if len(keys_failed_to_match) > 0:
shared.log.warning(f'Network load: type=LoRA name="{name}" type={set(network_types)} unmatched={len(keys_failed_to_match)} matched={len(matched_networks)}')
if l.debug:
shared.log.debug(f'Network load: type=LoRA name="{name}" unmatched={keys_failed_to_match}')
else:
shared.log.debug(f'Network load: type=LoRA name="{name}" type={set(network_types)} keys={len(matched_networks)} dtypes={dtypes} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers}')
if len(matched_networks) == 0:
return None
lora_cache[name] = net
net.bundle_embeddings = bundle_embeddings
return net
def maybe_recompile_model(names, te_multipliers):
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
recompile_model = False
skip_lora_load = False
if shared.compiled_model_state is not None and shared.compiled_model_state.is_compiled:
if len(names) == len(shared.compiled_model_state.lora_model):
for i, name in enumerate(names):
if shared.compiled_model_state.lora_model[
i] != f"{name}:{te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier}":
recompile_model = True
shared.compiled_model_state.lora_model = []
break
if not recompile_model:
skip_lora_load = True
if len(l.loaded_networks) > 0 and l.debug:
shared.log.debug('Model Compile: Skipping LoRa loading')
return recompile_model, skip_lora_load
else:
recompile_model = True
shared.compiled_model_state.lora_model = []
if recompile_model:
current_task = sd_models.get_diffusers_task(shared.sd_model)
shared.log.debug(f'Compile: task={current_task} force model reload')
backup_cuda_compile = shared.opts.cuda_compile
backup_scheduler = getattr(sd_model, "scheduler", None)
sd_models.unload_model_weights(op='model')
shared.opts.cuda_compile = []
sd_models.reload_model_weights(op='model')
shared.sd_model = sd_models.set_diffuser_pipe(shared.sd_model, current_task)
shared.opts.cuda_compile = backup_cuda_compile
if backup_scheduler is not None:
sd_model.scheduler = backup_scheduler
return recompile_model, skip_lora_load
def list_available_networks():
t0 = time.time()
available_networks.clear()
available_network_aliases.clear()
forbidden_network_aliases.clear()
available_network_hash_lookup.clear()
forbidden_network_aliases.update({"none": 1, "Addams": 1})
if not os.path.exists(shared.cmd_opts.lora_dir):
shared.log.warning(f'LoRA directory not found: path="{shared.cmd_opts.lora_dir}"')
def add_network(filename):
if not os.path.isfile(filename):
return
name = os.path.splitext(os.path.basename(filename))[0]
name = name.replace('.', '_')
try:
entry = network.NetworkOnDisk(name, filename)
available_networks[entry.name] = entry
if entry.alias in available_network_aliases:
forbidden_network_aliases[entry.alias.lower()] = 1
available_network_aliases[entry.name] = entry
if entry.shorthash:
available_network_hash_lookup[entry.shorthash] = entry
except OSError as e: # should catch FileNotFoundError and PermissionError etc.
shared.log.error(f'LoRA: filename="{filename}" {e}')
candidates = sorted(files_cache.list_files(shared.cmd_opts.lora_dir, ext_filter=[".pt", ".ckpt", ".safetensors"]))
with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
for fn in candidates:
executor.submit(add_network, fn)
t1 = time.time()
l.timer.list = t1 - t0
shared.log.info(f'Available LoRAs: path="{shared.cmd_opts.lora_dir}" items={len(available_networks)} folders={len(forbidden_network_aliases)} time={t1 - t0:.2f}')
def network_download(name):
from huggingface_hub import hf_hub_download
if os.path.exists(name):
return network.NetworkOnDisk(name, name)
parts = name.split('/')
if len(parts) >= 5 and parts[1] == 'huggingface.co':
repo_id = f'{parts[2]}/{parts[3]}'
filename = '/'.join(parts[4:])
fn = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=shared.opts.hfcache_dir)
return network.NetworkOnDisk(name, fn)
return None
def gather_networks(names):
networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
if any(x is None for x in networks_on_disk):
list_available_networks()
networks_on_disk: list[network.NetworkOnDisk] = [available_network_aliases.get(name, None) for name in names]
for i in range(len(names)):
if names[i].startswith('/'):
networks_on_disk[i] = network_download(names[i])
return networks_on_disk
def network_load(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None, lora_modules=None):
networks_on_disk = gather_networks(names)
failed_to_load_networks = []
recompile_model, skip_lora_load = maybe_recompile_model(names, te_multipliers)
sd_model = getattr(shared.sd_model, "pipe", shared.sd_model)
l.loaded_networks.clear()
lora_diffusers.diffuser_loaded.clear()
lora_diffusers.diffuser_scales.clear()
t0 = time.time()
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
net = None
if network_on_disk is not None:
shorthash = getattr(network_on_disk, 'shorthash', '').lower()
if l.debug:
shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" hash="{shorthash}"')
try:
lora_scale = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier
lora_module = lora_modules[i] if lora_modules and len(lora_modules) > i else None
if recompile_model and shared.compiled_model_state is not None:
shared.compiled_model_state.lora_model.append(f"{name}:{lora_scale}")
lora_method = lora_overrides.get_method(shorthash)
if lora_method == 'diffusers':
net = lora_diffusers.load_diffusers(name, network_on_disk, lora_scale, lora_module)
elif lora_method == 'nunchaku':
pass # handled directly from extra_networks_lora.load_nunchaku
else:
net = load_safetensors(name, network_on_disk)
if net is not None:
net.mentioned_name = name
network_on_disk.read_hash()
except Exception as e:
shared.log.error(f'Network load: type=LoRA file="{network_on_disk.filename}" {e}')
if l.debug:
errors.display(e, 'LoRA')
continue
if net is None:
failed_to_load_networks.append(name)
shared.log.error(f'Network load: type=LoRA name="{name}" detected={network_on_disk.sd_version if network_on_disk is not None else None} failed')
continue
if hasattr(sd_model, 'embedding_db'):
sd_model.embedding_db.load_diffusers_embedding(None, net.bundle_embeddings)
net.te_multiplier = te_multipliers[i] if te_multipliers else shared.opts.extra_networks_default_multiplier
net.unet_multiplier = unet_multipliers[i] if unet_multipliers else shared.opts.extra_networks_default_multiplier
net.dyn_dim = dyn_dims[i] if dyn_dims else shared.opts.extra_networks_default_multiplier
l.loaded_networks.append(net)
while len(lora_cache) > shared.opts.lora_in_memory_limit:
name = next(iter(lora_cache))
lora_cache.pop(name, None)
if not skip_lora_load and len(lora_diffusers.diffuser_loaded) > 0:
shared.log.debug(f'Network load: type=LoRA loaded={lora_diffusers.diffuser_loaded} available={sd_model.get_list_adapters()} active={sd_model.get_active_adapters()} scales={lora_diffusers.diffuser_scales}')
try:
t1 = time.time()
if l.debug:
shared.log.trace(f'Network load: type=LoRA list={sd_model.get_list_adapters()}')
shared.log.trace(f'Network load: type=LoRA active={sd_model.get_active_adapters()}')
sd_model.set_adapters(adapter_names=lora_diffusers.diffuser_loaded, adapter_weights=lora_diffusers.diffuser_scales)
except Exception as e:
if str(e) not in exclude_errors:
shared.log.error(f'Network load: type=LoRA action=strength {str(e)}')
if l.debug:
errors.display(e, 'LoRA')
try:
if shared.opts.lora_fuse_diffusers and not lora_overrides.disable_fuse():
sd_model.fuse_lora(adapter_names=lora_diffusers.diffuser_loaded, lora_scale=1.0, fuse_unet=True, fuse_text_encoder=True) # diffusers with fuse uses fixed scale since later apply does the scaling
sd_model.unload_lora_weights()
l.timer.activate += time.time() - t1
except Exception as e:
shared.log.error(f'Network load: type=LoRA action=fuse {str(e)}')
if l.debug:
errors.display(e, 'LoRA')
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model, force=True, silent=True) # some layers may end up on cpu without hook
if len(l.loaded_networks) > 0 and l.debug:
shared.log.debug(f'Network load: type=LoRA loaded={[n.name for n in l.loaded_networks]} cache={list(lora_cache)} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers}')
if recompile_model:
shared.log.info("Network load: type=LoRA recompiling model")
if shared.compiled_model_state is not None:
backup_lora_model = shared.compiled_model_state.lora_model
else:
backup_lora_model = []
if 'Model' in shared.opts.cuda_compile:
sd_model = sd_models_compile.compile_diffusers(sd_model)
if shared.compiled_model_state is not None:
shared.compiled_model_state.lora_model = backup_lora_model
l.timer.load = time.time() - t0