commit ebcbad2f38
parent 44987ad98c
Author: Dhruv Nair
Date: 2024-10-24 18:07:43 +02:00


@@ -261,15 +261,18 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
 
-        return prompt_embeds, prompt_attention_mask
+        return prompt_embeds
 
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         num_videos_per_prompt: int = 1,
         prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
+        do_classifier_free_guidance=True,
         lora_scale: Optional[float] = None,
     ):
         r"""
@@ -277,9 +280,6 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to the `tokenizer` and `text_encoder`. If not defined, `prompt` is
-                used in all text-encoders
             device: (`torch.device`):
                 torch device
             num_videos_per_prompt (`int`):
@@ -287,14 +287,15 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
             lora_scale (`float`, *optional*):
                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
         device = device or self._execution_device
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
 
         if prompt_embeds is None:
             prompt_embeds = self._get_t5_prompt_embeds(
@@ -307,8 +308,32 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
             )
         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
 
-        # TODO: Add negative prompts back
-        return prompt_embeds
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+
+            negative_prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
+
+        return prompt_embeds, negative_prompt_embeds
 
     def check_inputs(
         self,
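The normalization added above mirrors the convention used across diffusers pipelines: an absent negative prompt falls back to the empty string, a single string is broadcast to the whole batch, and a list must match the prompt batch size exactly. A self-contained sketch of that pattern, with a hypothetical helper name that is not part of this commit:

    from typing import List, Optional, Union

    def normalize_negative_prompt(
        negative_prompt: Optional[Union[str, List[str]]],
        prompt: Optional[Union[str, List[str]]],
        batch_size: int,
    ) -> List[str]:
        # The empty string acts as the unconditional prompt for classifier-free guidance.
        negative_prompt = negative_prompt or ""
        # A single string is broadcast to every item in the batch.
        if isinstance(negative_prompt, str):
            return batch_size * [negative_prompt]
        # A list must match `prompt` in type and in batch size.
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} != {type(prompt)}.")
        if batch_size != len(negative_prompt):
            raise ValueError(f"`negative_prompt` has batch size {len(negative_prompt)}, but `prompt` has batch size {batch_size}.")
        return negative_prompt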
@@ -541,7 +566,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
-        (prompt_embeds) = self.encode_prompt(
+        (prompt_embeds, negative_prompt_embeds) = self.encode_prompt(
             prompt=prompt,
             prompt_embeds=prompt_embeds,
             device=device,
@@ -589,12 +614,8 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
 
-        # handle guidance
-        if self.transformer.config.guidance_embeds:
-            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
-            guidance = guidance.expand(latents.shape[0])
-        else:
-            guidance = None
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
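With the negative embeddings concatenated ahead of the positive ones, each denoising step can run the unconditional and conditional branches in a single forward pass and then recombine them. The loop body itself is outside this diff; the following is a generic sketch of the classifier-free guidance step as used throughout diffusers pipelines (the function and tensor names are illustrative):

    import torch

    def apply_classifier_free_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # The transformer ran on torch.cat([negative, positive]); split in that same order.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        # Standard CFG: push the prediction away from the unconditional branch.
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)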