"""Model offloading helpers.

Implements the diffusers offload modes selected by `shared.opts.diffusers_offload_mode`:
none, model, sequential and group (delegated to diffusers pipeline methods), and balanced,
which is a custom scheme built on accelerate hooks (`OffloadHook`).
"""
import os
import re
import sys
import time
import inspect
import torch
import accelerate.hooks
import accelerate.utils.modeling
from installer import log
from modules import shared, devices, errors, model_quant, sd_models
from modules.timer import process as process_timer


debug = os.environ.get('SD_MOVE_DEBUG', None) is not None
verbose = os.environ.get('SD_MOVE_VERBOSE', None) is not None
debug_move = log.trace if debug else lambda *args, **kwargs: None
offload_allow_none = ['sd', 'sdxl'] # model types for which offload=none does not emit the "large model" warning
offload_post = ['h1'] # model types whose text_encoder* modules are offloaded right after their forward
offload_hook_instance = None # shared OffloadHook used by balanced offload; (re)created in apply_balanced_offload
balanced_offload_exclude = ['CogView4Pipeline', 'MeissonicPipeline'] # pipeline classes skipped by balanced offload
no_split_module_classes = [ # passed to accelerate.infer_auto_device_map so these layers are never split across devices
    "Linear", "Conv1d", "Conv2d", "Conv3d", "ConvTranspose1d", "ConvTranspose2d", "ConvTranspose3d",
    "SDNQLinear", "SDNQConv1d", "SDNQConv2d", "SDNQConv3d", "SDNQConvTranspose1d", "SDNQConvTranspose2d", "SDNQConvTranspose3d",
    "WanTransformerBlock",
]
accelerate_dtype_byte_size = None # original accelerate dtype_byte_size, saved before it is monkey-patched
move_stream = None # lazily created CUDA stream used for offload moves when diffusers_offload_streams is enabled


def dtype_byte_size(dtype: torch.dtype):
    """Replacement for `accelerate.utils.modeling.dtype_byte_size` that understands float8 dtypes.

    Float8 torch dtypes are mapped to accelerate's `CustomDtype.FP8` before delegating to the
    original accelerate implementation saved in `accelerate_dtype_byte_size`.
    """
    try:
        if dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]:
            dtype = accelerate.utils.modeling.CustomDtype.FP8
    except Exception: # catch since older torch may not have defined dtypes
        pass
    return accelerate_dtype_byte_size(dtype)


def get_signature(cls):
    """Return the parameter mapping of `cls.__init__`, following wrapped functions."""
    signature = inspect.signature(cls.__init__, follow_wrapped=True)
    return signature.parameters


def disable_offload(sd_model):
    """Remove accelerate hooks from every torch module of `sd_model` and clear `has_accelerate`.

    Does nothing unless `sd_model.has_accelerate` is truthy. A module's `network_layer_name`
    attribute is preserved across hook removal.
    """
    if not getattr(sd_model, 'has_accelerate', False):
        return
    for module_name in get_module_names(sd_model):
        module = getattr(sd_model, module_name, None)
        if isinstance(module, torch.nn.Module):
            network_layer_name = getattr(module, "network_layer_name", None)
            try:
                module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
            except Exception as e:
                shared.log.warning(f'Offload remove hook: module={module_name} {e}')
            if network_layer_name:
                module.network_layer_name = network_layer_name
    sd_model.has_accelerate = False


def set_accelerate(sd_model):
    """Mark `sd_model` and its torch components (including `pipe`, `prior_pipe` and `decoder_pipe` variants) with `has_accelerate = True`."""
    def set_accelerate_to_module(model):
        if hasattr(model, "pipe"):
            set_accelerate_to_module(model.pipe)
        for module_name in get_module_names(model):
            component = getattr(model, module_name, None)
            if isinstance(component, torch.nn.Module):
                component.has_accelerate = True

    sd_model.has_accelerate = True
    set_accelerate_to_module(sd_model)
    if hasattr(sd_model, "prior_pipe"):
        set_accelerate_to_module(sd_model.prior_pipe)
    if hasattr(sd_model, "decoder_pipe"):
        set_accelerate_to_module(sd_model.decoder_pipe)


