# sdnext/modules/merging/modules_sdxl.py
# mirror of https://github.com/vladmandic/sdnext.git
# commit 74bb2b9ea6 "autodetect v-pred", Vladimir Mandic <mandic00@live.com>, 2025-07-28
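"""Merge SDXL models from a declarative Recipe: load a base StableDiffusionXLPipeline,
swap in optional component overrides (UNet, VAE, text encoders, scheduler), fuse LoRAs,
and save the result as a diffusers folder and/or a single-file safetensors checkpoint."""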
import io
import os
import json
import base64
from datetime import datetime
from PIL import Image
import torch
from safetensors.torch import load_file
import diffusers
import transformers
from modules import shared, devices, errors
class Recipe:
    author = ''
    name = ''
    version = ''
    desc = ''
    hint = ''
    license = ''
    prediction = ''
    thumbnail = None
    base = None
    unet = None
    vae = None
    te1 = None
    te2 = None
    scheduler = 'UniPCMultistepScheduler'
    dtype = torch.float16
    diffusers = True
    safetensors = True
    debug = False
    lora = {}
    fuse = 1.0

    def __repr__(self):
        return f'Recipe(name="{self.name}" version="{self.version}" author="{self.author}" desc="{self.desc}" hint="{self.hint}" license="{self.license}" dtype="{self.dtype}" fuse={self.fuse} diffusers={self.diffusers} safetensors={self.safetensors})'
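# Example recipe values (a sketch; the file names below are hypothetical):
#   recipe.base = 'sdxl-base-1.0.safetensors'   # single file or diffusers folder
#   recipe.vae = 'sdxl-vae-fp16-fix'            # resolved against shared.opts.vae_dir
#   recipe.lora = {'add-detail-xl': 0.8}        # name: weight, resolved against shared.opts.lora_dir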
class Test:
    generate = True
    prompt = 'astronaut in a diner drinking coffee with burger and french fries on the table'
    negative = 'ugly, blurry'
    width = 1024
    height = 1024
    guidance = 4
    steps = 20
recipe = Recipe()
test = Test()
pipeline: diffusers.StableDiffusionXLPipeline = None
status = ''
def msg(text, err:bool=False):
    global status  # pylint: disable=global-statement
    if err:
        shared.log.error(f'Modules merge: {text}')
    else:
        shared.log.info(f'Modules merge: {text}')
    status += text + '<br>'
    return status
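# Every load_* function below is a generator: each step yields the accumulated
# status string from msg(), so merge() can stream progress with `yield from`.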
def load_base(override:str=None):
    global pipeline  # pylint: disable=global-statement
    fn = override or recipe.base
    yield msg(f'base={fn}')
    if os.path.isfile(fn):
        pipeline = diffusers.StableDiffusionXLPipeline.from_single_file(fn, cache_dir=shared.opts.hfcache_dir, torch_dtype=recipe.dtype, add_watermarker=False)
    elif os.path.isdir(fn):
        pipeline = diffusers.StableDiffusionXLPipeline.from_pretrained(fn, cache_dir=shared.opts.hfcache_dir, torch_dtype=recipe.dtype, add_watermarker=False)
    else:
        yield msg('base: not found')
        return
    pipeline.vae.register_to_config(force_upcast=False)
def load_unet(pipe: diffusers.StableDiffusionXLPipeline, override:str=None):
    if (recipe.unet is None or len(recipe.unet) == 0) and override is None:
        return
    fn = override or recipe.unet
    if not os.path.isabs(fn):
        fn = os.path.join(shared.opts.unet_dir, fn)
    if not fn.endswith('.safetensors'):
        fn += '.safetensors'
    yield msg(f'unet={fn}')
    if recipe.debug:
        yield msg(f'config={pipe.unet.config}')
    try:
        unet = diffusers.UNet2DConditionModel.from_config(pipe.unet.config).to(recipe.dtype)
        state_dict = load_file(fn)
        unet.load_state_dict(state_dict)
        pipe.unet = unet.to(device=devices.device, dtype=recipe.dtype)
    except Exception as e:
        yield msg(f'unet: {e}')
def load_scheduler(pipe: diffusers.StableDiffusionXLPipeline, override:str=None):
    if recipe.scheduler is None and override is None:
        return
    config = pipe.scheduler.config.__dict__
    scheduler = override or recipe.scheduler
    yield msg(f'scheduler={scheduler}')
    if recipe.debug:
        yield msg(f'config={config}')
    try:
        pipe.scheduler = getattr(diffusers, scheduler).from_config(config)
    except Exception as e:
        yield msg(f'scheduler: {e}')
def load_vae(pipe: diffusers.StableDiffusionXLPipeline, override:str=None):
    if (recipe.vae is None or len(recipe.vae) == 0) and override is None:
        return
    fn = override or recipe.vae
    if not os.path.isabs(fn):
        fn = os.path.join(shared.opts.vae_dir, fn)
    if not fn.endswith('.safetensors'):
        fn += '.safetensors'
    try:
        vae = diffusers.AutoencoderKL.from_single_file(fn, cache_dir=shared.opts.hfcache_dir, torch_dtype=recipe.dtype)
        vae.config.force_upcast = False
        vae.config.scaling_factor = 0.13025  # standard SDXL VAE scaling factor
        vae.config.sample_size = 1024
        yield msg(f'vae={fn}')
        if recipe.debug:
            yield msg(f'config={pipe.vae.config}')
        pipe.vae = vae.to(device=devices.device, dtype=recipe.dtype)
    except Exception as e:
        yield msg(f'vae: {e}')
def load_te1(pipe: diffusers.StableDiffusionXLPipeline, override:str=None):
    if (recipe.te1 is None or len(recipe.te1) == 0) and override is None:
        return
    config = pipe.text_encoder.config.__dict__
    pretrained_config = transformers.PretrainedConfig.from_dict(config)
    fn = override or recipe.te1
    if not os.path.isabs(fn):
        fn = os.path.join(shared.opts.te_dir, fn)
    if not fn.endswith('.safetensors'):
        fn += '.safetensors'
    yield msg(f'te1={fn}')
    if recipe.debug:
        yield msg(f'config={config}')
    try:
        state_dict = load_file(fn)
        te1 = transformers.CLIPTextModel.from_pretrained(pretrained_model_name_or_path=None, state_dict=state_dict, config=pretrained_config, cache_dir=shared.opts.hfcache_dir)
        pipe.text_encoder = te1.to(device=devices.device, dtype=recipe.dtype)
    except Exception as e:
        yield msg(f'te1: {e}')
def load_te2(pipe: diffusers.StableDiffusionXLPipeline, override:str=None):
    if (recipe.te2 is None or len(recipe.te2) == 0) and override is None:
        return
    config = pipe.text_encoder_2.config.__dict__
    pretrained_config = transformers.PretrainedConfig.from_dict(config)
    fn = override or recipe.te2
    if not os.path.isabs(fn):
        fn = os.path.join(shared.opts.te_dir, fn)
    if not fn.endswith('.safetensors'):
        fn += '.safetensors'
    yield msg(f'te2={fn}')  # log the resolved path so an override is reported correctly
    if recipe.debug:
        yield msg(f'config={config}')
    try:
        state_dict = load_file(fn)
        te2 = transformers.CLIPTextModelWithProjection.from_pretrained(pretrained_model_name_or_path=None, state_dict=state_dict, config=pretrained_config, cache_dir=shared.opts.hfcache_dir)
        pipe.text_encoder_2 = te2.to(device=devices.device, dtype=recipe.dtype)
    except Exception as e:
        yield msg(f'te2: {e}')
def load_lora(pipe: diffusers.StableDiffusionXLPipeline, override: dict=None, fuse: float=None):
    if recipe.lora is None and override is None:
        return
    names = []
    pipe.unfuse_lora()
    pipe.unload_lora_weights()
    loras = override or recipe.lora
    for lora, weight in loras.items():
        try:
            fn = lora
            if not os.path.isabs(fn):
                fn = os.path.join(shared.opts.lora_dir, fn)
            if not fn.endswith('.safetensors'):
                fn += '.safetensors'
            yield msg(f'lora={fn} weight={weight} fuse={fuse or recipe.fuse}')
            name = os.path.splitext(os.path.basename(lora))[0].replace('.', '').replace(' ', '').replace('-', '').replace('_', '')  # sanitized adapter name
            names.append(name)
            pipe.load_lora_weights(fn, adapter_name=name)
        except Exception as e:
            yield msg(f'lora: {e}')
    if len(names) > 0:
        pipe.set_adapters(adapter_names=names, adapter_weights=list(loras.values()))
        pipe.fuse_lora(adapter_names=names, lora_scale=fuse or recipe.fuse, components=["unet", "text_encoder", "text_encoder_2"])
        pipe.unload_lora_weights()
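# fuse_lora() bakes the weighted LoRA deltas directly into the UNet and text
# encoder weights; the final unload_lora_weights() then drops the adapter
# modules so the merged model carries no separate LoRA state.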
def test_model(pipe: diffusers.StableDiffusionXLPipeline, fn: str, **kwargs):
    if not test.generate:
        return
    try:
        generator = torch.Generator(devices.device).manual_seed(int(4242))
        args = {
            'prompt': test.prompt,
            'negative_prompt': test.negative,
            'num_inference_steps': test.steps,
            'width': test.width,
            'height': test.height,
            'guidance_scale': test.guidance,
            'generator': generator,
        }
        args.update(kwargs)
        yield msg(f'test={args}')
        image = pipe(**args).images[0]
        yield msg(f'image={fn} {image}')
        image.save(fn)
    except Exception as e:
        yield msg(f'test: {e}')
def get_thumbnail():
    if recipe.thumbnail is None:
        return ''
    image = Image.open(recipe.thumbnail)
    image = image.convert('RGB')
    image.thumbnail((512, 512), resample=Image.Resampling.LANCZOS)
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", quality=50)
    b64encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f'data:image/jpeg;base64,{b64encoded}'
def get_metadata():
    return {
        "modelspec.sai_model_spec": "1.0.0",
        "modelspec.architecture": "stable-diffusion-xl-v1-base",
        "modelspec.implementation": "diffusers",
        "modelspec.title": recipe.name,
        "modelspec.version": recipe.version,
        "modelspec.description": recipe.desc,
        "modelspec.author": recipe.author,
        "modelspec.date": datetime.now().isoformat(timespec='minutes'),
        "modelspec.license": recipe.license,
        "modelspec.usage_hint": recipe.hint,
        "modelspec.prediction_type": recipe.prediction,
        "modelspec.dtype": str(recipe.dtype).split('.')[1],
        "modelspec.hash_sha256": "",
        "modelspec.thumbnail": get_thumbnail(),
        "recipe": json.dumps({
            "base": os.path.basename(recipe.base) if recipe.base else "default",
            "unet": os.path.basename(recipe.unet) if recipe.unet else "default",
            "vae": os.path.basename(recipe.vae) if recipe.vae else "default",
            "te1": os.path.basename(recipe.te1) if recipe.te1 else "default",
            "te2": os.path.basename(recipe.te2) if recipe.te2 else "default",
            "scheduler": recipe.scheduler or "default",
            "lora": [f'{os.path.basename(k)}:{v}' for k, v in recipe.lora.items()],
        }),
    }
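# Keys follow the modelspec metadata convention (modelspec.*), and the full
# merge recipe is embedded as JSON so a saved checkpoint is self-describing.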
def save_model(pipe: diffusers.StableDiffusionXLPipeline):
    author = recipe.author if len(recipe.author) > 0 else 'anonymous'
    folder = os.path.join(shared.opts.diffusers_dir, f'models--{author}--{recipe.name}')
    if len(recipe.version) > 0:
        folder += f'-{recipe.version}'
    if not (recipe.diffusers or recipe.safetensors):
        shared.log.debug(f'Modules merge: type=sdxl {recipe} skipping save')
        return
    try:
        yield msg('save')
        yield msg(f'pretrained={folder}')
        shared.log.info(f'Modules merge save: type=sdxl diffusers="{folder}"')
        pipe.save_pretrained(folder, safe_serialization=True, push_to_hub=False)
        with open(os.path.join(folder, 'vae', 'config.json'), 'r', encoding='utf8') as f:
            vae_config = json.load(f)
        vae_config['force_upcast'] = False
        vae_config['scaling_factor'] = 0.13025
        vae_config['sample_size'] = 1024
        with open(os.path.join(folder, 'vae', 'config.json'), 'w', encoding='utf8') as f:
            json.dump(vae_config, f, indent=2)
        if recipe.safetensors:
            fn = recipe.name
            if len(recipe.version) > 0:
                fn += f'-{recipe.version}'
            if not os.path.isabs(fn):
                fn = os.path.join(shared.opts.ckpt_dir, fn)
            if not fn.endswith('.safetensors'):
                fn += '.safetensors'
            shared.log.info(f'Modules merge save: type=sdxl safetensors="{fn}"')
            yield msg(f'safetensors={fn}')
            from modules.merging import convert_sdxl
            metadata = convert_sdxl.convert(model_path=folder, checkpoint_path=fn, metadata=get_metadata())
            if 'modelspec.thumbnail' in metadata:
                metadata['modelspec.thumbnail'] = f"{metadata['modelspec.thumbnail'].split(',')[0]}:{len(metadata['modelspec.thumbnail'])}"  # pylint: disable=use-maxsplit-arg
            yield msg(f'metadata={metadata}')
    except Exception as e:
        shared.log.error(f'Modules merge save: {e}')
        errors.display(e, 'merge')
        yield msg(f'save: {e}')
def merge():
    global pipeline  # pylint: disable=global-statement
    yield from load_base()
    if pipeline is None:
        return
    shared.log.info(f'Modules merge: type=sdxl {recipe}')
    pipeline = pipeline.to(device=devices.device, dtype=recipe.dtype)
    yield from load_scheduler(pipeline)
    yield from load_unet(pipeline)
    yield from load_vae(pipeline)
    yield from load_te1(pipeline)
    yield from load_te2(pipeline)
    yield from load_lora(pipeline)
    yield from save_model(pipeline)
    # pipeline = pipeline.to(device=devices.device, dtype=recipe.dtype)
    # test_model(pipeline, '/tmp/merge.png')
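# Example driver (a minimal sketch; the recipe values below are hypothetical):
#   recipe.base = '/models/Stable-diffusion/sdxl-base-1.0.safetensors'
#   recipe.name = 'my-merge'
#   recipe.lora = {'add-detail-xl': 0.8}
#   for html_status in merge():  # consume the generator to run the merge
#       pass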