mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
94 lines
4.7 KiB
Python
94 lines
4.7 KiB
Python
from typing import Union
|
|
import os
|
|
import time
|
|
import diffusers
|
|
from modules import shared, errors
|
|
from modules.lora import network
|
|
from modules.lora import lora_common as l
|
|
|
|
|
|
# Module-level record of LoRA adapters loaded through the diffusers path.
# load_diffusers() appends a newly registered adapter name to diffuser_loaded and its
# multiplier to diffuser_scales in the same step, so (within this module) the two lists
# line up by index.
diffuser_loaded: list[str] = list()
diffuser_scales: list[float] = list()
|
|
|
|
|
|
def load_per_module(sd_model: diffusers.DiffusionPipeline, filename: str, adapter_name: str, lora_modules: list[str]):
    """Load one LoRA file into an explicit subset of pipeline components.

    The file is read once with ``sd_model.lora_state_dict`` and the weights are then
    injected into each requested component under ``adapter_name``.

    Args:
        sd_model: pipeline providing ``lora_state_dict`` and the ``load_lora_into_*`` loaders.
        filename: path of the LoRA file handed to ``lora_state_dict``.
        adapter_name: adapter name registered on every targeted component.
        lora_modules: component names to target. Recognised values are ``transformer``,
            ``transformer_2``, ``unet`` and ``text_encoder`` (``te`` is accepted as an alias
            for ``text_encoder``). Unknown names and components that are absent or ``None``
            on the pipeline are skipped with a warning.

    Returns:
        ``adapter_name`` after every requested module has been processed (even if some were
        skipped), or ``None`` if reading the state dict raised. Exceptions raised by the
        ``load_lora_into_*`` calls are not caught here and propagate to the caller.
    """
    shared.log.debug(f'LoRA load: modules={lora_modules}')
    try:
        loaded = sd_model.lora_state_dict(filename)
        # lora_state_dict may hand back a (state_dict, network_alphas) pair;
        # anything else is treated as the bare state dict with no alphas
        if isinstance(loaded, tuple) and len(loaded) == 2:
            state_dict, network_alphas = loaded
        else:
            state_dict, network_alphas = loaded, {}
    except Exception as e:
        shared.log.error(f'LoRA load: {e}')
        if l.debug:
            errors.display(e, "LoRA")
        return None

    for lora_module in lora_modules:
        target = 'text_encoder' if lora_module == 'te' else lora_module
        if target not in ('transformer', 'transformer_2', 'unet', 'text_encoder'):
            shared.log.warning(f'LoRA load: requested={lora_module} unknown')
            continue
        component = getattr(sd_model, target, None)
        if component is None:
            shared.log.warning(f'LoRA load: requested={lora_module} missing')
            continue
        if target == 'unet':
            sd_model.load_lora_into_unet(state_dict, network_alphas, unet=component, adapter_name=adapter_name)
        elif target == 'text_encoder':
            sd_model.load_lora_into_text_encoder(state_dict, network_alphas, text_encoder=component, adapter_name=adapter_name)
        else:
            # both transformer slots share one loader, which is not passed network_alphas
            sd_model.load_lora_into_transformer(state_dict, transformer=component, adapter_name=adapter_name)
    return adapter_name
|
|
|
|
|
|
def load_diffusers(name: str, network_on_disk: network.NetworkOnDisk, lora_scale: Union[float, None] = None, lora_module=None) -> Union[network.Network, None]:
    """Load a LoRA through the diffusers adapter API and wrap it in a ``network.Network``.

    Args:
        name: network name; dots are replaced with underscores before it is used as the
            diffusers adapter name.
        network_on_disk: descriptor of the LoRA file (``filename`` and ``sd_version`` are read).
        lora_scale: multiplier recorded in ``diffuser_scales`` for a newly registered adapter.
            ``None`` (the default) resolves to ``shared.opts.extra_networks_default_multiplier``
            at call time.
        lora_module: optional non-empty list of component names. When given, only those
            components are loaded via ``load_per_module`` and the pipeline is flagged with
            ``_lora_partial``; otherwise the whole file goes through ``load_lora_weights``.

    Returns:
        A ``network.Network`` for the file, or ``None`` when the pipeline has no
        ``load_lora_weights``, loading raised an error (other than 'already in use'),
        or ``load_per_module`` could not read the state dict.
    """
    t0 = time.time()
    if lora_scale is None:
        # Resolved per call rather than in the signature: a default expression in the
        # signature is evaluated only once at import time, so later changes to the option
        # would otherwise be ignored.
        lora_scale = shared.opts.extra_networks_default_multiplier
    name = name.replace(".", "_")
    # use `.pipe` when shared.sd_model has one (presumably a wrapper), else the model itself
    sd_model: diffusers.DiffusionPipeline = getattr(shared.sd_model, "pipe", shared.sd_model)
    shared.log.debug(f'Network load: type=LoRA name="{name}" file="{network_on_disk.filename}" detected={network_on_disk.sd_version} method=diffusers scale={lora_scale} fuse={shared.opts.lora_fuse_native}:{shared.opts.lora_fuse_diffusers}')
    if not hasattr(sd_model, 'load_lora_weights'):
        shared.log.error(f'Network load: type=LoRA class={sd_model.__class__} does not implement load lora')
        return None
    try:
        if isinstance(lora_module, list) and len(lora_module) > 0:
            # partial load into the selected components only
            name = load_per_module(sd_model, network_on_disk.filename, adapter_name=name, lora_modules=lora_module)
            sd_model._lora_partial = True # pylint: disable=protected-access
        else:
            sd_model.load_lora_weights(network_on_disk.filename, adapter_name=name)
    except Exception as e:
        msg = str(e)
        if 'already in use' not in msg:
            if 'following keys have not been correctly renamed' in msg:
                shared.log.error(f'Network load: type=LoRA name="{name}" diffusers unsupported format')
            elif 'object has no attribute' in msg:
                shared.log.error(f'Network load: type=LoRA name="{name}" diffusers empty module')
            else:
                shared.log.error(f'Network load: type=LoRA name="{name}" {e}')
            if l.debug:
                errors.display(e, "LoRA")
            return None
        # NOTE(review): an 'already in use' error is tolerated - presumably the adapter name is
        # already registered on the pipeline, so execution continues with the existing adapter.
    if name is None:  # load_per_module returns None when the state dict could not be read
        return None
    if name not in diffuser_loaded:
        adapters = [adapter for component_adapters in sd_model.get_list_adapters().values() for adapter in component_adapters]
        if name not in adapters:
            shared.log.error(f'Network load: type=LoRA name="{name}" adapters={adapters} not loaded')
            # NOTE(review): execution still continues and a Network is returned below even though
            # the adapter is not registered; confirm callers expect this.
        else:
            diffuser_loaded.append(name)
            diffuser_scales.append(lora_scale)
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)
    l.timer.activate += time.time() - t0
    return net
|