diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index 3bbb9421f7..45d4ccdf9a 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -64,8 +64,8 @@ def get_freqs(dim, max_period=10000.0):
 def fractal_flatten(x, rope, shape, block_mask=False):
     if block_mask:
         pixel_size = 8
-        x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
-        rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
+        x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
+        rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
         x = x.flatten(1, 2)
         rope = rope.flatten(1, 2)
     else:
@@ -77,15 +77,15 @@ def fractal_flatten(x, rope, shape, block_mask=False):
 def fractal_unflatten(x, shape, block_mask=False):
     if block_mask:
         pixel_size = 8
-        x = x.reshape(-1, pixel_size**2, *x.shape[1:])
-        x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
+        x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:])
+        x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
     else:
         x = x.reshape(*shape, *x.shape[2:])
     return x
 
 
 def local_patching(x, shape, group_size, dim=0):
-    duration, height, width = shape
+    batch_size, duration, height, width = shape
     g1, g2, g3 = group_size
     x = x.reshape(
         *x.shape[:dim],
@@ -112,7 +112,7 @@ def local_patching(x, shape, group_size, dim=0):
 
 
 def local_merge(x, shape, group_size, dim=0):
-    duration, height, width = shape
+    batch_size, duration, height, width = shape
     g1, g2, g3 = group_size
     x = x.reshape(
         *x.shape[:dim],
@@ -138,6 +138,36 @@ def local_merge(x, shape, group_size, dim=0):
     return x
 
 
+def nablaT_v2(
+    q: Tensor,
+    k: Tensor,
+    sta: Tensor,
+    thr: float = 0.9,
+) -> BlockMask:
+    # Map estimation
+    B, h, S, D = q.shape
+    s1 = S // 64
+    qa = q.reshape(B, h, s1, 64, D).mean(-2)
+    ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
+    map = qa @ ka
+
+    map = torch.softmax(map / math.sqrt(D), dim=-1)
+    # Map binarization
+    vals, inds = map.sort(-1)
+    cvals = vals.cumsum_(-1)
+    mask = (cvals >= 1 - thr).int()
+    mask = mask.gather(-1, inds.argsort(-1))
+
+    mask = torch.logical_or(mask, sta)
+
+    # BlockMask creation
+    kv_nb = mask.sum(-1).to(torch.int32)
+    kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
+    return BlockMask.from_kv_blocks(
+        torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
+    )
+
+
 def sdpa(q, k, v):
     query = q.transpose(1, 2).contiguous()
     key = k.transpose(1, 2).contiguous()
@@ -392,6 +422,29 @@ class MultiheadSelfAttentionDec(nn.Module):
     def attention(self, query, key, value):
         out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
         return out
+
+    def nabla(self, query, key, value, sparse_params=None):
+        query = query.transpose(1, 2).contiguous()
+        key = key.transpose(1, 2).contiguous()
+        value = value.transpose(1, 2).contiguous()
+        block_mask = nablaT_v2(
+            query,
+            key,
+            sparse_params["sta_mask"],
+            thr=sparse_params["P"],
+        )
+        out = (
+            flex_attention(
+                query,
+                key,
+                value,
+                block_mask=block_mask
+            )
+            .transpose(1, 2)
+            .contiguous()
+        )
+        out = out.flatten(-2, -1)
+        return out
 
     def out_l(self, x):
         return self.out_layer(x)
@@ -402,7 +455,10 @@ class MultiheadSelfAttentionDec(nn.Module):
         query = apply_rotary(query, rope).type_as(query)
         key = apply_rotary(key, rope).type_as(key)
 
-        out = self.attention(query, key, value)
+        if sparse_params is not None:
+            out = self.nabla(query, key, value, sparse_params=sparse_params)
+        else:
+            out = self.attention(query, key, value)
         out = self.out_l(out)
         return out
 
@@ -587,7 +643,18 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
         num_text_blocks=2,
         num_visual_blocks=32,
         axes_dims=(16, 24, 24),
-        visual_cond=False,
+        visual_cond=False,
+        attention_type: str = "regular",
+        attention_causal: bool = None,  # Default for Nabla: False
+        attention_local: bool = None,  # Default for Nabla: False
+        attention_glob: bool = None,  # Default for Nabla: False
+        attention_window: int = None,  # Default for Nabla: 3
+        attention_P: float = None,  # Default for Nabla: 0.9
+        attention_wT: int = None,  # Default for Nabla: 11
+        attention_wW: int = None,  # Default for Nabla: 3
+        attention_wH: int = None,  # Default for Nabla: 3
+        attention_add_sta: bool = None,  # Default for Nabla: True
+        attention_method: str = None,  # Default for Nabla: "topcdf"
     ):
         super().__init__()
 
@@ -596,6 +663,7 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
         self.model_dim = model_dim
         self.patch_size = patch_size
         self.visual_cond = visual_cond
+        self.attention_type = attention_type
 
         visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
         self.time_embeddings = TimeEmbeddings(model_dim, time_dim)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
index 5d1eb7d605..05230a604f 100644
--- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
@@ -223,6 +223,66 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
         self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+    @staticmethod
+    def fast_sta_nabla(
+        T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda"
+    ) -> torch.Tensor:
+        l = torch.Tensor([T, H, W]).amax()
+        r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
+        mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
+        sta_t, sta_h, sta_w = (
+            mat[:T, :T].flatten(),
+            mat[:H, :H].flatten(),
+            mat[:W, :W].flatten(),
+        )
+        sta_t = sta_t <= wT // 2
+        sta_h = sta_h <= wH // 2
+        sta_w = sta_w <= wW // 2
+        sta_hw = (
+            (sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
+            .reshape(H, H, W, W)
+            .transpose(1, 2)
+            .flatten()
+        )
+        sta = (
+            (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
+            .reshape(T, T, H * W, H * W)
+            .transpose(1, 2)
+        )
+        return sta.reshape(T * H * W, T * H * W)
+
+    def get_sparse_params(self, sample, device):
+        assert self.transformer.config.patch_size[0] == 1
+        B, T, H, W, _ = sample.shape
+        T, H, W = (
+            T // self.transformer.config.patch_size[0],
+            H // self.transformer.config.patch_size[1],
+            W // self.transformer.config.patch_size[2],
+        )
+        if self.transformer.config.attention_type == "nabla":
+            sta_mask = self.fast_sta_nabla(
+                T, H // 8, W // 8,
+                self.transformer.config.attention_wT, self.transformer.config.attention_wH, self.transformer.config.attention_wW,
+                device=device
+            )
+
+            sparse_params = {
+                "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
+                "attention_type": self.transformer.config.attention_type,
+                "to_fractal": True,
+                "P": self.transformer.config.attention_P,
+                "wT": self.transformer.config.attention_wT,
+                "wW": self.transformer.config.attention_wW,
+                "wH": self.transformer.config.attention_wH,
+                "add_sta": self.transformer.config.attention_add_sta,
+                "visual_shape": (T, H, W),
+                "method": self.transformer.config.attention_method,
+            }
+        else:
+            sparse_params = None
+
+        return sparse_params
 
     def _encode_prompt_qwen(
         self,
@@ -681,8 +741,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             if negative_cu_seqlens is not None
             else None
         )
+
+        # 7. Sparse Params
+        sparse_params = self.get_sparse_params(latents, device)
 
-        # 7. Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
@@ -702,7 +765,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                     visual_rope_pos=visual_rope_pos,
                     text_rope_pos=text_rope_pos,
                     scale_factor=(1, 2, 2),
-                    sparse_params=None,
+                    sparse_params=sparse_params,
                     return_dict=True
                 ).sample
 
@@ -715,7 +778,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                    visual_rope_pos=visual_rope_pos,
                    text_rope_pos=negative_text_rope_pos,
                    scale_factor=(1, 2, 2),
-                   sparse_params=None,
+                   sparse_params=sparse_params,
                    return_dict=True
                ).sample