1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[Chore] add set_default_attn_processor to pixart. (#9196)

add set_default_attn_processor to pixart.
This commit is contained in:
Sayak Paul
2024-08-16 13:07:06 +05:30
committed by GitHub
parent e649678bf5
commit e780c05cc3

View File

@@ -19,7 +19,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ..attention import BasicTransformerBlock
from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
"""
self.set_attn_processor(AttnProcessor())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""