diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index a057cc13cc..cca83988a7 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -35,6 +35,23 @@ from ..normalization import FP32LayerNorm logger = logging.get_logger(__name__) +if torch.cuda.get_device_capability()[0] >= 9: + try: + from flash_attn_interface import flash_attn_func as FA + except: + FA = None + + try: + from flash_attn import flash_attn_func as FA + except: + FA = None +else: + try: + from flash_attn import flash_attn_func as FA + except: + FA = None + + # @torch.compile() @torch.autocast(device_type="cuda", dtype=torch.float32) def apply_scale_shift_norm(norm, x, scale, shift): @@ -99,7 +116,7 @@ class VisualEmbeddings(nn.Module): super().__init__() self.patch_size = patch_size self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim) - + def forward(self, x): batch_size, duration, height, width, dim = x.shape x = ( @@ -107,7 +124,7 @@ class VisualEmbeddings(nn.Module): batch_size, duration // self.patch_size[0], self.patch_size[0], - height // self.patch_size[1], + height // self.patch_size[1], self.patch_size[1], width // self.patch_size[2], self.patch_size[2], @@ -169,24 +186,23 @@ class RoPE3D(nn.Module): @torch.autocast(device_type="cuda", enabled=False) def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): batch_size, duration, height, width = shape + args_t = self.args_0[pos[0]] / scale_factor[0] args_h = self.args_1[pos[1]] / scale_factor[1] args_w = self.args_2[pos[2]] / scale_factor[2] - # Replicate the original logic with batch dimension args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1) args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1) args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1) - # Concatenate along the last dimension - args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) # [B, D, H, W, F] + args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) cosine = torch.cos(args) sine = torch.sin(args) - rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) # [B, D, H, W, F, 4] - rope = rope.view(*rope.shape[:-1], 2, 2) # [B, D, H, W, F, 2, 2] - return rope.unsqueeze(-4) # [B, D, H, 1, W, F, 2, 2] - + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) + rope = rope.view(*rope.shape[:-1], 2, 2) + return rope.unsqueeze(-4) + class Modulation(nn.Module): def __init__(self, time_dim, model_dim, num_params): @@ -230,11 +246,14 @@ class MultiheadSelfAttentionEnc(nn.Module): key = apply_rotary(key, rope).type_as(key) # Use torch's scaled_dot_product_attention - out = F.scaled_dot_product_attention( - query, - key, - value, - ).flatten(-2, -1) + # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionEnc SHAPE") + # out = F.scaled_dot_product_attention( + # query.permute(0, 2, 1, 3), + # key.permute(0, 2, 1, 3), + # value.permute(0, 2, 1, 3), + # ).permute(0, 2, 1, 3).flatten(-2, -1) + + out = FA(q=query, k=key, v=value).flatten(-2, -1) out = self.out_layer(out) return out @@ -270,11 +289,15 @@ class MultiheadSelfAttentionDec(nn.Module): key = apply_rotary(key, rope).type_as(key) # Use standard attention (can be extended with sparse attention) - out = F.scaled_dot_product_attention( - query, - key, - value, - ).flatten(-2, -1) 
+ # out = F.scaled_dot_product_attention( + # query.permute(0, 2, 1, 3), + # key.permute(0, 2, 1, 3), + # value.permute(0, 2, 1, 3), + # ).permute(0, 2, 1, 3).flatten(-2, -1) + + # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionDec SHAPE") + + out = FA(q=query, k=key, v=value).flatten(-2, -1) out = self.out_layer(out) return out @@ -306,11 +329,15 @@ class MultiheadCrossAttention(nn.Module): query = self.query_norm(query.float()).type_as(query) key = self.key_norm(key.float()).type_as(key) - out = F.scaled_dot_product_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - ).permute(0, 2, 1, 3).flatten(-2, -1) + # out = F.scaled_dot_product_attention( + # query.permute(0, 2, 1, 3), + # key.permute(0, 2, 1, 3), + # value.permute(0, 2, 1, 3), + # ).permute(0, 2, 1, 3).flatten(-2, -1) + + # print(query.shape, key.shape, value.shape, "QKV MultiheadCrossAttention SHAPE") + + out = FA(q=query, k=key, v=value).flatten(-2, -1) out = self.out_layer(out) return out @@ -339,19 +366,18 @@ class TransformerEncoderBlock(nn.Module): self.feed_forward = FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + self_attn_params, ff_params = torch.chunk( + self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 + ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - - out = self.self_attention_norm(x) - out = out * (scale + 1.0) + shift + out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) out = self.self_attention(out, rope) - x = x + gate * out + x = apply_gate_sum(x, out, gate) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - out = self.feed_forward_norm(x) - out = out * (scale + 1.0) + shift + out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) out = self.feed_forward(out) - x = x + gate * out + x = apply_gate_sum(x, out, gate) return x @@ -371,26 +397,22 @@ class TransformerDecoderBlock(nn.Module): def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): self_attn_params, cross_attn_params, ff_params = torch.chunk( - self.visual_modulation(time_embed), 3, dim=-1 + self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - - visual_out = self.self_attention_norm(visual_embed) - visual_out = visual_out * (scale + 1.0) + shift + visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) visual_out = self.self_attention(visual_out, rope, sparse_params) - visual_embed = visual_embed + gate * visual_out + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = self.cross_attention_norm(visual_embed) - visual_out = visual_out * (scale + 1.0) + shift + visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) visual_out = self.cross_attention(visual_out, text_embed) - visual_embed = visual_embed + gate * visual_out + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = self.feed_forward_norm(visual_embed) - visual_out = visual_out * (scale + 1.0) + shift + visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) visual_out = self.feed_forward(visual_out) - visual_embed = visual_embed + gate * visual_out + visual_embed = 
apply_gate_sum(visual_embed, visual_out, gate) return visual_embed @@ -575,7 +597,7 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): # 1. Process text embeddings text_embed = self.text_embeddings(encoder_hidden_states) time_embed = self.time_embeddings(timestep) - + # Add pooled text embedding to time embedding pooled_embed = self.pooled_text_embeddings(pooled_text_embed) time_embed = time_embed + pooled_embed @@ -587,22 +609,29 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): text_rope = self.text_rope_embeddings(text_rope_pos) # 4. Text transformer blocks + i = 0 for text_block in self.text_transformer_blocks: if self.gradient_checkpointing and self.training: text_embed = torch.utils.checkpoint.checkpoint( text_block, text_embed, time_embed, text_rope, use_reentrant=False ) + else: text_embed = text_block(text_embed, time_embed, text_rope) + i += 1 + # 5. Prepare visual rope visual_shape = visual_embed.shape[:-1] visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + + # visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) + # visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) + visual_embed = visual_embed.flatten(1, 3) + visual_rope = visual_rope.flatten(1, 3) - visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) - visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) - # 6. Visual transformer blocks + i = 0 for visual_block in self.visual_transformer_blocks: if self.gradient_checkpointing and self.training: visual_embed = torch.utils.checkpoint.checkpoint( @@ -619,6 +648,8 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): visual_embed = visual_block( visual_embed, text_embed, time_embed, visual_rope, sparse_params ) + + i += 1 # 7. 
Output projection visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 02eae13633..9dbf31fea9 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -220,19 +220,14 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): ): device = device or self._execution_device - # Encode with Qwen2.5-VL - prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen( - prompt, device, num_videos_per_prompt - ) + prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt) pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt) if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen( - negative_prompt, device, num_videos_per_prompt - ) + negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt) negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt) else: negative_prompt_embeds = None @@ -264,23 +259,25 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: - return latents.to(device=device, dtype=dtype) + num_latent_frames = latents.shape[1] + latents = latents.to(device=device, dtype=dtype) - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - num_channels_latents, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." + else: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) if visual_cond: # For visual conditioning, concatenate with zeros and mask @@ -294,50 +291,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): return latents - def get_velocity( - self, - latents: torch.Tensor, - timestep: torch.Tensor, - text_embeds: Dict[str, torch.Tensor], - negative_text_embeds: Optional[Dict[str, torch.Tensor]], - visual_rope_pos: List[torch.Tensor], - text_rope_pos: torch.Tensor, - negative_text_rope_pos: torch.Tensor, - guidance_scale: float, - sparse_params: Optional[Dict] = None, - ): - # print(latents.shape, text_embeds["text_embeds"].shape, text_embeds["pooled_embed"].shape, timestep.shape, [el.shape for el in visual_rope_pos], text_rope_pos, sparse_params) - - pred_velocity = self.transformer( - latents, - text_embeds["text_embeds"], - text_embeds["pooled_embed"], - timestep * 1000, # Scale to match training - visual_rope_pos, - text_rope_pos, - scale_factor=(1, 2, 2), # From Kandinsky config - sparse_params=sparse_params, - return_dict=False - )[0] - - if guidance_scale > 1.0 and negative_text_embeds is not None: - uncond_pred_velocity = self.transformer( - latents, - negative_text_embeds["text_embeds"], - negative_text_embeds["pooled_embed"], - timestep * 1000, - visual_rope_pos, - negative_text_rope_pos, - scale_factor=(1, 2, 2), - sparse_params=sparse_params, - return_dict=False - )[0] - - pred_velocity = uncond_pred_velocity + guidance_scale * ( - pred_velocity - uncond_pred_velocity - ) - - return pred_velocity @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -402,11 +355,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - # 1. Check inputs if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - # 2. Define call parameters if isinstance(prompt, str): batch_size = 1 else: @@ -415,16 +366,18 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 - if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - - # 3. Encode input prompt text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, @@ -433,11 +386,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): device=device, ) - # 4. Prepare timesteps (Kandinsky uses custom flow matching) - timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device) - timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps) - - # 5. 
Prepare latent variables num_channels_latents = 16 latents = self.prepare_latents( batch_size=batch_size * num_videos_per_prompt, @@ -451,11 +399,12 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): generator=generator, latents=latents, ) + + visual_cond = latents[:, :, :, :, 16:] - # 6. Prepare rope positions visual_rope_pos = [ torch.arange(num_frames // 4 + 1, device=device), - torch.arange(height // 8 // 2, device=device), # patch size 2 + torch.arange(height // 8 // 2, device=device), torch.arange(width // 8 // 2, device=device), ] @@ -467,31 +416,43 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): else None ) - # 7. Prepare sparse attention params if needed - sparse_params = None # Can be extended based on Kandinsky attention config - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))): - # Expand timestep to match batch size - time = timestep.unsqueeze(0) + for i, t in enumerate(timesteps): + timestep = t.unsqueeze(0) - pred_velocity = self.get_velocity( - latents, - time, - text_embeds, - negative_text_embeds, - visual_rope_pos, - text_rope_pos, - negative_text_rope_pos, - guidance_scale, - sparse_params, - ) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + # print(latents.shape) + pred_velocity = self.transformer( + latents, + text_embeds["text_embeds"], + text_embeds["pooled_embed"], + timestep, + visual_rope_pos, + text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=None, + return_dict=False + )[0] + + if guidance_scale > 1.0 and negative_text_embeds is not None: + uncond_pred_velocity = self.transformer( + latents, + negative_text_embeds["text_embeds"], + negative_text_embeds["pooled_embed"], + timestep, + visual_rope_pos, + negative_text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=None, + return_dict=False + )[0] - # Update latents using flow matching - latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity + pred_velocity = uncond_pred_velocity + guidance_scale * ( + pred_velocity - uncond_pred_velocity + ) + + latents = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] + latents = torch.cat([latents, visual_cond], dim=-1) if callback_on_step_end is not None: callback_kwargs = {} @@ -499,8 +460,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) latents = callback_outputs.pop("latents", latents) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0): + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() latents = latents[:, :, :, :, :16] @@ -524,7 +485,6 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): video = video / self.vae.config.scaling_factor video = self.vae.decode(video).sample video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) - # Convert to output format if output_type == "pil": if num_frames == 1: @@ -533,6 +493,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): else: # Video frames video = [video[i] for i in range(video.shape[0])] + else: video = latents
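
A note on the module-level FlashAttention import added to `transformer_kandinsky.py`: as written, the bare `except:` clauses swallow unrelated errors, the second `try` on SM90+ devices overwrites a successful `flash_attn_interface` import (and resets `FA` to `None` when only FlashAttention-3 is installed), and `torch.cuda.get_device_capability()` raises on CPU-only machines. Since the attention blocks now call `FA(q=query, k=key, v=value)` unconditionally, `FA is None` also turns into a `TypeError` at runtime. Below is a hedged sketch of a guarded import plus a dispatch helper that falls back to `torch.nn.functional.scaled_dot_product_attention`; the helper name `dispatch_attention` is hypothetical, and the transposes mirror the commented-out SDPA path in the patch, because `flash_attn_func` takes `(batch, seq, heads, head_dim)` while SDPA expects the heads axis second.

```python
# Sketch only (not part of the patch): guarded FlashAttention import with an SDPA fallback.
import torch
import torch.nn.functional as F

_flash_attn_func = None
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9:
    try:
        # FlashAttention-3 ships its Hopper kernels as `flash_attn_interface`.
        from flash_attn_interface import flash_attn_func as _flash_attn_func
    except ImportError:
        _flash_attn_func = None
if _flash_attn_func is None:
    try:
        # FlashAttention-2 fallback for pre-Hopper GPUs or missing FA3 builds.
        from flash_attn import flash_attn_func as _flash_attn_func
    except ImportError:
        _flash_attn_func = None


def dispatch_attention(query, key, value):
    # query/key/value: (batch, seq_len, num_heads, head_dim); fp16/bf16 required for FlashAttention.
    if _flash_attn_func is not None:
        out = _flash_attn_func(query, key, value)
        # Some FlashAttention-3 builds return (out, softmax_lse).
        return out[0] if isinstance(out, tuple) else out
    # SDPA expects (batch, num_heads, seq_len, head_dim), hence the transposes.
    out = F.scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
    )
    return out.transpose(1, 2)
```

Routing the three attention modules through a helper like this would also keep the existing SDPA behaviour available when neither flash-attn package is installed.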
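
The encoder and decoder blocks now call `apply_scale_shift_norm` and `apply_gate_sum` instead of the previous inline modulation, and the modulation output gains an `.unsqueeze(dim=1)` so the per-sample `(shift, scale, gate)` parameters broadcast over the sequence dimension. The helper bodies are not visible in these hunks; judging by the inline code they replace, they presumably reduce to the following (only `apply_scale_shift_norm` is shown with the autocast decorator in the patch, so the plain definition of `apply_gate_sum` here is an assumption):

```python
import torch

@torch.autocast(device_type="cuda", dtype=torch.float32)
def apply_scale_shift_norm(norm, x, scale, shift):
    # AdaLN-style modulation: normalize, then scale/shift with parameters
    # predicted from the timestep embedding (matches the removed inline code).
    return norm(x) * (scale + 1.0) + shift


def apply_gate_sum(x, out, gate):
    # Gated residual connection: blend the block output back into the stream.
    return x + gate * out
```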
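
`RoPE3D.forward` now expands the per-axis angle tables over the full `(duration, height, width)` grid, packs them into `[[cos, -sin], [sin, cos]]` blocks, and inserts a singleton axis (presumably for the attention heads) before returning. `apply_rotary` is not part of the diff; the sketch below shows one common way such a 2x2-block table is consumed, assuming consecutive head-dim elements form the rotated pairs (the real implementation may pair them differently):

```python
import torch

def apply_rotary(x, rope):
    # x: (batch, seq, heads, head_dim); rope: (batch, seq, 1, head_dim // 2, 2, 2)
    # once the spatial axes are flattened, so it broadcasts over the heads axis.
    pairs = x.float().view(*x.shape[:-1], x.shape[-1] // 2, 1, 2)
    # Row i of each 2x2 block dotted with the (even, odd) pair: a 2D rotation.
    rotated = (rope * pairs).sum(dim=-1)
    return rotated.flatten(-2, -1)
```

The callers in the patch cast the result back with `.type_as(query)` / `.type_as(key)`, so computing the rotation in float32 here is consistent with the surrounding code.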
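
The reworked `prepare_latents` keeps user-provided latents as-is and otherwise samples channels-last noise of shape `(batch, num_latent_frames, height // vae_scale_factor_spatial, width // vae_scale_factor_spatial, num_channels_latents)` with `num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1`. A worked example, assuming the common 8x spatial / 4x temporal VAE factors (the actual values come from the VAE config):

```python
# Illustrative numbers only: vae_scale_factor_spatial=8, vae_scale_factor_temporal=4.
num_frames, height, width = 121, 512, 768
num_latent_frames = (num_frames - 1) // 4 + 1            # 31 latent frames
latent_shape = (1, num_latent_frames, height // 8, width // 8, 16)
print(latent_shape)                                       # (1, 31, 64, 96, 16)
```

Note that `num_frames = 121` already satisfies the pipeline's `num_frames % vae_scale_factor_temporal == 1` rule, so no rounding warning is triggered in this example.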
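
The denoising loop replaces the deleted `get_velocity` helper with two inline transformer calls under bfloat16 autocast and the standard classifier-free-guidance combination, then steps the scheduler on the first 16 latent channels only and re-attaches the visual-conditioning channels. The guidance combination is easy to isolate; a self-contained sketch follows (shapes are arbitrary and `combine_cfg` is a hypothetical name, not part of the patch):

```python
import torch

def combine_cfg(cond_velocity, uncond_velocity, guidance_scale):
    # Classifier-free guidance: extrapolate away from the unconditional prediction.
    if guidance_scale <= 1.0 or uncond_velocity is None:
        return cond_velocity
    return uncond_velocity + guidance_scale * (cond_velocity - uncond_velocity)

cond = torch.randn(1, 31, 64, 96, 16)
uncond = torch.randn(1, 31, 64, 96, 16)
guided = combine_cfg(cond, uncond, guidance_scale=5.0)
```

After the combination, only `latents[..., :16]` is passed to `scheduler.step`, and `visual_cond` (the channels beyond 16, captured before the loop) is concatenated back each step so the conditioning latents stay fixed throughout sampling.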