# sdnext/modules/model_te.py
# commit 41ae06bd90: fix loading custom t5
# Signed-off-by: Vladimir Mandic <mandic00@live.com>, 2025-08-13

import os
import json
import torch
import transformers
from safetensors.torch import load_file
from modules import shared, devices, files_cache, errors, model_quant
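
# te_dict maps discovered text-encoder names to file paths (populated by refresh_te_list);
# loaded_te tracks which encoder is currently attached so redundant reloads can be skipped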
te_dict = {}
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
loaded_te = None
def load_t5(name=None, cache_dir=None):
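    """Load a T5 text encoder selected by name.

    Local .gguf and .safetensors files are instantiated with the bundled FLUX
    text_encoder_2 config; otherwise the name picks a quantization recipe applied
    to the reference SD3 text_encoder_3, or is treated as a Hugging Face repo id.
    """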
global loaded_te # pylint: disable=global-statement
if name is None:
return None
cache_dir = cache_dir or shared.opts.hfcache_dir
from modules import modelloader
modelloader.hf_login()
    repo_id = 'stabilityai/stable-diffusion-3-medium-diffusers'  # reference repo providing text_encoder_3 (T5-XXL) weights
if os.path.exists(name):
fn = name
else:
        fn = te_dict.get(name)
    t5 = None  # guard: several branches below assign t5 only when a quant config is available
if fn is not None and name.lower().endswith('gguf'):
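        # GGUF files carry their own quantization; transformers reads them directly via gguf_file=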
from modules import ggml
ggml.install_gguf()
with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
t5 = transformers.T5EncoderModel.from_pretrained(None, gguf_file=fn, config=t5_config, device_map="auto", cache_dir=cache_dir, torch_dtype=devices.dtype)
elif fn is not None and 'fp8' in name.lower():
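        # pre-quantized fp8 safetensors: materialize the model tensor-by-tensor to avoid a full-precision intermediate copy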
from accelerate.utils import set_module_tensor_to_device
with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
state_dict = load_file(fn)
dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
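        # build the model on the meta device so no memory is allocated until each tensor is assigned below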
with torch.device("meta"):
t5 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
for param_name, param in state_dict.items():
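            # keep float8_e4m3fn weights as stored; cast all other floating-point tensors to the runtime dtype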
is_param_float8_e4m3fn = hasattr(torch, "float8_e4m3fn") and param.dtype == torch.float8_e4m3fn
if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
param = param.to(devices.dtype)
set_module_tensor_to_device(t5, param_name, device=0, value=param)
if t5.dtype != devices.dtype:
try:
t5 = t5.to(dtype=devices.dtype)
except Exception:
shared.log.error(f"T5: Failed to cast text encoder to {devices.dtype}, set dtype to {t5.dtype}")
raise
del state_dict
elif fn is not None:
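        # plain safetensors checkpoint: load its state dict against the bundled FLUX T5 config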
with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
t5_config = transformers.T5Config(**json.load(f))
state_dict = load_file(fn)
t5 = transformers.T5EncoderModel.from_pretrained(None, state_dict=state_dict, config=t5_config, torch_dtype=devices.dtype)
elif 'fp16' in name.lower():
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'fp4' in name.lower():
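        # no local file: quantize the reference encoder on the fly with bitsandbytes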
model_quant.load_bnb('Load model: type=T5')
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True)
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'fp8' in name.lower():
model_quant.load_bnb('Load model: type=T5')
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'int8' in name.lower():
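        # SDNQ variants; create_sdnq_config can return None, in which case t5 stays None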
from modules.model_quant import create_sdnq_config
quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='int8')
if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'uint4' in name.lower():
from modules.model_quant import create_sdnq_config
quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='uint4')
if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'qint4' in name.lower():
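        # optimum-quanto variants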
model_quant.load_quanto('Load model: type=T5')
quantization_config = transformers.QuantoConfig(weights='int4')
if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif 'qint8' in name.lower():
model_quant.load_quanto('Load model: type=T5')
quantization_config = transformers.QuantoConfig(weights='int8')
if quantization_config is not None:
t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)
elif '/' in name:
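        # any name containing '/' is treated as a Hugging Face repo id, with the user's TE quantization settings applied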
        shared.log.debug(f'Load model: type=T5 repo="{name}"')
        quant_config = model_quant.create_config(module='TE')
        if quant_config is not None:
            t5 = transformers.T5EncoderModel.from_pretrained(name, cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_config)
else:
t5 = None
if t5 is not None:
loaded_te = name
return t5
def set_t5(pipe, module, t5=None, cache_dir=None):
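    """Attach the selected T5 encoder to the given pipeline module, then re-apply offloading."""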
global loaded_te # pylint: disable=global-statement
if loaded_te == shared.opts.sd_text_encoder:
return pipe
if pipe is None or not hasattr(pipe, module):
return pipe
try:
t5 = load_t5(name=t5, cache_dir=cache_dir)
except Exception as e:
        shared.log.error(f'Load module: type="{module}" class="T5" file="{shared.opts.sd_text_encoder}" {e}')
if debug:
errors.display(e, 'TE:')
t5 = None
if t5 is None:
return pipe
loaded_te = shared.opts.sd_text_encoder
setattr(pipe, module, t5)
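    # re-apply the configured offload mode so the newly attached encoder is hooked like the rest of the pipeline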
if shared.opts.diffusers_offload_mode == "sequential":
from accelerate import cpu_offload
getattr(pipe, module).to("cpu")
cpu_offload(getattr(pipe, module), devices.device, offload_buffers=len(getattr(pipe, module)._parameters) > 0) # pylint: disable=protected-access
elif shared.opts.diffusers_offload_mode == "model":
if not hasattr(pipe, "_all_hooks") or len(pipe._all_hooks) == 0: # pylint: disable=protected-access
pipe.enable_model_cpu_offload(device=devices.device)
if hasattr(pipe, "maybe_free_model_hooks"):
pipe.maybe_free_model_hooks()
devices.torch_gc()
return pipe
def load_vit_l():
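    """Load a CLIP ViT-L text encoder from a local safetensors file using the bundled SDXL config."""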
global loaded_te # pylint: disable=global-statement
config = transformers.PretrainedConfig.from_json_file('configs/sdxl/text_encoder/config.json')
state_dict = load_file(os.path.join(shared.opts.te_dir, f'{shared.opts.sd_text_encoder}.safetensors'))
te = transformers.CLIPTextModel.from_pretrained(pretrained_model_name_or_path=None, state_dict=state_dict, config=config)
te = te.to(dtype=devices.dtype)
loaded_te = shared.opts.sd_text_encoder
del state_dict
return te
def load_vit_g():
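    """Load a CLIP ViT-G text encoder (with projection) from a local safetensors file using the bundled SDXL config."""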
global loaded_te # pylint: disable=global-statement
config = transformers.PretrainedConfig.from_json_file('configs/sdxl/text_encoder_2/config.json')
state_dict = load_file(os.path.join(shared.opts.te_dir, f'{shared.opts.sd_text_encoder}.safetensors'))
te = transformers.CLIPTextModelWithProjection.from_pretrained(pretrained_model_name_or_path=None, state_dict=state_dict, config=config)
te = te.to(dtype=devices.dtype)
loaded_te = shared.opts.sd_text_encoder
del state_dict
return te
def set_clip(pipe):
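    """Swap the pipeline's CLIP text encoder(s) when the selected TE file is a ViT-L or ViT-G checkpoint."""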
if loaded_te == shared.opts.sd_text_encoder:
return
from modules.sd_models import move_model
if 'vit-l' in shared.opts.sd_text_encoder.lower() and hasattr(shared.sd_model, 'text_encoder') and shared.sd_model.text_encoder.__class__.__name__ == 'CLIPTextModel':
try:
te = load_vit_l()
except Exception as e:
shared.log.error(f'Load module: type="text_encoder" class="ViT-L" file="{shared.opts.sd_text_encoder}" {e}')
if debug:
errors.display(e, 'TE:')
te = None
if te is not None:
pipe.text_encoder = te
shared.log.info(f'Load module: type="text_encoder" class="ViT-L" file="{shared.opts.sd_text_encoder}"')
import modules.prompt_parser_diffusers
modules.prompt_parser_diffusers.cache.clear()
move_model(pipe.text_encoder, devices.device)
devices.torch_gc()
if 'vit-g' in shared.opts.sd_text_encoder.lower() and hasattr(shared.sd_model, 'text_encoder_2') and shared.sd_model.text_encoder_2.__class__.__name__ == 'CLIPTextModelWithProjection':
try:
te = load_vit_g()
except Exception as e:
            shared.log.error(f'Load module: type="text_encoder_2" class="ViT-G" file="{shared.opts.sd_text_encoder}" {e}')
if debug:
errors.display(e, 'TE:')
te = None
if te is not None:
pipe.text_encoder_2 = te
shared.log.info(f'Load module: type="text_encoder_2" class="ViT-G" file="{shared.opts.sd_text_encoder}"')
import modules.prompt_parser_diffusers
modules.prompt_parser_diffusers.cache.clear()
move_model(pipe.text_encoder_2, devices.device)
devices.torch_gc()
def refresh_te_list():
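    """Rescan shared.opts.te_dir for .safetensors and .gguf files and rebuild te_dict."""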
te_dict.clear()
for file in files_cache.list_files(shared.opts.te_dir, ext_filter=['.safetensors', '.gguf']):
basename = os.path.basename(file)
        name = os.path.splitext(basename)[0] if '.safetensors' in basename else basename  # keep the .gguf extension so load_t5 can detect the format
te_dict[name] = file
shared.log.info(f'Available TEs: path="{shared.opts.te_dir}" items={len(te_dict)}')