sdnext/scripts/consistory/utils/general_utils.py

# Copyright (C) 2024 NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the LICENSE file
# located at the root directory.
import torch
import torch.nn.functional as F
import numpy as np

## Attention Utils

def get_dynamic_threshold(tensor):
    # Otsu's method picks a threshold that separates the map into
    # foreground/background; skimage is imported lazily to keep it optional
    from skimage import filters
    return filters.threshold_otsu(tensor.float().cpu().numpy())


def attn_map_to_binary(attention_map, scaler=1.):
    # Binarize an attention map: locations above the (scaled) Otsu
    # threshold become 1, the rest 0
    from skimage import filters
    attention_map_np = attention_map.float().cpu().numpy()
    threshold_value = filters.threshold_otsu(attention_map_np) * scaler
    binary_mask = (attention_map_np > threshold_value).astype(np.uint8)
    return binary_mask
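
# Usage sketch (illustrative, not part of the original file): binarize a
# random map with the Otsu threshold; the 32x32 shape is made up, real
# inputs would be attention maps produced elsewhere in the pipeline
def _demo_attn_map_to_binary():
    attention_map = torch.rand(32, 32)  # hypothetical attention map
    mask = attn_map_to_binary(attention_map, scaler=1.0)
    print(mask.shape, mask.dtype)  # (32, 32) uint8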

## Features

def gaussian_smooth(input_tensor, kernel_size=3, sigma=1):
    """
    Function to apply Gaussian smoothing on each 2D slice of a 3D tensor.
    """
    kernel = np.fromfunction(
        lambda x, y: (1 / (2 * np.pi * sigma ** 2)) *
                     np.exp(-((x - (kernel_size - 1) / 2) ** 2 + (y - (kernel_size - 1) / 2) ** 2) / (2 * sigma ** 2)),
        (kernel_size, kernel_size)
    )
    kernel = torch.Tensor(kernel / kernel.sum()).to(input_tensor.dtype).to(input_tensor.device)
    # Add batch and channel dimensions to the kernel
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    # Iterate over each 2D slice and apply convolution
    smoothed_slices = []
    for i in range(input_tensor.size(0)):
        slice_tensor = input_tensor[i, :, :]
        slice_tensor = F.conv2d(slice_tensor.unsqueeze(0).unsqueeze(0), kernel, padding=kernel_size // 2)[0, 0]
        smoothed_slices.append(slice_tensor)
    # Stack the smoothed slices to get the final tensor
    smoothed_tensor = torch.stack(smoothed_slices, dim=0)
    return smoothed_tensor
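
# Usage sketch (illustrative, not part of the original file): smooth each
# 2D slice of a stack of maps; the (8, 64, 64) shape is made up
def _demo_gaussian_smooth():
    maps = torch.rand(8, 64, 64)
    smoothed = gaussian_smooth(maps, kernel_size=3, sigma=1)
    print(smoothed.shape)  # torch.Size([8, 64, 64])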

## Dense correspondence utils

def cos_dist(a, b):
    # Cosine distance between all row pairs: 1 - cosine similarity,
    # so values lie in [0, 2]
    a_norm = F.normalize(a, dim=-1)
    b_norm = F.normalize(b, dim=-1)
    res = a_norm @ b_norm.T
    return 1 - res
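
# Usage sketch (illustrative, not part of the original file): the output
# lies in [0, 2]: 0 for parallel vectors, 2 for opposite directions
def _demo_cos_dist():
    a = torch.tensor([[1.0, 0.0]])
    b = torch.tensor([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]])
    print(cos_dist(a, b))  # tensor([[0., 2., 1.]])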

def gen_nn_map(src_features, src_mask, tgt_features, tgt_mask, device, batch_size=100, tgt_size=768):
    # Resize both feature maps to tgt_size x tgt_size, then flatten to
    # (tgt_size**2, C) so every row is the feature vector of one location
    resized_src_features = F.interpolate(src_features.unsqueeze(0), size=tgt_size, mode='bilinear', align_corners=False).squeeze(0)
    resized_src_features = resized_src_features.permute(1, 2, 0).view(tgt_size**2, -1)
    resized_tgt_features = F.interpolate(tgt_features.unsqueeze(0), size=tgt_size, mode='bilinear', align_corners=False).squeeze(0)
    resized_tgt_features = resized_tgt_features.permute(1, 2, 0).view(tgt_size**2, -1)
    nearest_neighbor_indices = torch.zeros(tgt_size**2, dtype=torch.long, device=device)
    nearest_neighbor_distances = torch.zeros(tgt_size**2, dtype=src_features.dtype, device=device)
    if not batch_size:
        batch_size = tgt_size**2
    # Process target locations in batches to bound the size of the
    # (tgt_size**2, batch_size) distance matrix
    for i in range(0, tgt_size**2, batch_size):
        distances = cos_dist(resized_src_features, resized_tgt_features[i:i+batch_size])
        # Exclude masked-out source locations by assigning the maximum
        # possible cosine distance
        distances[~src_mask] = 2.
        min_distances, min_indices = torch.min(distances, dim=0)
        nearest_neighbor_indices[i:i+batch_size] = min_indices
        nearest_neighbor_distances[i:i+batch_size] = min_distances
    return nearest_neighbor_indices, nearest_neighbor_distances
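
# Usage sketch (illustrative, not part of the original file): all shapes
# are made up; src_mask is a flat boolean mask over the tgt_size**2 source
# locations, matching how gen_nn_map indexes the distance matrix
def _demo_gen_nn_map():
    device = 'cpu'
    src_feat = torch.rand(16, 8, 8)  # (C, H, W) source features
    tgt_feat = torch.rand(16, 8, 8)  # (C, H, W) target features
    src_mask = torch.ones(64, dtype=torch.bool)
    tgt_mask = torch.ones(64, dtype=torch.bool)
    idx, dist = gen_nn_map(src_feat, src_mask, tgt_feat, tgt_mask, device, batch_size=None, tgt_size=8)
    print(idx.shape, dist.shape)  # torch.Size([64]) torch.Size([64])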

def cyclic_nn_map(features, masks, latent_resolutions, device):
    # For every ordered pair of images (i, j) in the batch, compute the
    # nearest-neighbor map from image j's features onto image i's grid,
    # at every requested latent resolution
    bsz = features.shape[0]
    nn_map_dict = {}
    nn_distances_dict = {}
    for tgt_size in latent_resolutions:
        nn_map = torch.empty(bsz, bsz, tgt_size**2, dtype=torch.long, device=device)
        nn_distances = torch.full((bsz, bsz, tgt_size**2), float('inf'), dtype=features.dtype, device=device)
        for i in range(bsz):
            for j in range(bsz):
                if i != j:
                    nearest_neighbor_indices, nearest_neighbor_distances = gen_nn_map(features[j], masks[tgt_size][j], features[i], masks[tgt_size][i], device, batch_size=None, tgt_size=tgt_size)
                    nn_map[i, j] = nearest_neighbor_indices
                    nn_distances[i, j] = nearest_neighbor_distances
        nn_map_dict[tgt_size] = nn_map
        nn_distances_dict[tgt_size] = nn_distances
    return nn_map_dict, nn_distances_dict
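
# Usage sketch (illustrative, not part of the original file): pairwise NN
# maps for a batch of 3 images at a single 8x8 latent resolution; masks
# are keyed by resolution with one flat boolean mask per image
def _demo_cyclic_nn_map():
    device = 'cpu'
    feats = torch.rand(3, 16, 8, 8)  # (bsz, C, H, W)
    masks = {8: torch.ones(3, 64, dtype=torch.bool)}
    nn_map, nn_dist = cyclic_nn_map(feats, masks, [8], device)
    print(nn_map[8].shape)  # torch.Size([3, 3, 64])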

def anchor_nn_map(features, anchor_features, masks, anchor_masks, latent_resolutions, device):
    # Same as cyclic_nn_map, but correspondences are computed only from a
    # fixed set of anchor images (sources) onto each image in the batch
    bsz = features.shape[0]
    anchor_bsz = anchor_features.shape[0]
    nn_map_dict = {}
    nn_distances_dict = {}
    for tgt_size in latent_resolutions:
        nn_map = torch.empty(bsz, anchor_bsz, tgt_size**2, dtype=torch.long, device=device)
        nn_distances = torch.full((bsz, anchor_bsz, tgt_size**2), float('inf'), dtype=features.dtype, device=device)
        for i in range(bsz):
            for j in range(anchor_bsz):
                nearest_neighbor_indices, nearest_neighbor_distances = gen_nn_map(anchor_features[j], anchor_masks[tgt_size][j], features[i], masks[tgt_size][i], device, batch_size=None, tgt_size=tgt_size)
                nn_map[i, j] = nearest_neighbor_indices
                nn_distances[i, j] = nearest_neighbor_distances
        nn_map_dict[tgt_size] = nn_map
        nn_distances_dict[tgt_size] = nn_distances
    return nn_map_dict, nn_distances_dict
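
# Usage sketch (illustrative, not part of the original file): map a batch
# of 2 images onto a single anchor image; shapes are made up
def _demo_anchor_nn_map():
    device = 'cpu'
    feats = torch.rand(2, 16, 8, 8)
    anchor_feats = torch.rand(1, 16, 8, 8)
    masks = {8: torch.ones(2, 64, dtype=torch.bool)}
    anchor_masks = {8: torch.ones(1, 64, dtype=torch.bool)}
    nn_map, nn_dist = anchor_nn_map(feats, anchor_feats, masks, anchor_masks, [8], device)
    print(nn_map[8].shape)  # torch.Size([2, 1, 64])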