Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)

Commit: add nabla attention
@@ -64,8 +64,8 @@ def get_freqs(dim, max_period=10000.0):
 def fractal_flatten(x, rope, shape, block_mask=False):
     if block_mask:
         pixel_size = 8
-        x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0)
-        rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0)
+        x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
+        rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
         x = x.flatten(1, 2)
         rope = rope.flatten(1, 2)
     else:
@@ -77,15 +77,15 @@ def fractal_flatten(x, rope, shape, block_mask=False):
 def fractal_unflatten(x, shape, block_mask=False):
     if block_mask:
         pixel_size = 8
-        x = x.reshape(-1, pixel_size**2, *x.shape[1:])
-        x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0)
+        x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:])
+        x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
     else:
         x = x.reshape(*shape, *x.shape[2:])
     return x


 def local_patching(x, shape, group_size, dim=0):
-    duration, height, width = shape
+    batch_size, duration, height, width = shape
     g1, g2, g3 = group_size
     x = x.reshape(
         *x.shape[:dim],
@@ -112,7 +112,7 @@ def local_patching(x, shape, group_size, dim=0):


 def local_merge(x, shape, group_size, dim=0):
-    duration, height, width = shape
+    batch_size, duration, height, width = shape
     g1, g2, g3 = group_size
     x = x.reshape(
         *x.shape[:dim],
@@ -138,6 +138,36 @@ def local_merge(x, shape, group_size, dim=0):
     return x


+def nablaT_v2(
+    q: Tensor,
+    k: Tensor,
+    sta: Tensor,
+    thr: float = 0.9,
+) -> BlockMask:
+    # Map estimation
+    B, h, S, D = q.shape
+    s1 = S // 64
+    qa = q.reshape(B, h, s1, 64, D).mean(-2)
+    ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
+    map = qa @ ka
+
+    map = torch.softmax(map / math.sqrt(D), dim=-1)
+    # Map binarization
+    vals, inds = map.sort(-1)
+    cvals = vals.cumsum_(-1)
+    mask = (cvals >= 1 - thr).int()
+    mask = mask.gather(-1, inds.argsort(-1))
+
+    mask = torch.logical_or(mask, sta)
+
+    # BlockMask creation
+    kv_nb = mask.sum(-1).to(torch.int32)
+    kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
+    return BlockMask.from_kv_blocks(
+        torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
+    )
+
+
 def sdpa(q, k, v):
     query = q.transpose(1, 2).contiguous()
     key = k.transpose(1, 2).contiguous()
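
A minimal standalone sketch (toy shapes, not part of the commit) of the top-CDF binarization used in nablaT_v2 above: per query block, keep the smallest set of key blocks whose softmax mass reaches the threshold thr, then map the selection back to the original key-block order.

    import math
    import torch

    torch.manual_seed(0)
    thr = 0.9
    D = 16
    scores = torch.randn(1, 1, 4, 4)                   # toy block-level map (B, h, s1, s1)
    block_map = torch.softmax(scores / math.sqrt(D), dim=-1)

    vals, inds = block_map.sort(-1)                    # ascending scores per query block
    cvals = vals.cumsum(-1)                            # cumulative mass, low -> high
    mask = (cvals >= 1 - thr).int()                    # drop the low-mass tail below (1 - thr)
    mask = mask.gather(-1, inds.argsort(-1))           # back to the original key-block order

    print(mask)                                        # 1 = key block kept for that query block
    print((block_map * mask).sum(-1))                  # retained mass per query block, always > thr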
@@ -392,6 +422,29 @@ class MultiheadSelfAttentionDec(nn.Module):
     def attention(self, query, key, value):
         out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
         return out

+    def nabla(self, query, key, value, sparse_params=None):
+        query = query.transpose(1, 2).contiguous()
+        key = key.transpose(1, 2).contiguous()
+        value = value.transpose(1, 2).contiguous()
+        block_mask = nablaT_v2(
+            query,
+            key,
+            sparse_params["sta_mask"],
+            thr=sparse_params["P"],
+        )
+        out = (
+            flex_attention(
+                query,
+                key,
+                value,
+                block_mask=block_mask
+            )
+            .transpose(1, 2)
+            .contiguous()
+        )
+        out = out.flatten(-2, -1)
+        return out
+
     def out_l(self, x):
         return self.out_layer(x)
@@ -402,7 +455,10 @@ class MultiheadSelfAttentionDec(nn.Module):
         query = apply_rotary(query, rope).type_as(query)
         key = apply_rotary(key, rope).type_as(key)

-        out = self.attention(query, key, value)
+        if sparse_params is not None:
+            out = self.nabla(query, key, value, sparse_params=sparse_params)
+        else:
+            out = self.attention(query, key, value)

         out = self.out_l(out)
         return out
@@ -587,7 +643,18 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
         num_text_blocks=2,
         num_visual_blocks=32,
         axes_dims=(16, 24, 24),
-        visual_cond=False,
+        visual_cond=False,
+        attention_type: str = "regular",
+        attention_causal: bool = None,  # Default for Nabla: False
+        attention_local: bool = None,  # Default for Nabla: False
+        attention_glob: bool = None,  # Default for Nabla: False
+        attention_window: int = None,  # Default for Nabla: 3
+        attention_P: float = None,  # Default for Nabla: 0.9
+        attention_wT: int = None,  # Default for Nabla: 11
+        attention_wW: int = None,  # Default for Nabla: 3
+        attention_wH: int = None,  # Default for Nabla: 3
+        attention_add_sta: bool = None,  # Default for Nabla: True
+        attention_method: str = None,  # Default for Nabla: "topcdf"
     ):
         super().__init__()

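For illustration only (not part of the commit), the Nabla defaults named in the comments above collected as hypothetical keyword arguments; the names follow the new signature and the values follow the inline comments.

    # Hypothetical kwargs mirroring the commented Nabla defaults above.
    nabla_attention_kwargs = {
        "attention_type": "nabla",
        "attention_causal": False,
        "attention_local": False,
        "attention_glob": False,
        "attention_window": 3,
        "attention_P": 0.9,
        "attention_wT": 11,
        "attention_wW": 3,
        "attention_wH": 3,
        "attention_add_sta": True,
        "attention_method": "topcdf",
    }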
@@ -596,6 +663,7 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
         self.model_dim = model_dim
         self.patch_size = patch_size
         self.visual_cond = visual_cond
+        self.attention_type = attention_type

         visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
         self.time_embeddings = TimeEmbeddings(model_dim, time_dim)
@@ -223,6 +223,66 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
         self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

+    @staticmethod
+    def fast_sta_nabla(
+        T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda"
+    ) -> torch.Tensor:
+        l = torch.Tensor([T, H, W]).amax()
+        r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
+        mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
+        sta_t, sta_h, sta_w = (
+            mat[:T, :T].flatten(),
+            mat[:H, :H].flatten(),
+            mat[:W, :W].flatten(),
+        )
+        sta_t = sta_t <= wT // 2
+        sta_h = sta_h <= wH // 2
+        sta_w = sta_w <= wW // 2
+        sta_hw = (
+            (sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
+            .reshape(H, H, W, W)
+            .transpose(1, 2)
+            .flatten()
+        )
+        sta = (
+            (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
+            .reshape(T, T, H * W, H * W)
+            .transpose(1, 2)
+        )
+        return sta.reshape(T * H * W, T * H * W)
+
+    def get_sparse_params(self, sample, device):
+        assert self.transformer.config.patch_size[0] == 1
+        B, T, H, W, _ = sample.shape
+        T, H, W = (
+            T // self.transformer.config.patch_size[0],
+            H // self.transformer.config.patch_size[1],
+            W // self.transformer.config.patch_size[2],
+        )
+        if self.transformer.config.attention_type == "nabla":
+            sta_mask = self.fast_sta_nabla(
+                T, H // 8, W // 8,
+                self.transformer.config.attention_wT, self.transformer.config.attention_wH, self.transformer.config.attention_wW,
+                device=device
+            )
+
+            sparse_params = {
+                "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
+                "attention_type": self.transformer.config.attention_type,
+                "to_fractal": True,
+                "P": self.transformer.config.attention_P,
+                "wT": self.transformer.config.attention_wT,
+                "wW": self.transformer.config.attention_wW,
+                "wH": self.transformer.config.attention_wH,
+                "add_sta": self.transformer.config.attention_add_sta,
+                "visual_shape": (T, H, W),
+                "method": self.transformer.config.attention_method,
+            }
+        else:
+            sparse_params = None
+
+        return sparse_params
+
     def _encode_prompt_qwen(
         self,
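
A minimal standalone sketch (not the committed method) of the separable spatio-temporal window ("STA") mask that fast_sta_nabla builds: a token at (t, h, w) may attend only to tokens inside a small window along each axis. get_sparse_params calls it at block resolution (T, H // 8, W // 8), which lines up with pixel_size = 8 in fractal_flatten and BLOCK_SIZE=64 in nablaT_v2.

    import torch

    def sta_mask(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3) -> torch.Tensor:
        r = torch.arange(max(T, H, W))
        dist = (r.unsqueeze(1) - r.unsqueeze(0)).abs()     # pairwise 1-D distances
        near_t = dist[:T, :T] <= wT // 2                   # inside the temporal window
        near_h = dist[:H, :H] <= wH // 2                   # inside the height window
        near_w = dist[:W, :W] <= wW // 2                   # inside the width window
        # Combine the three axes and flatten to (T*H*W, T*H*W) in (t, h, w) token order.
        m = (
            near_t[:, None, None, :, None, None]
            & near_h[None, :, None, None, :, None]
            & near_w[None, None, :, None, None, :]
        )
        return m.reshape(T * H * W, T * H * W)

    m = sta_mask(T=4, H=3, W=3)
    print(m.shape)        # torch.Size([36, 36])
    print(m.sum(-1)[:4])  # each token only sees its local (t, h, w) neighbourhood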
@@ -681,8 +741,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             if negative_cu_seqlens is not None
             else None
         )

+        # 7. Sparse Params
+        sparse_params = self.get_sparse_params(latents, device)
+
-        # 7. Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

@@ -702,7 +765,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                     visual_rope_pos=visual_rope_pos,
                     text_rope_pos=text_rope_pos,
                     scale_factor=(1, 2, 2),
-                    sparse_params=None,
+                    sparse_params=sparse_params,
                     return_dict=True
                 ).sample

@@ -715,7 +778,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                     visual_rope_pos=visual_rope_pos,
                     text_rope_pos=negative_text_rope_pos,
                     scale_factor=(1, 2, 2),
-                    sparse_params=None,
+                    sparse_params=sparse_params,
                     return_dict=True
                 ).sample