# Mirror of https://github.com/vladmandic/sdnext.git (synced 2026-01-29 05:02:09 +03:00)
# 390 lines, 17 KiB, Python
import os
|
|
import inspect
|
|
from statistics import stdev, mean
|
|
from rich import progress
|
|
import torch
|
|
from torch import einsum
|
|
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
|
from einops import rearrange, repeat
|
|
from modules import devices, shared, hashes, errors, files_cache
|
|
|
|
|
|
# Hypernetworks currently active for inference: cleared and refilled by
# load_hypernetworks(), read by attention_CrossAttention_forward().
loaded_hypernetworks = list()
|
|
|
|
class HypernetworkModule(torch.nn.Module):
    """Residual MLP applied to a cross-attention context tensor.

    Computes ``x + linear(x) * multiplier`` (multiplier only in eval mode), where
    ``linear`` is a stack of Linear layers of width ``dim * layer_structure[i]``,
    optionally interleaved with an activation, LayerNorm and Dropout.
    """

    # Activation names accepted by ``activation_func``: the hand-written aliases below,
    # plus every class defined in torch.nn.modules.activation under its lower-cased
    # class name (e.g. "gelu"). Aliases whose lower-cased name matches a torch class
    # are overwritten by that class; "swish" stays mapped to Hardswish.
    activation_dict = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
                 add_layer_norm=False, activate_output=False, dropout_structure=None):
        """Build the layer stack, then either load ``state_dict`` or initialize fresh weights.

        Args:
            dim: width of the context vectors this module is applied to.
            state_dict: optional weights to load; legacy ``linear1``/``linear2`` keys
                are renamed in place first (see ``fix_old_state_dict``).
            layer_structure: width multipliers per layer; must start and end with 1.
            activation_func: key of ``activation_dict``; ``None`` or ``"linear"`` adds no activation.
            weight_init: "Normal", "XavierUniform", "XavierNormal", "KaimingUniform" or
                "KaimingNormal"; ``None`` is treated as "Normal".
            add_layer_norm: append a LayerNorm after every Linear layer.
            activate_output: also apply the activation after the final Linear layer.
            dropout_structure: dropout probabilities indexed like ``layer_structure``;
                entry ``i + 1`` is applied after the ``i``-th Linear layer.

        Raises:
            AssertionError: ``layer_structure`` is missing or does not start/end with 1,
                or a nonzero dropout probability is not below 1.
            RuntimeError: ``activation_func`` is not in ``activation_dict``.
            KeyError: ``weight_init`` is not a known initialization scheme.
        """
        super().__init__()
        self.multiplier = 1.0
        assert layer_structure is not None, "layer_structure must not be None"
        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"

        linears = []
        output_index = len(layer_structure) - 2  # index of the final (output) Linear layer
        for i in range(len(layer_structure) - 1):
            out_features = int(dim * layer_structure[i + 1])
            # Fully-connected layer
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), out_features))
            # Activation: skipped after the output layer unless activate_output is set
            if activation_func == "linear" or activation_func is None or (i >= output_index and not activate_output):
                pass
            elif activation_func in self.activation_dict:
                linears.append(self.activation_dict[activation_func]())
            else:
                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(out_features))
            # Dropout only ever follows a layer, so dropout_structure[0] is unused.
            # e.g. [1, 2, 2, 1] -> [0, 0.3, 0, 0], or [0, 0.3, 0.3, 0] with last-layer dropout.
            if dropout_structure is not None and dropout_structure[i + 1] > 0:
                assert 0 < dropout_structure[i + 1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
                linears.append(torch.nn.Dropout(p=dropout_structure[i + 1]))
        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            self.fix_old_state_dict(state_dict)
            self.load_state_dict(state_dict)
        else:
            self._init_weights(weight_init, activation_func)
        self.to(devices.device)

    def _init_weights(self, weight_init, activation_func):
        """Initialize every Linear/LayerNorm in ``self.linear`` according to ``weight_init``.

        Raises:
            KeyError: ``weight_init`` is not a known initialization scheme.
        """
        if weight_init is None:
            # Hypernetwork() passes weight_init=None by default; use the same "Normal"
            # default as this constructor's signature and Hypernetwork.load().
            weight_init = 'Normal'
        nonlinearity = 'leaky_relu' if activation_func == 'leakyrelu' else 'relu'
        for layer in self.linear:
            if not isinstance(layer, (torch.nn.Linear, torch.nn.LayerNorm)):
                continue
            w, b = layer.weight.data, layer.bias.data
            # LayerNorm layers always get the "Normal" scheme, whatever weight_init says
            if weight_init == "Normal" or isinstance(layer, torch.nn.LayerNorm):
                normal_(w, mean=0.0, std=0.01)
                # std=0 makes the bias (effectively) zero; kept as normal_ to match the original init exactly
                normal_(b, mean=0.0, std=0)
            elif weight_init == 'XavierUniform':
                xavier_uniform_(w)
                zeros_(b)
            elif weight_init == 'XavierNormal':
                xavier_normal_(w)
                zeros_(b)
            elif weight_init == 'KaimingUniform':
                kaiming_uniform_(w, nonlinearity=nonlinearity)
                zeros_(b)
            elif weight_init == 'KaimingNormal':
                kaiming_normal_(w, nonlinearity=nonlinearity)
                zeros_(b)
            else:
                raise KeyError(f"Key {weight_init} is not defined as initialization!")

    def fix_old_state_dict(self, state_dict):
        """Rename legacy ``linear1.*``/``linear2.*`` keys in place to ``linear.0.*``/``linear.1.*``."""
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }
        for fr, to in changes.items():
            x = state_dict.get(fr, None)
            if x is None:
                continue
            del state_dict[fr]
            state_dict[to] = x

    def forward(self, x):
        """Return ``x`` plus the MLP output; ``self.multiplier`` scales it only in eval mode."""
        return x + self.linear(x) * (self.multiplier if not self.training else 1)

    def trainables(self):
        """Return ``[weight, bias, ...]`` for every Linear/LayerNorm layer, in stack order."""
        params = []
        for layer in self.linear:
            if isinstance(layer, (torch.nn.Linear, torch.nn.LayerNorm)):
                params += [layer.weight, layer.bias]
        return params