def apply_group_offload(sd_model, op:str='model'):
    """Enable diffusers group offload on `sd_model` using the group_offload_* options.

    Logs a warning if the pipeline has no `enable_group_offload` method. Always marks the model
    via `set_accelerate` and returns it.
    """
    offload_dct = {
        'onload_device': devices.device,
        'offload_device': devices.cpu,
        'offload_type': shared.opts.group_offload_type,
        'num_blocks_per_group': shared.opts.group_offload_blocks,
        'non_blocking': shared.opts.diffusers_offload_nonblocking,
        'use_stream': shared.opts.group_offload_stream,
        'record_stream': shared.opts.group_offload_record,
        'low_cpu_mem_usage': False,
    }
    if shared.opts.group_offload_type == 'block_level':
        offload_dct['exclude_modules'] = ['vae']
    shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} options={offload_dct}')
    if hasattr(sd_model, "enable_group_offload"):
        sd_model.enable_group_offload(**offload_dct)
    else:
        shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} not supported')
    set_accelerate(sd_model)
    return sd_model


def _disable_move_model(op: str, offload_label: str):
    """Turn off the diffusers_move_* options (which conflict with CPU offload), warning if any was set."""
    if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner:
        shared.opts.diffusers_move_base = False
        shared.opts.diffusers_move_unet = False
        shared.opts.diffusers_move_refiner = False
        shared.log.warning(f'Disabling {op} "Move model to CPU" since "{offload_label}" is enabled')


def apply_model_offload(sd_model, op:str='model', quiet:bool=False):
    """Enable diffusers model CPU offload, or free existing model hooks if they are already installed.

    Errors are logged, not raised.
    """
    try:
        shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
        _disable_move_model(op, 'Model CPU offload')
        if not hasattr(sd_model, "_all_hooks") or len(sd_model._all_hooks) == 0: # pylint: disable=protected-access
            sd_model.enable_model_cpu_offload(device=devices.device)
        else:
            sd_model.maybe_free_model_hooks()
        set_accelerate(sd_model)
    except Exception as e:
        shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')


def apply_sequential_offload(sd_model, op:str='model', quiet:bool=False):
    """Enable diffusers sequential CPU offload.

    If offload is already applied, only a VAE change (`op == "vae"`) triggers reapplying
    accelerate's `cpu_offload` to the VAE. Errors are logged, not raised.
    """
    try:
        shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
        _disable_move_model(op, 'Sequential CPU offload')
        if sd_model.has_accelerate:
            if op == "vae": # reapply sequential offload to vae
                from accelerate import cpu_offload
                sd_model.vae.to(devices.cpu)
                cpu_offload(sd_model.vae, devices.device, offload_buffers=len(sd_model.vae._parameters) > 0) # pylint: disable=protected-access
            else:
                pass # do nothing if offload is already applied
        else:
            sd_model.enable_sequential_cpu_offload(device=devices.device)
        set_accelerate(sd_model)
    except Exception as e:
        shared.log.error(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} {e}')


def apply_none_offload(sd_model, op:str='model', quiet:bool=False):
    """Remove offload hooks (best effort) and move the whole model to the compute device.

    Warns when the current model type is not in `offload_allow_none`.
    """
    if shared.sd_model_type not in offload_allow_none:
        shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={sd_model.__class__.__name__} large model')
    else:
        shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}')
    try:
        sd_model.has_accelerate = False
        if hasattr(sd_model, 'maybe_free_model_hooks'):
            sd_model.maybe_free_model_hooks()
        sd_model = accelerate.hooks.remove_hook_from_module(sd_model, recurse=True)
    except Exception: # deliberate best effort: pipelines are not always nn.Modules with hooks
        pass
    sd_models.move_model(sd_model, devices.device)


