mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
moved sdpa inside processor
This commit is contained in:
@@ -174,14 +174,6 @@ def nablaT_v2(
|
||||
)
|
||||
|
||||
|
||||
def sdpa(q, k, v):
|
||||
query = q.transpose(1, 2).contiguous()
|
||||
key = k.transpose(1, 2).contiguous()
|
||||
value = v.transpose(1, 2).contiguous()
|
||||
out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous()
|
||||
return out
|
||||
|
||||
|
||||
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
||||
def apply_scale_shift_norm(norm, x, scale, shift):
|
||||
return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)
|
||||
@@ -355,7 +347,12 @@ class Kandinsky5SDPAAttentionProcessor(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
# Process attention with the given query, key, value tensors
|
||||
out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
|
||||
|
||||
query = query.transpose(1, 2).contiguous()
|
||||
key = key.transpose(1, 2).contiguous()
|
||||
value = value.transpose(1, 2).contiguous()
|
||||
out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous().flatten(-2, -1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@@ -260,7 +260,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
|
||||
|
||||
return sparse_params
|
||||
|
||||
def _get_qwen_prompt_embeds(
|
||||
def _encode_prompt_qwen(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
@@ -314,7 +314,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
|
||||
|
||||
return embeds.to(dtype), cu_seqlens
|
||||
|
||||
def _get_clip_prompt_embeds(
|
||||
def _encode_prompt_clip(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
|
||||
Reference in New Issue
Block a user