mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
712 lines
37 KiB
Python
712 lines
37 KiB
Python
from typing import Type, Dict, Any, Tuple, Optional
|
|
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from diffusers.pipelines import auto_pipeline
|
|
|
|
|
|
# Default number of inference steps. apply_hidiffusion() overwrites this global before it
# creates the patched classes; the cross-attention down block reads it as its class-level
# max_timestep default.
current_steps: int = 50
|
|
def sd15_hidiffusion_key():
    """Return the SD 1.5 UNet module names that HiDiffusion patches.

    Returned keys:
        down_module_key / down_module_key_extra: modules replaced on the down path.
        up_module_key / up_module_key_extra: modules replaced on the up path.
        windown_attn_module_key: transformer blocks that get windowed (MSW-MSA) self-attention.
    """
    window_blocks = [f'down_blocks.0.attentions.{idx}.transformer_blocks.0' for idx in range(2)]
    window_blocks.extend(f'up_blocks.3.attentions.{idx}.transformer_blocks.0' for idx in range(3))
    return {
        'down_module_key': ['down_blocks.0.downsamplers.0.conv'],
        'down_module_key_extra': ['down_blocks.1'],
        'up_module_key': ['up_blocks.2.upsamplers.0.conv'],
        'up_module_key_extra': ['up_blocks.2'],
        'windown_attn_module_key': window_blocks,
    }
|
|
|
|
def sdxl_hidiffusion_key():
    """Return the SDXL UNet module names that HiDiffusion patches.

    Returned keys:
        down_module_key / down_module_key_extra: modules replaced on the down path.
        up_module_key / up_module_key_extra: modules replaced on the up path.
        windown_attn_module_key: transformer blocks that get windowed (MSW-MSA) self-attention;
            both transformer blocks of every attention in down_blocks.1 (2 attentions) and
            up_blocks.1 (3 attentions).
    """
    window_blocks = [
        f'{prefix}.attentions.{attn_idx}.transformer_blocks.{tb_idx}'
        for prefix, attn_count in (('down_blocks.1', 2), ('up_blocks.1', 3))
        for attn_idx in range(attn_count)
        for tb_idx in range(2)
    ]
    return {
        'down_module_key': ['down_blocks.1'],
        'down_module_key_extra': ['down_blocks.1.downsamplers.0.conv'],
        'up_module_key': ['up_blocks.1'],
        'up_module_key_extra': ['up_blocks.0.upsamplers.0.conv'],
        'windown_attn_module_key': window_blocks,
    }
|
|
|
|
|
|
def sdxl_turbo_hidiffusion_key():
    """Return the SDXL-Turbo UNet module names that HiDiffusion patches.

    Same layout as SDXL but without the *_extra sampler keys:
        down_module_key / up_module_key: cross-attention blocks replaced on the down/up path.
        windown_attn_module_key: both transformer blocks of every attention in down_blocks.1
            (2 attentions) and up_blocks.1 (3 attentions).
    """
    window_blocks = []
    for prefix, attn_count in (('down_blocks.1', 2), ('up_blocks.1', 3)):
        for attn_idx in range(attn_count):
            for tb_idx in range(2):
                window_blocks.append(f'{prefix}.attentions.{attn_idx}.transformer_blocks.{tb_idx}')
    return {
        'down_module_key': ['down_blocks.1'],
        'up_module_key': ['up_blocks.1'],
        'windown_attn_module_key': window_blocks,
    }
|
|
|
|
|
|
# RAU-Net switching thresholds, keyed by '<model>_<target resolution>'. Both ratios only matter
# when apply_raunet=True.
#   T1_ratio: T1 from the main HiDiffusion paper, T1 = number_inference_steps * T1_ratio. A higher
#             T1_ratio mitigates object duplication better; 0.4 is the default and may need
#             tuning for a given prompt.
#   T2_ratio: T2 from the paper's appendix, used for extreme-resolution generation,
#             T2 = number_inference_steps * T2_ratio. A higher T2_ratio also mitigates object
#             duplication better.
switching_threshold_ratio_dict = {
    'sd15_1024': dict(T1_ratio=0.4, T2_ratio=0.0),
    'sd15_2048': dict(T1_ratio=0.7, T2_ratio=0.3),
    'sdxl_2048': dict(T1_ratio=0.4, T2_ratio=0.0),
    'sdxl_4096': dict(T1_ratio=0.7, T2_ratio=0.3),
    'sdxl_turbo_1024': dict(T1_ratio=0.5, T2_ratio=0.0),
}