|
|
|
|
|
|
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
    """Derive per-layer dropout probabilities for ``layer_structure``.

    Args:
        layer_structure: width multipliers per layer; ``None`` means ``[1, 2, 1]``.
            Only its length is used.
        use_dropout: when false, every entry is 0.
        last_layer_dropout: whether the layer feeding the output Linear gets dropout
            (kept for compatibility with older saved hypernetworks).

    Returns:
        A list with exactly one entry per element of ``layer_structure``. The first
        and last entries are always 0 (dropout never follows the output layer);
        interior entries are 0.3, except the one just before the output, which is
        0.3 only when ``last_layer_dropout`` is set. For example
        ``[1, 2, 1]`` -> ``[0, 0.3, 0]`` (or ``[0, 0, 0]``) and
        ``[1, 2, 2, 1]`` -> ``[0, 0.3, 0.3, 0]`` (or ``[0, 0.3, 0, 0]``).
    """
    if layer_structure is None:
        layer_structure = [1, 2, 1]
    if not use_dropout or len(layer_structure) < 3:
        # With fewer than 3 entries there is no hidden layer: the only Linear is the
        # output layer, which must not be followed by dropout.
        return [0] * len(layer_structure)
    dropout_values = [0]
    dropout_values.extend([0.3] * (len(layer_structure) - 3))
    dropout_values.append(0.3 if last_layer_dropout else 0)
    dropout_values.append(0)
    return dropout_values