def set_diffuser_offload(sd_model, op:str='model', quiet:bool=False, force:bool=False):
    """Apply the offload mode selected by `shared.opts.diffusers_offload_mode` to `sd_model`.

    On first use this also patches accelerate's `dtype_byte_size` with the float8-aware
    `dtype_byte_size` above. Time spent is recorded under the 'offload' timer.
    """
    global accelerate_dtype_byte_size # pylint: disable=global-statement
    t0 = time.time()
    if sd_model is None:
        shared.log.warning(f'{op} is not loaded')
        return
    if not (hasattr(sd_model, "has_accelerate") and sd_model.has_accelerate):
        sd_model.has_accelerate = False
    if accelerate_dtype_byte_size is None:
        accelerate_dtype_byte_size = accelerate.utils.modeling.dtype_byte_size
        accelerate.utils.modeling.dtype_byte_size = dtype_byte_size

    if shared.opts.diffusers_offload_mode == "none":
        apply_none_offload(sd_model, op=op, quiet=quiet)
    if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"):
        apply_model_offload(sd_model, op=op, quiet=quiet)
    if shared.opts.diffusers_offload_mode == "sequential" and hasattr(sd_model, "enable_sequential_cpu_offload"):
        apply_sequential_offload(sd_model, op=op, quiet=quiet)
    if shared.opts.diffusers_offload_mode == "group":
        sd_model = apply_group_offload(sd_model, op=op)
    if shared.opts.diffusers_offload_mode == "balanced":
        sd_model = apply_balanced_offload(sd_model, force=force)
    process_timer.add('offload', time.time() - t0)


