1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

moved sdpa inside processor

This commit is contained in:
leffff
2025-10-16 07:35:17 +00:00
parent b9a3be2a15
commit 56b90b10ef
2 changed files with 8 additions and 11 deletions

View File

@@ -174,14 +174,6 @@ def nablaT_v2(
)
def sdpa(q, k, v):
query = q.transpose(1, 2).contiguous()
key = k.transpose(1, 2).contiguous()
value = v.transpose(1, 2).contiguous()
out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous()
return out
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)
@@ -355,7 +347,12 @@ class Kandinsky5SDPAAttentionProcessor(nn.Module):
**kwargs,
):
# Process attention with the given query, key, value tensors
out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
query = query.transpose(1, 2).contiguous()
key = key.transpose(1, 2).contiguous()
value = value.transpose(1, 2).contiguous()
out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous().flatten(-2, -1)
return out

View File

@@ -260,7 +260,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
return sparse_params
def _get_qwen_prompt_embeds(
def _encode_prompt_qwen(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
@@ -314,7 +314,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
return embeds.to(dtype), cu_seqlens
def _get_clip_prompt_embeds(
def _encode_prompt_clip(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,