# sdnext/modules/sd_models_utils.py (mirror of https://github.com/vladmandic/sdnext.git)

import io
import copy
import json
import inspect
import os.path
from rich import progress # pylint: disable=redefined-builtin
import torch
import safetensors.torch
from modules import paths, shared, errors
from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closest_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import
from modules.sd_offload import disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import
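
# No-op replacement for a pipeline's invisible-watermark encoder: images are returned unchanged.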
class NoWatermark:
def apply_watermark(self, img):
return img
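
# Introspection helpers: get_signature returns a class constructor's parameters and get_call the parameters
# of its __call__ method, so callers can filter keyword arguments before passing them through.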
def get_signature(cls):
if cls is None or not hasattr(cls, '__init__'):
return {}
signature = inspect.signature(cls.__init__, follow_wrapped=True)
return signature.parameters
def get_call(cls):
if cls is None or not hasattr(cls, '__call__'): # noqa: B004
return {}
signature = inspect.signature(cls.__call__, follow_wrapped=True)
return signature.parameters
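
# Illustrative use (sketch, not from this module; `pipe` is a hypothetical pipeline instance):
#   accepted = get_call(pipe.__class__)
#   kwargs = {k: v for k, v in kwargs.items() if k in accepted}

# Derive a huggingface repo id from a CheckpointInfo or a plain string; existing local paths are returned as-is.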
def path_to_repo(checkpoint_info):
if isinstance(checkpoint_info, CheckpointInfo):
if os.path.exists(checkpoint_info.path) and 'models--' not in checkpoint_info.path:
return checkpoint_info.path # local models
repo_id = checkpoint_info.name
else:
repo_id = checkpoint_info # fallback if fn is used with str param
repo_id = repo_id.replace('\\', '/')
if repo_id.startswith('Diffusers/'):
repo_id = repo_id.split('Diffusers/')[-1]
if repo_id.startswith('models--'):
repo_id = repo_id.split('models--')[-1]
repo_id = repo_id.replace('--', '/')
if repo_id.count('/') != 1:
shared.log.warning(f'Model: repo="{repo_id}" repository not recognized')
if '+' in repo_id:
repo_id = repo_id.split('+')[0]
return repo_id
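
# Example mapping (illustrative): 'models--runwayml--stable-diffusion-v1-5' -> 'runwayml/stable-diffusion-v1-5'

# Swap a module's weight for a FakeTensor parameter so it no longer holds real data; returns the module unchanged on failure.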
def convert_to_faketensors(tensor):
try:
fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access
if hasattr(tensor, "weight"):
tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight))
return tensor
except Exception:
pass
return tensor
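
# Load a checkpoint file (.safetensors or legacy .ckpt/.pt) into a raw state dict, honoring the
# stream_load and sd_disable_ckpt options; returns None on failure.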
def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pylint: disable=unused-argument
if not os.path.isfile(checkpoint_file):
shared.log.error(f'Load dict: path="{checkpoint_file}" not a file')
return None
try:
pl_sd = None
with progress.open(checkpoint_file, 'rb', description=f'[cyan]Load {what}: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f:
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt:
shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}")
return None
if shared.opts.stream_load:
if extension.lower() == ".safetensors":
buffer = f.read()
pl_sd = safetensors.torch.load(buffer)
else:
buffer = io.BytesIO(f.read())
pl_sd = torch.load(buffer, map_location='cpu')
else:
if extension.lower() == ".safetensors":
pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
else:
pl_sd = torch.load(f, map_location='cpu')
sd = get_state_dict_from_checkpoint(pl_sd)
del pl_sd
except Exception as e:
errors.display(e, f'Load model: {checkpoint_file}')
sd = None
return sd
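
# Unwrap a nested 'state_dict' entry and rename legacy CLIP text-encoder keys to the current transformers layout.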
def get_state_dict_from_checkpoint(pl_sd):
checkpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
def transform_checkpoint_dict_key(k):
for text, replacement in checkpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
return k
pl_sd = pl_sd.pop("state_dict", pl_sd)
pl_sd.pop("state_dict", None)
sd = {}
for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k)
if new_key is not None:
sd[new_key] = v
pl_sd.clear()
pl_sd.update(sd)
return pl_sd
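
# Apply per-component config overrides from sidecar JSON files and force 9 input channels for inpainting unets.
# Sidecar naming (as read by load_config below): for model.safetensors, a model_unet.json next to the checkpoint,
# or under the bundled configs path, overrides keys in sd_model.unet.config.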
def patch_diffuser_config(sd_model, model_file):
def load_config(fn, k):
model_file = os.path.splitext(fn)[0]
cfg_file = f'{model_file}_{k}.json'
try:
if os.path.exists(cfg_file):
with open(cfg_file, 'r', encoding='utf-8') as f:
return json.load(f)
cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json'
if os.path.exists(cfg_file):
with open(cfg_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception:
pass
return {}
if sd_model is None:
return sd_model
if hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpaint' in model_file.lower():
sd_model.unet.config.in_channels = 9
if not hasattr(sd_model, '_internal_dict'):
return sd_model
for c in sd_model._internal_dict.keys(): # pylint: disable=protected-access
component = getattr(sd_model, c, None)
if hasattr(component, 'config'):
override = load_config(model_file, c)
updated = {}
for k, v in override.items():
if k.startswith('_'):
continue
if v != component.config.get(k, None):
if hasattr(component.config, '__frozen'):
component.config.__frozen = False # pylint: disable=protected-access
component.config[k] = v
updated[k] = v
return sd_model
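
# Apply `function` to the selected pipeline components: "Model" covers the unet/transformer/dit/prior variants,
# "TE" the text encoders, and "VAE" the vae/movq/vqgan plus the image encoder. `op` names the operation
# (e.g. "compile", "sdnq") and changes how some components are wrapped.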
def apply_function_to_model(sd_model, function, options, op=None):
if "Model" in options:
if hasattr(sd_model, 'model') and (hasattr(sd_model.model, 'config') or isinstance(sd_model.model, torch.nn.Module)):
sd_model.model = function(sd_model.model, op="model", sd_model=sd_model)
if hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config'):
sd_model.unet = function(sd_model.unet, op="unet", sd_model=sd_model)
if hasattr(sd_model, 'transformer') and hasattr(sd_model.transformer, 'config'):
sd_model.transformer = function(sd_model.transformer, op="transformer", sd_model=sd_model)
if hasattr(sd_model, 'dit') and hasattr(sd_model.dit, 'config'):
sd_model.dit = function(sd_model.dit, op="dit", sd_model=sd_model)
if hasattr(sd_model, 'transformer_2') and hasattr(sd_model.transformer_2, 'config'):
sd_model.transformer_2 = function(sd_model.transformer_2, op="transformer_2", sd_model=sd_model)
if hasattr(sd_model, 'transformer_3') and hasattr(sd_model.transformer_3, 'config'):
sd_model.transformer_3 = function(sd_model.transformer_3, op="transformer_3", sd_model=sd_model)
if hasattr(sd_model, 'decoder_pipe') and hasattr(sd_model, 'decoder'):
sd_model.decoder = None
sd_model.decoder = sd_model.decoder_pipe.decoder = function(sd_model.decoder_pipe.decoder, op="decoder_pipe.decoder", sd_model=sd_model)
if hasattr(sd_model, 'prior_pipe') and hasattr(sd_model.prior_pipe, 'prior'):
if op == "sdnq" and "StableCascade" in sd_model.__class__.__name__: # fixes dtype errors
backup_clip_txt_pooled_mapper = copy.deepcopy(sd_model.prior_pipe.prior.clip_txt_pooled_mapper)
sd_model.prior_pipe.prior = function(sd_model.prior_pipe.prior, op="prior_pipe.prior", sd_model=sd_model)
if op == "sdnq" and "StableCascade" in sd_model.__class__.__name__:
sd_model.prior_pipe.prior.clip_txt_pooled_mapper = backup_clip_txt_pooled_mapper
if "TE" in options:
if hasattr(sd_model, 'text_encoder') and hasattr(sd_model.text_encoder, 'config'):
if hasattr(sd_model, 'decoder_pipe') and hasattr(sd_model.decoder_pipe, 'text_encoder') and hasattr(sd_model.decoder_pipe.text_encoder, 'config'):
sd_model.decoder_pipe.text_encoder = function(sd_model.decoder_pipe.text_encoder, op="decoder_pipe.text_encoder", sd_model=sd_model)
else:
sd_model.text_encoder = function(sd_model.text_encoder, op="text_encoder", sd_model=sd_model)
if hasattr(sd_model, 'text_encoder_2') and hasattr(sd_model.text_encoder_2, 'config'):
sd_model.text_encoder_2 = function(sd_model.text_encoder_2, op="text_encoder_2", sd_model=sd_model)
if hasattr(sd_model, 'text_encoder_3') and hasattr(sd_model.text_encoder_3, 'config'):
sd_model.text_encoder_3 = function(sd_model.text_encoder_3, op="text_encoder_3", sd_model=sd_model)
if hasattr(sd_model, 'text_encoder_4') and hasattr(sd_model.text_encoder_4, 'config'):
sd_model.text_encoder_4 = function(sd_model.text_encoder_4, op="text_encoder_4", sd_model=sd_model)
if hasattr(sd_model, 'mllm') and hasattr(sd_model.mllm, 'config'):
sd_model.mllm = function(sd_model.mllm, op="text_encoder_mllm", sd_model=sd_model)
if hasattr(sd_model, 'prior_pipe') and hasattr(sd_model.prior_pipe, 'text_encoder') and hasattr(sd_model.prior_pipe.text_encoder, 'config'):
sd_model.prior_pipe.text_encoder = function(sd_model.prior_pipe.text_encoder, op="prior_pipe.text_encoder", sd_model=sd_model)
if "VAE" in options:
if hasattr(sd_model, 'vae') and hasattr(sd_model.vae, 'decode'):
if op == "compile":
sd_model.vae.decode = function(sd_model.vae.decode, op="vae_decode", sd_model=sd_model)
sd_model.vae.encode = function(sd_model.vae.encode, op="vae_encode", sd_model=sd_model)
else:
sd_model.vae = function(sd_model.vae, op="vae", sd_model=sd_model)
if hasattr(sd_model, 'movq') and hasattr(sd_model.movq, 'decode'):
if op == "compile":
sd_model.movq.decode = function(sd_model.movq.decode, op="movq_decode", sd_model=sd_model)
sd_model.movq.encode = function(sd_model.movq.encode, op="movq_encode", sd_model=sd_model)
else:
sd_model.movq = function(sd_model.movq, op="movq", sd_model=sd_model)
if hasattr(sd_model, 'vqgan') and hasattr(sd_model.vqgan, 'decode'):
if op == "compile":
sd_model.vqgan.decode = function(sd_model.vqgan.decode, op="vqgan_decode", sd_model=sd_model)
sd_model.vqgan.encode = function(sd_model.vqgan.encode, op="vqgan_encode", sd_model=sd_model)
else:
sd_model.vqgan = function(sd_model.vqgan, op="vqgan", sd_model=sd_model)
if hasattr(sd_model, 'decoder_pipe') and hasattr(sd_model.decoder_pipe, 'vqgan'):
if op == "compile":
sd_model.decoder_pipe.vqgan.decode = function(sd_model.decoder_pipe.vqgan.decode, op="vqgan_decode", sd_model=sd_model)
sd_model.decoder_pipe.vqgan.encode = function(sd_model.decoder_pipe.vqgan.encode, op="vqgan_encode", sd_model=sd_model)
else:
sd_model.decoder_pipe.vqgan = sd_model.vqgan
if hasattr(sd_model, 'image_encoder') and hasattr(sd_model.image_encoder, 'config'):
sd_model.image_encoder = function(sd_model.image_encoder, op="image_encoder", sd_model=sd_model)
return sd_model
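
# Illustrative use (sketch; `cast_to_fp16` is a hypothetical callable matching the expected signature):
#   def cast_to_fp16(module, op=None, sd_model=None):
#       return module.half() if hasattr(module, 'half') else module
#   sd_model = apply_function_to_model(sd_model, cast_to_fp16, options=["Model", "TE"], op="cast")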