# Replacement SDXL thresholds used by the patched blocks when the pipeline has a ControlNet
# (info['text_to_img_controlnet']).
text_to_img_controlnet_switching_threshold_ratio_dict = {
    'sdxl_2048': dict(T1_ratio=0.5, T2_ratio=0.0),
}

# NOTE(review): not referenced anywhere in this file; presumably read by the controlnet helpers.
controlnet_apply_steps_rate = 0.6
# Aggressive RAU-Net for SDXL at the lower target resolution; the patched blocks then switch
# using aggressive_step out of a nominal 50 steps (aggressive_step / 50 * max_timestep).
is_aggressive_raunet = True
aggressive_step = 8
# Aggressive mode is disabled for inpainting pipelines.
inpainting_is_aggressive_raunet = False
# NOTE(review): not referenced anywhere in this file.
playground_is_aggressive_raunet = False
|
|
|
|
|
|
def make_diffusers_transformer_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """Create a subclass of ``block_class`` whose self-attention is MSW-MSA.

    MSW-MSA (multi-scale shifted-window self-attention) replaces the global self-attention:
    the token grid is cyclically shifted by a random quarter-window offset, split into a 2x2
    layout of windows, ``attn1`` runs on each window as a separate batch entry, and the windows
    are stitched back and un-shifted. The cross-attention (norm2/attn2) and feed-forward
    (norm3/ff) paths are unchanged.

    The patched instance needs ``self.info['size']`` (input height/width of the UNet) before
    ``forward`` is called; otherwise a RuntimeError is raised.

    Args:
        block_class: transformer block class to extend; kept as ``_parent`` for unpatching.

    Returns:
        The patched subclass.
    """

    # reference: https://github.com/microsoft/Swin-Transformer
    def window_partition(x, window_size, shift_size, H, W):
        """Split (B, H*W, C) tokens into (4*B, window_h*window_w, C) windows (2x2 layout)."""
        B, _N, C = x.shape
        x = x.view(B, H, W, C)
        if H % 2 != 0 or W % 2 != 0:
            from modules.errors import log
            log.warning('HiDiffusion: The feature size is not divisible by 2')
            # resize to an even grid so it splits into exactly 2x2 windows
            x = F.interpolate(x.permute(0, 3, 1, 2).contiguous(), size=(window_size[0] * 2, window_size[1] * 2), mode='bicubic').permute(0, 2, 3, 1).contiguous()
        if isinstance(shift_size, (list, tuple)):
            if shift_size[0] > 0:
                x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
        elif shift_size > 0:
            x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
        x = x.view(B, 2, window_size[0], 2, window_size[1], C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
        return windows.view(-1, window_size[0] * window_size[1], C)

    def window_reverse(windows, window_size, H, W, shift_size):
        """Inverse of window_partition: (4*B, window_h*window_w, C) back to (B, H*W, C)."""
        _, _, C = windows.shape
        windows = windows.view(-1, window_size[0], window_size[1], C)
        B = int(windows.shape[0] / 4)  # 2x2 windows per sample
        x = windows.view(B, 2, 2, window_size[0], window_size[1], -1)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, window_size[0] * 2, window_size[1] * 2, -1)
        if isinstance(shift_size, (list, tuple)):
            if shift_size[0] > 0:
                x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))
        elif shift_size > 0:
            x = torch.roll(x, shifts=(shift_size, shift_size), dims=(1, 2))
        if H % 2 != 0 or W % 2 != 0:
            # undo the even-grid resize done in window_partition
            x = F.interpolate(x.permute(0, 3, 1, 2).contiguous(), size=(H, W), mode='bicubic').permute(0, 2, 3, 1).contiguous()
        return x.view(B, H * W, C)

    # replace global self-attention with MSW-MSA
    class transformer_block(block_class):
        # Save for unpatching later
        _parent = block_class

        def forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Dict[str, Any] = None,
            class_labels: Optional[torch.LongTensor] = None,
            added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        ) -> torch.FloatTensor:
            batch_size = hidden_states.shape[0]

            # 1. Self-attention input norm
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype)
            elif self.use_layer_norm:
                norm_hidden_states = self.norm1(hidden_states)
            elif self.use_ada_layer_norm_continuous:
                norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
            elif self.use_ada_layer_norm_single:
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)).chunk(6, dim=1)
                norm_hidden_states = self.norm1(hidden_states)
                norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
                norm_hidden_states = norm_hidden_states.squeeze(1)
            else:
                raise ValueError("HiDiffusion: Incorrect norm used")

            if self.pos_embed is not None:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            # MSW-MSA: infer the token grid (H, W) from the UNet input size and the token count
            rand_num = torch.rand(1)
            _B, N, _C = hidden_states.shape
            try:
                ori_H, ori_W = self.info['size']
            except Exception as e:
                raise RuntimeError(f'HiDiffusion: cls={self.__class__.__name__} info={hasattr(self, "info")} parent={hasattr(self, "_parent")} orphaned call') from e

            downsample_ratio = round(((ori_H * ori_W) / N) ** 0.5)
            H, W = (math.ceil(ori_H / downsample_ratio), math.ceil(ori_W / downsample_ratio))
            window_size = (math.ceil(H / 2), math.ceil(W / 2))
            # random cyclic shift of 0, 1, 2 or 3 quarter-windows
            if rand_num <= 0.25:
                quarters = 0
            elif rand_num <= 0.5:
                quarters = 1
            elif rand_num <= 0.75:
                quarters = 2
            elif rand_num <= 1:
                quarters = 3
            else:
                quarters = 0
            shift_size = (window_size[0] // 4 * quarters, window_size[1] // 4 * quarters)
            norm_hidden_states = window_partition(norm_hidden_states, window_size, shift_size, H, W)

            # 2. Prepare GLIGEN inputs (copy so popping does not mutate the caller's dict)
            cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
            gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )
            if self.use_ada_layer_norm_zero:
                attn_output = gate_msa.unsqueeze(1) * attn_output
            elif self.use_ada_layer_norm_single:
                attn_output = gate_msa * attn_output

            attn_output = window_reverse(attn_output, window_size, H, W, shift_size)

            hidden_states = attn_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)

            # 2.5 GLIGEN Control
            if gligen_kwargs is not None:
                hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

            # 3. Cross-Attention
            if self.attn2 is not None:
                if self.use_ada_layer_norm:
                    norm_hidden_states = self.norm2(hidden_states, timestep)
                elif self.use_ada_layer_norm_zero or self.use_layer_norm:
                    norm_hidden_states = self.norm2(hidden_states)
                elif self.use_ada_layer_norm_single:
                    norm_hidden_states = hidden_states
                elif self.use_ada_layer_norm_continuous:
                    norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
                else:
                    raise ValueError("HiDiffusion: Incorrect norm")

                if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
                    norm_hidden_states = self.pos_embed(norm_hidden_states)

                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states = attn_output + hidden_states

            # 4. Feed-forward
            if self.use_ada_layer_norm_continuous:
                norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
            elif not self.use_ada_layer_norm_single:
                norm_hidden_states = self.norm3(hidden_states)
            if self.use_ada_layer_norm_zero:
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
            if self.use_ada_layer_norm_single:
                norm_hidden_states = self.norm2(hidden_states)
                norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

            if self._chunk_size is not None:
                # Chunked feed-forward done inline: this module never defined or imported the
                # `_chunked_feed_forward` helper it used to call, so this path raised NameError.
                chunk_dim, chunk_size = self._chunk_dim, self._chunk_size
                dim_len = norm_hidden_states.shape[chunk_dim]
                if dim_len % chunk_size != 0:
                    raise ValueError(f'HiDiffusion: hidden_states dimension {dim_len} is not divisible by feed-forward chunk size {chunk_size}')
                num_chunks = dim_len // chunk_size
                ff_output = torch.cat(
                    [self.ff(part) for part in norm_hidden_states.chunk(num_chunks, dim=chunk_dim)],
                    dim=chunk_dim,
                )
            else:
                ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output
            elif self.use_ada_layer_norm_single:
                ff_output = gate_mlp * ff_output

            hidden_states = ff_output + hidden_states
            if hidden_states.ndim == 4:
                hidden_states = hidden_states.squeeze(1)
            return hidden_states

    return transformer_block
