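# Offload management for SD.Next diffusers pipelines (part of https://github.com/vladmandic/sdnext).
# Implements the 'none', 'model', 'sequential', 'group' and 'balanced' offload modes on top of
# accelerate hooks; 'balanced' offload moves pipeline modules between GPU and CPU based on the
# configured memory watermarks. Set SD_MOVE_DEBUG / SD_MOVE_VERBOSE to enable trace logging.
# Typical flow (a sketch based on the functions below, not a guaranteed call order): model-load
# code calls set_diffuser_offload() once, and apply_balanced_offload() is re-invoked between
# pipeline stages while balanced offload is active.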
import os
import re
import sys
import time
import inspect
import torch
import accelerate.hooks
import accelerate.utils.modeling
from installer import log
from modules import shared, devices, errors, model_quant, sd_models
from modules.timer import process as process_timer


debug = os.environ.get('SD_MOVE_DEBUG', None) is not None
verbose = os.environ.get('SD_MOVE_VERBOSE', None) is not None
debug_move = log.trace if debug else lambda *args, **kwargs: None
offload_allow_none = ['sd', 'sdxl']
offload_post = ['h1']
offload_hook_instance = None
balanced_offload_exclude = ['CogView4Pipeline', 'MeissonicPipeline']
no_split_module_classes = [
    "Linear", "Conv1d", "Conv2d", "Conv3d", "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d",
    "SDNQLinear", "SDNQConv1d", "SDNQConv2d", "SDNQConv3d", "SDNQConvTranspose1d", "SDNQConvTranspose2d", "SDNQConvTranspose3d",
    "WanTransformerBlock",
]
accelerate_dtype_byte_size = None
move_stream = None

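# Replacement for accelerate.utils.modeling.dtype_byte_size: maps torch FP8 dtypes to accelerate's
# CustomDtype.FP8 so device-map size estimates stay correct. The original function is saved in
# accelerate_dtype_byte_size and the patch is installed by set_diffuser_offload().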
def dtype_byte_size(dtype: torch.dtype):
    try:
        if dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]:
            dtype = accelerate.utils.modeling.CustomDtype.FP8
    except Exception: # catch since older torch may not have defined dtypes
        pass
    return accelerate_dtype_byte_size(dtype)

def get_signature(cls):
    signature = inspect.signature(cls.__init__, follow_wrapped=True)
    return signature.parameters

def disable_offload(sd_model):
    if not getattr(sd_model, 'has_accelerate', False):
        return
    for module_name in get_module_names(sd_model):
        module = getattr(sd_model, module_name, None)
        if isinstance(module, torch.nn.Module):
            network_layer_name = getattr(module, "network_layer_name", None)
            try:
                module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
            except Exception as e:
                shared.log.warning(f'Offload remove hook: module={module_name} {e}')
            if network_layer_name:
                module.network_layer_name = network_layer_name
    sd_model.has_accelerate = False

def set_accelerate(sd_model):
    def set_accelerate_to_module(model):
        if hasattr(model, "pipe"):
            set_accelerate_to_module(model.pipe)
        for module_name in get_module_names(model):
            component = getattr(model, module_name, None)
            if isinstance(component, torch.nn.Module):
                component.has_accelerate = True

    sd_model.has_accelerate = True
    set_accelerate_to_module(sd_model)
    if hasattr(sd_model, "prior_pipe"):
        set_accelerate_to_module(sd_model.prior_pipe)
    if hasattr(sd_model, "decoder_pipe"):
        set_accelerate_to_module(sd_model.decoder_pipe)

def apply_group_offload(sd_model, op:str='model'):
    offload_dct = {
        'onload_device': devices.device,
        'offload_device': devices.cpu,
        'offload_type': shared.opts.group_offload_type,
        'num_blocks_per_group': shared.opts.group_offload_blocks,
        'non_blocking': shared.opts.diffusers_offload_nonblocking,
        'use_stream': shared.opts.group_offload_stream,
        'record_stream': shared.opts.group_offload_record,
        'low_cpu_mem_usage': False,
    }
    if shared.opts.group_offload_type == 'block_level':
        offload_dct['exclude_modules'] = ['vae']
    shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} options={offload_dct}')
    if hasattr(sd_model, "enable_group_offload"):
        sd_model.enable_group_offload(**offload_dct)
    else:
        shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} not supported')
    set_accelerate(sd_model)
    return sd_model

def apply_model_offload(sd_model, op:str='model', quiet:bool=False):
    try:
        shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
        if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
            shared.opts.diffusers_move_base = False
            shared.opts.diffusers_move_unet = False
            shared.opts.diffusers_move_refiner = False
            shared.log.warning(f'Disabling {op} "Move model to CPU" since "Model CPU offload" is enabled')
        if not hasattr(sd_model, "_all_hooks") or len(sd_model._all_hooks) == 0: # pylint: disable=protected-access
            sd_model.enable_model_cpu_offload(device=devices.device)
        else:
            sd_model.maybe_free_model_hooks()
        set_accelerate(sd_model)
    except Exception as e:
        shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')

def apply_sequential_offload(sd_model, op:str='model', quiet:bool=False):
    try:
        shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
        if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
            shared.opts.diffusers_move_base = False
            shared.opts.diffusers_move_unet = False
            shared.opts.diffusers_move_refiner = False
            shared.log.warning(f'Disabling {op} "Move model to CPU" since "Sequential CPU offload" is enabled')
        if sd_model.has_accelerate:
            if op == "vae": # reapply sequential offload to vae
                from accelerate import cpu_offload
                sd_model.vae.to(devices.cpu)
                cpu_offload(sd_model.vae, devices.device, offload_buffers=len(sd_model.vae._parameters) > 0) # pylint: disable=protected-access
            else:
                pass # do nothing if offload is already applied
        else:
            sd_model.enable_sequential_cpu_offload(device=devices.device)
            set_accelerate(sd_model)
    except Exception as e:
        shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')

def apply_none_offload(sd_model, op:str='model', quiet:bool=False):
    if shared.sd_model_type not in offload_allow_none:
        shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model')
    else:
        shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
    try:
        sd_model.has_accelerate = False
        if hasattr(sd_model, 'maybe_free_model_hooks'):
            sd_model.maybe_free_model_hooks()
            sd_model = accelerate.hooks.remove_hook_from_module(sd_model, recurse=True)
    except Exception:
        pass
    sd_models.move_model(sd_model, devices.device)

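# Dispatcher: applies the offload mode selected in settings (none/model/sequential/group/balanced)
# to the loaded pipeline and, on first use, monkey-patches accelerate's dtype_byte_size with the
# FP8-aware wrapper above.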
def set_diffuser_offload(sd_model, op:str='model', quiet:bool=False, force:bool=False):
    global accelerate_dtype_byte_size # pylint: disable=global-statement
    t0 = time.time()
    if sd_model is None:
        shared.log.warning(f'{op} is not loaded')
        return
    if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate):
        sd_model.has_accelerate = False
    if accelerate_dtype_byte_size is None:
        accelerate_dtype_byte_size = accelerate.utils.modeling.dtype_byte_size
        accelerate.utils.modeling.dtype_byte_size = dtype_byte_size

    if shared.opts.diffusers_offload_mode == "none":
        apply_none_offload(sd_model, op=op, quiet=quiet)

    if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"):
        apply_model_offload(sd_model, op=op, quiet=quiet)

    if shared.opts.diffusers_offload_mode == "sequential" and hasattr(sd_model, "enable_sequential_cpu_offload"):
        apply_sequential_offload(sd_model, op=op, quiet=quiet)

    if shared.opts.diffusers_offload_mode == "group":
        sd_model = apply_group_offload(sd_model, op=op)

    if shared.opts.diffusers_offload_mode == "balanced":
        sd_model = apply_balanced_offload(sd_model, force=force)

    process_timer.add('offload', time.time() - t0)

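# accelerate ModelHook used by balanced offload. pre_forward optionally offloads the other
# pipeline modules first (when offload_pre is enabled), then dispatches the about-to-run module
# with accelerate.infer_auto_device_map/dispatch_model within the configured GPU/CPU memory
# budget. post_forward can push modules flagged with offload_post straight back to CPU.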
class OffloadHook(accelerate.hooks.ModelHook):
    def __init__(self, checkpoint_name):
        if shared.opts.diffusers_offload_max_gpu_memory > 1:
            shared.opts.diffusers_offload_max_gpu_memory = 0.75
        if shared.opts.diffusers_offload_max_cpu_memory > 1:
            shared.opts.diffusers_offload_max_cpu_memory = 0.75
        self.checkpoint_name = checkpoint_name
        self.min_watermark = shared.opts.diffusers_offload_min_gpu_memory
        self.max_watermark = shared.opts.diffusers_offload_max_gpu_memory
        self.cpu_watermark = shared.opts.diffusers_offload_max_cpu_memory
        self.offload_always = [m.strip() for m in re.split(';|,| ', shared.opts.diffusers_offload_always) if len(m.strip()) > 2]
        self.offload_never = [m.strip() for m in re.split(';|,| ', shared.opts.diffusers_offload_never) if len(m.strip()) > 2]
        self.gpu = int(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory * 1024*1024*1024)
        self.cpu = int(shared.cpu_memory * shared.opts.diffusers_offload_max_cpu_memory * 1024*1024*1024)
        self.offload_map = {}
        self.param_map = {}
        self.last_pre = None
        self.last_post = None
        self.last_cls = None
        gpu = f'{(shared.gpu_memory * shared.opts.diffusers_offload_min_gpu_memory):.2f}-{(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory):.2f}:{shared.gpu_memory:.2f}'
        shared.log.info(f'Offload: type=balanced op=init watermark={self.min_watermark}-{self.max_watermark} gpu={gpu} cpu={shared.cpu_memory:.3f} limit={shared.opts.cuda_mem_fraction:.2f} always={self.offload_always} never={self.offload_never} pre={shared.opts.diffusers_offload_pre} streams={shared.opts.diffusers_offload_streams}')
        self.validate()
        super().__init__()

    def validate(self):
        if shared.opts.diffusers_offload_mode != 'balanced':
            return
        if shared.opts.diffusers_offload_min_gpu_memory < 0 or shared.opts.diffusers_offload_min_gpu_memory > 1:
            shared.opts.diffusers_offload_min_gpu_memory = 0.2
            shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} invalid value')
        if shared.opts.diffusers_offload_max_gpu_memory < 0.1 or shared.opts.diffusers_offload_max_gpu_memory > 1:
            shared.opts.diffusers_offload_max_gpu_memory = 0.7
            shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} invalid value')
        if shared.opts.diffusers_offload_min_gpu_memory > shared.opts.diffusers_offload_max_gpu_memory:
            shared.opts.diffusers_offload_min_gpu_memory = shared.opts.diffusers_offload_max_gpu_memory
            shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} reset')
        if shared.opts.diffusers_offload_max_gpu_memory * shared.gpu_memory < 4:
            shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} low memory')

    def model_size(self):
        return sum(self.offload_map.values())

    def init_hook(self, module):
        return module

    def offload_allowed(self, module):
        if hasattr(module, "offload_never"):
            return False
        if hasattr(module, 'nets') and any(hasattr(n, "offload_never") for n in module.nets):
            return False
        if shared.sd_model_type.lower() in [m.lower().strip() for m in re.split(r'[ ,]+', shared.opts.models_not_to_offload)]:
            return False
        return True

    @torch.compiler.disable
    def pre_forward(self, module, *args, **kwargs):
        _id = id(module)
        do_offload = (self.last_pre != _id) or (module.__class__.__name__ != self.last_cls)

        if do_offload and self.offload_allowed(module): # offload every other module first time when new module starts pre-forward
            if shared.opts.diffusers_offload_pre:
                t0 = time.time()
                debug_move(f'Offload: type=balanced op=pre module={module.__class__.__name__}')
                for pipe in get_pipe_variants():
                    for module_name in get_module_names(pipe):
                        module_instance = getattr(pipe, module_name, None)
                        module_cls = module_instance.__class__.__name__
                        if (module_instance is not None) and (_id != id(module_instance)) and (module_cls not in self.offload_never) and (not devices.same_device(module_instance.device, devices.cpu)):
                            apply_balanced_offload_to_module(module_instance, op='pre')
                self.last_cls = module.__class__.__name__
                process_timer.add('offload', time.time() - t0)

        if not devices.same_device(module.device, devices.device): # move-to-device
            t0 = time.time()
            device_index = torch.device(devices.device).index
            if device_index is None:
                device_index = 0
            max_memory = { device_index: self.gpu, "cpu": self.cpu }
            device_map = getattr(module, "balanced_offload_device_map", None)
            if (device_map is None) or (max_memory != getattr(module, "balanced_offload_max_memory", None)):
                device_map = accelerate.infer_auto_device_map(module,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    verbose=verbose,
                    clean_result=False,
                )
            offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
            if devices.backend == "directml":
                for k, v in device_map.items():
                    if isinstance(v, int):
                        device_map[k] = f"{devices.device.type}:{v}" # int implies CUDA or XPU device, but it will break DirectML backend so we add type
            if debug:
                shared.log.trace(f'Offload: type=balanced op=dispatch map={device_map}')
            if device_map is not None:
                skip_keys = getattr(module, "_skip_keys", None)
                module = accelerate.dispatch_model(module,
                    main_device=torch.device(devices.device),
                    device_map=device_map,
                    offload_dir=offload_dir,
                    skip_keys=skip_keys,
                    force_hooks=True,
                )
            module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
            module.balanced_offload_device_map = device_map
            module.balanced_offload_max_memory = max_memory
            process_timer.add('onload', time.time() - t0)

        if debug:
            for _i, pipe in enumerate(get_pipe_variants()):
                for module_name in get_module_names(pipe):
                    module_instance = getattr(pipe, module_name, None)
                    shared.log.trace(f'Offload: type=balanced op=pre:status forward={module.__class__.__name__} module={module_name} class={module_instance.__class__.__name__} pipe={_i} device={module_instance.device} dtype={module_instance.dtype}')

        self.last_pre = _id
        return args, kwargs

    @torch.compiler.disable
    def post_forward(self, module, output):
        if self.last_post != id(module):
            self.last_post = id(module)
            if getattr(module, "offload_post", False) and (module.device != devices.cpu):
                apply_balanced_offload_to_module(module, op='post')
        return output

    def detach_hook(self, module):
        return module

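# Returns the top-level pipeline plus any nested pipe/prior_pipe/decoder_pipe wrappers so that
# offload can walk every sub-pipeline of the loaded model.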
def get_pipe_variants(pipe=None):
    if pipe is None:
        if shared.sd_loaded:
            pipe = shared.sd_model
        else:
            return [pipe]
    variants = [pipe]
    if hasattr(pipe, "pipe"):
        variants.append(pipe.pipe)
    if hasattr(pipe, "prior_pipe"):
        variants.append(pipe.prior_pipe)
    if hasattr(pipe, "decoder_pipe"):
        variants.append(pipe.decoder_pipe)
    return variants

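# Collects candidate component names from the pipeline's _internal_dict and from its __init__
# signature, then keeps only attributes that are actual torch.nn.Module instances.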
def get_module_names(pipe=None, exclude=None):
    def is_valid(module):
        if isinstance(getattr(pipe, module, None), torch.nn.ModuleDict):
            return True
        if isinstance(getattr(pipe, module, None), torch.nn.ModuleList):
            return True
        if isinstance(getattr(pipe, module, None), torch.nn.Module):
            return True
        return False

    if exclude is None:
        exclude = []
    if pipe is None:
        if shared.sd_loaded:
            pipe = shared.sd_model
        else:
            return []
    modules_names = []
    try:
        dict_keys = pipe._internal_dict.keys() # pylint: disable=protected-access
        modules_names.extend(dict_keys)
    except Exception:
        pass
    try:
        dict_keys = get_signature(pipe).keys()
        modules_names.extend(dict_keys)
    except Exception:
        pass
    modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')]
    modules_names = [m for m in modules_names if is_valid(m)]
    modules_names = sorted(set(modules_names))
    return modules_names

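# Returns (name, size) pairs for the pipeline components, sorted largest first; sizes are in GiB
# and cached in offload_hook_instance.offload_map so they are only computed once per module.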
def get_module_sizes(pipe=None, exclude=None):
    if exclude is None:
        exclude = []
    modules = {}
    for module_name in get_module_names(pipe, exclude):
        module_size = offload_hook_instance.offload_map.get(module_name, None)
        if module_size is None:
            module = getattr(pipe, module_name, None)
            if not isinstance(module, torch.nn.Module):
                continue
            try:
                module_size = sum(p.numel() * p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
                param_num = sum(p.numel() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
            except Exception as e:
                shared.log.error(f'Offload: type=balanced op=calc module={module_name} {e}')
                module_size = 0
                param_num = 0 # keep param_num defined when size calculation fails
            offload_hook_instance.offload_map[module_name] = module_size
            offload_hook_instance.param_map[module_name] = param_num
        modules[module_name] = module_size
    modules = sorted(modules.items(), key=lambda x: x[1], reverse=True)
    return modules

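# Moves a module to CPU depending on the offload policy: always when forced or listed in
# offload_always, never when listed in offload_never, otherwise only while GPU usage is above
# the minimum watermark. Optionally uses a dedicated CUDA stream when offload streams are enabled.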
def move_module_to_cpu(module, op='unk', force:bool=False):
    def do_move(module):
        if shared.opts.diffusers_offload_streams:
            global move_stream # pylint: disable=global-statement
            if move_stream is None:
                move_stream = torch.cuda.Stream(device=devices.device)
            with torch.cuda.stream(move_stream):
                module = module.to(devices.cpu)
        else:
            module = module.to(devices.cpu)
        return module

    try:
        module_name = getattr(module, "module_name", module.__class__.__name__)
        module_size = offload_hook_instance.offload_map.get(module_name, offload_hook_instance.model_size())
        used_gpu, used_ram = devices.torch_gc(fast=True)
        perc_gpu = used_gpu / shared.gpu_memory
        prev_gpu = used_gpu
        module_cls = module.__class__.__name__
        op = f'{op}:skip'
        if force:
            op = f'{op}:force'
            module = do_move(module)
            used_gpu -= module_size
        elif module_cls in offload_hook_instance.offload_never:
            op = f'{op}:never'
        elif module_cls in offload_hook_instance.offload_always:
            op = f'{op}:always'
            module = do_move(module)
            used_gpu -= module_size
        elif perc_gpu > shared.opts.diffusers_offload_min_gpu_memory:
            op = f'{op}:mem'
            module = do_move(module)
            used_gpu -= module_size
        if debug:
            quant = getattr(module, "quantization_method", None)
            debug_move(f'Offload: type=balanced op={op} gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f}:{shared.opts.diffusers_offload_min_gpu_memory} ram={used_ram:.3f} current={module.device} dtype={module.dtype} quant={quant} module={module_cls} size={module_size:.3f}')
    except Exception as e:
        if 'out of memory' in str(e):
            devices.torch_gc(fast=True, force=True, reason='oom')
        elif 'bitsandbytes' in str(e):
            pass
        else:
            shared.log.error(f'Offload: type=balanced op=apply module={getattr(module, "__name__", None)} cls={module.__class__ if inspect.isclass(module) else None} {e}')
        if os.environ.get('SD_MOVE_DEBUG', None):
            errors.display(e, f'Offload: type=balanced op=apply module={getattr(module, "__name__", None)}')

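# Re-hooks a single module: removes existing accelerate hooks, moves it to CPU via the policy
# above, re-attaches the shared OffloadHook, restores cached device-map metadata and reapplies
# layerwise quantization if needed.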
def apply_balanced_offload_to_module(module, op="apply", force:bool=False):
    module_name = getattr(module, "module_name", module.__class__.__name__)
    network_layer_name = getattr(module, "network_layer_name", None)
    device_map = getattr(module, "balanced_offload_device_map", None)
    max_memory = getattr(module, "balanced_offload_max_memory", None)
    try:
        module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
    except Exception as e:
        shared.log.warning(f'Offload remove hook: module={module_name} {e}')
    move_module_to_cpu(module, op=op, force=force)
    try:
        module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True)
    except Exception as e:
        shared.log.warning(f'Offload add hook: module={module_name} {e}')
    module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
    if network_layer_name:
        module.network_layer_name = network_layer_name
    if device_map and max_memory:
        module.balanced_offload_device_map = device_map
        module.balanced_offload_max_memory = max_memory
    module.offload_post = shared.sd_model_type in offload_post and module_name.startswith("text_encoder")
    if shared.opts.layerwise_quantization or getattr(module, 'quantization_method', None) == 'LayerWise':
        model_quant.apply_layerwise(module, quiet=True) # need to reapply since hooks were removed/readded
    devices.torch_gc(fast=True, force=True, reason='offload')

def report_model_stats(module_name, module):
    try:
        size = offload_hook_instance.offload_map.get(module_name, 0)
        quant = getattr(module, "quantization_method", None)
        params = sum(p.numel() for p in module.parameters(recurse=True))
        shared.log.debug(f'Module: name={module_name} cls={module.__class__.__name__} size={size:.3f} params={params} quant={quant}')
    except Exception as e:
        shared.log.error(f'Module stats: name={module_name} {e}')

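# Entry point for balanced offload: (re)creates the OffloadHook when settings or the checkpoint
# change, then walks every pipe variant and component, tagging each module with its name and
# offload directory before applying the hook.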
def apply_balanced_offload(sd_model=None, exclude:list[str]=None, force:bool=False, silent:bool=False):
    global offload_hook_instance # pylint: disable=global-statement
    if shared.opts.diffusers_offload_mode != "balanced":
        return sd_model
    if sd_model is None:
        if not shared.sd_loaded:
            return sd_model
        sd_model = shared.sd_model
    if sd_model is None:
        return sd_model
    if exclude is None:
        exclude = []
    t0 = time.time()
    if sd_model.__class__.__name__ in balanced_offload_exclude:
        return sd_model
    cached = True
    checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else sd_model.__class__.__name__
    if force or (offload_hook_instance is None) or (offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory) or (offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory) or (checkpoint_name != offload_hook_instance.checkpoint_name):
        cached = False
        offload_hook_instance = OffloadHook(checkpoint_name)

    if cached and shared.opts.diffusers_offload_pre:
        debug_move('Offload: type=balanced op=apply skip')
        return sd_model

    for pipe in get_pipe_variants(sd_model):
        for module_name, _module_size in get_module_sizes(pipe, exclude):
            module = getattr(pipe, module_name, None)
            if module is None:
                continue
            module.module_name = module_name
            module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name)
            apply_balanced_offload_to_module(module, op='apply')
            if not silent:
                report_model_stats(module_name, module)

    set_accelerate(sd_model)
    t = time.time() - t0
    process_timer.add('offload', t)
    fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
    debug_move(f'Apply offload: time={t:.2f} type=balanced fn={fn}')
    if not cached:
        shared.log.info(f'Model class={sd_model.__class__.__name__} modules={len(offload_hook_instance.offload_map)} size={offload_hook_instance.model_size():.3f}')
    return sd_model