Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Rewrite AuraFlowPatchEmbed.pe_selection_index_based_on_dim to be torch.compile compatible (#11297)
* Update pe_selection_index_based_on_dim
* Make pe_selection_index_based_on_dim work with torch.compile
* Fix AuraFlowTransformer2DModel's docstring default values

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
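As a usage sketch (not part of this commit), the torch.compile path this change targets is typically exercised by compiling the pipeline's transformer. The checkpoint id, prompt, and call arguments below are assumptions for illustration only:

    import torch
    from diffusers import AuraFlowPipeline

    # Load AuraFlow and compile its transformer; the rewritten
    # pe_selection_index_based_on_dim runs inside the compiled forward.
    pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
    pipe.transformer = torch.compile(pipe.transformer)

    image = pipe("a watercolor painting of a lighthouse", num_inference_steps=25).images[0]
    image.save("lighthouse.png")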
@@ -74,15 +74,23 @@ class AuraFlowPatchEmbed(nn.Module):
         # PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
         # because original input are in flattened format, we have to flatten this 2d grid as well.
         h_p, w_p = h // self.patch_size, w // self.patch_size
-        original_pe_indexes = torch.arange(self.pos_embed.shape[1])
         h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
-        original_pe_indexes = original_pe_indexes.view(h_max, w_max)
+
+        # Calculate the top-left corner indices for the centered patch grid
         starth = h_max // 2 - h_p // 2
-        endh = starth + h_p
         startw = w_max // 2 - w_p // 2
-        endw = startw + w_p
-        original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
-        return original_pe_indexes.flatten()
+
+        # Generate the row and column indices for the desired patch grid
+        rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
+        cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)
+
+        # Create a 2D grid of indices
+        row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")
+
+        # Convert the 2D grid indices to flattened 1D indices
+        selected_indices = (row_indices * w_max + col_indices).flatten()
+
+        return selected_indices
 
     def forward(self, latent):
         batch_size, num_channels, height, width = latent.size()
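Standalone sketch (not taken from the diff) showing that the new meshgrid-based arithmetic selects the same centered grid of positional-embedding indices as the old view-and-slice approach; pos_embed_max_size, patch_size, and the input size are assumed example values:

    import torch

    pos_embed_max_size = 1024  # assumed: PE grid is 32 x 32
    patch_size = 2             # assumed
    h, w = 32, 48              # assumed latent height/width

    h_p, w_p = h // patch_size, w // patch_size
    h_max = w_max = int(pos_embed_max_size**0.5)
    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2

    # Old approach: materialize every index, view as a grid, slice the center.
    old = torch.arange(pos_embed_max_size).view(h_max, w_max)
    old = old[starth : starth + h_p, startw : startw + w_p].flatten()

    # New approach: build row/column indices directly and combine them.
    rows = torch.arange(starth, starth + h_p)
    cols = torch.arange(startw, startw + w_p)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")
    new = (row_idx * w_max + col_idx).flatten()

    assert torch.equal(old, new)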
@@ -275,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
         sample_size (`int`): The width of the latent images. This is fixed during training since
             it is used to learn a number of position embeddings.
         patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
         num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
-        num_single_dit_layers (`int`, *optional*, defaults to 4):
+        num_single_dit_layers (`int`, *optional*, defaults to 32):
             The number of layers of Transformer blocks to use. These blocks use concatenated image and text
             representations.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
+        num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
         joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
         caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
-        out_channels (`int`, defaults to 16): Number of output channels.
-        pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
+        out_channels (`int`, defaults to 4): Number of output channels.
+        pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
     """
 
     _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
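A quick check sketch (not in the commit) that the corrected docstring defaults match the constructor's signature; it assumes AuraFlowTransformer2DModel is importable from the top-level diffusers namespace and that its constructor signature is introspectable:

    import inspect
    from diffusers import AuraFlowTransformer2DModel

    # Collect the constructor's default values.
    defaults = {
        name: param.default
        for name, param in inspect.signature(AuraFlowTransformer2DModel.__init__).parameters.items()
        if param.default is not inspect.Parameter.empty
    }

    # Values now documented by the fixed docstring above.
    assert defaults["in_channels"] == 4
    assert defaults["num_single_dit_layers"] == 32
    assert defaults["attention_head_dim"] == 256
    assert defaults["num_attention_heads"] == 12
    assert defaults["out_channels"] == 4
    assert defaults["pos_embed_max_size"] == 1024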