|
|
|
|
|
|
def make_diffusers_cross_attn_down_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """Create a subclass of a cross-attention down block with RAU-Net downsampling.

    Inside the switching window, the output of the first resnet/attention pair is 2x
    average-pooled (ceil mode) and its pre-pool (H, W) is stored in ``info["upsample_size"]``.
    The window is ``timestep < T1``, or ``T1_start <= timestep < T1_end`` when
    ``aggressive_raunet`` is on. ``timestep`` counts forward calls and wraps at ``max_timestep``.

    The patched instance needs ``info``, ``model`` and ``switching_threshold_ratio`` attributes
    (set by apply_hidiffusion).

    Args:
        block_class: down block class to extend; kept as ``_parent`` for unpatching.

    Returns:
        The patched subclass.
    """
    # replace conventional downsampler with resolution-aware downsampler
    class cross_attn_down_block(block_class):
        _parent = block_class  # Save for unpatching later
        timestep = 0
        aggressive_raunet = False
        T1_ratio = 0
        T1_start = 0
        T1_end = 0
        T1 = 0  # to avoid conflict with sdxl-turbo
        max_timestep = current_steps  # captured when the class is created

        def forward(
            self,
            hidden_states: torch.FloatTensor,
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            additional_residuals: Optional[torch.FloatTensor] = None,
        ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
            pipeline = self.info['pipeline']
            if not hasattr(pipeline, '_num_timesteps'):
                pipeline._num_timesteps = self.max_timestep  # pylint: disable=protected-access
            self.max_timestep = pipeline._num_timesteps  # pylint: disable=protected-access
            try:
                ori_H, ori_W = self.info['size']
            except Exception as e:
                raise RuntimeError(f'HiDiffusion: cls={self.__class__.__name__} info={hasattr(self, "info")} parent={hasattr(self, "_parent")} orphaned call') from e

            # Recomputed every call. Reset first so aggressive mode enabled by an earlier
            # low-resolution SDXL run does not leak into runs that never enable it.
            self.aggressive_raunet = False
            if self.model == 'sd15':
                size_key = 'sd15_1024' if ori_H < 256 or ori_W < 256 else 'sd15_2048'
                self.T1_ratio = switching_threshold_ratio_dict[size_key][self.switching_threshold_ratio]
            elif self.model == 'sdxl':
                if ori_H < 512 or ori_W < 512:
                    table = text_to_img_controlnet_switching_threshold_ratio_dict if self.info['text_to_img_controlnet'] else switching_threshold_ratio_dict
                    self.T1_ratio = table['sdxl_2048'][self.switching_threshold_ratio]
                    self.aggressive_raunet = inpainting_is_aggressive_raunet if self.info['is_inpainting_task'] else is_aggressive_raunet
                else:
                    self.T1_ratio = switching_threshold_ratio_dict['sdxl_4096'][self.switching_threshold_ratio]
            elif self.model == 'sdxl_turbo':
                self.T1_ratio = switching_threshold_ratio_dict['sdxl_turbo_1024'][self.switching_threshold_ratio]
            else:
                raise RuntimeError('HiDiffusion: unsupported model type')

            if self.aggressive_raunet:
                self.T1_start = int(aggressive_step/50 * self.max_timestep)
                self.T1_end = int(self.max_timestep * self.T1_ratio)
                self.T1 = 0  # to avoid conflict with sdxl-turbo
            else:
                self.T1 = int(self.max_timestep * self.T1_ratio)
            in_switch_window = (self.aggressive_raunet and self.T1_start <= self.timestep < self.T1_end) or self.timestep < self.T1

            output_states = ()
            blocks = list(zip(self.resnets, self.attentions))
            for i, (resnet, attn) in enumerate(blocks):
                if self.training and self.gradient_checkpointing:
                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            return module(*inputs)
                        return custom_forward

                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                else:
                    hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

                # apply additional residuals to the output of the last pair of resnet and attention blocks
                if i == len(blocks) - 1 and additional_residuals is not None:
                    hidden_states = hidden_states + additional_residuals

                if i == 0 and in_switch_window:
                    # remember the pre-pool size so the patched up block can restore it
                    self.info["upsample_size"] = (hidden_states.shape[2], hidden_states.shape[3])
                    hidden_states = F.avg_pool2d(hidden_states, kernel_size=(2, 2), ceil_mode=True)
                output_states = output_states + (hidden_states,)

            if self.downsamplers is not None:
                for downsampler in self.downsamplers:
                    hidden_states = downsampler(hidden_states)
                output_states = output_states + (hidden_states,)

            self.timestep += 1
            if self.timestep == self.max_timestep:
                self.timestep = 0

            return hidden_states, output_states

    return cross_attn_down_block
|
|
|
|
|
|
def make_diffusers_cross_attn_up_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """Create a subclass of a cross-attention up block with RAU-Net upsampling.

    Inside the switching window, the output of the second resnet/attention pair is bicubically
    resized to ``info["upsample_size"]``, the size recorded by the patched down block. The window
    is ``timestep < T1``, or ``T1_start <= timestep < T1_end`` when ``aggressive_raunet`` is on.
    Before each skip connection is concatenated, the incoming hidden states are resized to the
    skip tensor's spatial size if they differ. ``timestep`` counts forward calls and wraps at
    ``max_timestep``.

    The patched instance needs ``info``, ``model`` and ``switching_threshold_ratio`` attributes
    (set by apply_hidiffusion).

    Args:
        block_class: up block class to extend; kept as ``_parent`` for unpatching.

    Returns:
        The patched subclass.
    """
    # replace conventional upsampler with resolution-aware upsampler
    class cross_attn_up_block(block_class):
        # Save for unpatching later
        _parent = block_class
        timestep = 0
        aggressive_raunet = False
        T1_ratio = 0
        T1_start = 0
        T1_end = 0
        T1 = 0  # to avoid conflict with sdxl-turbo
        max_timestep = current_steps  # fallback when the pipeline has no _num_timesteps

        def forward(
            self,
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ) -> torch.FloatTensor:
            def fix_scale(first, second):
                # torch.cat needs identical spatial dims; resize to the skip tensor's exact size
                # (a single scale factor can still be off when H and W ratios differ)
                if first.shape[-2:] != second.shape[-2:]:
                    return F.interpolate(first, size=tuple(second.shape[-2:]), mode='bicubic')
                return first

            # Same fallback as the down block instead of an AttributeError
            self.max_timestep = getattr(self.info['pipeline'], '_num_timesteps', self.max_timestep)
            try:
                ori_H, ori_W = self.info['size']
            except Exception as e:
                raise RuntimeError(f'HiDiffusion: cls={self.__class__.__name__} info={hasattr(self, "info")} parent={hasattr(self, "_parent")} orphaned call') from e

            # Recomputed every call. Reset first so aggressive mode enabled by an earlier
            # low-resolution SDXL run does not leak into runs that never enable it.
            self.aggressive_raunet = False
            if self.model == 'sd15':
                size_key = 'sd15_1024' if ori_H < 256 or ori_W < 256 else 'sd15_2048'
                self.T1_ratio = switching_threshold_ratio_dict[size_key][self.switching_threshold_ratio]
            elif self.model == 'sdxl':
                if ori_H < 512 or ori_W < 512:
                    table = text_to_img_controlnet_switching_threshold_ratio_dict if self.info['text_to_img_controlnet'] else switching_threshold_ratio_dict
                    self.T1_ratio = table['sdxl_2048'][self.switching_threshold_ratio]
                    self.aggressive_raunet = inpainting_is_aggressive_raunet if self.info['is_inpainting_task'] else is_aggressive_raunet
                else:
                    self.T1_ratio = switching_threshold_ratio_dict['sdxl_4096'][self.switching_threshold_ratio]
            elif self.model == 'sdxl_turbo':
                self.T1_ratio = switching_threshold_ratio_dict['sdxl_turbo_1024'][self.switching_threshold_ratio]
            else:
                raise RuntimeError('HiDiffusion: unsupported model type')

            if self.aggressive_raunet:
                self.T1_start = int(aggressive_step/50 * self.max_timestep)
                self.T1_end = int(self.max_timestep * self.T1_ratio)
                self.T1 = 0  # to avoid conflict with sdxl-turbo
            else:
                self.T1 = int(self.max_timestep * self.T1_ratio)
            in_switch_window = (self.aggressive_raunet and self.T1_start <= self.timestep < self.T1_end) or self.timestep < self.T1

            for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = fix_scale(hidden_states, res_hidden_states)
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]
                if i == 1 and in_switch_window:
                    # restore the size the patched down block recorded before pooling
                    hidden_states = F.interpolate(hidden_states, size=self.info["upsample_size"], mode='bicubic')

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            self.timestep += 1
            if self.timestep == self.max_timestep:
                self.timestep = 0
            return hidden_states

    return cross_attn_up_block
