From 275041d21e6bf786708eae66dc2aad1e7758e499 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Thu, 24 Oct 2024 14:26:23 +0200
Subject: [PATCH] update

---
 .../pipelines/mochi/pipeline_mochi.py         | 45 +++++++++++++++++----
 1 file changed, 37 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index f9a026440c..dcfed214c5 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -139,6 +139,7 @@ def retrieve_timesteps(
 class MochiPipeline(
     DiffusionPipeline,
+    TextualInversionLoaderMixin,
 ):
     r"""
-    The Flux pipeline for text-to-image generation.
+    The Mochi pipeline for text-to-video generation.
 
@@ -187,14 +188,17 @@ class MochiPipeline(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
-        )
+        # TODO: determine these scaling factors from model parameters
+        self.vae_spatial_scale_factor = 8
+        self.vae_temporal_scale_factor = 6
+        self.patch_size = 2
+
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_height = 64
+        self.default_width = 64
 
     def _get_t5_prompt_embeds(
         self,
@@ -235,7 +239,7 @@ class MochiPipeline(
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False)[0]
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False).last_hidden_state
 
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
@@ -246,7 +250,32 @@ class MochiPipeline(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        return prompt_embeds
+        return prompt_embeds, prompt_attention_mask
+
+    def _pack_indices(self, attention_mask, latent_frames_dim, latent_height_dim, latent_width_dim):
+        N = latent_frames_dim * latent_height_dim * latent_width_dim // (self.patch_size**2)
+
+        # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
+        assert N > 0 and len(attention_mask) == 1
+        attention_mask = attention_mask[0]
+
+        mask = F.pad(attention_mask, (N, 0), value=True)  # (B, N + L)
+        seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
+        valid_token_indices = torch.nonzero(
+            mask.flatten(), as_tuple=False
+        ).flatten()  # up to (B * (N + L),)
+
+        assert valid_token_indices.size(0) >= attention_mask.size(0) * N  # At least (B * N,)
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+
+        return {
+            "cu_seqlens_kv": cu_seqlens,
+            "max_seqlen_in_batch_kv": max_seqlen_in_batch,
+            "valid_token_indices_kv": valid_token_indices,
+        }
 
     def encode_prompt(
         self,
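
A note on the hardcoded scale factors: the sketch below shows how they would map pixel-space video dimensions to the latent dimensions that _pack_indices consumes. This is a minimal illustration, not part of the patch; the (num_frames - 1) // temporal + 1 formula is an assumption based on the usual causal-video-VAE convention, and the 480x848x163 defaults are assumed Mochi-style values.

    # Assumed mapping from pixel space to latent space, using the factors
    # hardcoded in __init__ above. Formula and defaults are assumptions.
    vae_spatial_scale_factor = 8
    vae_temporal_scale_factor = 6
    patch_size = 2

    height, width, num_frames = 480, 848, 163  # assumed Mochi-style defaults

    latent_height = height // vae_spatial_scale_factor                 # 60
    latent_width = width // vae_spatial_scale_factor                   # 106
    latent_frames = (num_frames - 1) // vae_temporal_scale_factor + 1  # 28

    # Number of visual tokens N, as computed in _pack_indices:
    N = latent_frames * latent_height * latent_width // patch_size**2  # 44520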
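
To sanity-check the _pack_indices arithmetic, here is a toy, hand-worked run with N = 4 visual tokens and a single text mask of length L = 3 whose last token is padding. It operates on the (B, L) mask directly, skipping the method's attention_mask[0] unbatching; the values in the comments were computed by hand.

    import torch
    import torch.nn.functional as F

    N = 4                                                   # visual tokens
    attention_mask = torch.tensor([[True, True, False]])    # (B=1, L=3)

    # Prepend N always-valid visual positions to the text mask.
    mask = F.pad(attention_mask, (N, 0), value=True)        # (1, 7)
    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # tensor([6])
    valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    # tensor([0, 1, 2, 3, 4, 5]) -- position 6, the padded text token, is dropped

    # Cumulative sequence boundaries, one entry per batch element plus a leading 0.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # tensor([0, 6], dtype=torch.int32)
    max_seqlen_in_batch = seqlens_in_batch.max().item()     # 6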
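
The _kv suffixes on the returned keys suggest this metadata is destined for a variable-length attention kernel. Purely as an assumption (the consumption path is not part of this patch), the sketch below shows how the dict could feed flash-attn's flash_attn_varlen_func over the packed visual+text sequence; packed_attention and the dense qkv layout are hypothetical names introduced here.

    # Hypothetical consumer of the _pack_indices output. Assumes a dense
    # qkv tensor of shape (B, N + L, 3, heads, head_dim); invalid text tokens
    # are gathered out via valid_token_indices before the kernel runs.
    from flash_attn import flash_attn_varlen_func  # flash-attn >= 2

    def packed_attention(qkv, packed):  # hypothetical helper, not in the patch
        B, S, three, H, D = qkv.shape
        qkv = qkv.reshape(B * S, three, H, D)[packed["valid_token_indices_kv"]]
        q, k, v = qkv.unbind(dim=1)  # each (total_valid_tokens, H, D)
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=packed["cu_seqlens_kv"],
            cu_seqlens_k=packed["cu_seqlens_kv"],
            max_seqlen_q=packed["max_seqlen_in_batch_kv"],
            max_seqlen_k=packed["max_seqlen_in_batch_kv"],
        )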