|
|
|
|
|
|
class Hypernetwork:
    """A named collection of key/value HypernetworkModule pairs plus training metadata.

    ``layers`` maps a context width (int) to a ``(module_k, module_v)`` tuple;
    apply_single_hypernetwork applies index 0 to the attention key context and
    index 1 to the value context. ``save``/``load`` round-trip the weights and the
    construction parameters through ``torch.save``/``torch.load``.
    """

    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
        """Create a hypernetwork, building fresh module pairs for each size in ``enable_sizes``.

        Args:
            name: display name, stored in saved files.
            enable_sizes: context widths to create module pairs for; ``None`` creates
                none (as ``load_hypernetwork`` does before calling ``load``).
            layer_structure, activation_func, weight_init, add_layer_norm, activate_output:
                forwarded to every HypernetworkModule.
            use_dropout: used to derive ``dropout_structure`` when it is not supplied.
            **kwargs: ``last_layer_dropout`` (default True) and ``dropout_structure``
                (derived with ``parse_dropout_structure`` when missing).
        """
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.activation_func = activation_func
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout
        self.activate_output = activate_output
        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
        self.dropout_structure = kwargs.get('dropout_structure', None)
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
        self.optimizer_name = None
        self.optimizer_state_dict = None
        self.optional_info = None
        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
            )
        self.eval()

    def weights(self):
        """Return the parameters of every key/value module as one flat list."""
        return [param for modules in self.layers.values() for module in modules for param in module.parameters()]

    def train(self, mode=True):
        """Set train (``mode=True``) or eval mode on every module and toggle ``requires_grad`` to match."""
        for layers in self.layers.values():
            for layer in layers:
                layer.train(mode=mode)
                for param in layer.parameters():
                    param.requires_grad = mode

    def to(self, device):
        """Move every module to ``device``; returns ``self`` for chaining."""
        for layers in self.layers.values():
            for layer in layers:
                layer.to(device)
        return self

    def set_multiplier(self, multiplier):
        """Set the inference strength of every module; returns ``self`` for chaining."""
        for layers in self.layers.values():
            for layer in layers:
                layer.multiplier = multiplier
        return self

    def eval(self):
        """Put every module in eval mode and freeze its parameters."""
        # Module.eval() is train(False), so this is exactly train(mode=False)
        self.train(mode=False)

    def save(self, filename):
        """Write weights and metadata to ``filename``.

        When ``shared.opts.save_optimizer_state`` is set and optimizer state exists,
        it is also written to ``<filename>.optim`` together with the short hash of
        the file just written, which ``load`` checks before restoring it.
        """
        state_dict = {}
        optimizer_saved_dict = {}
        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())
        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['activation_func'] = self.activation_func
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['weight_initialization'] = self.weight_init
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
        state_dict['activate_output'] = self.activate_output
        state_dict['use_dropout'] = self.use_dropout
        state_dict['dropout_structure'] = self.dropout_structure
        state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
        state_dict['optional_info'] = self.optional_info if self.optional_info else None
        if self.optimizer_name is not None:
            optimizer_saved_dict['optimizer_name'] = self.optimizer_name
        torch.save(state_dict, filename)
        if shared.opts.save_optimizer_state and self.optimizer_state_dict:
            # Hash the file that was just written. load() compares the stored hash with
            # the hash of the file it opens; self.filename may be None or still point at
            # the file this hypernetwork was originally loaded from.
            optimizer_saved_dict['hash'] = self.shorthash(filename)
            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
            torch.save(optimizer_saved_dict, f"{filename}.optim")

    def load(self, filename):
        """Load weights and metadata from ``filename``, plus matching optimizer state.

        ``filename`` is used as-is when it exists, otherwise it is resolved against
        ``shared.opts.hypernetwork_dir``. Optimizer state from ``<file>.optim`` is kept
        only if its stored hash equals this file's ``shorthash()``.
        """
        self.filename = filename if os.path.exists(filename) else os.path.join(shared.opts.hypernetwork_dir, filename)
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(self.filename))[0]
        # NOTE(review): unless torch.load runs with weights_only=True (the default only
        # in newer torch releases), this unpickles arbitrary objects - load trusted files only.
        with progress.open(self.filename, 'rb', description=f'Load hypernetwork: [cyan]{self.filename}', auto_refresh=True, console=shared.console) as f:
            state_dict = torch.load(f, map_location='cpu')
        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        self.optional_info = state_dict.get('optional_info', None)
        self.activation_func = state_dict.get('activation_func', None)
        self.weight_init = state_dict.get('weight_initialization', 'Normal')
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        self.dropout_structure = state_dict.get('dropout_structure', None)
        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
        self.activate_output = state_dict.get('activate_output', True)
        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
        # Dropout structure should have the same length as layer structure, every entry should be in [0, 1), and the last entry must be 0.
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
        # shorthash() runs before self.name is overwritten from the file below
        if self.shorthash() == optimizer_saved_dict.get('hash', None):
            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
        else:
            self.optimizer_state_dict = None
        if self.optimizer_state_dict:
            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
        else:
            self.optimizer_name = "AdamW"
        for size, sd in state_dict.items():
            # integer keys hold (key_module_state, value_module_state) pairs; string keys are metadata
            if isinstance(size, int):
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                )
        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
        self.eval()

    def shorthash(self, filename=None):
        """Return the first 10 characters of the file's sha256, or None if no hash is available.

        ``filename`` defaults to ``self.filename``; ``hypernet/<name>`` is passed to
        ``hashes.sha256`` as the title.
        """
        target = filename if filename is not None else self.filename
        sha256 = hashes.sha256(target, f'hypernet/{self.name}')
        return sha256[0:10] if sha256 else None
