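"""Text encoder loading helpers for SD.Next diffusers pipelines.

- load_t5: build a T5 encoder from a local .safetensors/.gguf file, a quantized download
  (bitsandbytes fp4/fp8, SDNQ int8/uint4, Quanto qint4/qint8) or an arbitrary repo id.
- set_t5: attach a freshly loaded T5 to a pipeline module and re-apply the offload mode.
- load_vit_l / load_vit_g / set_clip: rebuild SDXL CLIP text encoders from local files.
- refresh_te_list: scan shared.opts.te_dir and populate te_dict.
"""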
import os
import json
import torch
import transformers
from safetensors.torch import load_file
from modules import shared, devices, files_cache, errors, model_quant


te_dict = {}
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
loaded_te = None


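# load_t5 resolves `name` in order: an existing file path, an entry from te_dict (local
# .gguf or .safetensors files, with a dedicated fp8 path that upcasts non-float8 tensors to
# devices.dtype), quantization presets downloaded from the SD3 repo (fp16, fp4/fp8 via
# bitsandbytes, int8/uint4 via SDNQ, qint4/qint8 via Quanto), and finally any name containing
# '/' treated as a Hugging Face repo id.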
def load_t5(name=None, cache_dir=None):
    global loaded_te  # pylint: disable=global-statement
    if name is None:
        return None
    cache_dir = cache_dir or shared.opts.hfcache_dir
    from modules import modelloader
    modelloader.hf_login()
    repo_id = 'stabilityai/stable-diffusion-3-medium-diffusers'
    if os.path.exists(name):
        fn = name
    else:
        fn = te_dict.get(name) if name in te_dict else None
    t5 = None  # stays None when no branch matches or a guarded quantization config is unavailable

    if fn is not None and name.lower().endswith('gguf'):
        from modules import ggml
        ggml.install_gguf()
        with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        t5 = transformers.T5EncoderModel.from_pretrained(None, gguf_file=fn, config=t5_config, device_map="auto", cache_dir=cache_dir, torch_dtype=devices.dtype)

    elif fn is not None and 'fp8' in name.lower():
        from accelerate.utils import set_module_tensor_to_device
        with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(fn)
        dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
        with torch.device("meta"):
            t5 = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
        for param_name, param in state_dict.items():
            is_param_float8_e4m3fn = hasattr(torch, "float8_e4m3fn") and param.dtype == torch.float8_e4m3fn
            if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
                param = param.to(devices.dtype)
            set_module_tensor_to_device(t5, param_name, device=0, value=param)
        if t5.dtype != devices.dtype:
            try:
                t5 = t5.to(dtype=devices.dtype)
            except Exception:
                shared.log.error(f"T5: Failed to cast text encoder to {devices.dtype}, set dtype to {t5.dtype}")
                raise
        del state_dict

    elif fn is not None:
        with open(os.path.join('configs', 'flux', 'text_encoder_2', 'config.json'), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(fn)
        t5 = transformers.T5EncoderModel.from_pretrained(None, state_dict=state_dict, config=t5_config, torch_dtype=devices.dtype)

    elif 'fp16' in name.lower():
        t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', cache_dir=cache_dir, torch_dtype=devices.dtype)

    elif 'fp4' in name.lower():
        model_quant.load_bnb('Load model: type=T5')
        quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True)
        t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)

    elif 'fp8' in name.lower():
        model_quant.load_bnb('Load model: type=T5')
        quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
        t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)

    elif 'int8' in name.lower():
        from modules.model_quant import create_sdnq_config
        quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='int8')
        if quantization_config is not None:
            t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)

    elif 'uint4' in name.lower():
        from modules.model_quant import create_sdnq_config
        quantization_config = create_sdnq_config(kwargs=None, allow=True, module='any', weights_dtype='uint4')
        if quantization_config is not None:
            t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)

    elif 'qint4' in name.lower():
        model_quant.load_quanto('Load model: type=T5')
        quantization_config = transformers.QuantoConfig(weights='int4')
        if quantization_config is not None:
            t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)

    elif 'qint8' in name.lower():
        model_quant.load_quanto('Load model: type=T5')
        quantization_config = transformers.QuantoConfig(weights='int8')
        if quantization_config is not None:
            t5 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder='text_encoder_3', quantization_config=quantization_config, cache_dir=cache_dir, torch_dtype=devices.dtype)

    elif '/' in name:
        shared.log.debug(f'Load model: type=T5 repo={name}')
        quant_config = model_quant.create_config(module='TE')
        if quant_config is not None:
            t5 = transformers.T5EncoderModel.from_pretrained(name, cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_config)

    if t5 is not None:
        loaded_te = name
    return t5


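# set_t5 replaces the given text-encoder attribute on the pipeline with the encoder returned
# by load_t5, then restores the configured offload mode: sequential offload re-wraps the new
# module with accelerate.cpu_offload, model offload re-registers the pipeline cpu-offload hooks.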
def set_t5(pipe, module, t5=None, cache_dir=None):
    global loaded_te  # pylint: disable=global-statement
    if loaded_te == shared.opts.sd_text_encoder:
        return pipe
    if pipe is None or not hasattr(pipe, module):
        return pipe
    try:
        t5 = load_t5(name=t5, cache_dir=cache_dir)
    except Exception as e:
        shared.log.error(f'Load module: type={module} class="T5" file="{shared.opts.sd_text_encoder}" {e}')
        if debug:
            errors.display(e, 'TE:')
        t5 = None
    if t5 is None:
        return pipe
    loaded_te = shared.opts.sd_text_encoder
    setattr(pipe, module, t5)
    if shared.opts.diffusers_offload_mode == "sequential":
        from accelerate import cpu_offload
        getattr(pipe, module).to("cpu")
        cpu_offload(getattr(pipe, module), devices.device, offload_buffers=len(getattr(pipe, module)._parameters) > 0)  # pylint: disable=protected-access
    elif shared.opts.diffusers_offload_mode == "model":
        if not hasattr(pipe, "_all_hooks") or len(pipe._all_hooks) == 0:  # pylint: disable=protected-access
            pipe.enable_model_cpu_offload(device=devices.device)
        if hasattr(pipe, "maybe_free_model_hooks"):
            pipe.maybe_free_model_hooks()
    devices.torch_gc()
    return pipe


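# load_vit_l and load_vit_g rebuild the SDXL CLIP text encoders (ViT-L and ViT-G) from
# <te_dir>/<sd_text_encoder>.safetensors, using the config.json files bundled under
# configs/sdxl/ rather than downloading a repo.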
def load_vit_l():
    global loaded_te  # pylint: disable=global-statement
    config = transformers.PretrainedConfig.from_json_file('configs/sdxl/text_encoder/config.json')
    state_dict = load_file(os.path.join(shared.opts.te_dir, f'{shared.opts.sd_text_encoder}.safetensors'))
    te = transformers.CLIPTextModel.from_pretrained(pretrained_model_name_or_path=None, state_dict=state_dict, config=config)
    te = te.to(dtype=devices.dtype)
    loaded_te = shared.opts.sd_text_encoder
    del state_dict
    return te


def load_vit_g():
    global loaded_te  # pylint: disable=global-statement
    config = transformers.PretrainedConfig.from_json_file('configs/sdxl/text_encoder_2/config.json')
    state_dict = load_file(os.path.join(shared.opts.te_dir, f'{shared.opts.sd_text_encoder}.safetensors'))
    te = transformers.CLIPTextModelWithProjection.from_pretrained(pretrained_model_name_or_path=None, state_dict=state_dict, config=config)
    te = te.to(dtype=devices.dtype)
    loaded_te = shared.opts.sd_text_encoder
    del state_dict
    return te


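# set_clip swaps the CLIP encoders on the current pipeline when the selected text encoder
# name contains 'vit-l' or 'vit-g' and the existing module has the matching class; after a
# successful swap it clears the prompt parser cache and moves the new encoder to the device.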
def set_clip(pipe):
    if loaded_te == shared.opts.sd_text_encoder:
        return
    from modules.sd_models import move_model
    if 'vit-l' in shared.opts.sd_text_encoder.lower() and hasattr(shared.sd_model, 'text_encoder') and shared.sd_model.text_encoder.__class__.__name__ == 'CLIPTextModel':
        try:
            te = load_vit_l()
        except Exception as e:
            shared.log.error(f'Load module: type="text_encoder" class="ViT-L" file="{shared.opts.sd_text_encoder}" {e}')
            if debug:
                errors.display(e, 'TE:')
            te = None
        if te is not None:
            pipe.text_encoder = te
            shared.log.info(f'Load module: type="text_encoder" class="ViT-L" file="{shared.opts.sd_text_encoder}"')
            import modules.prompt_parser_diffusers
            modules.prompt_parser_diffusers.cache.clear()
            move_model(pipe.text_encoder, devices.device)
        devices.torch_gc()
    if 'vit-g' in shared.opts.sd_text_encoder.lower() and hasattr(shared.sd_model, 'text_encoder_2') and shared.sd_model.text_encoder_2.__class__.__name__ == 'CLIPTextModelWithProjection':
        try:
            te = load_vit_g()
        except Exception as e:
            shared.log.error(f'Load module: type="text_encoder_2" class="ViT-G" file="{shared.opts.sd_text_encoder}" {e}')
            if debug:
                errors.display(e, 'TE:')
            te = None
        if te is not None:
            pipe.text_encoder_2 = te
            shared.log.info(f'Load module: type="text_encoder_2" class="ViT-G" file="{shared.opts.sd_text_encoder}"')
            import modules.prompt_parser_diffusers
            modules.prompt_parser_diffusers.cache.clear()
            move_model(pipe.text_encoder_2, devices.device)
        devices.torch_gc()


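# refresh_te_list rebuilds te_dict from shared.opts.te_dir: .safetensors files are keyed by
# basename without extension, .gguf files keep their extension in the key.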
def refresh_te_list():
    te_dict.clear()
    for file in files_cache.list_files(shared.opts.te_dir, ext_filter=['.safetensors', '.gguf']):
        basename = os.path.basename(file)
        name = os.path.splitext(basename)[0] if '.safetensors' in basename else basename
        te_dict[name] = file
    shared.log.info(f'Available TEs: path="{shared.opts.te_dir}" items={len(te_dict)}')
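

# Illustrative call sequence (a sketch for readers, not invoked by this module; the pipeline
# object and the module/encoder names below are assumptions for the example):
#   refresh_te_list()                                              # scan shared.opts.te_dir
#   pipe = set_t5(shared.sd_model, 'text_encoder_3', t5='t5-fp8')  # swap in a local fp8 T5
#   set_clip(pipe)                                                 # apply ViT-L/ViT-G overrides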