mirror of
https://github.com/huggingface/diffusers.git
update
@@ -139,6 +139,7 @@ def retrieve_timesteps(

 class MochiPipeline(
     DiffusionPipeline,
+    TextualInversionLoaderMixin
 ):
     r"""
     The Mochi pipeline for text-to-video generation.
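A note on the new base class: in diffusers, inheriting from TextualInversionLoaderMixin exposes a load_textual_inversion method on the pipeline for pulling learned token embeddings into the text encoder. A minimal usage sketch; the embedding path and token below are placeholders, not part of this commit:

import torch
from diffusers import MochiPipeline

# Placeholder embedding file and token; illustrative only.
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.load_textual_inversion("path/to/learned_embeds.safetensors", token="<my-concept>")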
@@ -187,14 +188,17 @@ class MochiPipeline(
             transformer=transformer,
             scheduler=scheduler,
         )
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
         )
         # TODO: determine these scaling factors from model parameters
         self.vae_spatial_scale_factor = 8
         self.vae_temporal_scale_factor = 6
         self.patch_size = 2

         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
         self.default_sample_size = 64
         self.default_height = 64
         self.default_width = 64

     def _get_t5_prompt_embeds(
         self,
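The hard-coded factors above (spatial 8x, temporal 6x, patch size 2) fix how a pixel-space video maps onto the latent grid whose token count _pack_indices needs further down. A sketch of that arithmetic; the input shape and the plus-one-frame convention for a causal video VAE are assumptions, not taken from this diff:

# Constants from this hunk (flagged TODO: derive from model parameters).
vae_spatial_scale_factor = 8
vae_temporal_scale_factor = 6
patch_size = 2

# Illustrative input shape, not from the diff.
num_frames, height, width = 19, 480, 848

latent_frames = (num_frames - 1) // vae_temporal_scale_factor + 1  # assumed causal-VAE convention
latent_height = height // vae_spatial_scale_factor
latent_width = width // vae_spatial_scale_factor

# Visual token count, matching the formula for N in _pack_indices below.
N = latent_frames * latent_height * latent_width // patch_size**2
print(latent_frames, latent_height, latent_width, N)  # 4 60 106 6360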
@@ -235,7 +239,7 @@ class MochiPipeline(
                 f" {max_sequence_length} tokens: {removed_text}"
             )

-        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False)[0]
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False).last_hidden_state

         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
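The only functional line in this hunk swaps positional indexing for the named attribute. For a transformers encoder, output[0] and output.last_hidden_state refer to the same tensor, so this is a readability change rather than a behavioral one. A quick self-contained check, using a small public T5 checkpoint purely for illustration:

import torch
from transformers import T5EncoderModel, T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")
enc = T5EncoderModel.from_pretrained("t5-small")

inputs = tok("a photo of a cat", return_tensors="pt")
with torch.no_grad():
    out = enc(**inputs, output_hidden_states=False)

# ModelOutput supports both access styles; they return the same object.
assert out[0] is out.last_hidden_state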
@@ -246,7 +250,32 @@ class MochiPipeline(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        return prompt_embeds
+        return prompt_embeds, prompt_attention_mask
+
+    def _pack_indices(self, attention_mask, latent_frames_dim, latent_height_dim, latent_width_dim):
+        N = latent_frames_dim * latent_height_dim * latent_width_dim // (self.patch_size**2)
+
+        # Create an expanded token mask indicating which tokens are valid across both visual and text tokens.
+        assert N > 0 and len(attention_mask) == 1
+        attention_mask = attention_mask[0]
+
+        mask = F.pad(attention_mask, (N, 0), value=True)  # (B, N + L)
+        seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
+        valid_token_indices = torch.nonzero(
+            mask.flatten(), as_tuple=False
+        ).flatten()  # up to (B * (N + L),)
+
+        assert valid_token_indices.size(0) >= attention_mask.size(0) * N  # at least (B * N,)
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+
+        return {
+            "cu_seqlens_kv": cu_seqlens,
+            "max_seqlen_in_batch_kv": max_seqlen_in_batch,
+            "valid_token_indices_kv": valid_token_indices,
+        }

     def encode_prompt(
         self,
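For context on the new method: returning prompt_attention_mask from the embedding helper makes text-token validity available to _pack_indices, and the keys it returns (cu_seqlens_kv, max_seqlen_in_batch_kv, valid_token_indices_kv) follow the cumulative-sequence-length convention used by variable-length attention kernels such as FlashAttention's varlen interface. A standalone toy run of the same logic, with N fixed by hand instead of derived from latent dimensions, and the mask passed directly rather than as a single-element list:

import torch
import torch.nn.functional as F

def pack_indices(attention_mask: torch.Tensor, N: int) -> dict:
    # Same steps as MochiPipeline._pack_indices, with N supplied directly.
    mask = F.pad(attention_mask, (N, 0), value=True)        # (B, N + L): visual tokens always valid
    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
    valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return {
        "cu_seqlens_kv": cu_seqlens,
        "max_seqlen_in_batch_kv": seqlens_in_batch.max().item(),
        "valid_token_indices_kv": valid_token_indices,
    }

# Toy example: 2 prompts, text length L=3 with trailing padding; N=4 visual tokens.
text_mask = torch.tensor([[True, True, False],
                          [True, False, False]])
out = pack_indices(text_mask, N=4)
print(out["cu_seqlens_kv"])           # tensor([ 0,  6, 11], dtype=torch.int32)
print(out["max_seqlen_in_batch_kv"])  # 6
print(out["valid_token_indices_kv"])  # tensor([ 0,  1,  2,  3,  4,  5,  7,  8,  9, 10, 11])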