sdnext/scripts/consistory/utils/ptp_utils.py
Vladimir Mandic 8c676d0a61 lint formatting
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-07-05 17:13:31 -04:00

195 lines · 8.9 KiB · Python

# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# MIT License
#
# Copyright (c) 2023 AttendAndExcite
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Not a contribution
# Changes made by NVIDIA CORPORATION & AFFILIATES enabling ConsiStory or otherwise documented as NVIDIA-proprietary
# are not a contribution and subject to the license under the LICENSE file located at the root directory.
import torch
from collections import defaultdict
import numpy as np
from typing import Union, List
from PIL import Image
from modules.consistory.utils.general_utils import attn_map_to_binary
import torch.nn.functional as F
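

# Overview (added note): AttentionStore accumulates cross-attention maps captured while
# the UNet denoises a batch of prompts and converts them into per-subject binary masks
# at 32x32 and 64x64 resolution. These masks are what ConsiStory's extended
# (subject-shared) self-attention consumes via get_extended_attn_mask_instance below.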
class AttentionStore:
    def __init__(self, attention_store_kwargs):
        """
        Initialize an empty AttentionStore. attention_store_kwargs configures the attention
        resolution to capture, the per-batch-item concept token indices, the extended-attention
        mapping between batch items, and the mask dropout probability.
        """
        self.attn_res = attention_store_kwargs.get('attn_res', (32,32))
        self.token_indices = attention_store_kwargs['token_indices']
        bsz = self.token_indices.size(1)
        self.mask_background_query = attention_store_kwargs.get('mask_background_query', False)
        self.original_attn_masks = attention_store_kwargs.get('original_attn_masks', None)
        self.extended_mapping = attention_store_kwargs.get('extended_mapping', torch.ones(bsz, bsz).bool())
        self.mask_dropout = attention_store_kwargs.get('mask_dropout', 0.0)
        torch.manual_seed(0) # For dropout mask reproducibility

        self.curr_iter = 0
        self.ALL_RES = [32, 64]
        self.step_store = defaultdict(list)
        self.attn_masks = {res: None for res in self.ALL_RES}
        self.last_mask = {res: None for res in self.ALL_RES}
        self.last_mask_dropout = {res: None for res in self.ALL_RES}
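
    # Added note: judging by the place_in_unet/attn_heads arguments, this is invoked once per
    # attention layer by whatever attention hooks the pipeline registers (an assumption, the
    # caller is not part of this file). Only cross-attention maps whose spatial size matches
    # attn_res are kept; the first half of the batch (the unconditional half under
    # classifier-free guidance) is dropped and the maps are averaged over heads before being
    # stored per UNet location.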
    def __call__(self, attn, is_cross: bool, place_in_unet: str, attn_heads: int):
        if is_cross and attn.shape[1] == np.prod(self.attn_res):
            guidance_attention = attn[attn.size(0)//2:]
            batched_guidance_attention = guidance_attention.reshape([guidance_attention.shape[0]//attn_heads, attn_heads, *guidance_attention.shape[1:]])
            batched_guidance_attention = batched_guidance_attention.mean(dim=1)
            self.step_store[place_in_unet].append(batched_guidance_attention)
    def reset(self):
        self.step_store = defaultdict(list)
        self.attn_masks = {res: None for res in self.ALL_RES}
        self.last_mask = {res: None for res in self.ALL_RES}
        self.last_mask_dropout = {res: None for res in self.ALL_RES}
        torch.cuda.empty_cache()
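
    # Added note on the aggregation pipeline below: average the last 20 stored maps from every
    # UNet location, pull out the map for each concept token of each batch item (skipping
    # tokens marked -1), upsample to 32x32 and 64x64, binarize each concept map with
    # attn_map_to_binary, take the union over concepts, and finally drop a random fraction of
    # the mask entries (mask_dropout). Results land in last_mask / last_mask_dropout per
    # resolution.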
    def aggregate_last_steps_attention(self) -> torch.Tensor:
        """Aggregates the attention across the different layers and heads at the specified resolution."""
        attention_maps = torch.cat([torch.stack(x[-20:]) for x in self.step_store.values()]).mean(dim=0)
        bsz, wh, _ = attention_maps.shape

        # Create attention maps for each concept token, for each batch item
        agg_attn_maps = []
        for i in range(bsz):
            curr_prompt_indices = []
            for concept_token_indices in self.token_indices:
                if concept_token_indices[i] != -1:
                    curr_prompt_indices.append(attention_maps[i, :, concept_token_indices[i]].view(*self.attn_res))
            agg_attn_maps.append(torch.stack(curr_prompt_indices))

        # Upsample the attention maps to the target resolution
        # and create the attention masks, unifying masks across the different concepts
        for tgt_size in self.ALL_RES:
            pixels = tgt_size ** 2
            tgt_agg_attn_maps = [F.interpolate(x.unsqueeze(1), size=tgt_size, mode='bilinear').squeeze(1) for x in agg_attn_maps]

            attn_masks = []
            for batch_item_map in tgt_agg_attn_maps:
                concept_attn_masks = []
                for concept_maps in batch_item_map:
                    concept_attn_masks.append(torch.from_numpy(attn_map_to_binary(concept_maps, 1.)).to(attention_maps.device).bool().view(-1))
                concept_attn_masks = torch.stack(concept_attn_masks, dim=0).max(dim=0).values
                attn_masks.append(concept_attn_masks)
            attn_masks = torch.stack(attn_masks)
            self.last_mask[tgt_size] = attn_masks.clone()

            # Add mask dropout
            if self.curr_iter < 1000:
                rand_mask = (torch.rand_like(attn_masks.float()) < self.mask_dropout)
                attn_masks[rand_mask] = False
            self.last_mask_dropout[tgt_size] = attn_masks.clone()

            # # Create subject driven extended self attention masks
            # output_attn_mask = torch.zeros((bsz, tgt_size**2, attn_masks.view(-1).size(0)), device=attn_masks.device).bool()
            # for i in range(bsz):
            #     for j in range(bsz):
            #         if i==j:
            #             output_attn_mask[i, :, j*pixels:(j+1)*pixels] = 1
            #         else:
            #             if self.extended_mapping[i,j]:
            #                 if not self.mask_background_query:
            #                     output_attn_mask[i, :, j*pixels:(j+1)*pixels] = attn_masks[j].unsqueeze(0).expand(pixels, -1)
            #                 else:
            #                     output_attn_mask[i, attn_masks[i], j*pixels:(j+1)*pixels] = attn_masks[j].unsqueeze(0).expand(attn_masks[i].sum(), -1)
            # self.attn_masks[tgt_size] = output_attn_mask
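
    # Added note: get_attn_mask_bias returns an additive attention bias, 0 where attention is
    # allowed and -inf where it is blocked, optionally repeated along the batch dimension. A
    # float bias of this shape can be added to attention logits (for example via the attn_mask
    # argument of F.scaled_dot_product_attention). Note that with the mask-building block above
    # commented out, self.attn_masks stays None, so a bias is only produced when
    # original_attn_masks is supplied from outside.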
    def get_attn_mask_bias(self, tgt_size, bsz=None):
        attn_mask = self.attn_masks[tgt_size] if self.original_attn_masks is None else self.original_attn_masks[tgt_size]
        if attn_mask is None:
            return None

        attn_bias = torch.zeros_like(attn_mask, dtype=torch.float16)
        attn_bias[~attn_mask] = float('-inf')
        if bsz and bsz != attn_bias.shape[0]:
            attn_bias = attn_bias.repeat(bsz // attn_bias.shape[0], 1, 1)
        return attn_bias
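
    # Added note: get_extended_attn_mask_instance builds the extended self-attention mask for a
    # single query image i at a given spatial width as a flat boolean vector over the
    # concatenated key patches of all batch images. The image always attends to its own
    # patches; patches of another image j are visible only where extended_mapping[i, j] is set
    # and only inside image j's (dropout-applied) subject mask.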
    def get_extended_attn_mask_instance(self, width, i):
        attn_mask = self.last_mask_dropout[width]
        if attn_mask is None:
            return None

        n_patches = width**2
        output_attn_mask = torch.zeros((attn_mask.shape[0] * attn_mask.shape[1],), device=attn_mask.device, dtype=torch.bool)
        for j in range(attn_mask.shape[0]):
            if i==j:
                output_attn_mask[j*n_patches:(j+1)*n_patches] = 1
            else:
                if self.extended_mapping[i,j]:
                    if not self.mask_background_query:
                        output_attn_mask[j*n_patches:(j+1)*n_patches] = attn_mask[j].unsqueeze(0) #.expand(n_patches, -1)
                    else:
                        raise NotImplementedError('mask_background_query is not supported anymore')
                        # Unreachable leftover matching the commented-out 3-D mask construction above;
                        # it also referenced an undefined name `k`, so it is kept here only as a comment:
                        # output_attn_mask[0, attn_mask[i], k*n_patches:(k+1)*n_patches] = attn_mask[j].unsqueeze(0).expand(attn_mask[i].sum(), -1)

        return output_attn_mask
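

# Illustrative usage sketch (added, not executed and not part of the module's API). The kwargs
# and tensor shapes below are assumptions for demonstration; in practice the pipeline's
# attention hooks are expected to feed the store during denoising before aggregation is called.
#
#   token_indices = torch.tensor([[5, 5], [9, -1]])  # [n_concepts, bsz]; -1 marks a missing concept
#   store = AttentionStore({
#       'token_indices': token_indices,
#       'attn_res': (32, 32),
#       'extended_mapping': torch.ones(2, 2).bool(),
#       'mask_dropout': 0.5,                          # illustrative value
#   })
#   # each attention layer would call: store(attn, is_cross, place_in_unet, attn_heads)
#   store.aggregate_last_steps_attention()
#   flat_mask = store.get_extended_attn_mask_instance(width=64, i=0)  # length = bsz * 64**2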