# mirror of https://github.com/vladmandic/sdnext.git
# synced 2026-01-27 15:02:48 +03:00
# Copyright (C) 2024 NVIDIA Corporation. All rights reserved.
|
|
#
|
|
# This work is licensed under the LICENSE file
|
|
# located at the root directory.
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
|
|
## Attention Utils
|
|
def get_dynamic_threshold(tensor):
    """Compute an Otsu threshold over the values of a tensor.

    The tensor is detached to CPU and converted to a float numpy array
    before thresholding.
    """
    from skimage import filters  # local import keeps skimage an optional dependency
    values = tensor.float().cpu().numpy()
    return filters.threshold_otsu(values)
|
|
|
|
|
|
def attn_map_to_binary(attention_map, scaler=1.):
    """Binarize an attention map with a (scaled) Otsu threshold.

    Args:
        attention_map: torch tensor of attention scores.
        scaler: multiplier on the Otsu threshold; > 1 makes the mask
            stricter, < 1 more permissive.

    Returns:
        ``np.uint8`` array of the same shape: 1 where the attention value
        exceeds the scaled threshold, 0 elsewhere.
    """
    from skimage import filters  # local import keeps skimage an optional dependency
    scores = attention_map.float().cpu().numpy()
    cutoff = filters.threshold_otsu(scores) * scaler
    return (scores > cutoff).astype(np.uint8)
