Mirror of https://github.com/huggingface/diffusers.git
invert sigmas in scheduler; fix pipeline
src/diffusers/pipelines/mochi/pipeline_mochi.py:

@@ -20,7 +20,6 @@ import torch
 from transformers import T5EncoderModel, T5TokenizerFast
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...loaders import TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import MochiTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -152,22 +151,19 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+class MochiPipeline(DiffusionPipeline):
     r"""
-    The mochi pipeline for text-to-image generation.
+    The mochi pipeline for text-to-video generation.
 
-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+    Reference: https://github.com/genmoai/models
 
     Args:
-        transformer ([`mochiTransformer2DModel`]):
-            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+        transformer ([`MochiTransformer3DModel`]):
+            Conditional Transformer architecture to denoise the encoded video latents.
         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`CLIPTextModel`]):
-            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
-            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder ([`T5EncoderModel`]):
+            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
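For context, a minimal usage sketch of the pipeline this docstring describes. The checkpoint id, the top-level `MochiPipeline` import, and the `export_to_video` call are assumptions for illustration, not part of this diff; the `num_frames`, `num_inference_steps`, and `guidance_scale` values are the defaults visible in the hunks below.

import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint id; substitute whichever Mochi checkpoint you are using.
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # follows the "text_encoder->transformer->vae" offload sequence

frames = pipe(
    prompt="A close-up of a chameleon slowly changing color",
    num_frames=16,
    num_inference_steps=28,
    guidance_scale=4.5,  # new default introduced in this commit
).frames[0]
export_to_video(frames, "mochi.mp4")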
@@ -181,7 +177,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
 
     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _optional_components = []
-    _callback_tensor_inputs = ["latents", "prompt_embeds"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
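Adding `negative_prompt_embeds` to `_callback_tensor_inputs` means that tensor can now be read and replaced from a `callback_on_step_end` callback. A minimal sketch of such a callback, following the generic diffusers step-end callback signature; the behaviour shown is purely illustrative and not something this commit does.

import torch

def zero_negative_prompt(pipe, step, timestep, callback_kwargs):
    # Illustrative only: after step 10, replace the negative prompt embeddings with zeros.
    if step == 10:
        callback_kwargs["negative_prompt_embeds"] = torch.zeros_like(
            callback_kwargs["negative_prompt_embeds"]
        )
    return callback_kwargs

# Assumed usage:
# pipe(prompt, callback_on_step_end=zero_negative_prompt,
#      callback_on_step_end_tensor_inputs=["negative_prompt_embeds"])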
@@ -217,7 +213,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         self,
         prompt: Union[str, List[str]] = None,
         num_videos_per_prompt: int = 1,
-        max_sequence_length: int = 226,
+        max_sequence_length: int = 256,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -272,7 +268,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         negative_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 226,
+        max_sequence_length: int = 256,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
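Both prompt-encoding helpers now default to `max_sequence_length=256`. A rough sketch of what that bound means for the T5 tokenizer imported at the top of the file; this is generic `transformers` usage under an assumed tokenizer checkpoint, not code from this diff.

from transformers import T5TokenizerFast

# The class docstring points at the T5 v1.1 XXL variant.
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

inputs = tokenizer(
    "A slow dolly shot of a cat playing a tiny piano",
    padding="max_length",   # pad/truncate every prompt to the same length
    max_length=256,         # matches the new max_sequence_length default
    truncation=True,
    return_tensors="pt",
)
print(inputs.input_ids.shape)            # torch.Size([1, 256])
print(int(inputs.attention_mask.sum()))  # number of non-padding tokens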
@@ -483,7 +479,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         num_frames: int = 16,
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
-        guidance_scale: float = 3.5,
+        guidance_scale: float = 4.5,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -517,7 +513,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
-            guidance_scale (`float`, *optional*, defaults to 7.0):
+            guidance_scale (`float`, defaults to `4.5`):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
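The docstring ties `guidance_scale` to classifier-free guidance (the weight `w` in eq. 2 of the Imagen paper). A schematic of how that weight is typically applied to the two predictions produced from the doubled batch in the denoising loop below; the actual combination code is not part of this diff, and the batch ordering is a common diffusers convention assumed here.

import torch

def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # The doubled batch (mirroring `torch.cat([latents] * 2)`) is assumed to hold the
    # unconditional prediction first and the text-conditional prediction second.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # guidance_scale > 1 pushes the result toward the text-conditional branch.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)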
@@ -655,14 +651,16 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
+                print(t)
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
+                    timestep=1000 - timestep,
                     encoder_attention_mask=prompt_attention_mask,
                     return_dict=False,
                 )[0]
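The pipeline now feeds `1000 - timestep` to the transformer. A small numeric sketch of what that flip does, assuming the scheduler's usual 1000 training timesteps laid out in descending order; the concrete timestep values below are illustrative, not taken from this diff.

import torch

# Illustrative descending timesteps of the kind the scheduler produces.
timesteps = torch.tensor([1000.0, 750.0, 500.0, 250.0])
print(1000 - timesteps)  # tensor([  0., 250., 500., 750.]) -> flipped to an ascending convention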
src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py:

@@ -71,6 +71,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         max_shift: Optional[float] = 1.15,
         base_image_seq_len: Optional[int] = 256,
         max_image_seq_len: Optional[int] = 4096,
+        invert_sigmas: bool = False,
     ):
         timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
@@ -294,6 +295,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigma = self.sigmas[self.step_index]
         sigma_next = self.sigmas[self.step_index + 1]
 
+        if self.config.invert_sigmas:
+            print("inverting")
+            sigma, sigma_next = sigma_next, sigma
+
         prev_sample = sample + (sigma_next - sigma) * model_output
 
         # Cast sample back to model compatible dtype
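With `invert_sigmas=True` the step swaps `sigma` and `sigma_next` before the Euler update, which flips the sign of the increment. A standalone sketch of the two behaviours using plain tensors instead of a scheduler object; the sigma values are assumed for illustration.

import torch

sample = torch.zeros(1)
model_output = torch.ones(1)   # stand-in for the model's flow prediction
sigma, sigma_next = 0.8, 0.6   # two consecutive sigmas from a descending schedule

# Default update: (sigma_next - sigma) * model_output, a step of about -0.2
prev_default = sample + (sigma_next - sigma) * model_output

# invert_sigmas=True swaps the two, so the same update moves by about +0.2
prev_inverted = sample + (sigma - sigma_next) * model_output

print(prev_default.item(), prev_inverted.item())  # approx. -0.2 and 0.2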