|
|
|
|
|
|
def make_diffusers_downsampler_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """Create a subclass of a downsampling conv with a resolution-aware stride.

    While ``timestep < T1`` the convolution runs with stride (4, 4), padding (2, 2) and
    dilation (2, 2). Otherwise it runs F.conv2d with the module's own stride, padding and
    dilation. The module's attributes are never modified, so an exception inside the conv
    cannot leave it in the altered configuration. ``timestep`` counts forward calls and wraps
    at ``max_timestep``.

    The patched instance needs ``info``, ``model`` and ``switching_threshold_ratio`` attributes
    (set by apply_hidiffusion).

    Args:
        block_class: conv class to extend; kept as ``_parent`` for unpatching.

    Returns:
        The patched subclass.
    """
    # replace conventional downsampler with resolution-aware downsampler
    class downsampler_block(block_class):
        # Save for unpatching later
        _parent = block_class
        T1_ratio = 0
        T1 = 0
        timestep = 0
        aggressive_raunet = False
        max_timestep = current_steps  # fallback when the pipeline has no _num_timesteps

        def forward(self, hidden_states: torch.Tensor, scale = 1.0) -> torch.Tensor: # pylint: disable=unused-argument
            # Same fallback as the cross-attention down block instead of an AttributeError
            self.max_timestep = getattr(self.info['pipeline'], '_num_timesteps', self.max_timestep)
            try:
                ori_H, ori_W = self.info['size']
            except Exception as e:
                raise RuntimeError(f'HiDiffusion: cls={self.__class__.__name__} info={hasattr(self, "info")} parent={hasattr(self, "_parent")} orphaned call') from e

            # Recomputed every call. Reset first so aggressive mode enabled by an earlier
            # low-resolution SDXL run does not leak into runs that never enable it.
            self.aggressive_raunet = False
            if self.model == 'sd15':
                size_key = 'sd15_1024' if ori_H < 256 or ori_W < 256 else 'sd15_2048'
                self.T1_ratio = switching_threshold_ratio_dict[size_key][self.switching_threshold_ratio]
            elif self.model == 'sdxl':
                if ori_H < 512 or ori_W < 512:
                    table = text_to_img_controlnet_switching_threshold_ratio_dict if self.info['text_to_img_controlnet'] else switching_threshold_ratio_dict
                    self.T1_ratio = table['sdxl_2048'][self.switching_threshold_ratio]
                    self.aggressive_raunet = inpainting_is_aggressive_raunet if self.info['is_inpainting_task'] else is_aggressive_raunet
                else:
                    self.T1_ratio = switching_threshold_ratio_dict['sdxl_4096'][self.switching_threshold_ratio]
            elif self.model == 'sdxl_turbo':
                self.T1_ratio = switching_threshold_ratio_dict['sdxl_turbo_1024'][self.switching_threshold_ratio]
            else:
                raise RuntimeError('HiDiffusion: unsupported model type')

            if self.aggressive_raunet:
                # aggressive mode: only the first aggressive_step of a nominal 50 steps
                self.T1 = int(aggressive_step/50 * self.max_timestep)
            else:
                self.T1 = int(self.max_timestep * self.T1_ratio)

            if self.timestep < self.T1:
                stride, padding, dilation = (4, 4), (2, 2), (2, 2)
            else:
                stride, padding, dilation = self.stride, self.padding, self.dilation
            hidden_states = F.conv2d(hidden_states, self.weight, self.bias, stride, padding, dilation, self.groups)

            self.timestep += 1
            if self.timestep == self.max_timestep:
                self.timestep = 0
            return hidden_states

    return downsampler_block