|
|
|
|
|
|
def list_hypernetworks(path):
    """Return a ``{base name without extension: file path}`` map of ``.pt`` files under ``path``.

    Files come from ``files_cache.list_files`` with ``recursive=files_cache.not_hidden``;
    if two files share a base name, the one listed later wins.
    """
    found = {}
    for file_path in files_cache.list_files(path, ext_filter=['.pt'], recursive=files_cache.not_hidden):
        stem, _extension = os.path.splitext(os.path.basename(file_path))
        found[stem] = file_path
    return found
|
|
|
|
|
|
def load_hypernetwork(name):
    """Load the hypernetwork registered under ``name`` in ``shared.hypernetworks``.

    Returns the loaded Hypernetwork, or None if the name is unknown or loading
    raises (the error is reported through ``errors.display``).
    """
    path = shared.hypernetworks.get(name, None)
    if path is None:
        return None
    result = Hypernetwork()
    try:
        result.load(path)
    except Exception as err:
        errors.display(err, f'hypernetwork load: {path}')
        return None
    else:
        return result
|
|
|
|
|
|
def load_hypernetworks(names, multipliers=None):
    """Replace the contents of ``loaded_hypernetworks`` with the hypernetworks named in ``names``.

    Entries already in ``loaded_hypernetworks`` are reused instead of being read
    again; names that ``load_hypernetwork`` cannot load are skipped.

    Args:
        names: hypernetwork names (keys of ``shared.hypernetworks``).
        multipliers: optional strength per name, matched by position. Missing
            values (``None``/empty, or a list shorter than ``names``) default to 1.0.
    """
    already_loaded = {hn.name: hn for hn in loaded_hypernetworks if hn.name in names}
    loaded_hypernetworks.clear()
    for i, name in enumerate(names):
        hypernetwork = already_loaded.get(name, None)
        if hypernetwork is None:
            hypernetwork = load_hypernetwork(name)
        if hypernetwork is None:
            continue
        # guard against a multipliers list shorter than names instead of raising IndexError
        multiplier = multipliers[i] if multipliers and i < len(multipliers) else 1.0
        hypernetwork.set_multiplier(multiplier)
        loaded_hypernetworks.append(hypernetwork)
|
|
|
|
|
|
def find_closest_hypernetwork_name(search: str):
    """Return the shortest name in ``shared.hypernetworks`` containing ``search`` (case-insensitive).

    Returns None for an empty/None search or when nothing matches. Among equally
    short matches, the first in ``shared.hypernetworks`` iteration order wins.
    """
    if not search:
        return None
    needle = search.lower()
    candidates = [name for name in shared.hypernetworks if needle in name.lower()]
    return min(candidates, key=len) if candidates else None
|
|
|
|
|
|
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
    """Run one hypernetwork's key/value module pair over the attention context.

    The pair is looked up by ``context_k.shape[2]``. If ``hypernetwork`` is None or
    has no pair for that width, both contexts are returned unchanged. When ``layer``
    is given, the selected modules are also stored on it as ``hyper_k``/``hyper_v``.
    """
    available = {} if hypernetwork is None else hypernetwork.layers
    modules = available.get(context_k.shape[2], None)
    if modules is None:
        return context_k, context_v
    module_k = modules[0]
    module_v = modules[1]
    if layer is not None:
        layer.hyper_k = module_k
        layer.hyper_v = module_v
    # cond_cast_float / cond_cast_unet presumably adjust dtype around the module call - TODO confirm in modules.devices
    new_k = devices.cond_cast_unet(module_k(devices.cond_cast_float(context_k)))
    new_v = devices.cond_cast_unet(module_v(devices.cond_cast_float(context_v)))
    return new_k, new_v
