mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
make style
This commit is contained in:
1077
src/diffusers/models/transformers/transformer_mochi_original.py
Normal file
1077
src/diffusers/models/transformers/transformer_mochi_original.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -255,7 +255,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
|
||||
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
|
||||
|
||||
@@ -377,7 +377,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
@@ -397,7 +397,7 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
@@ -563,13 +563,13 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
|
||||
Returns:
|
||||
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated images.
|
||||
If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
|
||||
is returned where the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
|
||||
height = height or self.default_height
|
||||
width = width or self.default_width
|
||||
|
||||
@@ -599,7 +599,12 @@ class MochiPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
(prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask) = self.encode_prompt(
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
|
||||
Reference in New Issue
Block a user