|
|
|
|
|
|
## Features
|
|
|
|
def gaussian_smooth(input_tensor, kernel_size=3, sigma=1):
    """Apply Gaussian smoothing to each 2D slice of a 3D tensor.

    Args:
        input_tensor: float tensor of shape (N, H, W); each (H, W) slice
            is smoothed independently.
        kernel_size: side length of the square Gaussian kernel.
        sigma: standard deviation of the Gaussian.

    Returns:
        Tensor of shape (N, H', W') where each slice has been convolved
        with the normalized Gaussian kernel using zero padding of
        ``kernel_size // 2`` (H' == H and W' == W for odd kernel sizes).
    """
    # Build the 2D Gaussian kernel with numpy, normalized to sum to 1 so
    # intensity is preserved away from the (zero-padded) borders.
    center = (kernel_size - 1) / 2
    kernel = np.fromfunction(
        lambda x, y: (1 / (2 * np.pi * sigma ** 2)) *
                     np.exp(-((x - center) ** 2 + (y - center) ** 2) / (2 * sigma ** 2)),
        (kernel_size, kernel_size)
    )
    kernel = torch.Tensor(kernel / kernel.sum()).to(input_tensor.dtype).to(input_tensor.device)
    # conv2d expects a (out_channels, in_channels, kH, kW) weight
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    # Treat the N slices as a batch of single-channel images: one batched
    # conv2d replaces the original per-slice Python loop with identical math.
    smoothed = F.conv2d(input_tensor.unsqueeze(1), kernel, padding=kernel_size // 2)
    return smoothed.squeeze(1)
|
|
|
|
|
|
## Dense correspondence utils
|
|
|
|
def cos_dist(a, b):
    """Pairwise cosine distance between the rows of ``a`` and ``b``.

    Args:
        a: tensor of shape (M, D).
        b: tensor of shape (N, D).

    Returns:
        (M, N) tensor of ``1 - cosine_similarity`` values.
    """
    return 1 - F.normalize(a, dim=-1) @ F.normalize(b, dim=-1).T
|
|
|
|
|
|
def gen_nn_map(src_features, src_mask, tgt_features, tgt_mask, device, batch_size=100, tgt_size=768):
    """For every target location, find the nearest source location in feature space.

    Both feature maps are bilinearly resized to (tgt_size, tgt_size) and
    flattened to rows of per-pixel feature vectors; nearest neighbors are
    computed under cosine distance.

    Args:
        src_features: (C, H, W) source feature map.
        src_mask: (tgt_size**2,) bool mask; False source locations are
            excluded from matching.
        tgt_features: (C, H, W) target feature map.
        tgt_mask: unused here — kept for interface compatibility with
            callers (presumably symmetric call sites; verify before removing).
        device: device for the output tensors.
        batch_size: number of target locations scored per iteration;
            falsy (None/0) processes all locations at once.
        tgt_size: side length of the resized feature grid.

    Returns:
        Tuple of (tgt_size**2,) tensors: per-target-location index of the
        nearest source location, and its cosine distance.
    """
    # Flatten to (tgt_size**2, C), one row per spatial location.
    # NOTE: permute() returns a non-contiguous tensor, so reshape() is
    # required here — the original .view() raises RuntimeError.
    resized_src_features = F.interpolate(src_features.unsqueeze(0), size=tgt_size, mode='bilinear', align_corners=False).squeeze(0)
    resized_src_features = resized_src_features.permute(1, 2, 0).reshape(tgt_size**2, -1)
    resized_tgt_features = F.interpolate(tgt_features.unsqueeze(0), size=tgt_size, mode='bilinear', align_corners=False).squeeze(0)
    resized_tgt_features = resized_tgt_features.permute(1, 2, 0).reshape(tgt_size**2, -1)

    # Row-normalize once, outside the batch loop (loop-invariant work);
    # 1 - src_norm @ tgt_norm.T is then the cosine distance matrix.
    src_norm = F.normalize(resized_src_features, dim=-1)
    tgt_norm = F.normalize(resized_tgt_features, dim=-1)

    nearest_neighbor_indices = torch.zeros(tgt_size**2, dtype=torch.long, device=device)
    nearest_neighbor_distances = torch.zeros(tgt_size**2, dtype=src_features.dtype, device=device)

    if not batch_size:  # None/0 -> score every target location in one pass
        batch_size = tgt_size**2
    for i in range(0, tgt_size**2, batch_size):
        # (num_src, batch) cosine distances for this slice of target rows
        distances = 1 - src_norm @ tgt_norm[i:i+batch_size].T
        # Exclude masked-out source rows; 2.0 exceeds any cosine distance.
        distances[~src_mask] = 2.
        min_distances, min_indices = torch.min(distances, dim=0)
        nearest_neighbor_indices[i:i+batch_size] = min_indices
        nearest_neighbor_distances[i:i+batch_size] = min_distances

    return nearest_neighbor_indices, nearest_neighbor_distances
|
|
|
|
|
|
def cyclic_nn_map(features, masks, latent_resolutions, device):
    """Build pairwise nearest-neighbor maps between all images in a batch.

    For each resolution and each ordered pair (i, j) with i != j, stores
    the per-location nearest-neighbor indices/distances mapping image i's
    locations onto image j's features (via ``gen_nn_map``).

    Args:
        features: (bsz, C, H, W) feature maps.
        masks: dict keyed by resolution, each entry indexable per image.
        latent_resolutions: iterable of grid side lengths.
        device: device for the output tensors.

    Returns:
        (nn_map_dict, nn_distances_dict): dicts keyed by resolution with
        tensors of shape (bsz, bsz, resolution**2); diagonal (i == j)
        entries keep their initial values (uninitialized indices, inf
        distances).
    """
    bsz = features.shape[0]
    nn_map_dict, nn_distances_dict = {}, {}

    for tgt_size in latent_resolutions:
        num_px = tgt_size**2
        nn_map = torch.empty(bsz, bsz, num_px, dtype=torch.long, device=device)
        nn_distances = torch.full((bsz, bsz, num_px), float('inf'), dtype=features.dtype, device=device)

        pairs = ((i, j) for i in range(bsz) for j in range(bsz) if i != j)
        for i, j in pairs:
            indices, distances = gen_nn_map(
                features[j], masks[tgt_size][j],
                features[i], masks[tgt_size][i],
                device, batch_size=None, tgt_size=tgt_size,
            )
            nn_map[i, j] = indices
            nn_distances[i, j] = distances

        nn_map_dict[tgt_size] = nn_map
        nn_distances_dict[tgt_size] = nn_distances

    return nn_map_dict, nn_distances_dict
|
|
|
|
|
|
def anchor_nn_map(features, anchor_features, masks, anchor_masks, latent_resolutions, device):
    """Build nearest-neighbor maps from each batch image to each anchor image.

    For each resolution and each (image i, anchor j) pair, stores the
    per-location nearest-neighbor indices/distances mapping image i's
    locations onto anchor j's features (via ``gen_nn_map``).

    Args:
        features: (bsz, C, H, W) batch feature maps.
        anchor_features: (anchor_bsz, C, H, W) anchor feature maps.
        masks: dict keyed by resolution, indexable per batch image.
        anchor_masks: dict keyed by resolution, indexable per anchor image.
        latent_resolutions: iterable of grid side lengths.
        device: device for the output tensors.

    Returns:
        (nn_map_dict, nn_distances_dict): dicts keyed by resolution with
        tensors of shape (bsz, anchor_bsz, resolution**2).
    """
    bsz = features.shape[0]
    anchor_bsz = anchor_features.shape[0]
    nn_map_dict, nn_distances_dict = {}, {}

    for tgt_size in latent_resolutions:
        num_px = tgt_size**2
        nn_map = torch.empty(bsz, anchor_bsz, num_px, dtype=torch.long, device=device)
        nn_distances = torch.full((bsz, anchor_bsz, num_px), float('inf'), dtype=features.dtype, device=device)

        for i in range(bsz):
            for j in range(anchor_bsz):
                indices, distances = gen_nn_map(
                    anchor_features[j], anchor_masks[tgt_size][j],
                    features[i], masks[tgt_size][i],
                    device, batch_size=None, tgt_size=tgt_size,
                )
                nn_map[i, j] = indices
                nn_distances[i, j] = distances

        nn_map_dict[tgt_size] = nn_map
        nn_distances_dict[tgt_size] = nn_distances

    return nn_map_dict, nn_distances_dict
|