import typing
import os
import time
import torch
import numpy as np
from modules import shared, devices, processing_correction, timer, prompt_parser_diffusers


p = None
debug = os.environ.get('SD_CALLBACK_DEBUG', None) is not None
debug_callback = shared.log.trace if debug else lambda *args, **kwargs: None
warned = False


def set_callbacks_p(processing):
    global p, warned # pylint: disable=global-statement
    p = processing
    warned = False


def prompt_callback(step, kwargs):
    if prompt_parser_diffusers.embedder is None or 'prompt_embeds' not in kwargs:
        return kwargs
    try:
        prompt_embeds = prompt_parser_diffusers.embedder('prompt_embeds', step + 1)
        negative_prompt_embeds = prompt_parser_diffusers.embedder('negative_prompt_embeds', step + 1)
        if p.cfg_scale > 1: # perform guidance
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) # combined embeds
        assert prompt_embeds.shape == kwargs['prompt_embeds'].shape, f"prompt_embed shape mismatch {kwargs['prompt_embeds'].shape} {prompt_embeds.shape}"
        kwargs['prompt_embeds'] = prompt_embeds
    except Exception as e:
        debug_callback(f"Callback: type=prompt {e}")
    return kwargs


def diffusers_callback_legacy(step: int, timestep: int, latents: typing.Union[torch.FloatTensor, np.ndarray]):
    if p is None:
        return
    if isinstance(latents, np.ndarray): # latents from Onnx pipelines is ndarray
        latents = torch.from_numpy(latents)
    shared.state.sampling_step = step
    shared.state.current_latent = latents
    latents = processing_correction.correction_callback(p, timestep, {'latents': latents})
    if shared.state.interrupted or shared.state.skipped:
        raise AssertionError('Interrupted...')
    if shared.state.paused:
        shared.log.debug('Sampling paused')
        while shared.state.paused:
            if shared.state.interrupted or shared.state.skipped:
                raise AssertionError('Interrupted...')
            time.sleep(0.1)


def diffusers_callback(pipe, step: int = 0, timestep: int = 0, kwargs: typing.Optional[dict] = None):
    kwargs = kwargs if kwargs is not None else {} # avoid a shared mutable default since kwargs is mutated below
    t0 = time.time()
    if devices.backend == "ipex":
        torch.xpu.synchronize(devices.device)
    elif devices.backend in {"cuda", "zluda", "rocm"}:
        torch.cuda.synchronize(devices.device)
    latents = kwargs.get('latents', None)
    if debug:
        debug_callback(f'Callback: step={step} timestep={timestep} latents={latents.shape if latents is not None else None} kwargs={list(kwargs)}')
    if shared.state.sampling_steps == 0 and getattr(pipe, 'num_timesteps', 0) > 0:
        shared.state.sampling_steps = pipe.num_timesteps
    shared.state.step()
    if shared.state.interrupted or shared.state.skipped:
        raise AssertionError('Interrupted...')
    if shared.state.paused:
        shared.log.debug('Sampling paused')
        while shared.state.paused:
            if shared.state.interrupted or shared.state.skipped:
                raise AssertionError('Interrupted...')
            time.sleep(0.1)
    if latents is None:
        return kwargs
    elif shared.opts.nan_skip:
        assert not torch.isnan(latents[..., 0, 0]).all(), f'NaN detected at step {step}: Skipping...'
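    # IP-Adapter scale windowing below: each adapter's scale is zeroed outside its
    # [start, end] window, expressed as fractions of the total sampling steps;
    # Flux pipelines instead receive a per-block linear ramp from the start value
    # to the end value across the 19 transformer blocks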
    if p is None:
        return kwargs
    if len(getattr(p, 'ip_adapter_names', [])) > 0 and p.ip_adapter_names[0] != 'None':
        ip_adapter_scales = list(p.ip_adapter_scales)
        ip_adapter_starts = list(p.ip_adapter_starts)
        ip_adapter_ends = list(p.ip_adapter_ends)
        if any(end != 1 for end in ip_adapter_ends) or any(start != 0 for start in ip_adapter_starts):
            if 'Flux' in pipe.__class__.__name__:
                ip_adapter_scales = [(ip_adapter_starts[0] + (ip_adapter_ends[0] - ip_adapter_starts[0]) * (i / (19 - 1))) for i in range(19)]
            else:
                for i in range(len(ip_adapter_scales)):
                    ip_adapter_scales[i] *= float(step >= pipe.num_timesteps * ip_adapter_starts[i])
                    ip_adapter_scales[i] *= float(step <= pipe.num_timesteps * ip_adapter_ends[i])
            debug_callback(f"Callback: IP Adapter scales={ip_adapter_scales}")
            pipe.set_ip_adapter_scale(ip_adapter_scales)
    if step != getattr(pipe, 'num_timesteps', 0):
        kwargs = processing_correction.correction_callback(p, timestep, kwargs, initial=step == 0)
    kwargs = prompt_callback(step, kwargs)
    # monkey patch for diffusers callback issues
    if step == 0:
        pipe._cfg_end_applied = False # pylint: disable=protected-access
    cfg_end = getattr(p, "cfg_end", 1.0) or 1.0
    total_steps = getattr(pipe, "num_timesteps", 0)
    target_step = int(total_steps * cfg_end) if total_steps else 0
    if (cfg_end < 1.0) and not getattr(pipe, "_cfg_end_applied", False) and (step >= target_step):
        pipe._cfg_end_applied = True # pylint: disable=protected-access
        if "PAG" in shared.sd_model.__class__.__name__:
            pipe._guidance_scale = 1.001 if pipe._guidance_scale > 1 else pipe._guidance_scale # pylint: disable=protected-access
            pipe._pag_scale = 0.001 # pylint: disable=protected-access
        else:
            pipe._guidance_scale = 0.0 # pylint: disable=protected-access
            for key in ["prompt_embeds", "negative_prompt_embeds", "add_text_embeds", "add_time_ids"]:
                tensor = kwargs.get(key, None)
                if tensor is not None and hasattr(tensor, "chunk") and tensor.shape[0] % 2 == 0:
                    kwargs[key] = tensor.chunk(2)[-1]
    try:
        current_noise_pred = kwargs.get("noise_pred", None)
        if current_noise_pred is None:
            current_noise_pred = kwargs.get("predicted_image_embedding", None)
        if hasattr(pipe, "_unpack_latents") and hasattr(pipe, "vae_scale_factor"): # FLUX.1
            if p.hr_resize_mode > 0 and (p.hr_upscaler != 'None' or p.hr_resize_mode == 5) and p.is_hr_pass:
                width = max(getattr(p, 'width', 0), getattr(p, 'hr_upscale_to_x', 0))
                height = max(getattr(p, 'height', 0), getattr(p, 'hr_upscale_to_y', 0))
            else:
                width = getattr(p, 'width', 1024)
                height = getattr(p, 'height', 1024)
            shared.state.current_latent = pipe._unpack_latents(kwargs['latents'], height, width, pipe.vae_scale_factor) # pylint: disable=protected-access
            if current_noise_pred is not None:
                shared.state.current_noise_pred = pipe._unpack_latents(current_noise_pred, height, width, pipe.vae_scale_factor) # pylint: disable=protected-access
            else:
                shared.state.current_noise_pred = current_noise_pred
        elif hasattr(pipe, "_unpatchify_latents"): # FLUX.2 - unpack [B, seq, patch_ch] to [B, ch, H, W]
            vae_scale = getattr(pipe, 'vae_scale_factor', 8)
            if p.hr_resize_mode > 0 and (p.hr_upscaler != 'None' or p.hr_resize_mode == 5) and p.is_hr_pass:
                width = max(getattr(p, 'width', 0), getattr(p, 'hr_upscale_to_x', 0))
                height = max(getattr(p, 'height', 0), getattr(p, 'hr_upscale_to_y', 0))
            else:
                width = getattr(p, 'width', 1024)
                height = getattr(p, 'height', 1024)
            latents = kwargs['latents']
            if len(latents.shape) == 3: # packed format [B, seq_len, patch_channels]
                b, seq_len, patch_ch = latents.shape
                channels = patch_ch // 4 # 4 = 2x2 patch
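                # each packed token covers a 2x2 patch of latent pixels and each latent
                # pixel covers vae_scale image pixels, so the expected patch grid is
                # (height // vae_scale // 2) x (width // vae_scale // 2)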
                h_patches = height // vae_scale // 2
                w_patches = width // vae_scale // 2
                if h_patches * w_patches != seq_len: # fallback to square assumption
                    h_patches = w_patches = int(seq_len ** 0.5)
                # [B, h*w, C*4] -> [B, h, w, C, 2, 2] -> [B, C, h, 2, w, 2] -> [B, C, H, W]
                latents = latents.view(b, h_patches, w_patches, channels, 2, 2)
                latents = latents.permute(0, 3, 1, 4, 2, 5).reshape(b, channels, h_patches * 2, w_patches * 2)
            shared.state.current_latent = latents
            if current_noise_pred is not None and len(current_noise_pred.shape) == 3:
                b, seq_len, patch_ch = current_noise_pred.shape
                channels = patch_ch // 4
                h_patches = height // vae_scale // 2
                w_patches = width // vae_scale // 2
                if h_patches * w_patches != seq_len:
                    h_patches = w_patches = int(seq_len ** 0.5)
                current_noise_pred = current_noise_pred.view(b, h_patches, w_patches, channels, 2, 2)
                current_noise_pred = current_noise_pred.permute(0, 3, 1, 4, 2, 5).reshape(b, channels, h_patches * 2, w_patches * 2)
            shared.state.current_noise_pred = current_noise_pred
        else:
            shared.state.current_latent = kwargs['latents']
            shared.state.current_noise_pred = current_noise_pred
        if hasattr(pipe, "scheduler") and hasattr(pipe.scheduler, "sigmas") and hasattr(pipe.scheduler, "step_index") and pipe.scheduler.step_index is not None:
            try:
                shared.state.current_sigma = pipe.scheduler.sigmas[pipe.scheduler.step_index-1]
                shared.state.current_sigma_next = pipe.scheduler.sigmas[pipe.scheduler.step_index]
                if (shared.opts.schedulers_sigma_adjust != 1.0) and (timestep > 1000 * shared.opts.schedulers_sigma_adjust_min) and (timestep < 1000 * shared.opts.schedulers_sigma_adjust_max):
                    pipe.scheduler.sigmas[pipe.scheduler.step_index+1] = pipe.scheduler.sigmas[pipe.scheduler.step_index+1] * shared.opts.schedulers_sigma_adjust
                    p.extra_generation_params["Sigma adjust"] = shared.opts.schedulers_sigma_adjust
            except Exception:
                pass
    except Exception as e:
        global warned # pylint: disable=global-statement
        if not warned:
            shared.log.error(f'Callback: {e}')
            warned = True
        # from modules import errors
        # errors.display(e, 'Callback')
    if shared.cmd_opts.profile and shared.profiler is not None:
        shared.profiler.step()
    t1 = time.time()
    timer.process.add('callback', t1 - t0)
    return kwargs
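

# Usage sketch (illustrative only; the actual hookup lives elsewhere in the codebase):
# `diffusers_callback` matches the `callback_on_step_end(pipe, step, timestep, kwargs)`
# signature that diffusers pipelines call on every step, so a minimal wiring looks like:
#
#   set_callbacks_p(p)  # register the active processing object first
#   result = pipe(
#       **pipe_args,  # hypothetical stand-in for the usual prompt/size arguments
#       callback_on_step_end=diffusers_callback,
#       callback_on_step_end_tensor_inputs=['latents', 'prompt_embeds', 'negative_prompt_embeds'],
#   )
#
# `callback_on_step_end_tensor_inputs` selects which tensors diffusers passes in the
# callback kwargs dict; the callback must return that dict, as done above, so any
# mutated tensors (e.g. per-step prompt embeds) feed back into the pipeline.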