diff --git a/src/diffusers/models/attention2d.py b/src/diffusers/models/attention2d.py
index e69de29bb2..2b19ccc49c 100644
--- a/src/diffusers/models/attention2d.py
+++ b/src/diffusers/models/attention2d.py
@@ -0,0 +1,245 @@
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+# Normalize, normalization, conv_nd, zero_module, NIN, QKVAttentionLegacy and
+# count_flops_attn are assumed to be provided by the source modules named in
+# the section comments below.
+
+
+# unet_grad_tts.py
+class LinearAttention(torch.nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.heads = heads
+        self.dim_head = dim_head
+        hidden_dim = dim_head * heads
+        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x)
+        # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
+        q, k, v = (
+            qkv.reshape(b, 3, self.heads, self.dim_head, h, w)
+            .permute(1, 0, 2, 3, 4, 5)
+            .reshape(3, b, self.heads, self.dim_head, -1)
+        )
+        k = k.softmax(dim=-1)
+        context = torch.einsum("bhdn,bhen->bhde", k, v)
+        out = torch.einsum("bhde,bhdn->bhen", context, q)
+        # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
+        out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w)
+        return self.to_out(out)
+
+
+# unet.py
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h * w)
+        q = q.permute(0, 2, 1)  # b,hw,c
+        k = k.reshape(b, c, h * w)  # b,c,hw
+        w_ = torch.bmm(q, k)  # b,hw,hw    w_[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h * w)
+        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v, w_)  # b,c,hw    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b, c, h, w)
+
+        h_ = self.proj_out(h_)
+
+        return x + h_
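+
+# Illustrative sketch (an assumption, not part of the ported modules): both
+# blocks above are shape-preserving over NCHW feature maps. LinearAttention
+# aggregates a per-head (dim_head x dim_head) context, so its cost is linear
+# in the number of pixels, while AttnBlock materializes the full (hw x hw)
+# attention matrix. A smoke test only needs to compare shapes; note that
+# AttnBlock requires the Normalize helper to be importable.
+def _smoke_test_spatial_attention():
+    x = torch.randn(2, 64, 16, 16)
+    assert LinearAttention(64, heads=4, dim_head=32)(x).shape == x.shape
+    assert AttnBlock(64)(x).shape == x.shape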
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + encoder_channels=None, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels, swish=0.0) + self.qkv = conv_nd(1, channels, channels * 3, 1) + self.attention = QKVAttention(self.num_heads) + + if encoder_channels is not None: + self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x, encoder_out=None): + b, c, *spatial = x.shape + qkv = self.qkv(self.norm(x).view(b, c, -1)) + if encoder_out is not None: + encoder_out = self.encoder_kv(encoder_out) + h = self.attention(qkv, encoder_out) + else: + h = self.attention(qkv) + h = self.proj_out(h) + return x + h.reshape(b, c, *spatial) + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, encoder_kv=None): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after + attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + if encoder_kv is not None: + assert encoder_kv.shape[1] == self.n_heads * ch * 2 + ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) + k = torch.cat([ek, k], dim=-1) + v = torch.cat([ev, v], dim=-1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + +# unet_ldm.py +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x + T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + +# unet_score_estimation.py +class AttnBlockpp(nn.Module): + """Channel-wise self-attention block. Modified from DDPM.""" + + def __init__(self, channels, skip_rescale=False, init_scale=0.0): + super().__init__() + self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, eps=1e-6) + self.NIN_0 = NIN(channels, channels) + self.NIN_1 = NIN(channels, channels) + self.NIN_2 = NIN(channels, channels) + self.NIN_3 = NIN(channels, channels, init_scale=init_scale) + self.skip_rescale = skip_rescale + + def forward(self, x): + B, C, H, W = x.shape + h = self.GroupNorm_0(x) + q = self.NIN_0(h) + k = self.NIN_1(h) + v = self.NIN_2(h) + + w = torch.einsum("bchw,bcij->bhwij", q, k) * (int(C) ** (-0.5)) + w = torch.reshape(w, (B, H, W, H * W)) + w = F.softmax(w, dim=-1) + w = torch.reshape(w, (B, H, W, H, W)) + h = torch.einsum("bhwij,bcij->bchw", w, v) + h = self.NIN_3(h) + if not self.skip_rescale: + return x + h + else: + return (x + h) / np.sqrt(2.0)