From 44987ad98cd92a2d91a8cb8dba8d7503c57711d7 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Thu, 24 Oct 2024 16:31:10 +0200
Subject: [PATCH] update

---
 .../pipelines/mochi/pipeline_mochi.py | 175 ++++++++----------
 1 file changed, 74 insertions(+), 101 deletions(-)

diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index dcfed214c5..3d140b8864 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -15,22 +15,18 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 
-import numpy as np
 import torch
 from transformers import T5EncoderModel, T5TokenizerFast
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...loaders import TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
-from ...models.transformers import MochiTransformer3D
+from ...models.transformers import MochiTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
-    USE_PEFT_BACKEND,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -53,13 +49,13 @@
 EXAMPLE_DOC_STRING = """
     ```py
     >>> import torch
    >>> from diffusers import MochiPipeline

-    >>> pipe = MochiPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
+    >>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
     >>> pipe.to("cuda")
     >>> prompt = "A cat holding a sign that says hello world"
     >>> # Depending on the variant being used, the pipeline call will slightly vary.
     >>> # Refer to the pipeline documentation for more details.
     >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
-    >>> image.save("flux.png")
+    >>> image.save("mochi.png")
     ```
"""
@@ -77,6 +73,24 @@ def calculate_shift(
     return mu
 
 
+# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
+def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
+    if linear_steps is None:
+        linear_steps = num_steps // 2
+    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
+    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
+    quadratic_steps = num_steps - linear_steps
+    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
+    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
+    const = quadratic_coef * (linear_steps**2)
+    quadratic_sigma_schedule = [
+        quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
+    ]
+    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
+    sigma_schedule = [1.0 - x for x in sigma_schedule]
+    return sigma_schedule
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
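The schedule added above keeps most of the step budget at high noise: sigma falls only linearly (to `1 - threshold_noise`) over the first `linear_steps` steps, then drops quadratically to zero. A quick sanity check, as a sketch that assumes the module is importable under the path this patch targets:

```py
# Sketch: sanity-check linear_quadratic_schedule (module path assumed from this patch).
from diffusers.pipelines.mochi.pipeline_mochi import linear_quadratic_schedule

sigmas = linear_quadratic_schedule(num_steps=8, threshold_noise=0.025)
print(len(sigmas))            # 9: num_steps + 1 entries
print(sigmas[0], sigmas[-1])  # 1.0 ... 0.0, i.e. from pure noise to fully denoised
assert all(a >= b for a, b in zip(sigmas, sigmas[1:]))  # monotonically decreasing
```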
@@ -137,17 +151,14 @@
     return timesteps, num_inference_steps
 
 
-class MochiPipeline(
-    DiffusionPipeline,
-    TextualInversionLoaderMixin
-):
+class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
-    The Flux pipeline for text-to-image generation.
+    The Mochi pipeline for text-to-video generation.
 
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
 
     Args:
-        transformer ([`FluxTransformer2DModel`]):
+        transformer ([`MochiTransformer3DModel`]):
             Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
         scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
@@ -177,7 +188,7 @@ class MochiPipeline(
         vae: AutoencoderKL,
         text_encoder: T5EncoderModel,
         tokenizer: T5TokenizerFast,
-        transformer: MochiTransformer3D,
+        transformer: MochiTransformer3DModel,
     ):
         super().__init__()
 
@@ -188,22 +199,22 @@ class MochiPipeline(
             transformer=transformer,
             scheduler=scheduler,
         )
-        #TODO: determine these scaling factors from model parameters
+        # TODO: determine these scaling factors from model parameters
         self.vae_spatial_scale_factor = 8
         self.vae_temporal_scale_factor = 6
         self.patch_size = 2
-        
+
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_height = 64
-        self.default_width = 64
+        self.default_height = 480
+        self.default_width = 848
 
     def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
+        num_videos_per_prompt: int = 1,
         max_sequence_length: int = 512,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
@@ -227,10 +238,8 @@ class MochiPipeline(
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
         prompt_attention_mask = text_inputs.attention_mask
-        prompt_attention_mask = prompt_attention_mask.to(device)
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
@@ -239,7 +248,9 @@ class MochiPipeline(
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False).last_hidden_state
+        prompt_embeds = self.text_encoder(
+            text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device), output_hidden_states=False
+        ).last_hidden_state
 
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
@@ -247,43 +258,17 @@ class MochiPipeline(
         _, seq_len, _ = prompt_embeds.shape
 
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
 
         return prompt_embeds, prompt_attention_mask
 
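The `repeat`/`view` pair above is the mps-friendly way this file duplicates embeddings per prompt (rather than `repeat_interleave`). A toy-shape sketch, with all dimensions illustrative:

```py
# Sketch: the repeat/view duplication from _get_t5_prompt_embeds (toy shapes).
import torch

batch_size, seq_len, dim, num_videos_per_prompt = 2, 4, 8, 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)

prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)                    # (2, 12, 8)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)  # (6, 4, 8)
print(prompt_embeds.shape)  # copies of each prompt stay adjacent in the batch
```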
-    def _pack_indices(self, attention_mask, latent_frames_dim, latent_height_dim, latent_width_dim):
-        N = latent_frames_dim * latent_height_dim * latent_width_dim // (self.patch_size**2)
-
-        # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
-        assert N > 0 and len(attention_mask) == 1
-        attention_mask = attention_mask[0]
-
-        mask = F.pad(attention_mask, (N, 0), value=True)  # (B, N + L)
-        seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
-        valid_token_indices = torch.nonzero(
-            mask.flatten(), as_tuple=False
-        ).flatten()  # up to (B * (N + L),)
-
-        assert valid_token_indices.size(0) >= attention_mask.size(0) * N  # At least (B * N,)
-        cu_seqlens = F.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
-        )
-        max_seqlen_in_batch = seqlens_in_batch.max().item()
-
-        return {
-            "cu_seqlens_kv": cu_seqlens,
-            "max_seqlen_in_batch_kv": max_seqlen_in_batch,
-            "valid_token_indices_kv": valid_token_indices,
-        }
-
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
+        num_videos_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
         lora_scale: Optional[float] = None,
     ):
         r"""
 
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
@@ -297,7 +282,7 @@ class MochiPipeline(
                 used in all text-encoders
             device: (`torch.device`):
                 torch device
-            num_images_per_prompt (`int`):
+            num_videos_per_prompt (`int`):
                 number of videos that should be generated per prompt
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
@@ -308,30 +293,29 @@ class MochiPipeline(
             lora_scale (`float`, *optional*):
                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
-        device = device ori self._execution_device
+        device = device or self._execution_device
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
 
         if prompt_embeds is None:
-            prompt_embeds = self._get_t5_prompt_embeds(
-                prompt=prompt_2,
-                num_images_per_prompt=num_images_per_prompt,
+            prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
+                prompt=prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
 
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+        # TODO: Add negative prompts back
 
         return prompt_embeds
 
     def check_inputs(
         self,
         prompt,
-        prompt_2,
         height,
         width,
         prompt_embeds=None,
-        pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
@@ -350,25 +334,12 @@ class MochiPipeline(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                 " only forward one of the two."
             )
-        elif prompt_2 is not None and prompt_embeds is not None:
-            raise ValueError(
-                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
-                " only forward one of the two."
-            )
         elif prompt is None and prompt_embeds is None:
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
-            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
-
-        if prompt_embeds is not None and pooled_prompt_embeds is None:
-            raise ValueError(
-                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
-            )
-
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
@@ -407,11 +378,16 @@ class MochiPipeline(
         num_channels_latents,
         height,
         width,
+        num_frames,
         dtype,
         device,
         generator,
         latents=None,
     ):
+        height = height // self.vae_spatial_scale_factor
+        width = width // self.vae_spatial_scale_factor
+        num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
+
         shape = (batch_size, num_channels_latents, num_frames, height, width)
 
         if latents is not None:
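The integer math added to `prepare_latents` maps pixel-space sizes onto the latent grid: 8x spatial compression, and 6x temporal compression that keeps the first frame (hence the `- 1 ... + 1`; the original parenthesization divided by 7 and dropped the leading frame). A sketch with illustrative sizes, using the factors registered in `__init__`:

```py
# Sketch: latent-grid math from prepare_latents (sizes illustrative).
vae_spatial_scale_factor = 8   # from __init__
vae_temporal_scale_factor = 6  # from __init__

height, width, num_frames = 480, 848, 19
latent_height = height // vae_spatial_scale_factor                 # 60
latent_width = width // vae_spatial_scale_factor                   # 106
latent_frames = (num_frames - 1) // vae_temporal_scale_factor + 1  # 4 = first frame + 3 groups of 6
print(latent_frames, latent_height, latent_width)
```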
@@ -429,6 +405,10 @@ class MochiPipeline(
     def guidance_scale(self):
         return self._guidance_scale
 
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1.0
+
     @property
     def joint_attention_kwargs(self):
         return self._joint_attention_kwargs
@@ -470,13 +450,12 @@ class MochiPipeline(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                 instead.
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to `tokenizer` and `text_encoder`. If not defined, `prompt` is
-                will be used instead
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_frames (`int`, defaults to 16):
+                The number of video frames to generate.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -490,7 +469,7 @@ class MochiPipeline(
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                 usually at the expense of lower image quality.
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of videos to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
@@ -509,7 +488,7 @@ class MochiPipeline(
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.flux.MochiPipelineOutput`] instead of a plain tuple.
+                Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
             joint_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
@@ -528,13 +507,12 @@ class MochiPipeline(
 
         Examples:
 
         Returns:
-            [`~pipelines.flux.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+            [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.mochi.MochiPipelineOutput`] if `return_dict`
             is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
             images.
         """
-
-        height = height or self.default_sample_size * self.vae_scale_factor
-        width = width or self.default_sample_size * self.vae_scale_factor
+        height = height or self.default_height
+        width = width or self.default_width
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -542,7 +520,6 @@ class MochiPipeline(
             prompt,
             height,
             width,
             prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -564,22 +541,18 @@ class MochiPipeline(
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
-        (
-            prompt_embeds,
-            pooled_prompt_embeds,
-            text_ids,
-        ) = self.encode_prompt(
+        prompt_embeds = self.encode_prompt(
             prompt=prompt,
             prompt_embeds=prompt_embeds,
             device=device,
-            num_images_per_prompt=num_images_per_prompt,
+            num_videos_per_prompt=num_videos_per_prompt,
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
 
         # 4. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels // 4
-        latents, latent_image_ids = self.prepare_latents(
+        num_channels_latents = self.transformer.config.in_channels
+        latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
             height,
             width,
+            num_frames,
             dtype,
             device,
             generator,
             latents,
         )
@@ -591,8 +564,12 @@ class MochiPipeline(
-        # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        # 5. Prepare timesteps
+        # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
+        threshold_noise = 0.025
+        sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
+
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
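`calculate_shift` (context near the top of this patch) is a linear interpolation of the timestep-shift parameter `mu` in sequence length. A pure-Python re-derivation; the default constants below are the Flux-style values and are an assumption here, since the full signature is elided from the hunk:

```py
# Sketch: what calculate_shift computes (default constants assumed, not confirmed by this patch).
def mu_for(seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return seq_len * m + b

print(mu_for(256))   # 0.5  -> base_shift at the base sequence length
print(mu_for(4096))  # 1.16 -> max_shift at the max sequence length
```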
@@ -624,18 +601,14 @@ class MochiPipeline(
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latents.shape[0]).to(latents.dtype)
+                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
 
                 noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
+                    hidden_states=latent_model_input,
+                    timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
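The loop now batches latents for classifier-free guidance, but the excerpt ends before the two predictions are split back apart (and `prompt_embeds` is still unbatched; the `# TODO: Add negative prompts back` in `encode_prompt` covers that). The standard diffusers recombination that would typically follow the transformer call, as a sketch with toy tensors:

```py
# Sketch: standard CFG recombination (this step falls outside the excerpt above).
import torch

guidance_scale = 4.5
noise_pred = torch.randn(2, 12, 4, 60, 106)  # (2 * B, C, F, H, W), toy stand-in for the transformer output

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # negative embeddings batched first, by convention
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 12, 4, 60, 106])
```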