class OffloadHook(accelerate.hooks.ModelHook):
    """accelerate hook implementing balanced offload.

    Before a module's forward it can offload other pipeline modules (when
    `diffusers_offload_pre` is set) and dispatches the module onto the compute device using a
    device map bounded by the GPU/CPU memory budgets. After forward it offloads modules flagged
    with `offload_post`.
    """

    def __init__(self, checkpoint_name):
        if shared.opts.diffusers_offload_max_gpu_memory > 1:
            shared.opts.diffusers_offload_max_gpu_memory = 0.75
        if shared.opts.diffusers_offload_max_cpu_memory > 1:
            shared.opts.diffusers_offload_max_cpu_memory = 0.75
        # sanitize options first so the watermarks and budgets captured below use the corrected values
        self.validate()
        self.checkpoint_name = checkpoint_name
        self.min_watermark = shared.opts.diffusers_offload_min_gpu_memory
        self.max_watermark = shared.opts.diffusers_offload_max_gpu_memory
        self.cpu_watermark = shared.opts.diffusers_offload_max_cpu_memory
        self.offload_always = [m.strip() for m in re.split(';|,| ', shared.opts.diffusers_offload_always) if len(m.strip()) > 2]
        self.offload_never = [m.strip() for m in re.split(';|,| ', shared.opts.diffusers_offload_never) if len(m.strip()) > 2]
        # byte budgets for infer_auto_device_map; assumes shared.gpu_memory/cpu_memory are in GiB - TODO confirm
        self.gpu = int(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory * 1024*1024*1024)
        self.cpu = int(shared.cpu_memory * shared.opts.diffusers_offload_max_cpu_memory * 1024*1024*1024)
        self.offload_map = {} # module name -> parameter size (bytes / 1024**3)
        self.param_map = {} # module name -> parameter count / 1024**3
        self.last_pre = None
        self.last_post = None
        self.last_cls = None
        gpu = f'{(shared.gpu_memory * shared.opts.diffusers_offload_min_gpu_memory):.2f}-{(shared.gpu_memory * shared.opts.diffusers_offload_max_gpu_memory):.2f}:{shared.gpu_memory:.2f}'
        shared.log.info(f'Offload: type=balanced op=init watermark={self.min_watermark}-{self.max_watermark} gpu={gpu} cpu={shared.cpu_memory:.3f} limit={shared.opts.cuda_mem_fraction:.2f} always={self.offload_always} never={self.offload_never} pre={shared.opts.diffusers_offload_pre} streams={shared.opts.diffusers_offload_streams}')
        super().__init__()

    def validate(self):
        """Reset out-of-range balanced-offload watermark options in place and warn about them."""
        if shared.opts.diffusers_offload_mode != 'balanced':
            return
        if shared.opts.diffusers_offload_min_gpu_memory < 0 or shared.opts.diffusers_offload_min_gpu_memory > 1:
            shared.opts.diffusers_offload_min_gpu_memory = 0.2
            shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} invalid value')
        if shared.opts.diffusers_offload_max_gpu_memory < 0.1 or shared.opts.diffusers_offload_max_gpu_memory > 1:
            shared.opts.diffusers_offload_max_gpu_memory = 0.7
            shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} invalid value')
        if shared.opts.diffusers_offload_min_gpu_memory > shared.opts.diffusers_offload_max_gpu_memory:
            shared.opts.diffusers_offload_min_gpu_memory = shared.opts.diffusers_offload_max_gpu_memory
            shared.log.warning(f'Offload: type=balanced op=validate: watermark low={shared.opts.diffusers_offload_min_gpu_memory} reset')
        if shared.opts.diffusers_offload_max_gpu_memory * shared.gpu_memory < 4:
            shared.log.warning(f'Offload: type=balanced op=validate: watermark high={shared.opts.diffusers_offload_max_gpu_memory} low memory')

    def model_size(self):
        """Return the sum of all recorded module sizes."""
        return sum(self.offload_map.values())

    def init_hook(self, module):
        return module

    def offload_allowed(self, module):
        """Return False if `module` (or any of its `nets`) is marked `offload_never` or the model type is excluded by `models_not_to_offload`."""
        if hasattr(module, "offload_never"):
            return False
        if hasattr(module, 'nets') and any(hasattr(n, "offload_never") for n in module.nets):
            return False
        if shared.sd_model_type.lower() in [m.lower().strip() for m in re.split(r'[ ,]+', shared.opts.models_not_to_offload)]:
            return False
        return True

    @torch.compiler.disable
    def pre_forward(self, module, *args, **kwargs):
        """Optionally offload other modules, then dispatch `module` to the compute device if needed."""
        _id = id(module)

        do_offload = (self.last_pre != _id) or (module.__class__.__name__ != self.last_cls)
        if do_offload and self.offload_allowed(module): # offload every other module first time when new module starts pre-forward
            if shared.opts.diffusers_offload_pre:
                t0 = time.time()
                debug_move(f'Offload: type=balanced op=pre module={module.__class__.__name__}')
                for pipe in get_pipe_variants():
                    for module_name in get_module_names(pipe):
                        module_instance = getattr(pipe, module_name, None)
                        if module_instance is None or id(module_instance) == _id:
                            continue
                        if module_instance.__class__.__name__ in self.offload_never:
                            continue
                        if not devices.same_device(module_instance.device, devices.cpu):
                            apply_balanced_offload_to_module(module_instance, op='pre')
                self.last_cls = module.__class__.__name__
                process_timer.add('offload', time.time() - t0)

        if not devices.same_device(module.device, devices.device): # move-to-device
            t0 = time.time()
            device_index = torch.device(devices.device).index
            if device_index is None:
                device_index = 0
            max_memory = {device_index: self.gpu, "cpu": self.cpu}
            # reuse the cached device map unless the memory budget changed since it was computed
            device_map = getattr(module, "balanced_offload_device_map", None)
            if (device_map is None) or (max_memory != getattr(module, "balanced_offload_max_memory", None)):
                device_map = accelerate.infer_auto_device_map(module, max_memory=max_memory, no_split_module_classes=no_split_module_classes, verbose=verbose, clean_result=False)
            offload_dir = getattr(module, "offload_dir", os.path.join(shared.opts.accelerate_offload_path, module.__class__.__name__))
            if devices.backend == "directml":
                for k, v in device_map.items():
                    if isinstance(v, int):
                        device_map[k] = f"{devices.device.type}:{v}" # int implies CUDA or XPU device, but it will break DirectML backend so we add type
            if debug:
                shared.log.trace(f'Offload: type=balanced op=dispatch map={device_map}')
            if device_map is not None:
                skip_keys = getattr(module, "_skip_keys", None)
                module = accelerate.dispatch_model(module, main_device=torch.device(devices.device), device_map=device_map, offload_dir=offload_dir, skip_keys=skip_keys, force_hooks=True)
                module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
            module.balanced_offload_device_map = device_map
            module.balanced_offload_max_memory = max_memory
            process_timer.add('onload', time.time() - t0)

        if debug:
            for _i, pipe in enumerate(get_pipe_variants()):
                for module_name in get_module_names(pipe):
                    module_instance = getattr(pipe, module_name, None)
                    shared.log.trace(f'Offload: type=balanced op=pre:status forward={module.__class__.__name__} module={module_name} class={module_instance.__class__.__name__} pipe={_i} device={module_instance.device} dtype={module_instance.dtype}')
        self.last_pre = _id
        return args, kwargs

    @torch.compiler.disable
    def post_forward(self, module, output):
        """Offload `module` after its first forward in a row when it carries the `offload_post` flag."""
        if self.last_post != id(module):
            self.last_post = id(module)
            if getattr(module, "offload_post", False) and (module.device != devices.cpu):
                apply_balanced_offload_to_module(module, op='post')
        return output

    def detach_hook(self, module):
        return module


