mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add nabla attention

Author: leffff
Date: 2025-10-12 21:59:23 +00:00
Parent: 22e14bdac8
Commit: 70fa62baea

2 changed files with 142 additions and 11 deletions
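
This commit wires NABLA block-sparse attention into the Kandinsky 5 transformer and its text-to-video pipeline. On the model side, nablaT_v2 estimates a block-level attention map from 64-token block averages of the queries and keys, keeps the key blocks carrying the top cumulative softmax mass (threshold P, the "topcdf" method), ORs the result with a spatio-temporal window (STA) mask, and converts it into a FlexAttention BlockMask; MultiheadSelfAttentionDec routes through this path whenever sparse_params is provided, and the transformer config gains the corresponding attention_* fields. On the pipeline side, fast_sta_nabla builds the STA window mask and get_sparse_params assembles the sparse_params dict that is now passed into the denoising loop.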

View File

@@ -64,8 +64,8 @@ def get_freqs(dim, max_period=10000.0):
 def fractal_flatten(x, rope, shape, block_mask=False):
     if block_mask:
         pixel_size = 8
-        x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
-        rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
+        x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
+        rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
         x = x.flatten(1, 2)
         rope = rope.flatten(1, 2)
     else:
@@ -77,15 +77,15 @@ def fractal_flatten(x, rope, shape, block_mask=False):
 def fractal_unflatten(x, shape, block_mask=False):
     if block_mask:
         pixel_size = 8
-        x = x.reshape(-1, pixel_size**2, *x.shape[1:])
-        x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
+        x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:])
+        x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
     else:
         x = x.reshape(*shape, *x.shape[2:])
     return x


 def local_patching(x, shape, group_size, dim=0):
-    duration, height, width = shape
+    batch_size, duration, height, width = shape
     g1, g2, g3 = group_size
     x = x.reshape(
         *x.shape[:dim],
@@ -112,7 +112,7 @@ def local_patching(x, shape, group_size, dim=0):
 def local_merge(x, shape, group_size, dim=0):
-    duration, height, width = shape
+    batch_size, duration, height, width = shape
     g1, g2, g3 = group_size
     x = x.reshape(
         *x.shape[:dim],
@@ -138,6 +138,36 @@ def local_merge(x, shape, group_size, dim=0):
     return x


+def nablaT_v2(
+    q: Tensor,
+    k: Tensor,
+    sta: Tensor,
+    thr: float = 0.9,
+) -> BlockMask:
+    # Map estimation
+    B, h, S, D = q.shape
+    s1 = S // 64
+    qa = q.reshape(B, h, s1, 64, D).mean(-2)
+    ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
+    map = qa @ ka
+    map = torch.softmax(map / math.sqrt(D), dim=-1)
+
+    # Map binarization
+    vals, inds = map.sort(-1)
+    cvals = vals.cumsum_(-1)
+    mask = (cvals >= 1 - thr).int()
+    mask = mask.gather(-1, inds.argsort(-1))
+    mask = torch.logical_or(mask, sta)
+
+    # BlockMask creation
+    kv_nb = mask.sum(-1).to(torch.int32)
+    kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
+    return BlockMask.from_kv_blocks(
+        torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
+    )
+
+
 def sdpa(q, k, v):
     query = q.transpose(1, 2).contiguous()
     key = k.transpose(1, 2).contiguous()
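
The binarization step in nablaT_v2 above keeps, for each query block, just enough key blocks to cover a fraction thr of the softmax mass (the "topcdf" criterion). Below is a minimal standalone sketch of that step on a toy block-level map; the shapes and names are illustrative and not part of the repository code:

import torch

# Toy block-level attention map: batch 1, 1 head, 4 query blocks x 4 key blocks.
attn_map = torch.softmax(torch.randn(1, 1, 4, 4), dim=-1)
thr = 0.9

# Sort each row ascending and accumulate; a key block is kept once the
# cumulative mass up to it reaches 1 - thr, i.e. the discarded tail of
# smallest entries carries less than 1 - thr of the mass.
vals, inds = attn_map.sort(-1)
keep_sorted = (vals.cumsum(-1) >= 1 - thr).int()

# Scatter the keep/drop decision back to the original key-block order,
# mirroring mask.gather(-1, inds.argsort(-1)) in nablaT_v2.
keep = keep_sorted.gather(-1, inds.argsort(-1))

print(keep)           # 0/1 mask over key blocks for every query block
print(keep.sum(-1))   # how many key blocks each query block keeps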
@@ -392,6 +422,29 @@ class MultiheadSelfAttentionDec(nn.Module):
     def attention(self, query, key, value):
         out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
         return out

+    def nabla(self, query, key, value, sparse_params=None):
+        query = query.transpose(1, 2).contiguous()
+        key = key.transpose(1, 2).contiguous()
+        value = value.transpose(1, 2).contiguous()
+        block_mask = nablaT_v2(
+            query,
+            key,
+            sparse_params["sta_mask"],
+            thr=sparse_params["P"],
+        )
+        out = (
+            flex_attention(
+                query,
+                key,
+                value,
+                block_mask=block_mask
+            )
+            .transpose(1, 2)
+            .contiguous()
+        )
+        out = out.flatten(-2, -1)
+        return out
+
     def out_l(self, x):
         return self.out_layer(x)
@@ -402,7 +455,10 @@ class MultiheadSelfAttentionDec(nn.Module):
         query = apply_rotary(query, rope).type_as(query)
         key = apply_rotary(key, rope).type_as(key)

-        out = self.attention(query, key, value)
+        if sparse_params is not None:
+            out = self.nabla(query, key, value, sparse_params=sparse_params)
+        else:
+            out = self.attention(query, key, value)
         out = self.out_l(out)
         return out
@@ -587,7 +643,18 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
         num_text_blocks=2,
         num_visual_blocks=32,
         axes_dims=(16, 24, 24),
-        visual_cond=False,
+        visual_cond=False,
+        attention_type: str = "regular",
+        attention_causal: bool = None,  # Default for Nabla: False
+        attention_local: bool = None,  # Default for Nabla: False
+        attention_glob: bool = None,  # Default for Nabla: False
+        attention_window: int = None,  # Default for Nabla: 3
+        attention_P: float = None,  # Default for Nabla: 0.9
+        attention_wT: int = None,  # Default for Nabla: 11
+        attention_wW: int = None,  # Default for Nabla: 3
+        attention_wH: int = None,  # Default for Nabla: 3
+        attention_add_sta: bool = None,  # Default for Nabla: True
+        attention_method: str = None,  # Default for Nabla: "topcdf"
     ):
         super().__init__()
@@ -596,6 +663,7 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
         self.model_dim = model_dim
         self.patch_size = patch_size
         self.visual_cond = visual_cond
+        self.attention_type = attention_type

         visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
         self.time_embeddings = TimeEmbeddings(model_dim, time_dim)
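
Taken together, the new attention path reduces to: build a boolean (B, h, S/64, S/64) block mask, convert it to a FlexAttention BlockMask via from_kv_blocks, and run flex_attention with it. The self-contained sketch below exercises just that conversion on toy tensors; it assumes a recent PyTorch exposing torch.nn.attention.flex_attention, and the diagonal-plus-neighbour mask is purely illustrative:

import torch
from torch.nn.attention.flex_attention import BlockMask, flex_attention

B, h, S, D = 1, 2, 256, 64   # sequence of 4 blocks of 64 tokens
q, k, v = (torch.randn(B, h, S, D) for _ in range(3))

# Illustrative block-level mask: each query block attends to itself and its
# right neighbour (in the commit this mask comes from the top-CDF map OR-ed
# with the STA window mask).
s1 = S // 64
keep = torch.zeros(B, h, s1, s1, dtype=torch.bool)
for i in range(s1):
    keep[..., i, i] = True
    keep[..., i, min(i + 1, s1 - 1)] = True

# Same conversion as in nablaT_v2: per-row counts and indices of kept blocks.
kv_nb = keep.sum(-1).to(torch.int32)
kv_inds = keep.int().argsort(dim=-1, descending=True).to(torch.int32)
block_mask = BlockMask.from_kv_blocks(
    torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
)

# Depending on the PyTorch build, flex_attention may require a CUDA device.
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)   # torch.Size([1, 2, 256, 64])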

View File

@@ -223,6 +223,66 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
         self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

+    @staticmethod
+    def fast_sta_nabla(
+        T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda"
+    ) -> torch.Tensor:
+        l = torch.Tensor([T, H, W]).amax()
+        r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
+        mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
+        sta_t, sta_h, sta_w = (
+            mat[:T, :T].flatten(),
+            mat[:H, :H].flatten(),
+            mat[:W, :W].flatten(),
+        )
+        sta_t = sta_t <= wT // 2
+        sta_h = sta_h <= wH // 2
+        sta_w = sta_w <= wW // 2
+        sta_hw = (
+            (sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
+            .reshape(H, H, W, W)
+            .transpose(1, 2)
+            .flatten()
+        )
+        sta = (
+            (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
+            .reshape(T, T, H * W, H * W)
+            .transpose(1, 2)
+        )
+        return sta.reshape(T * H * W, T * H * W)
+
+    def get_sparse_params(self, sample, device):
+        assert self.transformer.config.patch_size[0] == 1
+        B, T, H, W, _ = sample.shape
+        T, H, W = (
+            T // self.transformer.config.patch_size[0],
+            H // self.transformer.config.patch_size[1],
+            W // self.transformer.config.patch_size[2],
+        )
+        if self.transformer.config.attention_type == "nabla":
+            sta_mask = self.fast_sta_nabla(
+                T, H // 8, W // 8,
+                self.transformer.config.attention_wT, self.transformer.config.attention_wH, self.transformer.config.attention_wW,
+                device=device
+            )
+            sparse_params = {
+                "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
+                "attention_type": self.transformer.config.attention_type,
+                "to_fractal": True,
+                "P": self.transformer.config.attention_P,
+                "wT": self.transformer.config.attention_wT,
+                "wW": self.transformer.config.attention_wW,
+                "wH": self.transformer.config.attention_wH,
+                "add_sta": self.transformer.config.attention_add_sta,
+                "visual_shape": (T, H, W),
+                "method": self.transformer.config.attention_method,
+            }
+        else:
+            sparse_params = None
+
+        return sparse_params
+
     def _encode_prompt_qwen(
         self,
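
fast_sta_nabla above builds the (T*H*W) x (T*H*W) window mask from three 1-D distance tests and broadcasting rather than looping over token pairs. The sketch below re-derives the same neighbourhood definition with an explicit 6-D broadcast; sta_window_mask is a hypothetical helper for illustration, not the pipeline method:

import torch

def sta_window_mask(T, H, W, wT=3, wH=3, wW=3):
    # Token (t, h, w) may attend to (t2, h2, w2) iff each axis distance is
    # within half the window size -- the neighbourhood fast_sta_nabla encodes.
    t = torch.arange(T)
    h = torch.arange(H)
    w = torch.arange(W)
    dt = (t[:, None] - t[None, :]).abs() <= wT // 2   # (T, T)
    dh = (h[:, None] - h[None, :]).abs() <= wH // 2   # (H, H)
    dw = (w[:, None] - w[None, :]).abs() <= wW // 2   # (W, W)
    mask = (
        dt[:, None, None, :, None, None]
        & dh[None, :, None, None, :, None]
        & dw[None, None, :, None, None, :]
    )                                                  # (T, H, W, T, H, W)
    return mask.reshape(T * H * W, T * H * W)

m = sta_window_mask(4, 6, 6)
print(m.shape, m.float().mean())   # mask size and its density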
@@ -681,8 +741,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             if negative_cu_seqlens is not None
             else None
         )

-        # 7. Denoising loop
+        # 7. Sparse Params
+        sparse_params = self.get_sparse_params(latents, device)
+
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -702,7 +765,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                     visual_rope_pos=visual_rope_pos,
                     text_rope_pos=text_rope_pos,
                     scale_factor=(1, 2, 2),
-                    sparse_params=None,
+                    sparse_params=sparse_params,
                     return_dict=True
                 ).sample
@@ -715,7 +778,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                     visual_rope_pos=visual_rope_pos,
                     text_rope_pos=negative_text_rope_pos,
                     scale_factor=(1, 2, 2),
-                    sparse_params=None,
+                    sparse_params=sparse_params,
                     return_dict=True
                 ).sample
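
For completeness, enabling the sparse path from user code could look roughly like the snippet below. The checkpoint id is a placeholder, the prompt and parameter values are illustrative (they mirror the "Default for Nabla" comments in the new config arguments), and register_to_config is the generic diffusers ConfigMixin helper rather than an API introduced by this commit:

import torch
from diffusers import Kandinsky5T2VPipeline

# Placeholder checkpoint id -- substitute a real Kandinsky 5 T2V checkpoint.
pipe = Kandinsky5T2VPipeline.from_pretrained(
    "<kandinsky5-t2v-checkpoint>", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Switch the transformer config to the NABLA attention path; get_sparse_params
# reads these fields when building sparse_params for the denoising loop.
pipe.transformer.register_to_config(
    attention_type="nabla",
    attention_P=0.9,
    attention_wT=11,
    attention_wH=3,
    attention_wW=3,
    attention_add_sta=True,
    attention_method="topcdf",
)

result = pipe(prompt="a red panda surfing a wave", num_inference_steps=50)
# Output handling (e.g. result.frames) follows the other diffusers video pipelines.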