1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
Dhruv Nair
2024-10-24 22:44:54 +02:00
parent ebcbad2f38
commit 8700d64d62

View File

@@ -104,6 +104,7 @@ def retrieve_timesteps(
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
@@ -272,8 +273,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
do_classifier_free_guidance=True,
lora_scale: Optional[float] = None,
do_classifier_free_guidance: bool = True,
):
r"""
@@ -305,14 +305,12 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
device=device,
)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
prompt_embeds = prompt_embeds.to(self.text_encoder.dtype)
# TODO: Add negative prompts back
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
@@ -332,6 +330,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
max_sequence_length=max_sequence_length,
device=device,
)
negative_prompt_embeds = negative_prompt_embeds.to(self.text_encoder.dtype)
return prompt_embeds, negative_prompt_embeds
@@ -532,9 +531,9 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
Examples:
Returns:
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
height = height or self.default_height
width = width or self.default_width
@@ -595,21 +594,12 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
threshold_noise = 0.025
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
@@ -628,12 +618,16 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
timestep=timestep / 1000,
encoder_hidden_states=prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -660,18 +654,16 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
xm.mark_step()
if output_type == "latent":
image = latents
video = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return (video,)
return MochiPipelineOutput(images=image)
return MochiPipelineOutput(frames=video)