def get_pipe_variants(pipe=None):
    """Return `pipe` plus any `pipe`, `prior_pipe` and `decoder_pipe` sub-pipelines it exposes.

    With `pipe=None`, the loaded `shared.sd_model` is used; if nothing is loaded, `[None]` is returned.
    """
    if pipe is None:
        if shared.sd_loaded:
            pipe = shared.sd_model
        else:
            return [pipe]
    variants = [pipe]
    if hasattr(pipe, "pipe"):
        variants.append(pipe.pipe)
    if hasattr(pipe, "prior_pipe"):
        variants.append(pipe.prior_pipe)
    if hasattr(pipe, "decoder_pipe"):
        variants.append(pipe.decoder_pipe)
    return variants


def get_module_names(pipe=None, exclude=None):
    """Return the sorted, de-duplicated names of torch modules registered on a pipeline.

    Candidate names come from the pipeline's `_internal_dict` keys and from the parameters of
    its `__init__` signature. Names starting with an underscore, names in `exclude`, and
    attributes that are not `torch.nn.Module` instances are dropped. With `pipe=None`, the
    loaded `shared.sd_model` is used, or [] is returned if nothing is loaded.
    """
    if exclude is None:
        exclude = []
    if pipe is None:
        if shared.sd_loaded:
            pipe = shared.sd_model
        else:
            return []

    def is_valid(name):
        # ModuleDict and ModuleList are nn.Module subclasses, so one isinstance check covers all of them
        return isinstance(getattr(pipe, name, None), torch.nn.Module)

    modules_names = []
    try:
        modules_names.extend(pipe._internal_dict.keys()) # pylint: disable=protected-access
    except Exception: # object has no diffusers config dict
        pass
    try:
        modules_names.extend(get_signature(pipe).keys())
    except Exception: # signature not inspectable
        pass
    modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')]
    modules_names = [m for m in modules_names if is_valid(m)]
    return sorted(set(modules_names))


