diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index c8b8c0af98..e2b3196a00 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -104,6 +104,7 @@ def retrieve_timesteps(
 
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
     Args:
         scheduler (`SchedulerMixin`):
             The scheduler to get timesteps from.
@@ -272,8 +273,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
-        do_classifier_free_guidance=True,
-        lora_scale: Optional[float] = None,
+        do_classifier_free_guidance: bool = True,
     ):
         r"""
 
@@ -305,14 +305,12 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 device=device,
             )
 
-        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+        prompt_embeds = prompt_embeds.to(self.text_encoder.dtype)
 
-        # TODO: Add negative prompts back
         if do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
             # normalize str to list
             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            )
 
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
@@ -332,6 +330,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 max_sequence_length=max_sequence_length,
                 device=device,
             )
+            negative_prompt_embeds = negative_prompt_embeds.to(self.text_encoder.dtype)
 
         return prompt_embeds, negative_prompt_embeds
 
@@ -532,9 +531,9 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         Examples:
 
         Returns:
-            [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
-            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
-            images.
+            [`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+            generated images.
         """
         height = height or self.default_height
         width = width or self.default_width
@@ -595,21 +594,12 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         threshold_noise = 0.025
         sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
 
-        image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.base_image_seq_len,
-            self.scheduler.config.max_image_seq_len,
-            self.scheduler.config.base_shift,
-            self.scheduler.config.max_shift,
-        )
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
             device,
             timesteps,
             sigmas,
-            mu=mu,
         )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
@@ -628,12 +618,16 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
-                    timestep=timestep,
+                    timestep=timestep / 1000,
                     encoder_hidden_states=prompt_embeds,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
 
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -660,18 +654,16 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     xm.mark_step()
 
         if output_type == "latent":
-            image = latents
+            video = latents
 
         else:
-            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
-            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-            image = self.vae.decode(latents, return_dict=False)[0]
-            image = self.image_processor.postprocess(image, output_type=output_type)
+            video = self.vae.decode(latents, return_dict=False)[0]
+            video = self.video_processor.postprocess(video, output_type=output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
 
         if not return_dict:
-            return (image,)
+            return (video,)
 
-        return MochiPipelineOutput(images=image)
+        return MochiPipelineOutput(frames=video)