commit 44987ad98c (parent c12ce7d32b)
Author: Dhruv Nair
Date: 2024-10-24 16:31:10 +02:00


@@ -15,22 +15,18 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...loaders import TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import MochiTransformer3D
from ...models.transformers import MochiTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -53,13 +49,13 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import MochiPipeline
>>> pipe = MochiPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
>>> pipe = MochiPipeline.from_pretrained("black-forest-labs/mochi.1-schnell", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> # Depending on the variant being used, the pipeline call will slightly vary.
>>> # Refer to the pipeline documentation for more details.
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
>>> image.save("flux.png")
>>> image.save("mochi.png")
```
"""
@@ -77,6 +73,24 @@ def calculate_shift(
return mu
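`calculate_shift` appears in this hunk only through its final `return mu`; the Flux helper it is copied from interpolates the timestep-shift value linearly with sequence length. A rough sketch under that assumption (the parameter defaults below are illustrative, not values taken from this diff):

```python
# Rough shape of calculate_shift as copied from the Flux pipeline (a sketch;
# the defaults below are assumptions for illustration, not from this diff).
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    # Linear interpolation: short token sequences get base_shift, long ones max_shift.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu
```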
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
if linear_steps is None:
linear_steps = num_steps // 2
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
quadratic_steps = num_steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
const = quadratic_coef * (linear_steps**2)
quadratic_sigma_schedule = [
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
return sigma_schedule
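As a quick sanity check of the schedule above, the sketch below simply calls the function; `threshold_noise=0.025` matches the value the pipeline passes in step 5 further down.

```python
# Sanity-check the linear-quadratic sigma schedule defined above.
sigmas = linear_quadratic_schedule(num_steps=8, threshold_noise=0.025)

# The schedule has num_steps + 1 entries and runs from pure noise to clean.
assert len(sigmas) == 9
assert sigmas[0] == 1.0 and sigmas[-1] == 0.0
# The first half falls linearly by threshold_noise / linear_steps per step,
# the second half falls quadratically toward 0.
print([round(s, 4) for s in sigmas])
# ≈ [1.0, 0.994, 0.988, 0.981, 0.975, 0.909, 0.725, 0.422, 0.0]
```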
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -137,17 +151,14 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class MochiPipeline(
DiffusionPipeline,
TextualInversionLoaderMixin
):
class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
The Flux pipeline for text-to-image generation.
The Mochi pipeline for text-to-video generation.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Args:
transformer ([`FluxTransformer2DModel`]):
transformer ([`MochiTransformer3DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
@@ -177,7 +188,7 @@ class MochiPipeline(
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: MochiTransformer3D,
transformer: MochiTransformer3DModel,
):
super().__init__()
@@ -188,22 +199,22 @@ class MochiPipeline(
transformer=transformer,
scheduler=scheduler,
)
#TODO: determine these scaling factors from model parameters
# TODO: determine these scaling factors from model parameters
self.vae_spatial_scale_factor = 8
self.vae_temporal_scale_factor = 6
self.patch_size = 2
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_height = 64
self.default_width = 64
self.default_height = 480
self.default_width = 848
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
@@ -227,10 +238,8 @@ class MochiPipeline(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
@@ -239,7 +248,9 @@ class MochiPipeline(
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False).last_hidden_state
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device), output_hidden_states=False
).last_hidden_state
dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
@@ -247,43 +258,17 @@ class MochiPipeline(
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, prompt_attention_mask
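Outside the pipeline, the T5 embedding step above corresponds roughly to the sketch below (hedged: the `"google/t5-v1_1-xxl"` checkpoint id is an illustrative placeholder, not something this diff pins down; `max_length=512` mirrors the `max_sequence_length` default above).

```python
# Rough standalone equivalent of _get_t5_prompt_embeds (illustrative sketch).
import torch
from transformers import T5EncoderModel, T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl", torch_dtype=torch.bfloat16)

text_inputs = tokenizer(
    ["A cat holding a sign that says hello world"],
    padding="max_length",
    max_length=512,  # max_sequence_length default above
    truncation=True,
    return_tensors="pt",
)
prompt_embeds = text_encoder(
    text_inputs.input_ids,
    attention_mask=text_inputs.attention_mask,
    output_hidden_states=False,
).last_hidden_state
# prompt_embeds: (batch, max_length, hidden_dim); the attention mask marks the real tokens.
```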
def _pack_indices(self, attention_mask, latent_frames_dim, latent_height_dim, latent_width_dim):
N = latent_frames_dim * latent_height_dim * latent_width_dim // (self.patch_size**2)
# Create an expanded token mask saying which tokens are valid across both visual and text tokens.
assert N > 0 and len(attention_mask) == 1
attention_mask = attention_mask[0]
mask = F.pad(attention_mask, (N, 0), value=True) # (B, N + L)
seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,)
valid_token_indices = torch.nonzero(
mask.flatten(), as_tuple=False
).flatten() # up to (B * (N + L),)
assert valid_token_indices.size(0) >= attention_mask.size(0) * N # At least (B * N,)
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
)
max_seqlen_in_batch = seqlens_in_batch.max().item()
return {
"cu_seqlens_kv": cu_seqlens,
"max_seqlen_in_batch_kv": max_seqlen_in_batch,
"valid_token_indices_kv": valid_token_indices,
}
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
@@ -297,7 +282,7 @@ class MochiPipeline(
used in all text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
num_videos_per_prompt (`int`):
number of videos that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
@@ -308,30 +293,29 @@ class MochiPipeline(
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device ori self._execution_device
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
# TODO: Add negative prompts back
return prompt_embeds
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
@@ -350,25 +334,12 @@ class MochiPipeline(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -407,11 +378,16 @@ class MochiPipeline(
num_channels_latents,
height,
width,
num_frames,
dtype,
device,
generator,
latents=None,
):
height = height // self.vae_spatial_scale_factor
width = width // self.vae_spatial_scale_factor
num_frames = (num_frames - 1) // (self.vae_temporal_scale_factor + 1)
shape = (batch_size, num_channels_latents, num_frames, height, width)
if latents is not None:
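Concretely, with the `__init__` defaults and the frame formula written above, the latent tensor prepared here has roughly the shape in the sketch below (the channel count of 12 is an assumption for illustration; the pipeline reads it from `transformer.config.in_channels`).

```python
# Illustrative shape arithmetic for prepare_latents (a sketch; the channel
# count is an assumption, the rest follows the formulas above).
import torch

batch_size, num_channels_latents = 1, 12
height, width, num_frames = 480, 848, 85
vae_spatial_scale_factor, vae_temporal_scale_factor = 8, 6

latent_height = height // vae_spatial_scale_factor                   # 60
latent_width = width // vae_spatial_scale_factor                     # 106
latent_frames = (num_frames - 1) // (vae_temporal_scale_factor + 1)  # 12, per the line above

latents = torch.randn(batch_size, num_channels_latents, latent_frames, latent_height, latent_width)
print(latents.shape)  # torch.Size([1, 12, 12, 60, 106])
```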
@@ -429,6 +405,10 @@ class MochiPipeline(
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@@ -470,13 +450,12 @@ class MochiPipeline(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds` instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer` and `text_encoder`. If not defined, `prompt` is
will be used instead
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_frames (`int`, defaults to 16):
The number of video frames to generate
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -490,7 +469,7 @@ class MochiPipeline(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
@@ -509,7 +488,7 @@ class MochiPipeline(
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.MochiPipelineOutput`] instead of a plain tuple.
Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -528,13 +507,12 @@ class MochiPipeline(
Examples:
Returns:
[`~pipelines.flux.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.mochi.MochiPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
height = height or self.default_height
width = width or self.default_width
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -542,7 +520,6 @@ class MochiPipeline(
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -564,22 +541,18 @@ class MochiPipeline(
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
(prompt_embeds) = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
@@ -591,8 +564,12 @@ class MochiPipeline(
latents,
)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
# 5. Prepare timestep
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
threshold_noise = 0.025
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
@@ -624,18 +601,14 @@ class MochiPipeline(
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
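The excerpt cuts off after the transformer call; since `latent_model_input` stacks the unconditional and conditional batches, the continuation is typically the standard classifier-free-guidance combination followed by a scheduler step. A hedged sketch of that continuation (not shown in this diff):

```python
# Sketch of the steps that typically follow the transformer call above
# (illustrative continuation, not part of this diff).
if self.do_classifier_free_guidance:
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

# Advance the flow-matching scheduler by one step.
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
```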