diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index 8d2bae11cb..680b456df3 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -174,14 +174,6 @@ def nablaT_v2(
     )
 
 
-def sdpa(q, k, v):
-    query = q.transpose(1, 2).contiguous()
-    key = k.transpose(1, 2).contiguous()
-    value = v.transpose(1, 2).contiguous()
-    out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous()
-    return out
-
-
 @torch.autocast(device_type="cuda", dtype=torch.float32)
 def apply_scale_shift_norm(norm, x, scale, shift):
     return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16)
@@ -355,7 +347,12 @@ class Kandinsky5SDPAAttentionProcessor(nn.Module):
         **kwargs,
     ):
         # Process attention with the given query, key, value tensors
-        out = sdpa(q=query, k=key, v=value).flatten(-2, -1)
+
+        query = query.transpose(1, 2).contiguous()
+        key = key.transpose(1, 2).contiguous()
+        value = value.transpose(1, 2).contiguous()
+        out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous().flatten(-2, -1)
+
         return out
 
 
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
index 3e61ae0bf2..bdf7e41df9 100644
--- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
@@ -260,7 +260,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
 
         return sparse_params
 
-    def _get_qwen_prompt_embeds(
+    def _encode_prompt_qwen(
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
@@ -314,7 +314,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
 
         return embeds.to(dtype), cu_seqlens
 
-    def _get_clip_prompt_embeds(
+    def _encode_prompt_clip(
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
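For context, here is a minimal standalone sketch of the attention path that the first hunk inlines into `Kandinsky5SDPAAttentionProcessor`. The tensor sizes are illustrative assumptions, not values from the model; the point is the layout round-trip, since `F.scaled_dot_product_attention` expects `(batch, num_heads, seq_len, head_dim)` while the processor's inputs are assumed to arrive as `(batch, seq_len, num_heads, head_dim)`:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only (hypothetical, not taken from the model config).
batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
query = torch.randn(batch, seq_len, num_heads, head_dim)
key = torch.randn(batch, seq_len, num_heads, head_dim)
value = torch.randn(batch, seq_len, num_heads, head_dim)

# Transpose seq_len and num_heads into SDPA's expected layout, run attention,
# transpose back, then merge the head dimensions for the output projection.
out = (
    F.scaled_dot_product_attention(
        query.transpose(1, 2).contiguous(),
        key.transpose(1, 2).contiguous(),
        value.transpose(1, 2).contiguous(),
    )
    .transpose(1, 2)
    .contiguous()
    .flatten(-2, -1)  # (batch, seq_len, num_heads * head_dim)
)
assert out.shape == (batch, seq_len, num_heads * head_dim)
```

The behavior matches the removed `sdpa` helper followed by `.flatten(-2, -1)`; the change only moves the logic into the processor's call path rather than altering the computation.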