
remove unused imports

leffff
2025-10-14 20:38:40 +00:00
parent d62dffcb21
commit 7084106eaa


@@ -19,21 +19,27 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import BoolTensor, IntTensor, Tensor, nn
-from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
-                                               flex_attention)
-from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from torch.nn.attention.flex_attention import (
+    BlockMask,
+    flex_attention,
+)
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import (USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers,
-                      unscale_lora_layers)
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
-from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
-from ..attention_dispatch import dispatch_attention_fn
+from ..attention import AttentionMixin, FeedForward
 from ..cache_utils import CacheMixin
-from ..embeddings import (PixArtAlphaTextProjection, TimestepEmbedding,
-                          Timesteps, get_1d_rotary_pos_embed)
+from ..embeddings import (
+    TimestepEmbedding,
+    get_1d_rotary_pos_embed,
+)
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -95,7 +101,7 @@ def local_patching(x, shape, group_size, dim=0):
         g2,
         width // g3,
         g3,
-        *x.shape[dim + 3 :]
+        *x.shape[dim + 3 :],
     )
     x = x.permute(
         *range(len(x.shape[:dim])),
@@ -105,7 +111,7 @@ def local_patching(x, shape, group_size, dim=0):
         dim + 1,
         dim + 3,
         dim + 5,
-        *range(dim + 6, len(x.shape))
+        *range(dim + 6, len(x.shape)),
     )
     x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
     return x
@@ -122,7 +128,7 @@ def local_merge(x, shape, group_size, dim=0):
         g1,
         g2,
         g3,
-        *x.shape[dim + 2 :]
+        *x.shape[dim + 2 :],
     )
     x = x.permute(
         *range(len(x.shape[:dim])),
@@ -132,7 +138,7 @@ def local_merge(x, shape, group_size, dim=0):
         dim + 4,
         dim + 2,
         dim + 5,
-        *range(dim + 6, len(x.shape))
+        *range(dim + 6, len(x.shape)),
    )
     x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
     return x
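Aside (not part of the diff): local_patching and local_merge regroup a [T, H, W, C] grid into local (g1, g2, g3) blocks and back; the view/permute/flatten fragments above are the core of it. A minimal self-contained round of the patching step, assuming dim=0 and no leading batch axes:

import torch

x = torch.randn(8, 16, 16, 64)  # [T, H, W, C]
g1, g2, g3 = 2, 4, 4
T, H, W, C = x.shape

# split each axis into (num_blocks, group), bring the three block axes
# forward, then flatten to [num_blocks, g1 * g2 * g3, C]
patched = (
    x.view(T // g1, g1, H // g2, g2, W // g3, g3, C)
    .permute(0, 2, 4, 1, 3, 5, 6)
    .flatten(0, 2)
    .flatten(1, 3)
)
assert patched.shape == (T * H * W // (g1 * g2 * g3), g1 * g2 * g3, C)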
@@ -172,15 +178,7 @@ def sdpa(q, k, v):
     query = q.transpose(1, 2).contiguous()
     key = k.transpose(1, 2).contiguous()
     value = v.transpose(1, 2).contiguous()
-    out = (
-        F.scaled_dot_product_attention(
-            query,
-            key,
-            value
-        )
-        .transpose(1, 2)
-        .contiguous()
-    )
+    out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous()
     return out
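Aside (not part of the diff): sdpa() above assumes [batch, seq, heads, head_dim] inputs; the transposes move the head axis in front of the sequence axis because F.scaled_dot_product_attention attends over the last two dims (seq, head_dim). A quick shape check:

import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 128, 8, 64)  # [B, L, H, D]
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # -> [B, H, L, D]
).transpose(1, 2)  # back to [B, L, H, D]
assert out.shape == (2, 128, 8, 64)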
@@ -279,7 +277,7 @@ class Kandinsky5RoPE1D(nn.Module):
         rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
         rope = rope.view(*rope.shape[:-1], 2, 2)
         return rope.unsqueeze(-4)
-
+
     def reset_dtype(self):
         freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device)
         pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device)
@@ -307,9 +305,15 @@ class Kandinsky5RoPE3D(nn.Module):
         args = torch.cat(
             [
-                args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
-                args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
-                args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
+                args_t.view(1, duration, 1, 1, -1).repeat(
+                    batch_size, 1, height, width, 1
+                ),
+                args_h.view(1, 1, height, 1, -1).repeat(
+                    batch_size, duration, 1, width, 1
+                ),
+                args_w.view(1, 1, 1, width, -1).repeat(
+                    batch_size, duration, height, 1, 1
+                ),
             ],
             dim=-1,
         )
@@ -318,12 +322,12 @@ class Kandinsky5RoPE3D(nn.Module):
         rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
         rope = rope.view(*rope.shape[:-1], 2, 2)
         return rope.unsqueeze(-4)
 
     def reset_dtype(self):
         for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)):
             freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device)
             pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device)
-            setattr(self, f'args_{i}', torch.outer(pos, freq))
+            setattr(self, f"args_{i}", torch.outer(pos, freq))
 
 
 class Kandinsky5Modulation(nn.Module):
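Aside (not part of the diff): both reset_dtype() methods rebuild cached RoPE angle tables with torch.outer(pos, freq), giving a [max_pos, dim // 2] grid of position-times-frequency angles from which the 2x2 cos/sin rotation blocks above are stacked. get_freqs is a helper defined elsewhere in this file; the body below is an assumption following the usual inverse-frequency ladder:

import torch

def get_freqs(half_dim, max_period=10_000.0):  # assumed definition
    return 1.0 / max_period ** (torch.arange(half_dim, dtype=torch.float32) / half_dim)

pos = torch.arange(64, dtype=torch.float32)
args = torch.outer(pos, get_freqs(8))  # [64, 8] angle table
cosine, sine = torch.cos(args), torch.sin(args)
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1).view(64, 8, 2, 2)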
@@ -341,7 +345,7 @@ class Kandinsky5Modulation(nn.Module):
 class Kandinsky5SDPAAttentionProcessor(nn.Module):
     """Custom attention processor for standard SDPA attention"""
-
+
     def __call__(
         self,
         attn,
@@ -357,7 +361,7 @@ class Kandinsky5SDPAAttentionProcessor(nn.Module):
 class Kandinsky5NablaAttentionProcessor(nn.Module):
     """Custom attention processor for Nabla attention"""
-
+
     def __call__(
         self,
         attn,
@@ -369,11 +373,11 @@ class Kandinsky5NablaAttentionProcessor(nn.Module):
     ):
         if sparse_params is None:
             raise ValueError("sparse_params is required for Nabla attention")
-
+
         query = query.transpose(1, 2).contiguous()
         key = key.transpose(1, 2).contiguous()
         value = value.transpose(1, 2).contiguous()
-
+
         block_mask = nablaT_v2(
             query,
             key,
@@ -381,12 +385,7 @@ class Kandinsky5NablaAttentionProcessor(nn.Module):
             thr=sparse_params["P"],
         )
         out = (
-            flex_attention(
-                query,
-                key,
-                value,
-                block_mask=block_mask
-            )
+            flex_attention(query, key, value, block_mask=block_mask)
             .transpose(1, 2)
             .contiguous()
         )
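Aside (not part of the diff): nablaT_v2 returns a BlockMask that lets flex_attention skip fully masked blocks. A minimal sketch of the same API using a trivial causal mask in place of the Nabla one (assumes PyTorch >= 2.5; flex_attention inputs are [batch, heads, seq, head_dim]):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

B, H, L, D = 1, 8, 256, 64
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

# B=None / H=None broadcast the mask over batch and heads
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=L, KV_LEN=L, device="cpu")
out = flex_attention(q, k, v, block_mask=block_mask)  # [B, H, L, D]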
@@ -407,7 +406,7 @@ class Kandinsky5MultiheadSelfAttentionEnc(nn.Module):
         self.key_norm = nn.RMSNorm(head_dim)
         self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
-
+
         # Initialize attention processor
         self.sdpa_processor = Kandinsky5SDPAAttentionProcessor()
@@ -430,13 +429,7 @@ class Kandinsky5MultiheadSelfAttentionEnc(nn.Module):
     def scaled_dot_product_attention(self, query, key, value):
         # Use the processor
-        return self.sdpa_processor(
-            attn=self,
-            query=query,
-            key=key,
-            value=value,
-            **{}
-        )
+        return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{})
 
     def out_l(self, x):
         return self.out_layer(x)
@@ -466,7 +459,7 @@ class Kandinsky5MultiheadSelfAttentionDec(nn.Module):
         self.key_norm = nn.RMSNorm(head_dim)
         self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
-
+
         # Initialize attention processors
         self.sdpa_processor = Kandinsky5SDPAAttentionProcessor()
         self.nabla_processor = Kandinsky5NablaAttentionProcessor()
@@ -490,14 +483,8 @@ class Kandinsky5MultiheadSelfAttentionDec(nn.Module):
     def attention(self, query, key, value):
         # Use the processor
-        return self.sdpa_processor(
-            attn=self,
-            query=query,
-            key=key,
-            value=value,
-            **{}
-        )
+        return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{})
 
     def nabla(self, query, key, value, sparse_params=None):
         # Use the processor
         return self.nabla_processor(
@@ -506,7 +493,7 @@ class Kandinsky5MultiheadSelfAttentionDec(nn.Module):
             key=key,
             value=value,
             sparse_params=sparse_params,
-            **{}
+            **{},
         )
 
     def out_l(self, x):
@@ -540,7 +527,7 @@ class Kandinsky5MultiheadCrossAttention(nn.Module):
         self.key_norm = nn.RMSNorm(head_dim)
         self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
-
+
         # Initialize attention processor
         self.sdpa_processor = Kandinsky5SDPAAttentionProcessor()
@@ -563,13 +550,7 @@ class Kandinsky5MultiheadCrossAttention(nn.Module):
     def attention(self, query, key, value):
         # Use the processor
-        return self.sdpa_processor(
-            attn=self,
-            query=query,
-            key=key,
-            value=value,
-            **{}
-        )
+        return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{})
 
     def out_l(self, x):
         return self.out_layer(x)
@@ -605,7 +586,9 @@ class Kandinsky5OutLayer(nn.Module):
         )
 
     def forward(self, visual_embed, text_embed, time_embed):
-        shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
+        shift, scale = torch.chunk(
+            self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
+        )
         visual_embed = apply_scale_shift_norm(
             self.norm,
             visual_embed,
@@ -646,7 +629,9 @@ class Kandinsky5TransformerEncoderBlock(nn.Module):
         self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
 
     def forward(self, x, time_embed, rope):
-        self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
+        self_attn_params, ff_params = torch.chunk(
+            self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
+        )
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
         out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
         out = self.self_attention(out, rope)
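Aside (not part of the diff): the encoder and decoder blocks all follow the same adaLN-style pattern: a timestep-conditioned modulation is chunked into shift/scale/gate, and each sub-layer runs normalize -> scale/shift -> sub-layer -> gated residual. apply_scale_shift_norm and apply_gate_sum are helpers defined elsewhere in this file; the bodies below are assumptions matching the call sites:

import torch

def apply_scale_shift_norm(norm, x, scale, shift):  # assumed body
    return norm(x) * (1 + scale) + shift

def apply_gate_sum(x, out, gate):  # assumed body: gated residual
    return x + gate * out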
@@ -678,26 +663,40 @@ class Kandinsky5TransformerDecoderBlock(nn.Module):
             self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
         )
         shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
+        visual_out = apply_scale_shift_norm(
+            self.self_attention_norm, visual_embed, scale, shift
+        )
         visual_out = self.self_attention(visual_out, rope, sparse_params)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
 
         shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
+        visual_out = apply_scale_shift_norm(
+            self.cross_attention_norm, visual_embed, scale, shift
+        )
         visual_out = self.cross_attention(visual_out, text_embed)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
 
         shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
-        visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
+        visual_out = apply_scale_shift_norm(
+            self.feed_forward_norm, visual_embed, scale, shift
+        )
         visual_out = self.feed_forward(visual_out)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
 
         return visual_embed
 
 
-class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
+class Kandinsky5Transformer3DModel(
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    CacheMixin,
+    AttentionMixin,
+):
     """
     A 3D Diffusion Transformer model for video-like data.
     """
 
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -714,21 +713,21 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         num_text_blocks=2,
         num_visual_blocks=32,
         axes_dims=(16, 24, 24),
-        visual_cond=False,
+        visual_cond=False,
         attention_type: str = "regular",
-        attention_causal: bool = None,  # Default for Nabla: false
-        attention_local: bool = None,  # Default for Nabla: false
-        attention_glob: bool = None,  # Default for Nabla: false
-        attention_window: int = None,  # Default for Nabla: 3
-        attention_P: float = None,  # Default for Nabla: 0.9
-        attention_wT: int = None,  # Default for Nabla: 11
-        attention_wW: int = None,  # Default for Nabla: 3
-        attention_wH: int = None,  # Default for Nabla: 3
-        attention_add_sta: bool = None,  # Default for Nabla: true
-        attention_method: str = None,  # Default for Nabla: "topcdf"
+        attention_causal: bool = None,  # Default for Nabla: false
+        attention_local: bool = None,  # Default for Nabla: false
+        attention_glob: bool = None,  # Default for Nabla: false
+        attention_window: int = None,  # Default for Nabla: 3
+        attention_P: float = None,  # Default for Nabla: 0.9
+        attention_wT: int = None,  # Default for Nabla: 11
+        attention_wW: int = None,  # Default for Nabla: 3
+        attention_wH: int = None,  # Default for Nabla: 3
+        attention_add_sta: bool = None,  # Default for Nabla: true
+        attention_method: str = None,  # Default for Nabla: "topcdf"
     ):
         super().__init__()
         head_dim = sum(axes_dims)
         self.in_visual_dim = in_visual_dim
         self.model_dim = model_dim
@@ -737,12 +736,14 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         self.attention_type = attention_type
 
         visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
 
         # Initialize embeddings
         self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim)
         self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim)
         self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim)
-        self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size)
+        self.visual_embeddings = Kandinsky5VisualEmbeddings(
+            visual_embed_dim, model_dim, patch_size
+        )
 
         # Initialize positional embeddings
         self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim)
@@ -764,10 +765,14 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         )
 
         # Initialize output layer
-        self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
+        self.out_layer = Kandinsky5OutLayer(
+            model_dim, time_dim, out_visual_dim, patch_size
+        )
 
         self.gradient_checkpointing = False
 
-    def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos):
+    def prepare_text_embeddings(
+        self, text_embed, time, pooled_text_embed, x, text_rope_pos
+    ):
         """Prepare text embeddings and related components"""
         text_embed = self.text_embeddings(text_embed)
         time_embed = self.time_embeddings(time)
@@ -777,38 +782,58 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         text_rope = text_rope.unsqueeze(dim=0)
         return text_embed, time_embed, text_rope, visual_embed
 
-    def prepare_visual_embeddings(self, visual_embed, visual_rope_pos, scale_factor, sparse_params):
+    def prepare_visual_embeddings(
+        self, visual_embed, visual_rope_pos, scale_factor, sparse_params
+    ):
         """Prepare visual embeddings and related components"""
         visual_shape = visual_embed.shape[:-1]
-        visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
+        visual_rope = self.visual_rope_embeddings(
+            visual_shape, visual_rope_pos, scale_factor
+        )
         to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
-        visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape,
-                                                    block_mask=to_fractal)
+        visual_embed, visual_rope = fractal_flatten(
+            visual_embed, visual_rope, visual_shape, block_mask=to_fractal
+        )
         return visual_embed, visual_shape, to_fractal, visual_rope
 
     def process_text_transformer_blocks(self, text_embed, time_embed, text_rope):
         """Process text through transformer blocks"""
         for text_transformer_block in self.text_transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                text_embed = self._gradient_checkpointing_func(text_transformer_block, text_embed, time_embed, text_rope)
+                text_embed = self._gradient_checkpointing_func(
+                    text_transformer_block, text_embed, time_embed, text_rope
+                )
             else:
                 text_embed = text_transformer_block(text_embed, time_embed, text_rope)
         return text_embed
 
-    def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params):
+    def process_visual_transformer_blocks(
+        self, visual_embed, text_embed, time_embed, visual_rope, sparse_params
+    ):
         """Process visual through transformer blocks"""
         for visual_transformer_block in self.visual_transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                visual_embed = self._gradient_checkpointing_func(visual_transformer_block, visual_embed, text_embed, time_embed,
-                                                                 visual_rope, sparse_params)
+                visual_embed = self._gradient_checkpointing_func(
+                    visual_transformer_block,
+                    visual_embed,
+                    text_embed,
+                    time_embed,
+                    visual_rope,
+                    sparse_params,
+                )
             else:
-                visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed,
-                                                        visual_rope, sparse_params)
+                visual_embed = visual_transformer_block(
+                    visual_embed, text_embed, time_embed, visual_rope, sparse_params
+                )
         return visual_embed
 
-    def prepare_output(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed):
+    def prepare_output(
+        self, visual_embed, visual_shape, to_fractal, text_embed, time_embed
+    ):
         """Prepare the final output"""
-        visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal)
+        visual_embed = fractal_unflatten(
+            visual_embed, visual_shape, block_mask=to_fractal
+        )
         x = self.out_layer(visual_embed, text_embed, time_embed)
         return x
@@ -846,25 +871,34 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
         text_embed = encoder_hidden_states
         time = timestep
         pooled_text_embed = pooled_projections
 
         # Prepare text embeddings and related components
         text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings(
-            text_embed, time, pooled_text_embed, x, text_rope_pos)
+            text_embed, time, pooled_text_embed, x, text_rope_pos
+        )
 
         # Process text through transformer blocks
-        text_embed = self.process_text_transformer_blocks(text_embed, time_embed, text_rope)
+        text_embed = self.process_text_transformer_blocks(
+            text_embed, time_embed, text_rope
+        )
 
         # Prepare visual embeddings and related components
-        visual_embed, visual_shape, to_fractal, visual_rope = self.prepare_visual_embeddings(
-            visual_embed, visual_rope_pos, scale_factor, sparse_params)
+        visual_embed, visual_shape, to_fractal, visual_rope = (
+            self.prepare_visual_embeddings(
+                visual_embed, visual_rope_pos, scale_factor, sparse_params
+            )
+        )
 
         # Process visual through transformer blocks
         visual_embed = self.process_visual_transformer_blocks(
-            visual_embed, text_embed, time_embed, visual_rope, sparse_params)
+            visual_embed, text_embed, time_embed, visual_rope, sparse_params
+        )
 
         # Prepare final output
-        x = self.prepare_output(visual_embed, visual_shape, to_fractal, text_embed, time_embed)
+        x = self.prepare_output(
+            visual_embed, visual_shape, to_fractal, text_embed, time_embed
+        )
 
         if not return_dict:
             return x