From e780c05cc3e7816dca976622ef03bc734c486866 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 16 Aug 2024 13:07:06 +0530
Subject: [PATCH] [Chore] add set_default_attn_processor to pixart. (#9196)

add set_default_attn_processor to pixart.
---
 .../models/transformers/pixart_transformer_2d.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 5c9c61243c..1e5cd57945 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -19,7 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+
+        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
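
---
Usage note (not part of the commit): a minimal sketch of how the method added by this patch can be called. It assumes the public `PixArt-alpha/PixArt-XL-2-1024-MS` checkpoint with the standard `transformer` subfolder layout; any PixArt checkpoint in the diffusers format should work the same way.

    from diffusers import PixArtTransformer2DModel

    # Load the PixArt transformer weights.
    model = PixArtTransformer2DModel.from_pretrained(
        "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer"
    )

    # Reset every attention layer back to the plain AttnProcessor(),
    # discarding any custom processors previously set via set_attn_processor().
    model.set_default_attn_processor()

Internally this just delegates to the existing `set_attn_processor`, which recurses over all attention modules, so no per-layer handling is needed at the call site.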