1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/lora/lora_diffusers.py
Vladimir Mandic ba270db6ad separate settings for lora fuse
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-11-08 11:08:06 -05:00

94 lines
4.7 KiB
Python

import inspect
import os
import time
from typing import Union

import diffusers

from modules import shared, errors
from modules.lora import network
from modules.lora import lora_common as l
# Adapter names that load_diffusers confirmed are present in sd_model.get_list_adapters(), in the order they were registered.
diffuser_loaded: list[str] = []
# Kept in step with diffuser_loaded within this module: diffuser_scales[i] is the lora_scale passed when diffuser_loaded[i] was appended.
diffuser_scales: list[float] = []
def _load_into_transformer(sd_model, state_dict, network_alphas, transformer, adapter_name: str):
    """Load `state_dict` into one transformer of `sd_model`, passing `network_alphas` only if the loader accepts it."""
    # NOTE(review): diffusers pipeline mixins do not share one signature for load_lora_into_transformer;
    # some (e.g. Flux) require a network_alphas argument while others (e.g. SD3/Wan) reject it,
    # so it is forwarded only when the bound loader declares it. Confirm against the installed diffusers version.
    try:
        accepts_alphas = 'network_alphas' in inspect.signature(sd_model.load_lora_into_transformer).parameters
    except (TypeError, ValueError):  # signature not introspectable: keep the call shape that omits alphas
        accepts_alphas = False
    if accepts_alphas:
        sd_model.load_lora_into_transformer(state_dict, network_alphas=network_alphas, transformer=transformer, adapter_name=adapter_name)
    else:
        sd_model.load_lora_into_transformer(state_dict, transformer=transformer, adapter_name=adapter_name)


def load_per_module(sd_model: diffusers.DiffusionPipeline, filename: str, adapter_name: str, lora_modules: list[str]):
    """Load a LoRA file into only the requested pipeline components.

    The state dict is read once with `sd_model.lora_state_dict(filename)`. If that returns a 2-tuple, it is
    unpacked as `(state_dict, network_alphas)`; otherwise `network_alphas` is an empty dict.

    Args:
        sd_model: pipeline exposing `lora_state_dict` and the `load_lora_into_*` loaders.
        filename: path of the LoRA file.
        adapter_name: adapter name to register on each component.
        lora_modules: component names to load into. Recognised values are 'transformer', 'transformer_2',
            'unet', and 'text_encoder' (alias 'te'). Unknown names and components missing from the
            pipeline are logged as warnings and skipped.

    Returns:
        `adapter_name`, or None if reading the state dict failed (the error is logged).
        Exceptions raised by the individual component loaders are not caught here and propagate to the caller.
    """
    shared.log.debug(f'LoRA load: modules={lora_modules}')
    try:
        state_dict = sd_model.lora_state_dict(filename)
        if isinstance(state_dict, tuple) and len(state_dict) == 2:
            state_dict, network_alphas = state_dict
        else:
            network_alphas = {}
    except Exception as e:
        shared.log.error(f'LoRA load: {e}')
        if l.debug:
            errors.display(e, "LoRA")
        return None
    # requested module name -> pipeline attribute holding that component
    component_attrs = {
        'transformer': 'transformer',
        'transformer_2': 'transformer_2',
        'unet': 'unet',
        'text_encoder': 'text_encoder',
        'te': 'text_encoder',
    }
    for lora_module in lora_modules:
        attr = component_attrs.get(lora_module)
        if attr is None:
            shared.log.warning(f'LoRA load: requested={lora_module} unknown')
            continue
        component = getattr(sd_model, attr, None)
        if component is None:
            shared.log.warning(f'LoRA load: requested={lora_module} missing')
            continue
        if attr == 'unet':
            sd_model.load_lora_into_unet(state_dict, network_alphas, unet=component, adapter_name=adapter_name)
        elif attr == 'text_encoder':
            sd_model.load_lora_into_text_encoder(state_dict, network_alphas, text_encoder=component, adapter_name=adapter_name)
        else:  # 'transformer' / 'transformer_2'
            _load_into_transformer(sd_model, state_dict, network_alphas, component, adapter_name)
    return adapter_name
def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale:float=shared.opts.extra_networks_default_multiplier, lora_module=None) -> Union[network.Network, None]:
    """Load a LoRA file as a native diffusers adapter on the active pipeline.

    Args:
        name: network name. Dots are replaced with underscores to form the adapter name.
        network_on_disk: descriptor of the LoRA file; its `filename` and `sd_version` are used here.
        lora_scale: scale recorded in `diffuser_scales` when the adapter is registered.
        lora_module: optional non-empty list of component names. When given, only those components are
            loaded via `load_per_module` and the pipeline gets `_lora_partial = True`.

    Returns:
        A `network.Network` for the file, or None in any of these cases:
        - the pipeline has no `load_lora_weights`;
        - loading raised an error other than one mentioning 'already in use' (those are ignored);
        - `load_per_module` returned None.
        NOTE: a Network is still returned when the adapter does not show up in `get_list_adapters()`;
        that case is only logged.
    """
    started = time.time()
    adapter = name.replace(".", "_")
    # use the wrapped pipeline when shared.sd_model exposes one as `.pipe`
    pipe: diffusers.DiffusionPipeline = getattr(shared.sd_model, "pipe", shared.sd_model)
    shared.log.debug(f'Network load: type=LoRA name="{adapter}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers}')
    if not hasattr(pipe, 'load_lora_weights'):
        shared.log.error(f'Network load: type=LoRA class={pipe.__class__} does not implement load lora')
        return None

    partial = isinstance(lora_module, list) and len(lora_module) > 0
    try:
        if partial:
            adapter = load_per_module(pipe, network_on_disk.filename, adapter_name=adapter, lora_modules=lora_module)
            pipe._lora_partial = True # pylint: disable=protected-access
        else:
            pipe.load_lora_weights(network_on_disk.filename, adapter_name=adapter)
    except Exception as e:
        reason = str(e)
        # errors mentioning 'already in use' (presumably an adapter name collision) are tolerated
        if 'already in use' not in reason:
            if 'following keys have not been correctly renamed' in reason:
                shared.log.error(f'Network load: type=LoRA name="{adapter}" diffusers unsupported format')
            elif 'object has no attribute' in reason:
                shared.log.error(f'Network load: type=LoRA name="{adapter}" diffusers empty module')
            else:
                shared.log.error(f'Network load: type=LoRA name="{adapter}" {e}')
            if l.debug:
                errors.display(e, "LoRA")
            return None

    if adapter is None:  # load_per_module could not read the state dict
        return None

    if adapter not in diffuser_loaded:
        # get_list_adapters() maps components to adapter-name lists; flatten to check registration anywhere
        registered = [entry for group in pipe.get_list_adapters().values() for entry in group]
        if adapter in registered:
            diffuser_loaded.append(adapter)
            diffuser_scales.append(lora_scale)
        else:
            shared.log.error(f'Network load: type=LoRA name="{adapter}" adapters={registered} not loaded')

    net = network.Network(adapter, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)
    l.timer.activate += time.time() - started
    return net