huggingface/diffusers
commit 275041d21e
parent ccc1b36b09
Author: Dhruv Nair
Date: 2024-10-24 14:26:23 +02:00


@@ -139,6 +139,7 @@ def retrieve_timesteps(
 class MochiPipeline(
     DiffusionPipeline,
+    TextualInversionLoaderMixin
 ):
     r"""
     The Mochi pipeline for text-to-video generation.
@@ -187,14 +188,17 @@ class MochiPipeline(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
-        )
+        # TODO: determine these scaling factors from model parameters
+        self.vae_spatial_scale_factor = 8
+        self.vae_temporal_scale_factor = 6
+        self.patch_size = 2
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_height = 64
+        self.default_width = 64

     def _get_t5_prompt_embeds(
         self,
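Note (not part of the diff): given the attribute names above, the hardcoded factors would map pixel-space video dimensions to latent token counts roughly as follows. A minimal sketch; the exact formulas, and the 848x480x163 example shape, are assumptions rather than something this commit states.

    vae_spatial_scale_factor = 8
    vae_temporal_scale_factor = 6
    patch_size = 2

    height, width, num_frames = 480, 848, 163
    latent_height = height // vae_spatial_scale_factor                 # 60
    latent_width = width // vae_spatial_scale_factor                   # 106
    latent_frames = (num_frames - 1) // vae_temporal_scale_factor + 1  # 28 (assumed "+1 for the first frame" convention)
    num_visual_tokens = latent_frames * latent_height * latent_width // patch_size**2  # 44520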
@@ -235,7 +239,7 @@ class MochiPipeline(
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False)[0]
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False).last_hidden_state
dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
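Context for the change above: for transformers encoder models such as T5, the model output object supports both tuple-style indexing and attribute access, so [0] and .last_hidden_state return the same tensor; the attribute access just makes the intent explicit. A minimal standalone sketch (the small T5 checkpoint is only for illustration):

    import torch
    from transformers import AutoTokenizer, T5EncoderModel

    tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-small")
    text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

    inputs = tokenizer("a photo of an astronaut", return_tensors="pt")
    with torch.no_grad():
        out = text_encoder(inputs.input_ids, attention_mask=inputs.attention_mask)
    # The output supports both access styles; the two views are the same tensor.
    assert torch.equal(out[0], out.last_hidden_state)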
@@ -246,7 +250,32 @@ class MochiPipeline(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        return prompt_embeds
+        return prompt_embeds, prompt_attention_mask
+    def _pack_indices(self, attention_mask, latent_frames_dim, latent_height_dim, latent_width_dim):
+        # Number of visual (latent) tokens after patchifying the latent grid.
+        N = latent_frames_dim * latent_height_dim * latent_width_dim // (self.patch_size**2)
+        # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
+        assert N > 0 and len(attention_mask) == 1
+        attention_mask = attention_mask[0]
+        mask = F.pad(attention_mask, (N, 0), value=True)  # (B, N + L)
+        seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
+        valid_token_indices = torch.nonzero(
+            mask.flatten(), as_tuple=False
+        ).flatten()  # up to (B * (N + L),)
+        assert valid_token_indices.size(0) >= attention_mask.size(0) * N  # At least (B * N,)
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+        return {
+            "cu_seqlens_kv": cu_seqlens,
+            "max_seqlen_in_batch_kv": max_seqlen_in_batch,
+            "valid_token_indices_kv": valid_token_indices,
+        }
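Note (not part of the diff): the keys returned by _pack_indices follow the cu_seqlens / max_seqlen naming used by variable-length attention kernels such as FlashAttention's flash_attn_varlen_func, which consume packed sequences described by cumulative lengths instead of padded batches. A minimal pure-PyTorch sketch of how these indices pack a padded batch (shapes are made up for illustration):

    import torch
    import torch.nn.functional as F

    B, N, L = 2, 6, 4  # batch, visual tokens per sample, max text tokens
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)

    mask = F.pad(attention_mask, (N, 0), value=True)  # (B, N + L); visual tokens are always valid
    seqlens = mask.sum(dim=-1, dtype=torch.int32)  # [9, 8]
    valid_idx = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # [0, 9, 17]

    # Pack a padded (B, N + L, D) tensor into one ragged (total_tokens, D) tensor.
    D = 8
    hidden = torch.randn(B, N + L, D)
    packed = hidden.reshape(-1, D)[valid_idx]  # (17, D)
    # Sample i occupies packed[cu_seqlens[i]:cu_seqlens[i + 1]], exactly the
    # layout a varlen attention kernel consumes.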
     def encode_prompt(
         self,