|
|
|
|
|
|
def apply_hypernetworks(hypernetworks, context, layer=None):
    """Chain every hypernetwork in ``hypernetworks`` over ``context``.

    Key and value both start from ``context``; each hypernetwork transforms the
    previous one's outputs. Returns the final ``(context_k, context_v)`` pair.
    """
    context_k = context_v = context
    for hn in hypernetworks:
        context_k, context_v = apply_single_hypernetwork(hn, context_k, context_v, layer)
    return context_k, context_v
|
|
|
|
|
|
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Cross-attention forward pass whose key/value context goes through ``loaded_hypernetworks``.

    ``self`` is the attention module (presumably ldm's CrossAttention - TODO confirm the
    patch site); it must provide ``heads``, ``scale``, ``to_q``, ``to_k``, ``to_v`` and
    ``to_out``. ``context`` defaults to ``x`` (self-attention). When ``mask`` is given,
    positions where it is False get the most negative finite value before softmax.
    """
    from ldm.util import default
    heads = self.heads
    query = self.to_q(x)
    context = default(context, x)
    context_k, context_v = apply_hypernetworks(loaded_hypernetworks, context, self)
    key = self.to_k(context_k)
    value = self.to_v(context_v)
    # fold heads into the batch axis: (b, n, h*d) -> (b*h, n, d)
    query, key, value = [rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (query, key, value)]

    scores = einsum('b i d, b j d -> b i j', query, key) * self.scale
    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        fill_value = -torch.finfo(scores.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=heads)
        scores.masked_fill_(~mask, fill_value)

    weights = scores.softmax(dim=-1)
    out = einsum('b i j, b j d -> b i d', weights, value)
    # unfold heads back: (b*h, n, d) -> (b, n, h*d)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(out)
|
|
|
|
|
|
def stack_conds(conds):
    """Stack 2-D conditioning tensors into one batch, padding the shorter ones.

    A tensor with fewer rows than the longest one is padded by repeating its last
    row (the original comment notes this matches reconstruct_multicond_batch).
    The input sequence is left unmodified, and a tuple is accepted as well as a list.

    Args:
        conds: non-empty sequence of tensors shaped (tokens, features); the padding
            uses ``repeat([n, 1])``, so tensors that need padding must be 2-D.

    Returns:
        Tensor of shape (len(conds), max_tokens, features).
    """
    if len(conds) == 1:
        return torch.stack(conds)
    token_count = max(cond.shape[0] for cond in conds)
    padded = []
    for cond in conds:
        missing = token_count - cond.shape[0]
        if missing:
            cond = torch.vstack([cond, cond[-1:].repeat([missing, 1])])
        padded.append(cond)
    return torch.stack(padded)
|
|
|
|
|
|
def statistics(data):
    """Summarize loss values as mean ± standard error, overall and for the last 32 values.

    Returns a ``(total_information, recent_information)`` pair of formatted strings.
    The standard error counts as 0 when fewer than two values are available; an
    empty ``data`` raises ``StatisticsError`` from ``mean``.
    """
    def summarize(values, label):
        spread = stdev(values) if len(values) >= 2 else 0
        return f"{label}:{mean(values):.3f}" + "\u00B1" + f"({spread / (len(values) ** 0.5):.3f})"

    overall = summarize(data, "loss")
    latest = summarize(data[-32:], "recent 32 loss")
    return overall, latest
|
|
|
|
|
|
def report_statistics(loss_info:dict):
    """Print loss statistics per file, ordered by ascending mean loss.

    Args:
        loss_info: maps a file name to its recorded loss values (any sized iterable).
            Entries with no values sort last instead of raising ZeroDivisionError.

    An error while summarizing one entry is printed and does not stop the report.
    """
    def mean_loss(key):
        values = loss_info[key]
        return sum(values) / len(values) if len(values) else float('inf')

    for key in sorted(loss_info.keys(), key=mean_loss):
        try:
            print("Loss statistics for file " + key)
            info, recent = statistics(list(loss_info[key]))
            print(info)
            print(recent)
        except Exception as e:
            print(e)
|