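# modules/sd_offload.py (part of SD.Next: https://github.com/vladmandic/sdnext)
# Offload management: implements the "none", "model", "sequential", "group" and "balanced"
# offload modes. Balanced offload is built around an accelerate ModelHook (OffloadHook) that
# moves pipeline components between GPU and CPU around each forward pass based on
# configurable memory watermarks.
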
import os
import re
import sys
import time
import inspect
import torch
import accelerate.hooks
import accelerate.utils.modeling
from installer import log
from modules import shared, devices, errors, model_quant, sd_models
from modules.timer import process as process_timer
debug = os.environ.get('SD_MOVE_DEBUG', None) is not None
verbose = os.environ.get('SD_MOVE_VERBOSE', None) is not None
debug_move = log.trace if debug else lambda *args, **kwargs: None
offload_allow_none = ['sd', 'sdxl']
offload_post = ['h1']
offload_hook_instance = None
balanced_offload_exclude = ['CogView4Pipeline', 'MeissonicPipeline']
no_split_module_classes = [
    "Linear", "Conv1d", "Conv2d", "Conv3d", "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d",
    "SDNQLinear", "SDNQConv1d", "SDNQConv2d", "SDNQConv3d", "SDNQConvTranspose1d", "SDNQConvTranspose2d", "SDNQConvTranspose3d",
    "WanTransformerBlock",
]
accelerate_dtype_byte_size = None
move_stream = None


def dtype_byte_size(dtype: torch.dtype):
    try:
        if dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]:
            dtype = accelerate.utils.modeling.CustomDtype.FP8
    except Exception: # catch since older torch may not have these dtypes defined
        pass
    return accelerate_dtype_byte_size(dtype)


def get_signature(cls):
    signature = inspect.signature(cls.__init__, follow_wrapped=True)
    return signature.parameters


def disable_offload(sd_model):
    if not getattr(sd_model, 'has_accelerate', False):
        return
    for module_name in get_module_names(sd_model):
        module = getattr(sd_model, module_name, None)
        if isinstance(module, torch.nn.Module):
            network_layer_name = getattr(module, "network_layer_name", None)
            try:
                module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
            except Exception as e:
                shared.log.warning(f'Offload remove hook: module={module_name} {e}')
            if network_layer_name:
                module.network_layer_name = network_layer_name
    sd_model.has_accelerate = False


def set_accelerate(sd_model):
    def set_accelerate_to_module(model):
        if hasattr(model, "pipe"):
            set_accelerate_to_module(model.pipe)
        for module_name in get_module_names(model):
            component = getattr(model, module_name, None)
            if isinstance(component, torch.nn.Module):
                component.has_accelerate = True

    sd_model.has_accelerate = True
    set_accelerate_to_module(sd_model)
    if hasattr(sd_model, "prior_pipe"):
        set_accelerate_to_module(sd_model.prior_pipe)
    if hasattr(sd_model, "decoder_pipe"):
        set_accelerate_to_module(sd_model.decoder_pipe)


def apply_group_offload(sd_model, op:str='model'):
    offload_dct = {
        'onload_device': devices.device,
        'offload_device': devices.cpu,
        'offload_type': shared.opts.group_offload_type,
        'num_blocks_per_group': shared.opts.group_offload_blocks,
        'non_blocking': shared.opts.diffusers_offload_nonblocking,
        'use_stream': shared.opts.group_offload_stream,
        'record_stream': shared.opts.group_offload_record,
        'low_cpu_mem_usage': False,
    }
    if shared.opts.group_offload_type == 'block_level':
        offload_dct['exclude_modules'] = ['vae']
    shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} options={offload_dct}')
    if hasattr(sd_model, "enable_group_offload"):
        sd_model.enable_group_offload(**offload_dct)
    else:
        shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} not supported')
    set_accelerate(sd_model)
    return sd_model
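
# apply_group_offload: the dict above simply maps SD.Next options onto the keyword arguments of
# `enable_group_offload`, the group-offloading entry point exposed by recent diffusers releases;
# pipelines that do not provide that method fall through to the warning branch.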


def apply_model_offload(sd_model, op:str='model', quiet:bool=False):
    try:
        shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
        if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
            shared.opts.diffusers_move_base = False
            shared.opts.diffusers_move_unet = False
            shared.opts.diffusers_move_refiner = False
            shared.log.warning(f'Disabling {op} "Move model to CPU" since "Model CPU offload" is enabled')
        if not hasattr(sd_model, "_all_hooks") or len(sd_model._all_hooks) == 0: # pylint: disable=protected-access
            sd_model.enable_model_cpu_offload(device=devices.device)
        else:
            sd_model.maybe_free_model_hooks()
        set_accelerate(sd_model)
    except Exception as e:
        shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')


def apply_sequential_offload(sd_model, op:str='model', quiet:bool=False):
    try:
        shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
        if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
            shared.opts.diffusers_move_base = False
            shared.opts.diffusers_move_unet = False
            shared.opts.diffusers_move_refiner = False
            shared.log.warning(f'Disabling {op} "Move model to CPU" since "Sequential CPU offload" is enabled')
        if sd_model.has_accelerate:
            if op == "vae": # reapply sequential offload to vae
                from accelerate import cpu_offload
                sd_model.vae.to(devices.cpu)
                cpu_offload(sd_model.vae, devices.device, offload_buffers=len(sd_model.vae._parameters) > 0) # pylint: disable=protected-access
            else:
                pass # do nothing if offload is already applied
        else:
            sd_model.enable_sequential_cpu_offload(device=devices.device)
        set_accelerate(sd_model)
    except Exception as e:
        shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')


def apply_none_offload(sd_model, op:str='model', quiet:bool=False):
    if shared.sd_model_type not in offload_allow_none:
        shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model')
    else:
        shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
    try:
        sd_model.has_accelerate = False
        if hasattr(sd_model, 'maybe_free_model_hooks'):
            sd_model.maybe_free_model_hooks()
        sd_model = accelerate.hooks.remove_hook_from_module(sd_model, recurse=True)
    except Exception:
        pass
    sd_models.move_model(sd_model, devices.device)


def set_diffuser_offload(sd_model, op:str='model', quiet:bool=False, force:bool=False):
    global accelerate_dtype_byte_size # pylint: disable=global-statement
    t0 = time.time()
    if sd_model is None:
        shared.log.warning(f'{op} is not loaded')
        return
    if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate):
        sd_model.has_accelerate = False
    if accelerate_dtype_byte_size is None:
        accelerate_dtype_byte_size = accelerate.utils.modeling.dtype_byte_size
        accelerate.utils.modeling.dtype_byte_size = dtype_byte_size
    if shared.opts.diffusers_offload_mode == "none":
        apply_none_offload(sd_model, op=op, quiet=quiet)
    if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"):
        apply_model_offload(sd_model, op=op, quiet=quiet)
    if shared.opts.diffusers_offload_mode == "sequential" and hasattr(sd_model, "enable_sequential_cpu_offload"):
        apply_sequential_offload(sd_model, op=op, quiet=quiet)
    if shared.opts.diffusers_offload_mode == "group":
        sd_model = apply_group_offload(sd_model, op=op)
    if shared.opts.diffusers_offload_mode == "balanced":
        sd_model = apply_balanced_offload(sd_model, force=force)
    process_timer.add('offload', time.time() - t0)
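
# set_diffuser_offload usage sketch (illustrative; the actual call sites live elsewhere in
# SD.Next, e.g. in the model load/reload paths): this is the single entry point that reads
# `shared.opts.diffusers_offload_mode`, delegates to the matching apply_* helper above, and
# monkey-patches accelerate's dtype_byte_size once so FP8 weights are sized correctly:
#   set_diffuser_offload(shared.sd_model, op='model')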


class OffloadHook(accelerate.hooks.ModelHook):
    def __init__(self, checkpoint_name):
        if shared.opts.diffusers_offload_max_gpu_memory > 1:
            shared.opts.diffusers_offload_max_gpu_memory = 0.75
        if shared.opts.diffusers_offload_max_cpu_memory > 1:
            shared.opts.diffusers_offload_max_cpu_memory = 0.75
        self.checkpoint_name = checkpoint_name
        self.min_watermark = shared.opts.diffusers_offload_min_gpu_memory
        self.max_watermark = shared.opts.diffusers_offload_max_gpu_memory
        self.cpu_watermark = shared.opts.diffusers_offload_max_cpu_memory
        self.offload_always = [m.strip() for m in re.split(';|,| ', shared.opts.diffusers_offload_always) if len(m.strip()) > 2]
        self.offload_never = [m.strip() for m in re.split(';|,| ', shared.opts.diffusers_offload_never) if len(m.strip()) > 2]
        self.gpu = int(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory * 1024*1024*1024)
        self.cpu = int(shared.cpu_memory * shared.opts.diffusers_offload_max_cpu_memory * 1024*1024*1024)
        self.offload_map = {}
        self.param_map = {}
        self.last_pre = None
        self.last_post = None
        self.last_cls = None
        gpu = f'{(shared.gpu_memory * shared.opts.diffusers_offload_min_gpu_memory):.2f}-{(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory):.2f}:{shared.gpu_memory:.2f}'
        shared.log.info(f'Offload: type=balanced op=init watermark={self.min_watermark}-{self.max_watermark} gpu={gpu} cpu={shared.cpu_memory:.3f} limit={shared.opts.cuda_mem_fraction:.2f} always={self.offload_always} never={self.offload_never} pre={shared.opts.diffusers_offload_pre} streams={shared.opts.diffusers_offload_streams}')
        self.validate()
        super().__init__()

    def validate(self):
        if shared.opts.diffusers_offload_mode != 'balanced':
            return
        if shared.opts.diffusers_offload_min_gpu_memory < 0 or shared.opts.diffusers_offload_min_gpu_memory > 1:
            shared.opts.diffusers_offload_min_gpu_memory = 0.2
            shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} invalid value')
        if shared.opts.diffusers_offload_max_gpu_memory < 0.1 or shared.opts.diffusers_offload_max_gpu_memory > 1:
            shared.opts.diffusers_offload_max_gpu_memory = 0.7
            shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} invalid value')
        if shared.opts.diffusers_offload_min_gpu_memory > shared.opts.diffusers_offload_max_gpu_memory:
            shared.opts.diffusers_offload_min_gpu_memory = shared.opts.diffusers_offload_max_gpu_memory
            shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} reset')
        if shared.opts.diffusers_offload_max_gpu_memory * shared.gpu_memory < 4:
            shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} low memory')

    def model_size(self):
        return sum(self.offload_map.values())

    def init_hook(self, module):
        return module

    def offload_allowed(self, module):
        if hasattr(module, "offload_never"):
            return False
        if hasattr(module, 'nets') and any(hasattr(n, "offload_never") for n in module.nets):
            return False
        if shared.sd_model_type.lower() in [m.lower().strip() for m in re.split(r'[ ,]+', shared.opts.models_not_to_offload)]:
            return False
        return True

    @torch.compiler.disable
    def pre_forward(self, module, *args, **kwargs):
        _id = id(module)
        do_offload = (self.last_pre != _id) or (module.__class__.__name__ != self.last_cls)
        if do_offload and self.offload_allowed(module): # offload every other module first time when new module starts pre-forward
            if shared.opts.diffusers_offload_pre:
                t0 = time.time()
                debug_move(f'Offload: type=balanced op=pre module={module.__class__.__name__}')
                for pipe in get_pipe_variants():
                    for module_name in get_module_names(pipe):
                        module_instance = getattr(pipe, module_name, None)
                        module_cls = module_instance.__class__.__name__
                        if (module_instance is not None) and (_id != id(module_instance)) and (module_cls not in self.offload_never) and (not devices.same_device(module_instance.device, devices.cpu)):
                            apply_balanced_offload_to_module(module_instance, op='pre')
                self.last_cls = module.__class__.__name__
                process_timer.add('offload', time.time() - t0)
        if not devices.same_device(module.device, devices.device): # move-to-device
            t0 = time.time()
            device_index = torch.device(devices.device).index
            if device_index is None:
                device_index = 0
            max_memory = { device_index: self.gpu, "cpu": self.cpu }
            device_map = getattr(module, "balanced_offload_device_map", None)
            if (device_map is None) or (max_memory != getattr(module, "balanced_offload_max_memory", None)):
                device_map = accelerate.infer_auto_device_map(module,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    verbose=verbose,
                    clean_result=False,
                )
            offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
            if devices.backend == "directml":
                for k, v in device_map.items():
                    if isinstance(v, int):
                        device_map[k] = f"{devices.device.type}:{v}" # int implies CUDA or XPU device, but it will break DirectML backend so we add type
            if debug:
                shared.log.trace(f'Offload: type=balanced op=dispatch map={device_map}')
            if device_map is not None:
                skip_keys = getattr(module, "_skip_keys", None)
                module = accelerate.dispatch_model(module,
                    main_device=torch.device(devices.device),
                    device_map=device_map,
                    offload_dir=offload_dir,
                    skip_keys=skip_keys,
                    force_hooks=True,
                )
            module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
            module.balanced_offload_device_map = device_map
            module.balanced_offload_max_memory = max_memory
            process_timer.add('onload', time.time() - t0)
        if debug:
            for _i, pipe in enumerate(get_pipe_variants()):
                for module_name in get_module_names(pipe):
                    module_instance = getattr(pipe, module_name, None)
                    shared.log.trace(f'Offload: type=balanced op=pre:status forward={module.__class__.__name__} module={module_name} class={module_instance.__class__.__name__} pipe={_i} device={module_instance.device} dtype={module_instance.dtype}')
        self.last_pre = _id
        return args, kwargs

    @torch.compiler.disable
    def post_forward(self, module, output):
        if self.last_post != id(module):
            self.last_post = id(module)
            if getattr(module, "offload_post", False) and (module.device != devices.cpu):
                apply_balanced_offload_to_module(module, op='post')
        return output

    def detach_hook(self, module):
        return module
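
# OffloadHook lifecycle in brief: accelerate calls `pre_forward` before a hooked module runs,
# which (when `diffusers_offload_pre` is enabled) first pushes the other pipeline components to
# CPU and then dispatches the active module onto the GPU via `accelerate.dispatch_model`,
# reusing a cached device map while the memory budget is unchanged; `post_forward` can return a
# module to CPU immediately when it is flagged with `offload_post` (see
# apply_balanced_offload_to_module below).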


def get_pipe_variants(pipe=None):
    if pipe is None:
        if shared.sd_loaded:
            pipe = shared.sd_model
        else:
            return [pipe]
    variants = [pipe]
    if hasattr(pipe, "pipe"):
        variants.append(pipe.pipe)
    if hasattr(pipe, "prior_pipe"):
        variants.append(pipe.prior_pipe)
    if hasattr(pipe, "decoder_pipe"):
        variants.append(pipe.decoder_pipe)
    return variants


def get_module_names(pipe=None, exclude=None):
    def is_valid(module):
        if isinstance(getattr(pipe, module, None), torch.nn.ModuleDict):
            return True
        if isinstance(getattr(pipe, module, None), torch.nn.ModuleList):
            return True
        if isinstance(getattr(pipe, module, None), torch.nn.Module):
            return True
        return False

    if exclude is None:
        exclude = []
    if pipe is None:
        if shared.sd_loaded:
            pipe = shared.sd_model
        else:
            return []
    modules_names = []
    try:
        dict_keys = pipe._internal_dict.keys() # pylint: disable=protected-access
        modules_names.extend(dict_keys)
    except Exception:
        pass
    try:
        dict_keys = get_signature(pipe).keys()
        modules_names.extend(dict_keys)
    except Exception:
        pass
    modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')]
    modules_names = [m for m in modules_names if is_valid(m)]
    modules_names = sorted(set(modules_names))
    return modules_names
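
# get_module_names: candidate names are collected from both the pipeline `_internal_dict`
# (the diffusers component registry) and the pipeline `__init__` signature, then filtered down
# to attributes that are actual torch.nn.Module/ModuleList/ModuleDict instances.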


def get_module_sizes(pipe=None, exclude=None):
    if exclude is None:
        exclude = []
    modules = {}
    for module_name in get_module_names(pipe, exclude):
        module_size = offload_hook_instance.offload_map.get(module_name, None)
        if module_size is None:
            module = getattr(pipe, module_name, None)
            if not isinstance(module, torch.nn.Module):
                continue
            try:
                module_size = sum(p.numel() * p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
                param_num = sum(p.numel() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
            except Exception as e:
                shared.log.error(f'Offload: type=balanced op=calc module={module_name} {e}')
                module_size = 0
                param_num = 0 # avoid an unbound name below if the size calculation failed
            offload_hook_instance.offload_map[module_name] = module_size
            offload_hook_instance.param_map[module_name] = param_num
        modules[module_name] = module_size
    modules = sorted(modules.items(), key=lambda x: x[1], reverse=True)
    return modules


def move_module_to_cpu(module, op='unk', force:bool=False):
    def do_move(module):
        if shared.opts.diffusers_offload_streams:
            global move_stream # pylint: disable=global-statement
            if move_stream is None:
                move_stream = torch.cuda.Stream(device=devices.device)
            with torch.cuda.stream(move_stream):
                module = module.to(devices.cpu)
        else:
            module = module.to(devices.cpu)
        return module

    try:
        module_name = getattr(module, "module_name", module.__class__.__name__)
        module_size = offload_hook_instance.offload_map.get(module_name, offload_hook_instance.model_size())
        used_gpu, used_ram = devices.torch_gc(fast=True)
        perc_gpu = used_gpu / shared.gpu_memory
        prev_gpu = used_gpu
        module_cls = module.__class__.__name__
        op = f'{op}:skip'
        if force:
            op = f'{op}:force'
            module = do_move(module)
            used_gpu -= module_size
        elif module_cls in offload_hook_instance.offload_never:
            op = f'{op}:never'
        elif module_cls in offload_hook_instance.offload_always:
            op = f'{op}:always'
            module = do_move(module)
            used_gpu -= module_size
        elif perc_gpu > shared.opts.diffusers_offload_min_gpu_memory:
            op = f'{op}:mem'
            module = do_move(module)
            used_gpu -= module_size
        if debug:
            quant = getattr(module, "quantization_method", None)
            debug_move(f'Offload: type=balanced op={op} gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f}:{shared.opts.diffusers_offload_min_gpu_memory} ram={used_ram:.3f} current={module.device} dtype={module.dtype} quant={quant} module={module_cls} size={module_size:.3f}')
    except Exception as e:
        if 'out of memory' in str(e):
            devices.torch_gc(fast=True, force=True, reason='oom')
        elif 'bitsandbytes' in str(e):
            pass
        else:
            shared.log.error(f'Offload: type=balanced op=apply module={getattr(module, "__name__", None)} cls={module.__class__ if inspect.isclass(module) else None} {e}')
            if os.environ.get('SD_MOVE_DEBUG', None):
                errors.display(e, f'Offload: type=balanced op=apply module={getattr(module, "__name__", None)}')
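
# move_module_to_cpu decision order: an explicit `force` wins, then the per-class
# "never"/"always" lists from the options, and finally the low GPU watermark
# (`diffusers_offload_min_gpu_memory`); when `diffusers_offload_streams` is enabled the actual
# `.to(cpu)` copy is issued on a dedicated CUDA side stream.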


def apply_balanced_offload_to_module(module, op="apply", force:bool=False):
    module_name = getattr(module, "module_name", module.__class__.__name__)
    network_layer_name = getattr(module, "network_layer_name", None)
    device_map = getattr(module, "balanced_offload_device_map", None)
    max_memory = getattr(module, "balanced_offload_max_memory", None)
    try:
        module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
    except Exception as e:
        shared.log.warning(f'Offload remove hook: module={module_name} {e}')
    move_module_to_cpu(module, op=op, force=force)
    try:
        module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True)
    except Exception as e:
        shared.log.warning(f'Offload add hook: module={module_name} {e}')
    module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
    if network_layer_name:
        module.network_layer_name = network_layer_name
    if device_map and max_memory:
        module.balanced_offload_device_map = device_map
        module.balanced_offload_max_memory = max_memory
    module.offload_post = shared.sd_model_type in offload_post and module_name.startswith("text_encoder")
    if shared.opts.layerwise_quantization or getattr(module, 'quantization_method', None) == 'LayerWise':
        model_quant.apply_layerwise(module, quiet=True) # need to reapply since hooks were removed/readded
    devices.torch_gc(fast=True, force=True, reason='offload')


def report_model_stats(module_name, module):
    try:
        size = offload_hook_instance.offload_map.get(module_name, 0)
        quant = getattr(module, "quantization_method", None)
        params = sum(p.numel() for p in module.parameters(recurse=True))
        shared.log.debug(f'Module: name={module_name} cls={module.__class__.__name__} size={size:.3f} params={params} quant={quant}')
    except Exception as e:
        shared.log.error(f'Module stats: name={module_name} {e}')


def apply_balanced_offload(sd_model=None, exclude:list[str]=None, force:bool=False, silent:bool=False):
    global offload_hook_instance # pylint: disable=global-statement
    if shared.opts.diffusers_offload_mode != "balanced":
        return sd_model
    if sd_model is None:
        if not shared.sd_loaded:
            return sd_model
        sd_model = shared.sd_model
    if sd_model is None:
        return sd_model
    if exclude is None:
        exclude = []
    t0 = time.time()
    if sd_model.__class__.__name__ in balanced_offload_exclude:
        return sd_model
    cached = True
    checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else sd_model.__class__.__name__
    if force or (offload_hook_instance is None) or (offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory) or (offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory) or (checkpoint_name != offload_hook_instance.checkpoint_name):
        cached = False
        offload_hook_instance = OffloadHook(checkpoint_name)
    if cached and shared.opts.diffusers_offload_pre:
        debug_move('Offload: type=balanced op=apply skip')
        return sd_model
    for pipe in get_pipe_variants(sd_model):
        for module_name, _module_size in get_module_sizes(pipe, exclude):
            module = getattr(pipe, module_name, None)
            if module is None:
                continue
            module.module_name = module_name
            module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name)
            apply_balanced_offload_to_module(module, op='apply')
            if not silent:
                report_model_stats(module_name, module)
    set_accelerate(sd_model)
    t = time.time() - t0
    process_timer.add('offload', t)
    fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
    debug_move(f'Apply offload: time={t:.2f} type=balanced fn={fn}')
    if not cached:
        shared.log.info(f'Model class={sd_model.__class__.__name__} modules={len(offload_hook_instance.offload_map)} size={offload_hook_instance.model_size():.3f}')
    return sd_model
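

# apply_balanced_offload usage sketch (illustrative): it can be called with no arguments, in
# which case it operates on `shared.sd_model`; it rebuilds the shared OffloadHook whenever the
# checkpoint or the watermark options change, attaches the hook to every discovered pipeline
# component, and returns the (possibly unchanged) pipeline:
#   sd_model = apply_balanced_offload(force=True)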