def get_module_sizes(pipe=None, exclude=None):
    """Return `(module_name, size)` pairs for a pipeline's modules, largest first.

    Size is the total parameter bytes divided by 1024**3. Sizes and parameter counts are cached
    in `offload_hook_instance.offload_map` / `param_map`, keyed by module name.
    """
    if exclude is None:
        exclude = []
    modules = {}
    for module_name in get_module_names(pipe, exclude):
        module_size = offload_hook_instance.offload_map.get(module_name, None)
        if module_size is None:
            module = getattr(pipe, module_name, None)
            if not isinstance(module, torch.nn.Module):
                continue
            param_num = 0 # must stay defined even when the calculation below fails
            try:
                module_size = sum(p.numel() * p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
                param_num = sum(p.numel() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
            except Exception as e:
                shared.log.error(f'Offload: type=balanced op=calc module={module_name} {e}')
                module_size = 0
            offload_hook_instance.offload_map[module_name] = module_size
            offload_hook_instance.param_map[module_name] = param_num
        modules[module_name] = module_size
    return sorted(modules.items(), key=lambda x: x[1], reverse=True)


def move_module_to_cpu(module, op='unk', force:bool=False):
    """Move `module` to CPU according to the balanced-offload policy.

    The module is moved when `force` is set, when its class is in `offload_always`, or when GPU
    usage is above the low watermark. It is never moved when its class is in `offload_never`.
    Errors are logged, not raised. Out-of-memory errors instead trigger a forced GC, and
    bitsandbytes errors are ignored.
    """
    def do_move(module):
        global move_stream # pylint: disable=global-statement
        if shared.opts.diffusers_offload_streams:
            if move_stream is None:
                move_stream = torch.cuda.Stream(device=devices.device)
            with torch.cuda.stream(move_stream):
                module = module.to(devices.cpu)
        else:
            module = module.to(devices.cpu)
        return module

    module_name = getattr(module, "module_name", module.__class__.__name__)
    module_cls = module.__class__.__name__
    try:
        # only compute the whole-model fallback when the module has no recorded size
        if module_name in offload_hook_instance.offload_map:
            module_size = offload_hook_instance.offload_map[module_name]
        else:
            module_size = offload_hook_instance.model_size()
        used_gpu, used_ram = devices.torch_gc(fast=True)
        perc_gpu = used_gpu / shared.gpu_memory if shared.gpu_memory > 0 else 0
        prev_gpu = used_gpu
        reason = 'skip'
        if force:
            reason = 'force'
            module = do_move(module)
            used_gpu -= module_size
        elif module_cls in offload_hook_instance.offload_never:
            reason = 'never'
        elif module_cls in offload_hook_instance.offload_always:
            reason = 'always'
            module = do_move(module)
            used_gpu -= module_size
        elif perc_gpu > shared.opts.diffusers_offload_min_gpu_memory:
            reason = 'mem'
            module = do_move(module)
            used_gpu -= module_size
        if debug:
            quant = getattr(module, "quantization_method", None)
            debug_move(f'Offload: type=balanced op={op}:{reason} gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f}:{shared.opts.diffusers_offload_min_gpu_memory} ram={used_ram:.3f} current={module.device} dtype={module.dtype} quant={quant} module={module_cls} size={module_size:.3f}')
    except Exception as e:
        if 'out of memory' in str(e):
            devices.torch_gc(fast=True, force=True, reason='oom')
        elif 'bitsandbytes' in str(e):
            pass
        else:
            shared.log.error(f'Offload: type=balanced op=apply module={module_name} cls={module_cls} {e}')
        if debug:
            errors.display(e, f'Offload: type=balanced op=apply module={module_name}')


def apply_balanced_offload_to_module(module, op="apply", force:bool=False):
    """Re-install the balanced offload hook on `module`.

    Removes existing accelerate hooks, possibly moves the module to CPU, then adds
    `offload_hook_instance`. `network_layer_name` and the cached device map / memory budget
    are restored afterwards, and layerwise quantization is reapplied when enabled.
    """
    module_name = getattr(module, "module_name", module.__class__.__name__)
    network_layer_name = getattr(module, "network_layer_name", None)
    device_map = getattr(module, "balanced_offload_device_map", None)
    max_memory = getattr(module, "balanced_offload_max_memory", None)
    try:
        module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
    except Exception as e:
        shared.log.warning(f'Offload remove hook: module={module_name} {e}')
    move_module_to_cpu(module, op=op, force=force)
    try:
        module = accelerate.hooks.add_hook_to_module(module, offload_hook_instance, append=True)
        # inside the try: if adding the hook failed there is no _hf_hook to configure
        module._hf_hook.execution_device = torch.device(devices.device) # pylint: disable=protected-access
    except Exception as e:
        shared.log.warning(f'Offload add hook: module={module_name} {e}')
    if network_layer_name:
        module.network_layer_name = network_layer_name
    if device_map and max_memory:
        module.balanced_offload_device_map = device_map
        module.balanced_offload_max_memory = max_memory
    module.offload_post = shared.sd_model_type in offload_post and module_name.startswith("text_encoder")
    if shared.opts.layerwise_quantization or getattr(module, 'quantization_method', None) == 'LayerWise':
        model_quant.apply_layerwise(module, quiet=True) # need to reapply since hooks were removed/readded
    devices.torch_gc(fast=True, force=True, reason='offload')


def report_model_stats(module_name, module):
    """Log the recorded size, parameter count and quantization method of a module."""
    try:
        size = offload_hook_instance.offload_map.get(module_name, 0)
        quant = getattr(module, "quantization_method", None)
        params = sum(p.numel() for p in module.parameters(recurse=True))
        shared.log.debug(f'Module: name={module_name} cls={module.__class__.__name__} size={size:.3f} params={params} quant={quant}')
    except Exception as e:
        shared.log.error(f'Module stats: name={module_name} {e}')


def apply_balanced_offload(sd_model=None, exclude:list[str]=None, force:bool=False, silent:bool=False):
    """Apply balanced offload to every module of `sd_model` (default: `shared.sd_model`).

    The shared `OffloadHook` is recreated when `force` is set, when the watermark options
    changed, or when the checkpoint changed. If the hook is reused and `diffusers_offload_pre`
    is enabled, the modules are left as they are. Returns `sd_model`.
    """
    global offload_hook_instance # pylint: disable=global-statement
    if shared.opts.diffusers_offload_mode != "balanced":
        return sd_model
    if sd_model is None:
        if not shared.sd_loaded:
            return sd_model
        sd_model = shared.sd_model
    if sd_model is None:
        return sd_model
    if exclude is None:
        exclude = []
    t0 = time.time()
    if sd_model.__class__.__name__ in balanced_offload_exclude:
        return sd_model
    cached = True
    checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else sd_model.__class__.__name__
    if force or (offload_hook_instance is None) or (offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory) or (offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory) or (checkpoint_name != offload_hook_instance.checkpoint_name):
        cached = False
        offload_hook_instance = OffloadHook(checkpoint_name)

    if cached and shared.opts.diffusers_offload_pre:
        debug_move('Offload: type=balanced op=apply skip')
        return sd_model

    for pipe in get_pipe_variants(sd_model):
        for module_name, _module_size in get_module_sizes(pipe, exclude):
            module = getattr(pipe, module_name, None)
            if module is None:
                continue
            module.module_name = module_name
            module.offload_dir = os.path.join(shared.opts.accelerate_offload_path, checkpoint_name, module_name)
            apply_balanced_offload_to_module(module, op='apply')
            if not silent:
                report_model_stats(module_name, module)

    set_accelerate(sd_model)
    t = time.time() - t0
    process_timer.add('offload', t)
    if debug: # only inspect the call stack when the trace is actually emitted
        try:
            fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
        except ValueError: # call stack shallower than two frames
            fn = sys._getframe(1).f_code.co_name # pylint: disable=protected-access
        debug_move(f'Apply offload: time={t:.2f} type=balanced fn={fn}')
    if not cached:
        shared.log.info(f'Model class={sd_model.__class__.__name__} modules={len(offload_hook_instance.offload_map)} size={offload_hook_instance.model_size():.3f}')
    return sd_model