|
|
|
|
|
|
def make_diffusers_upsampler_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
    """Create a subclass of an upsampler conv that tracks the denoising step.

    The convolution itself is unchanged: ``forward`` runs F.conv2d with the module's own
    stride, padding, dilation and groups. Each call refreshes ``T1_ratio``, ``T1`` and
    ``aggressive_raunet`` the same way the other patched blocks do. It also validates the model
    type and counts forward calls in ``timestep``, wrapping at ``max_timestep``. Nothing in this
    block reads ``T1``.

    The patched instance needs ``info``, ``model`` and ``switching_threshold_ratio`` attributes
    (set by apply_hidiffusion).

    Args:
        block_class: conv class to extend; kept as ``_parent`` for unpatching.

    Returns:
        The patched subclass.
    """
    # replace conventional upsampler with resolution-aware upsampler
    class upsampler_block(block_class):
        # Save for unpatching later
        _parent = block_class
        T1_ratio = 0
        T1 = 0
        timestep = 0
        aggressive_raunet = False
        max_timestep = current_steps  # fallback when the pipeline has no _num_timesteps

        def forward(self, hidden_states: torch.Tensor, scale = 1.0) -> torch.Tensor: # pylint: disable=unused-argument
            # Same fallback as the cross-attention down block instead of an AttributeError
            self.max_timestep = getattr(self.info['pipeline'], '_num_timesteps', self.max_timestep)
            try:
                ori_H, ori_W = self.info['size']
            except Exception as e:
                raise RuntimeError(f'HiDiffusion: cls={self.__class__.__name__} info={hasattr(self, "info")} parent={hasattr(self, "_parent")} orphaned call') from e

            # Recomputed every call. Reset first so aggressive mode enabled by an earlier
            # low-resolution SDXL run does not leak into runs that never enable it.
            self.aggressive_raunet = False
            if self.model == 'sd15':
                size_key = 'sd15_1024' if ori_H < 256 or ori_W < 256 else 'sd15_2048'
                self.T1_ratio = switching_threshold_ratio_dict[size_key][self.switching_threshold_ratio]
            elif self.model == 'sdxl':
                if ori_H < 512 or ori_W < 512:
                    table = text_to_img_controlnet_switching_threshold_ratio_dict if self.info['text_to_img_controlnet'] else switching_threshold_ratio_dict
                    self.T1_ratio = table['sdxl_2048'][self.switching_threshold_ratio]
                    self.aggressive_raunet = inpainting_is_aggressive_raunet if self.info['is_inpainting_task'] else is_aggressive_raunet
                else:
                    self.T1_ratio = switching_threshold_ratio_dict['sdxl_4096'][self.switching_threshold_ratio]
            elif self.model == 'sdxl_turbo':
                self.T1_ratio = switching_threshold_ratio_dict['sdxl_turbo_1024'][self.switching_threshold_ratio]
            else:
                raise RuntimeError('HiDiffusion: unsupported model type')

            if self.aggressive_raunet:
                self.T1 = int(aggressive_step/50 * self.max_timestep)
            else:
                self.T1 = int(self.max_timestep * self.T1_ratio)

            self.timestep += 1
            if self.timestep == self.max_timestep:
                self.timestep = 0

            return F.conv2d(hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    return upsampler_block
|
|
|
|
|
|
def hook_diffusion_model(model: torch.nn.Module):
    """Record the spatial size of every input to ``model`` via a forward pre-hook.

    Before each forward pass, ``model.info["size"]`` becomes (dim 2, dim 3) of the first
    positional argument. The hook handle is appended to ``model.info["hooks"]``;
    remove_hidiffusion calls ``remove()`` on every handle in that list.
    """
    def _record_input_size(module, args):
        sample = args[0]
        module.info["size"] = (sample.shape[2], sample.shape[3])
        # implicit None: the forward inputs are left unchanged

    handle = model.register_forward_pre_hook(_record_input_size)
    model.info["hooks"].append(handle)
|
|
|
|
|
|
def apply_hidiffusion(
        model: torch.nn.Module,
        apply_raunet: bool = True,
        apply_window_attn: bool = True,
        model_type: str = 'None',
        steps: int = 50):
    """Patch a diffusers pipeline (or a bare UNet) in place with HiDiffusion.

    Args:
        model: diffusers pipeline with a ``unet`` attribute, or the UNet itself.
        apply_raunet: whether to apply RAU-Net (resolution-aware down/up sampling blocks).
        apply_window_attn: whether to apply MSW-MSA (windowed self-attention).
        model_type: 'sd' (SD 1.5 UNet key layout) or 'sdxl'.
        steps: number of inference steps; stored in the module-level ``current_steps``.

    Raises:
        RuntimeError: for an unsupported ``model_type`` (raised before anything is modified),
            or when a module of the UNet is already patched.
    """
    global current_steps  # pylint: disable=global-statement

    # Resolve the key layout and block factories first so an unsupported model_type fails
    # before the model has been mutated. Each entry: (key group, class factory, threshold name).
    if model_type == 'sd':
        modified_key = sd15_hidiffusion_key()
        model_name = 'sd15'
        raunet_patches = (
            ('down_module_key', make_diffusers_downsampler_block, 'T1_ratio'),
            ('down_module_key_extra', make_diffusers_cross_attn_down_block, 'T2_ratio'),
            ('up_module_key', make_diffusers_upsampler_block, 'T1_ratio'),
            ('up_module_key_extra', make_diffusers_cross_attn_up_block, 'T2_ratio'),
        )
    elif model_type == 'sdxl':
        modified_key = sdxl_hidiffusion_key()
        model_name = 'sdxl'
        raunet_patches = (
            ('down_module_key', make_diffusers_cross_attn_down_block, 'T1_ratio'),
            ('down_module_key_extra', make_diffusers_downsampler_block, 'T2_ratio'),
            ('up_module_key', make_diffusers_cross_attn_up_block, 'T1_ratio'),
            ('up_module_key_extra', make_diffusers_upsampler_block, 'T2_ratio'),
        )
    else:
        raise RuntimeError('HiDiffusion: unsupported model type')

    # must be set before the block factories run: the down block reads it at class creation
    current_steps = steps

    # Classify the pipeline before the ControlNet path below replaces model.__class__,
    # otherwise the lookup sees the patched subclass instead of the real pipeline class.
    is_inpainting_task = model.__class__ in auto_pipeline.AUTO_INPAINT_PIPELINES_MAPPING.values()
    has_controlnet = hasattr(model, 'controlnet')

    if has_controlnet:
        from .hidiffusion_controlnet import make_diffusers_sdxl_contrtolnet_ppl, make_diffusers_unet_2d_condition
        model.__class__ = make_diffusers_sdxl_contrtolnet_ppl(model.__class__)
        model.unet.__class__ = make_diffusers_unet_2d_condition(model.unet.__class__)

    diffusion_model = model.unet if hasattr(model, "unet") else model
    # NOTE(review): presumably inflates the UNet's overall up-factor (2**num_upsamplers) so it
    # always forwards explicit upsample sizes; remove_hidiffusion reverts it. Confirm upstream.
    diffusion_model.num_upsamplers += 12
    # Shared by every patched module (and the pipeline) below
    diffusion_model.info = {
        'size': None,
        'upsample_size': None,
        'hooks': [],
        'text_to_img_controlnet': has_controlnet,
        'is_inpainting_task': is_inpainting_task,
        'pipeline': model,
    }

    for key, module in diffusion_model.named_modules():
        if hasattr(module, "_parent"):
            raise RuntimeError(f'HiDiffusion: key={key} module={module.__class__} already patched')
        if apply_raunet:
            for key_group, make_block, ratio_name in raunet_patches:
                if key in modified_key[key_group]:
                    module.__class__ = make_block(module.__class__)
                    module.switching_threshold_ratio = ratio_name
        if apply_window_attn and key in modified_key['windown_attn_module_key']:
            module.__class__ = make_diffusers_transformer_block(module.__class__)
        if hasattr(module, "_parent"):
            module.model = model_name
            module.info = diffusion_model.info

    model.info = diffusion_model.info
    model.hidiffusion = True
    hook_diffusion_model(diffusion_model)
|
|
|
|
|
|
def remove_hidiffusion(model: torch.nn.Module):
    """Undo apply_hidiffusion on a pipeline or UNet if it was patched.

    - Every module whose class defines ``_parent`` is restored to that parent class.
    - Hooks listed in ``info["hooks"]`` are removed, and ``info`` is deleted from each module
      that carries it.
    - The ``num_upsamplers += 12`` bump made by apply_hidiffusion is reverted. The diffusion
      model having ``info`` is used as the marker that the bump happened.
    - The ``hidiffusion`` flag is set to False.

    NOTE(review): the pipeline class swap apply_hidiffusion does for ControlNet pipelines is
    not reverted here; confirm whether it needs to be.
    """
    diffusion_model = model.unet if hasattr(model, "unet") else model
    # apply_hidiffusion attaches `info` to the diffusion model together with the num_upsamplers bump
    was_applied = hasattr(diffusion_model, "info")
    unpatched = False
    for _, module in diffusion_model.named_modules():
        while hasattr(module, "_parent"):
            module.__class__ = module._parent  # pylint: disable=protected-access
            unpatched = True
        if hasattr(module, "info"):
            # the dict is shared by all patched modules, so hooks are removed on the first visit
            hooks = module.info.get("hooks", [])
            for hook in hooks:
                hook.remove()
            hooks.clear()
            del module.info
    if was_applied and hasattr(diffusion_model, "num_upsamplers"):
        diffusion_model.num_upsamplers -= 12
    if unpatched or was_applied:
        # apply_hidiffusion sets the flag on the object it was given; mark it removed there
        # (and on the UNet, which older versions of this function flagged instead)
        model.hidiffusion = False
        if diffusion_model is not model:
            diffusion_model.hidiffusion = False
|