from pathlib import Path
from typing import List, Type, Any, Dict, Union
import math

from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import Attention
import numpy as np
import PIL.Image as Image
import torch
import torch.nn.functional as F

from .utils import cache_dir, auto_autocast
from .experiment import GenerationExperiment
from .heatmap import RawHeatMapCollection, GlobalHeatMap
from .hook import ObjectHooker, AggregateHooker, UNetCrossAttentionLocator

__all__ = ['trace', 'DiffusionHeatMapHooker', 'GlobalHeatMap']


class DiffusionHeatMapHooker(AggregateHooker):
    def __init__(
            self,
            pipeline: Union[StableDiffusionPipeline, StableDiffusionXLPipeline],
            low_memory: bool = False,
            load_heads: bool = False,
            save_heads: bool = False,
            data_dir: str = None
    ):
        self.all_heat_maps = RawHeatMapCollection()
        h = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
        self.latent_hw = 4096 if h == 512 or h == 1024 else 9216  # 64x64 or 96x96, depending on whether it's 2.0 or 2.0-v
        locate_middle = load_heads or save_heads
        self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle)
        self.last_prompt: str = ''
        self.last_image: Image.Image = None
        self.time_idx = 0
        self._gen_idx = 0

        modules = [
            UNetCrossAttentionHooker(
                x,
                self,
                layer_idx=idx,
                latent_hw=self.latent_hw,
                load_heads=load_heads,
                save_heads=save_heads,
                data_dir=data_dir
            ) for idx, x in enumerate(self.locator.locate(pipeline.unet))
        ]

        modules.append(PipelineHooker(pipeline, self))

        if isinstance(pipeline, StableDiffusionXLPipeline):
            modules.append(ImageProcessorHooker(pipeline.image_processor, self))

        super().__init__(modules)
        self.pipe = pipeline

    def time_callback(self, *args, **kwargs):
        self.time_idx += 1

    @property
    def layer_names(self):
        return self.locator.layer_names

    def to_experiment(self, path, seed=None, id='.', subtype='.', **compute_kwargs):
        # type: (Union[Path, str], int, str, str, Dict[str, Any]) -> GenerationExperiment
        """Exports the last generation call to a serializable generation experiment."""
        return GenerationExperiment(
            self.last_image,
            self.compute_global_heat_map(**compute_kwargs).heat_maps,
            self.last_prompt,
            seed=seed,
            id=id,
            subtype=subtype,
            path=path,
            tokenizer=self.pipe.tokenizer,
        )

    def compute_global_heat_map(self, prompt=None, factors=None, head_idx=None, layer_idx=None, normalize=False):
        # type: (str, List[int], int, int, bool) -> GlobalHeatMap
        """
        Compute the global heat map for the given prompt, aggregating across time (inference steps) and space
        (different spatial transformer block heat maps).

        Args:
            prompt: The prompt to compute the heat map for. If `None`, uses the last prompt used for generation.
            factors: Restrict the aggregation to heat maps with spatial factors in this set. If `None`, use all sizes.
            head_idx: Restrict the aggregation to heat maps with this head index. If `None`, use all heads.
            layer_idx: Restrict the aggregation to heat maps with this layer index. If `None`, use all layers.
            normalize: Whether to normalize the heat maps over the prompt tokens so they form proper probabilities.

        Returns:
            A heat map object for computing word-level heat maps.
""" heat_maps = self.all_heat_maps if prompt is None: prompt = self.last_prompt if factors is None: factors = {0, 1, 2, 4, 8, 16, 32, 64} else: factors = set(factors) all_merges = [] x = int(np.sqrt(self.latent_hw)) with auto_autocast(dtype=torch.float32): for (factor, layer, head), heat_map in heat_maps: if factor in factors and (head_idx is None or head_idx == head) and (layer_idx is None or layer_idx == layer): heat_map = heat_map.unsqueeze(1) # The clamping fixes undershoot. all_merges.append(F.interpolate(heat_map, size=(x, x), mode='bicubic').clamp_(min=0)) try: maps = torch.stack(all_merges, dim=0) except RuntimeError: if head_idx is not None or layer_idx is not None: raise RuntimeError('No heat maps found for the given parameters.') else: raise RuntimeError('No heat maps found. Did you forget to call `with trace(...)` during generation?') maps = maps.mean(0)[:, 0] maps = maps[:len(self.pipe.tokenizer.tokenize(prompt)) + 2] # 1 for SOS and 1 for padding if normalize: maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6) # drop out [SOS] and [PAD] for proper probabilities return GlobalHeatMap(self.pipe.tokenizer, prompt, maps) class ImageProcessorHooker(ObjectHooker[VaeImageProcessor]): def __init__(self, processor: VaeImageProcessor, parent_trace: 'trace'): super().__init__(processor) self.parent_trace = parent_trace def _hooked_postprocess(hk_self, _: VaeImageProcessor, *args, **kwargs): images = hk_self.monkey_super('postprocess', *args, **kwargs) hk_self.parent_trace.last_image = images[0] return images def _hook_impl(self): self.monkey_patch('postprocess', self._hooked_postprocess) class PipelineHooker(ObjectHooker[StableDiffusionPipeline]): def __init__(self, pipeline: StableDiffusionPipeline, parent_trace: 'trace'): super().__init__(pipeline) self.heat_maps = parent_trace.all_heat_maps self.parent_trace = parent_trace def _hooked_run_safety_checker(hk_self, self: StableDiffusionPipeline, image, *args, **kwargs): image, has_nsfw = hk_self.monkey_super('run_safety_checker', image, *args, **kwargs) if self.image_processor: if torch.is_tensor(image): images = self.image_processor.postprocess(image, output_type='pil') else: images = self.image_processor.numpy_to_pil(image) else: images = self.numpy_to_pil(image) hk_self.parent_trace.last_image = images[len(images)-1] return image, has_nsfw def _hooked_check_inputs(hk_self, _: StableDiffusionPipeline, prompt: Union[str, List[str]], *args, **kwargs): if not isinstance(prompt, str) and len(prompt) > 1: raise ValueError('Only single prompt generation is supported for heat map computation.') elif not isinstance(prompt, str): last_prompt = prompt[0] else: last_prompt = prompt hk_self.heat_maps.clear() hk_self.parent_trace.last_prompt = last_prompt return hk_self.monkey_super('check_inputs', prompt, *args, **kwargs) def _hook_impl(self): self.monkey_patch('run_safety_checker', self._hooked_run_safety_checker, strict=False) # not present in SDXL self.monkey_patch('check_inputs', self._hooked_check_inputs) class UNetCrossAttentionHooker(ObjectHooker[Attention]): def __init__( self, module: Attention, parent_trace: 'trace', context_size: int = 77, layer_idx: int = 0, latent_hw: int = 9216, load_heads: bool = False, save_heads: bool = False, data_dir: Union[str, Path] = None, ): super().__init__(module) self.heat_maps = parent_trace.all_heat_maps self.context_size = context_size self.layer_idx = layer_idx self.latent_hw = latent_hw self.load_heads = load_heads self.save_heads = save_heads self.trace = parent_trace if data_dir is not None: 
            data_dir = Path(data_dir)
        else:
            data_dir = cache_dir() / 'heads'

        self.data_dir = data_dir
        self.data_dir.mkdir(parents=True, exist_ok=True)

    @torch.no_grad()
    def _unravel_attn(self, x):
        # type: (torch.Tensor) -> torch.Tensor
        # x shape: (heads, height * width, tokens)
        """
        Unravels the attention, returning it as a collection of heat maps.

        Args:
            x (`torch.Tensor`): cross-attention map between the spatial locations and the prompt tokens.

        Returns:
            `torch.Tensor`: the heat maps across heads, of shape (heads, tokens, height, width).
        """
        h = w = int(math.sqrt(x.size(1)))
        maps = []
        x = x.permute(2, 0, 1)

        with auto_autocast(dtype=torch.float32):
            for map_ in x:
                map_ = map_.view(map_.size(0), h, w)

                # For Instruct Pix2Pix, the map divides into three parts: text condition, image condition, and
                # unconditional. Keep only the text condition part, which comes first (as per the diffusers
                # implementation).
                if map_.size(0) == 24:
                    map_ = map_[:(map_.size(0) // 3) + 1]  # filter out the unconditional and image condition parts
                else:
                    map_ = map_[map_.size(0) // 2:]  # filter out the unconditional part

                maps.append(map_)

        maps = torch.stack(maps, 0)  # shape: (tokens, heads, height, width)
        return maps.permute(1, 0, 2, 3).contiguous()  # shape: (heads, tokens, height, width)

    def _save_attn(self, attn_slice: torch.Tensor):
        torch.save(attn_slice, self.data_dir / f'{self.trace._gen_idx}.pt')

    def _load_attn(self) -> torch.Tensor:
        return torch.load(self.data_dir / f'{self.trace._gen_idx}.pt')

    def __call__(
            self,
            attn: Attention,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
    ):
        """Capture attentions and aggregate them."""
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross is not None:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # DAAM: save or load the attention heads
        if self.save_heads:
            self._save_attn(attention_probs)
        elif self.load_heads:
            attention_probs = self._load_attn()

        # compute the spatial scale factor of this attention map
        factor = int(math.sqrt(self.latent_hw // attention_probs.shape[1]))
        self.trace._gen_idx += 1

        # aggregate cross-attention maps only (token dimension matches the context size), skipping the
        # coarsest factor-8 maps
        if attention_probs.shape[-1] == self.context_size and factor != 8:
            # shape: (batch_size, 64 // factor, 64 // factor, 77)
            maps = self._unravel_attn(attention_probs)

            for head_idx, heatmap in enumerate(maps):
                self.heat_maps.update(factor, self.layer_idx, head_idx, heatmap)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

    def _hook_impl(self):
        # install this hooker as the module's attention processor, saving the original for unhooking
        self.original_processor = self.module.processor
        self.module.set_processor(self)

    def _unhook_impl(self):
        self.module.set_processor(self.original_processor)

    @property
    def num_heat_maps(self):
        return len(next(iter(self.heat_maps.values())))


trace: Type[DiffusionHeatMapHooker] = DiffusionHeatMapHooker
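

# A minimal usage sketch (illustrative, not part of the module above). It assumes the package is
# importable as `daam`, that a CUDA device is available, and that the `GlobalHeatMap` returned by
# `compute_global_heat_map` exposes a `compute_word_heat_map` method (defined in `.heatmap`); the
# checkpoint name is only an example:
#
#     import torch
#     from diffusers import StableDiffusionPipeline
#     from daam import trace
#
#     pipe = StableDiffusionPipeline.from_pretrained(
#         'stabilityai/stable-diffusion-2-base', torch_dtype=torch.float16
#     ).to('cuda')
#
#     with trace(pipe) as tc:
#         image = pipe('a dog runs across the field').images[0]
#         global_heat_map = tc.compute_global_heat_map()
#         word_heat_map = global_heat_map.compute_word_heat_map('dog')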