mirror of https://github.com/huggingface/diffusers.git
Transformer: move all methods to forward
@@ -616,47 +616,6 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
         return x
 
 
-# class Kandinsky5TransformerDecoderBlock(nn.Module):
-#     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
-#         super().__init__()
-#         self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
-
-#         self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
-#         self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim)
-
-#         self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
-#         self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim)
-
-#         self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
-#         self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
-
-#     def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
-#         self_attn_params, cross_attn_params, ff_params = torch.chunk(
-#             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
-#         )
-#         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-#         visual_out = apply_scale_shift_norm(
-#             self.self_attention_norm, visual_embed, scale, shift
-#         )
-#         visual_out = self.self_attention(visual_out, rope, sparse_params)
-#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
-
-#         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
-#         visual_out = apply_scale_shift_norm(
-#             self.cross_attention_norm, visual_embed, scale, shift
-#         )
-#         visual_out = self.cross_attention(visual_out, text_embed)
-#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
-
-#         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-#         visual_out = apply_scale_shift_norm(
-#             self.feed_forward_norm, visual_embed, scale, shift
-#         )
-#         visual_out = self.feed_forward(visual_out)
-#         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
-#         return visual_embed
-
-
 class Kandinsky5TransformerDecoderBlock(nn.Module):
     def __init__(self, model_dim, time_dim, ff_dim, head_dim):
         super().__init__()
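The commented-out block deleted above follows the same per-sub-layer modulation pattern as the live `Kandinsky5TransformerDecoderBlock`: chunk the time-conditioned modulation into shift/scale/gate, scale-shift the normalized input, run the sub-layer, then add the gated result back. The helpers `apply_scale_shift_norm` and `apply_gate_sum` are not shown in this diff; the sketch below assumes the usual adaptive-LayerNorm definitions and uses a plain `nn.Linear` as a stand-in for the attention or feed-forward sub-layer.

```python
import torch
import torch.nn as nn


def apply_scale_shift_norm(norm, x, scale, shift):
    # Normalize, then modulate with the per-sample scale/shift derived from time_embed
    # (assumed definition; the real helper is defined elsewhere in the file).
    return norm(x) * (1 + scale) + shift


def apply_gate_sum(x, out, gate):
    # Gated residual update (assumed definition).
    return x + gate * out


# Toy usage of one sub-layer step; shapes only, sub-layer is a stand-in Linear.
batch, seq, dim = 2, 16, 64
x = torch.randn(batch, seq, dim)
mod = torch.randn(batch, 1, 3 * dim)  # stand-in for one chunk of the Kandinsky5Modulation output
shift, scale, gate = torch.chunk(mod, 3, dim=-1)
norm = nn.LayerNorm(dim, elementwise_affine=False)
sublayer = nn.Linear(dim, dim)  # stand-in for self-attention / cross-attention / feed-forward
h = apply_scale_shift_norm(norm, x, scale, shift)
x = apply_gate_sum(x, sublayer(h), gate)
print(x.shape)  # torch.Size([2, 16, 64])
```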
@@ -724,16 +683,16 @@ class Kandinsky5Transformer3DModel(
         axes_dims=(16, 24, 24),
         visual_cond=False,
         attention_type: str = "regular",
-        attention_causal: bool = None,  # Default for Nabla: false
-        attention_local: bool = None,  # Default for Nabla: false
-        attention_glob: bool = None,  # Default for Nabla: false
-        attention_window: int = None,  # Default for Nabla: 3
-        attention_P: float = None,  # Default for Nabla: 0.9
-        attention_wT: int = None,  # Default for Nabla: 11
-        attention_wW: int = None,  # Default for Nabla: 3
-        attention_wH: int = None,  # Default for Nabla: 3
-        attention_add_sta: bool = None,  # Default for Nabla: true
-        attention_method: str = None,  # Default for Nabla: "topcdf"
+        attention_causal: bool = None,
+        attention_local: bool = None,
+        attention_glob: bool = None,
+        attention_window: int = None,
+        attention_P: float = None,
+        attention_wT: int = None,
+        attention_wW: int = None,
+        attention_wH: int = None,
+        attention_add_sta: bool = None,
+        attention_method: str = None,
     ):
         super().__init__()
 
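The ten `attention_*` arguments configure the NABLA-style sparse attention; this change drops the inline `# Default for Nabla: ...` comments that documented their intended defaults. For reference, the values recorded in the removed comments are collected below as a hypothetical keyword-argument dict. `attention_type="nabla"` and the commented-out instantiation are assumptions for illustration only, since how these values are consumed is not shown in this hunk.

```python
# Values taken from the inline comments removed above ("Default for Nabla: ...").
nabla_attention_kwargs = dict(
    attention_type="nabla",  # assumption: a non-"regular" type that activates these knobs
    attention_causal=False,
    attention_local=False,
    attention_glob=False,
    attention_window=3,
    attention_P=0.9,
    attention_wT=11,
    attention_wW=3,
    attention_wH=3,
    attention_add_sta=True,
    attention_method="topcdf",
)
# model = Kandinsky5Transformer3DModel(..., **nabla_attention_kwargs)  # hypothetical usage
```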
@@ -779,73 +738,6 @@ class Kandinsky5Transformer3DModel(
         )
         self.gradient_checkpointing = False
 
-    def prepare_text_embeddings(
-        self, text_embed, time, pooled_text_embed, x, text_rope_pos
-    ):
-        """Prepare text embeddings and related components"""
-        text_embed = self.text_embeddings(text_embed)
-        time_embed = self.time_embeddings(time)
-        time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
-        visual_embed = self.visual_embeddings(x)
-        text_rope = self.text_rope_embeddings(text_rope_pos)
-        text_rope = text_rope.unsqueeze(dim=0)
-        return text_embed, time_embed, text_rope, visual_embed
-
-    def prepare_visual_embeddings(
-        self, visual_embed, visual_rope_pos, scale_factor, sparse_params
-    ):
-        """Prepare visual embeddings and related components"""
-        visual_shape = visual_embed.shape[:-1]
-        visual_rope = self.visual_rope_embeddings(
-            visual_shape, visual_rope_pos, scale_factor
-        )
-        to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
-        visual_embed, visual_rope = fractal_flatten(
-            visual_embed, visual_rope, visual_shape, block_mask=to_fractal
-        )
-        return visual_embed, visual_shape, to_fractal, visual_rope
-
-    def process_text_transformer_blocks(self, text_embed, time_embed, text_rope):
-        """Process text through transformer blocks"""
-        for text_transformer_block in self.text_transformer_blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                text_embed = self._gradient_checkpointing_func(
-                    text_transformer_block, text_embed, time_embed, text_rope
-                )
-            else:
-                text_embed = text_transformer_block(text_embed, time_embed, text_rope)
-        return text_embed
-
-    def process_visual_transformer_blocks(
-        self, visual_embed, text_embed, time_embed, visual_rope, sparse_params
-    ):
-        """Process visual through transformer blocks"""
-        for visual_transformer_block in self.visual_transformer_blocks:
-            if torch.is_grad_enabled() and self.gradient_checkpointing:
-                visual_embed = self._gradient_checkpointing_func(
-                    visual_transformer_block,
-                    visual_embed,
-                    text_embed,
-                    time_embed,
-                    visual_rope,
-                    sparse_params,
-                )
-            else:
-                visual_embed = visual_transformer_block(
-                    visual_embed, text_embed, time_embed, visual_rope, sparse_params
-                )
-        return visual_embed
-
-    def prepare_output(
-        self, visual_embed, visual_shape, to_fractal, text_embed, time_embed
-    ):
-        """Prepare the final output"""
-        visual_embed = fractal_unflatten(
-            visual_embed, visual_shape, block_mask=to_fractal
-        )
-        x = self.out_layer(visual_embed, text_embed, time_embed)
-        return x
-
     def forward(
         self,
         hidden_states: torch.FloatTensor,  # x
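The deleted `process_text_transformer_blocks` / `process_visual_transformer_blocks` helpers (whose bodies now live directly in `forward`, per the commit title) share one dispatch: when gradients are enabled and `gradient_checkpointing` is on, each block runs through `self._gradient_checkpointing_func`, so its activations are recomputed during backward instead of being stored. Below is a minimal, self-contained sketch of that pattern; it uses `torch.utils.checkpoint.checkpoint` directly as a stand-in for the diffusers `_gradient_checkpointing_func` wiring, which is not shown in this diff.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyStack(nn.Module):
    # Simplified stand-in for the text/visual block loops: trade compute for memory
    # by recomputing each block in the backward pass when checkpointing is enabled.
    def __init__(self, dim=32, depth=3):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
        self.gradient_checkpointing = True

    def forward(self, x):
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


x = torch.randn(4, 32, requires_grad=True)
TinyStack()(x).sum().backward()
print(x.grad.shape)  # torch.Size([4, 32])
```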
@@ -881,32 +773,49 @@ class Kandinsky5Transformer3DModel(
         time = timestep
         pooled_text_embed = pooled_projections
 
-        # Prepare text embeddings and related components
-        text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings(
-            text_embed, time, pooled_text_embed, x, text_rope_pos
-        )
+        text_embed = self.text_embeddings(text_embed)
+        time_embed = self.time_embeddings(time)
+        time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
+        visual_embed = self.visual_embeddings(x)
+        text_rope = self.text_rope_embeddings(text_rope_pos)
+        text_rope = text_rope.unsqueeze(dim=0)
+
+        for text_transformer_block in self.text_transformer_blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                text_embed = self._gradient_checkpointing_func(
+                    text_transformer_block, text_embed, time_embed, text_rope
+                )
+            else:
+                text_embed = text_transformer_block(text_embed, time_embed, text_rope)
+
+        visual_shape = visual_embed.shape[:-1]
+        visual_rope = self.visual_rope_embeddings(
+            visual_shape, visual_rope_pos, scale_factor
+        )
+        to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
+        visual_embed, visual_rope = fractal_flatten(
+            visual_embed, visual_rope, visual_shape, block_mask=to_fractal
+        )
 
-        # Process text through transformer blocks
-        text_embed = self.process_text_transformer_blocks(
-            text_embed, time_embed, text_rope
-        )
-
-        # Prepare visual embeddings and related components
-        visual_embed, visual_shape, to_fractal, visual_rope = (
-            self.prepare_visual_embeddings(
-                visual_embed, visual_rope_pos, scale_factor, sparse_params
-            )
-        )
-
-        # Process visual through transformer blocks
-        visual_embed = self.process_visual_transformer_blocks(
-            visual_embed, text_embed, time_embed, visual_rope, sparse_params
-        )
+        for visual_transformer_block in self.visual_transformer_blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                visual_embed = self._gradient_checkpointing_func(
+                    visual_transformer_block,
+                    visual_embed,
+                    text_embed,
+                    time_embed,
+                    visual_rope,
+                    sparse_params,
+                )
+            else:
+                visual_embed = visual_transformer_block(
+                    visual_embed, text_embed, time_embed, visual_rope, sparse_params
+                )
 
-        # Prepare final output
-        x = self.prepare_output(
-            visual_embed, visual_shape, to_fractal, text_embed, time_embed
-        )
+        visual_embed = fractal_unflatten(
+            visual_embed, visual_shape, block_mask=to_fractal
+        )
+        x = self.out_layer(visual_embed, text_embed, time_embed)
 
         if not return_dict:
             return x
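Around the visual transformer blocks, `forward` flattens the visual grid into a token sequence with `fractal_flatten` and restores it with `fractal_unflatten` before the output layer; `to_fractal` additionally enables a block-wise token reordering for sparse attention. Those functions are not part of this diff, so the sketch below only illustrates the plain flatten/restore round trip with reshapes, under the assumption of a `(batch, T, H, W, dim)` layout, and ignores the RoPE handling and the `block_mask` reordering.

```python
import torch


def flatten_grid(x):
    # Simplified stand-in for fractal_flatten with block_mask=False:
    # collapse the (T, H, W) grid into a single token axis.
    b, t, h, w, d = x.shape
    return x.reshape(b, t * h * w, d), (t, h, w)


def unflatten_grid(tokens, grid):
    # Simplified stand-in for fractal_unflatten: restore the grid for the output layer.
    t, h, w = grid
    b, _, d = tokens.shape
    return tokens.reshape(b, t, h, w, d)


visual = torch.randn(1, 4, 8, 8, 64)  # assumed (batch, T, H, W, dim) layout
tokens, grid = flatten_grid(visual)
# ... the visual transformer blocks would run on `tokens` here ...
restored = unflatten_grid(tokens, grid)
print(tokens.shape, restored.shape)  # (1, 256, 64) (1, 4, 8, 8, 64)
assert torch.equal(visual, restored)
```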