mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
commit 7db6093c53 (parent d53f848720)
Author: leffff
Date:   2025-10-06 12:43:04 +00:00
2 changed files with 141 additions and 149 deletions

Changed file 1 of 2: Kandinsky5 transformer model (Kandinsky5Transformer3DModel)

@@ -35,6 +35,23 @@ from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__)
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
    # Hopper (SM90+): prefer FlashAttention-3, fall back to FlashAttention-2.
    try:
        from flash_attn_interface import flash_attn_func as FA
    except ImportError:
        try:
            from flash_attn import flash_attn_func as FA
        except ImportError:
            FA = None
else:
    try:
        from flash_attn import flash_attn_func as FA
    except ImportError:
        FA = None
# @torch.compile()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
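    # (Editor's note: this hunk is truncated here, so the helper bodies are not shown. Judging
    #  from the block code they replace below, apply_scale_shift_norm presumably reduces to the
    #  single line that follows, and the companion apply_gate_sum(x, out, gate) helper presumably
    #  returns x + gate * out.)
    return norm(x) * (scale + 1.0) + shift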
@@ -99,7 +116,7 @@ class VisualEmbeddings(nn.Module):
super().__init__()
self.patch_size = patch_size
self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)
def forward(self, x):
batch_size, duration, height, width, dim = x.shape
x = (
@@ -107,7 +124,7 @@ class VisualEmbeddings(nn.Module):
batch_size,
duration // self.patch_size[0],
self.patch_size[0],
height // self.patch_size[1],
height // self.patch_size[1],
self.patch_size[1],
width // self.patch_size[2],
self.patch_size[2],
@@ -169,24 +186,23 @@ class RoPE3D(nn.Module):
@torch.autocast(device_type="cuda", enabled=False)
def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
batch_size, duration, height, width = shape
args_t = self.args_0[pos[0]] / scale_factor[0]
args_h = self.args_1[pos[1]] / scale_factor[1]
args_w = self.args_2[pos[2]] / scale_factor[2]
# Replicate the original logic with batch dimension
args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1)
args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1)
args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1)
# Concatenate along the last dimension
args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) # [B, D, H, W, F]
args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1)
cosine = torch.cos(args)
sine = torch.sin(args)
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) # [B, D, H, W, F, 4]
rope = rope.view(*rope.shape[:-1], 2, 2) # [B, D, H, W, F, 2, 2]
return rope.unsqueeze(-4) # [B, D, H, 1, W, F, 2, 2]
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
rope = rope.view(*rope.shape[:-1], 2, 2)
return rope.unsqueeze(-4)
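# Editor's note: apply_rotary, which consumes these [..., freqs, 2, 2] rotation matrices in the
# attention blocks below (after the frame/height/width dims are flattened into one sequence dim),
# is not shown in this diff. A minimal sketch of how such a rope is typically applied, assuming
# queries/keys laid out as [batch, seq, heads, head_dim] with head_dim = 2 * freqs; the helper
# name below is hypothetical, not the file's API:
def _apply_rotary_sketch(x, rope):
    # rope: [batch, seq, 1, freqs, 2, 2]; the singleton dim broadcasts over attention heads.
    x_pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)  # [batch, seq, heads, freqs, 1, 2]
    rotated = (rope * x_pairs).sum(dim=-1)                 # 2x2 matrix-vector product per frequency
    return rotated.flatten(-2, -1)                         # back to [batch, seq, heads, head_dim]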
class Modulation(nn.Module):
def __init__(self, time_dim, model_dim, num_params):
@@ -230,11 +246,14 @@ class MultiheadSelfAttentionEnc(nn.Module):
key = apply_rotary(key, rope).type_as(key)
# Use torch's scaled_dot_product_attention
out = F.scaled_dot_product_attention(
query,
key,
value,
).flatten(-2, -1)
        # flash_attn_func consumes [batch, seq_len, num_heads, head_dim] tensors directly,
        # so no permute is needed here (unlike F.scaled_dot_product_attention).
        out = FA(q=query, k=key, v=value).flatten(-2, -1)
out = self.out_layer(out)
return out
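# Editor's note: FA is None when neither flash-attn package is importable, so the call above would
# fail on such machines. A guarded wrapper in the spirit of the scaled_dot_product_attention path
# this commit replaces could keep SDPA as a fallback (a sketch, not part of this commit;
# flash_attn_func expects [batch, seq, heads, head_dim] while SDPA expects [batch, heads, seq, head_dim]):
def _attention_with_fallback(query, key, value):
    if FA is not None:
        return FA(q=query, k=key, v=value)
    out = F.scaled_dot_product_attention(
        query.permute(0, 2, 1, 3),
        key.permute(0, 2, 1, 3),
        value.permute(0, 2, 1, 3),
    )
    return out.permute(0, 2, 1, 3)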
@@ -270,11 +289,15 @@ class MultiheadSelfAttentionDec(nn.Module):
key = apply_rotary(key, rope).type_as(key)
# Use standard attention (can be extended with sparse attention)
out = F.scaled_dot_product_attention(
query,
key,
value,
).flatten(-2, -1)
        # flash_attn_func consumes [batch, seq_len, num_heads, head_dim] tensors directly.
        out = FA(q=query, k=key, v=value).flatten(-2, -1)
out = self.out_layer(out)
return out
@@ -306,11 +329,15 @@ class MultiheadCrossAttention(nn.Module):
query = self.query_norm(query.float()).type_as(query)
key = self.key_norm(key.float()).type_as(key)
out = F.scaled_dot_product_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
).permute(0, 2, 1, 3).flatten(-2, -1)
        # flash_attn_func consumes [batch, seq_len, num_heads, head_dim] tensors directly.
        out = FA(q=query, k=key, v=value).flatten(-2, -1)
out = self.out_layer(out)
return out
@@ -339,19 +366,18 @@ class TransformerEncoderBlock(nn.Module):
self.feed_forward = FeedForward(model_dim, ff_dim)
def forward(self, x, time_embed, rope):
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
self_attn_params, ff_params = torch.chunk(
self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
out = self.self_attention_norm(x)
out = out * (scale + 1.0) + shift
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
out = self.self_attention(out, rope)
x = x + gate * out
x = apply_gate_sum(x, out, gate)
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
out = self.feed_forward_norm(x)
out = out * (scale + 1.0) + shift
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
out = self.feed_forward(out)
x = x + gate * out
x = apply_gate_sum(x, out, gate)
return x
@@ -371,26 +397,22 @@ class TransformerDecoderBlock(nn.Module):
def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
self_attn_params, cross_attn_params, ff_params = torch.chunk(
self.visual_modulation(time_embed), 3, dim=-1
self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
)
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
visual_out = self.self_attention_norm(visual_embed)
visual_out = visual_out * (scale + 1.0) + shift
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
visual_out = self.self_attention(visual_out, rope, sparse_params)
visual_embed = visual_embed + gate * visual_out
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
visual_out = self.cross_attention_norm(visual_embed)
visual_out = visual_out * (scale + 1.0) + shift
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
visual_out = self.cross_attention(visual_out, text_embed)
visual_embed = visual_embed + gate * visual_out
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
visual_out = self.feed_forward_norm(visual_embed)
visual_out = visual_out * (scale + 1.0) + shift
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
visual_out = self.feed_forward(visual_out)
visual_embed = visual_embed + gate * visual_out
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
return visual_embed
@@ -575,7 +597,7 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
# 1. Process text embeddings
text_embed = self.text_embeddings(encoder_hidden_states)
time_embed = self.time_embeddings(timestep)
# Add pooled text embedding to time embedding
pooled_embed = self.pooled_text_embeddings(pooled_text_embed)
time_embed = time_embed + pooled_embed
@@ -587,22 +609,29 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
text_rope = self.text_rope_embeddings(text_rope_pos)
# 4. Text transformer blocks
i = 0
for text_block in self.text_transformer_blocks:
if self.gradient_checkpointing and self.training:
text_embed = torch.utils.checkpoint.checkpoint(
text_block, text_embed, time_embed, text_rope, use_reentrant=False
)
else:
text_embed = text_block(text_embed, time_embed, text_rope)
i += 1
# 5. Prepare visual rope
visual_shape = visual_embed.shape[:-1]
visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
# visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1])
# visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:]))
visual_embed = visual_embed.flatten(1, 3)
visual_rope = visual_rope.flatten(1, 3)
visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1])
visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:]))
# 6. Visual transformer blocks
i = 0
for visual_block in self.visual_transformer_blocks:
if self.gradient_checkpointing and self.training:
visual_embed = torch.utils.checkpoint.checkpoint(
@@ -619,6 +648,8 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
visual_embed = visual_block(
visual_embed, text_embed, time_embed, visual_rope, sparse_params
)
i += 1
# 7. Output projection
visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1)

Changed file 2 of 2: Kandinsky5 text-to-video pipeline (Kandinsky5T2VPipeline)

@@ -220,19 +220,14 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
):
device = device or self._execution_device
# Encode with Qwen2.5-VL
prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(
prompt, device, num_videos_per_prompt
)
prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt)
pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt)
if do_classifier_free_guidance:
negative_prompt = negative_prompt or ""
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(
negative_prompt, device, num_videos_per_prompt
)
negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt)
negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt)
else:
negative_prompt_embeds = None
@@ -264,23 +259,25 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        # Before this commit:
        if latents is not None:
            return latents.to(device=device, dtype=dtype)
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
            num_channels_latents,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # After this commit:
        if latents is not None:
            num_latent_frames = latents.shape[1]
            latents = latents.to(device=device, dtype=dtype)
        else:
            num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
            shape = (
                batch_size,
                num_latent_frames,
                int(height) // self.vae_scale_factor_spatial,
                int(width) // self.vae_scale_factor_spatial,
                num_channels_latents,
            )
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if visual_cond:
# For visual conditioning, concatenate with zeros and mask
@@ -294,50 +291,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
return latents
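    # Editor's note: a quick worked example of the latent geometry, using the temporal/spatial
    # factors implied elsewhere in this diff (4 and 8) and num_channels_latents = 16. For
    # batch_size=1, num_frames=121, height=512, width=768:
    #     num_latent_frames = (121 - 1) // 4 + 1 = 31
    #     shape = (1, 31, 512 // 8, 768 // 8, 16) = (1, 31, 64, 96, 16)
    # i.e. channels-last latents, matching the [..., :16] slicing in the denoising loop.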
def get_velocity(
self,
latents: torch.Tensor,
timestep: torch.Tensor,
text_embeds: Dict[str, torch.Tensor],
negative_text_embeds: Optional[Dict[str, torch.Tensor]],
visual_rope_pos: List[torch.Tensor],
text_rope_pos: torch.Tensor,
negative_text_rope_pos: torch.Tensor,
guidance_scale: float,
sparse_params: Optional[Dict] = None,
):
# print(latents.shape, text_embeds["text_embeds"].shape, text_embeds["pooled_embed"].shape, timestep.shape, [el.shape for el in visual_rope_pos], text_rope_pos, sparse_params)
pred_velocity = self.transformer(
latents,
text_embeds["text_embeds"],
text_embeds["pooled_embed"],
timestep * 1000, # Scale to match training
visual_rope_pos,
text_rope_pos,
scale_factor=(1, 2, 2), # From Kandinsky config
sparse_params=sparse_params,
return_dict=False
)[0]
if guidance_scale > 1.0 and negative_text_embeds is not None:
uncond_pred_velocity = self.transformer(
latents,
negative_text_embeds["text_embeds"],
negative_text_embeds["pooled_embed"],
timestep * 1000,
visual_rope_pos,
negative_text_rope_pos,
scale_factor=(1, 2, 2),
sparse_params=sparse_params,
return_dict=False
)[0]
pred_velocity = uncond_pred_velocity + guidance_scale * (
pred_velocity - uncond_pred_velocity
)
return pred_velocity
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -402,11 +355,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
# 1. Check inputs
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
# 2. Define call parameters
if isinstance(prompt, str):
batch_size = 1
else:
@@ -415,16 +366,18 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
device = self._execution_device
do_classifier_free_guidance = guidance_scale > 1.0
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# 3. Encode input prompt
text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
@@ -433,11 +386,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
device=device,
)
# 4. Prepare timesteps (Kandinsky uses custom flow matching)
timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device)
timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps)
# 5. Prepare latent variables
num_channels_latents = 16
latents = self.prepare_latents(
batch_size=batch_size * num_videos_per_prompt,
@@ -451,11 +399,12 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
generator=generator,
latents=latents,
)
visual_cond = latents[:, :, :, :, 16:]
# 6. Prepare rope positions
visual_rope_pos = [
torch.arange(num_frames // 4 + 1, device=device),
torch.arange(height // 8 // 2, device=device), # patch size 2
torch.arange(height // 8 // 2, device=device),
torch.arange(width // 8 // 2, device=device),
]
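        # Editor's note: continuing the example above (num_frames=121, height=512, width=768),
        # these rope position ranges are arange(121 // 4 + 1) = arange(31) temporal positions and
        # arange(512 // 8 // 2) = arange(32) / arange(768 // 8 // 2) = arange(48) spatial positions,
        # i.e. one position per 2x2 latent patch of the (31, 64, 96) latent grid.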
@@ -467,31 +416,43 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
else None
)
# 7. Prepare sparse attention params if needed
sparse_params = None # Can be extended based on Kandinsky attention config
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps
with self.progress_bar(total=num_inference_steps) as progress_bar:
            # Before this commit (manual Euler update over the custom flow-matching timesteps):
            for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))):
                # Expand timestep to match batch size
                time = timestep.unsqueeze(0)

                pred_velocity = self.get_velocity(
                    latents,
                    time,
                    text_embeds,
                    negative_text_embeds,
                    visual_rope_pos,
                    text_rope_pos,
                    negative_text_rope_pos,
                    guidance_scale,
                    sparse_params,
                )

                # Update latents using flow matching
                latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity

            # After this commit (direct transformer calls and a scheduler step):
            for i, t in enumerate(timesteps):
                timestep = t.unsqueeze(0)

                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    pred_velocity = self.transformer(
                        latents,
                        text_embeds["text_embeds"],
                        text_embeds["pooled_embed"],
                        timestep,
                        visual_rope_pos,
                        text_rope_pos,
                        scale_factor=(1, 2, 2),
                        sparse_params=None,
                        return_dict=False,
                    )[0]

                    if guidance_scale > 1.0 and negative_text_embeds is not None:
                        uncond_pred_velocity = self.transformer(
                            latents,
                            negative_text_embeds["text_embeds"],
                            negative_text_embeds["pooled_embed"],
                            timestep,
                            visual_rope_pos,
                            negative_text_rope_pos,
                            scale_factor=(1, 2, 2),
                            sparse_params=None,
                            return_dict=False,
                        )[0]

                        # Classifier-free guidance: blend conditional and unconditional velocities.
                        pred_velocity = uncond_pred_velocity + guidance_scale * (
                            pred_velocity - uncond_pred_velocity
                        )

                latents = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0]
                latents = torch.cat([latents, visual_cond], dim=-1)
if callback_on_step_end is not None:
callback_kwargs = {}
@@ -499,8 +460,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
latents = latents[:, :, :, :, :16]
@@ -524,7 +485,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
video = video / self.vae.config.scaling_factor
video = self.vae.decode(video).sample
video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
# Convert to output format
if output_type == "pil":
if num_frames == 1:
@@ -533,6 +493,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
else:
# Video frames
video = [video[i] for i in range(video.shape[0])]
else:
video = latents