mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add video img2img (#3900)
* Add image to image video * Improve * better naming * make fix copies * add docs * finish tests * trigger tests * make style * correct * finish * Fix more * make style * finish
This commit is contained in:
committed by
GitHub
parent
5439e917ca
commit
62825064bf
@@ -37,9 +37,12 @@ Resources:
|
||||
| Pipeline | Tasks | Demo
|
||||
|---|---|:---:|
|
||||
| [TextToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py) | *Text-to-Video Generation* | [🤗 Spaces](https://huggingface.co/spaces/damo-vilab/modelscope-text-to-video-synthesis)
|
||||
| [VideoToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py) | *Text-Guided Video-to-Video Generation* | [(TODO)🤗 Spaces]()
|
||||
|
||||
## Usage example
|
||||
|
||||
### `text-to-video-ms-1.7b`
|
||||
|
||||
Let's start by generating a short video with the default length of 16 frames (2s at 8 fps):
|
||||
|
||||
```python
|
||||
@@ -119,12 +122,72 @@ Here are some sample outputs:
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### `cerspense/zeroscope_v2_576w` & `cerspense/zeroscope_v2_XL`
|
||||
|
||||
Zeroscope are watermark-free model and have been trained on specific sizes such as `576x320` and `1024x576`.
|
||||
One should first generate a video using the lower resolution checkpoint [`cerspense/zeroscope_v2_576w`](https://huggingface.co/cerspense/zeroscope_v2_576w) with [`TextToVideoSDPipeline`],
|
||||
which can then be upscaled using [`VideoToVideoSDPipeline`] and [`cerspense/zeroscope_v2_XL`](https://huggingface.co/cerspense/zeroscope_v2_XL).
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# memory optimization
|
||||
pipe.enable_vae_slicing()
|
||||
|
||||
prompt = "Darth Vader surfing a wave"
|
||||
video_frames = pipe(prompt, num_frames=24).frames
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
|
||||
Now the video can be upscaled:
|
||||
|
||||
```py
|
||||
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16)
|
||||
pipe.vae.enable_slicing()
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]
|
||||
|
||||
video_frames = pipe(prompt, video=video, strength=0.6).frames
|
||||
video_path = export_to_video(video_frames)
|
||||
video_path
|
||||
```
|
||||
|
||||
Here are some sample outputs:
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td ><center>
|
||||
Darth vader surfing in waves.
|
||||
<br>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/darthvader_cerpense.gif"
|
||||
alt="Darth vader surfing in waves."
|
||||
style="width: 576px;" />
|
||||
</center></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Available checkpoints
|
||||
|
||||
* [damo-vilab/text-to-video-ms-1.7b](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b/)
|
||||
* [damo-vilab/text-to-video-ms-1.7b-legacy](https://huggingface.co/damo-vilab/text-to-video-ms-1.7b-legacy)
|
||||
* [cerspense/zeroscope_v2_576w](https://huggingface.co/cerspense/zeroscope_v2_576w)
|
||||
* [cerspense/zeroscope_v2_XL](https://huggingface.co/cerspense/zeroscope_v2_XL)
|
||||
|
||||
## TextToVideoSDPipeline
|
||||
[[autodoc]] TextToVideoSDPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## VideoToVideoSDPipeline
|
||||
[[autodoc]] VideoToVideoSDPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -173,6 +173,7 @@ else:
|
||||
VersatileDiffusionImageVariationPipeline,
|
||||
VersatileDiffusionPipeline,
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VideoToVideoSDPipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
|
||||
|
||||
@@ -229,7 +229,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
||||
return self.tiled_encode(x, return_dict=return_dict)
|
||||
|
||||
h = self.encoder(x)
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self.encoder(x)
|
||||
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ else:
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline
|
||||
from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline
|
||||
from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
||||
from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
|
||||
from .versatile_diffusion import (
|
||||
|
||||
@@ -28,5 +28,6 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_text_to_video_synth import TextToVideoSDPipeline # noqa: F401
|
||||
from .pipeline_text_to_video_synth import TextToVideoSDPipeline
|
||||
from .pipeline_text_to_video_synth_img2img import VideoToVideoSDPipeline # noqa: F401
|
||||
from .pipeline_text_to_video_zero import TextToVideoZeroPipeline
|
||||
|
||||
@@ -672,6 +672,9 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
if output_type == "latent":
|
||||
return TextToVideoSDPipelineOutput(frames=latents)
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
|
||||
@@ -0,0 +1,770 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet3DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
randn_tensor,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import TextToVideoSDPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
|
||||
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "spiderman running in the desert"
|
||||
>>> video_frames = pipe(prompt, num_inference_steps=40, height=320, width=576, num_frames=24).frames
|
||||
>>> # safe low-res video
|
||||
>>> video_path = export_to_video(video_frames, output_video_path="./video_576_spiderman.mp4")
|
||||
|
||||
>>> # let's offload the text-to-image model
|
||||
>>> pipe.to("cpu")
|
||||
|
||||
>>> # and load the image-to-image model
|
||||
>>> pipe = DiffusionPipeline.from_pretrained(
|
||||
... "cerspense/zeroscope_v2_XL", torch_dtype=torch.float16, revision="refs/pr/15"
|
||||
... )
|
||||
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> # The VAE consumes A LOT of memory, let's make sure we run it in sliced mode
|
||||
>>> pipe.vae.enable_slicing()
|
||||
|
||||
>>> # now let's upscale it
|
||||
>>> video = [Image.fromarray(frame).resize((1024, 576)) for frame in video_frames]
|
||||
|
||||
>>> # and denoise it
|
||||
>>> video_frames = pipe(prompt, video=video, strength=0.6).frames
|
||||
>>> video_path = export_to_video(video_frames, output_video_path="./video_1024_spiderman.mp4")
|
||||
>>> video_path
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
|
||||
# This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
|
||||
# reshape to ncfhw
|
||||
mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
|
||||
# unnormalize back to [0,1]
|
||||
video = video.mul_(std).add_(mean)
|
||||
video.clamp_(0, 1)
|
||||
# prepare the final outputs
|
||||
i, c, f, h, w = video.shape
|
||||
images = video.permute(2, 3, 0, 4, 1).reshape(
|
||||
f, h, i * w, c
|
||||
) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
|
||||
images = images.unbind(dim=0) # prepare a list of indvidual (consecutive frames)
|
||||
images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
|
||||
return images
|
||||
|
||||
|
||||
def preprocess_video(video):
|
||||
supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image)
|
||||
|
||||
if isinstance(video, supported_formats):
|
||||
video = [video]
|
||||
elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)):
|
||||
raise ValueError(
|
||||
f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}"
|
||||
)
|
||||
|
||||
if isinstance(video[0], PIL.Image.Image):
|
||||
video = [np.array(frame) for frame in video]
|
||||
|
||||
if isinstance(video[0], np.ndarray):
|
||||
video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0)
|
||||
|
||||
if video.dtype == np.uint8:
|
||||
video = np.array(video).astype(np.float32) / 255.0
|
||||
|
||||
if video.ndim == 4:
|
||||
video = video[None, ...]
|
||||
|
||||
video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3))
|
||||
|
||||
elif isinstance(video[0], torch.Tensor):
|
||||
video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0)
|
||||
|
||||
# don't need any preprocess if the video is latents
|
||||
channel = video.shape[1]
|
||||
if channel == 4:
|
||||
return video
|
||||
|
||||
# move channels before num_frames
|
||||
video = video.permute(0, 2, 1, 3, 4)
|
||||
|
||||
# normalize video
|
||||
video = 2.0 * video - 1.0
|
||||
|
||||
return video
|
||||
|
||||
|
||||
class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. Same as Stable Diffusion 2.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded video latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet3DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
||||
steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
|
||||
several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
text_encoder, vae have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded
|
||||
to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a
|
||||
submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
|
||||
from accelerate import cpu_offload
|
||||
else:
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
||||
cpu_offload(cpu_offloaded_model, device)
|
||||
|
||||
def enable_model_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
||||
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
|
||||
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
|
||||
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
|
||||
"""
|
||||
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
|
||||
from accelerate import cpu_offload_with_hook
|
||||
else:
|
||||
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
||||
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
|
||||
hook = None
|
||||
for cpu_offloaded_model in [self.text_encoder, self.vae, self.unet]:
|
||||
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
|
||||
|
||||
# We'll offload the last model manually.
|
||||
self.final_offload_hook = hook
|
||||
|
||||
@property
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
|
||||
def _execution_device(self):
|
||||
r"""
|
||||
Returns the device on which the pipeline's models will be executed. After calling
|
||||
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
||||
hooks.
|
||||
"""
|
||||
if not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# textual inversion: procecss multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
batch_size, channels, num_frames, height, width = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
|
||||
|
||||
image = self.vae.decode(latents).sample
|
||||
video = (
|
||||
image[None, :]
|
||||
.reshape(
|
||||
(
|
||||
batch_size,
|
||||
num_frames,
|
||||
-1,
|
||||
)
|
||||
+ image.shape[2:]
|
||||
)
|
||||
.permute(0, 2, 1, 3, 4)
|
||||
)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
video = video.float()
|
||||
return video
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(self, video, timestep, batch_size, dtype, device, generator=None):
|
||||
video = video.to(device=device, dtype=dtype)
|
||||
|
||||
# change from (b, c, f, h, w) -> (b * f, c, w, h)
|
||||
bsz, channel, frames, width, height = video.shape
|
||||
video = video.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
|
||||
|
||||
if video.shape[1] == 4:
|
||||
init_latents = video
|
||||
else:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
elif isinstance(generator, list):
|
||||
init_latents = [
|
||||
self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = self.vae.encode(video).latent_dist.sample(generator)
|
||||
|
||||
init_latents = self.vae.config.scaling_factor * init_latents
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `video` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
shape = init_latents.shape
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
latents = init_latents
|
||||
|
||||
latents = latents[None, :].reshape((bsz, frames, latents.shape[1]) + latents.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
video: Union[List[np.ndarray], torch.FloatTensor] = None,
|
||||
strength: float = 0.6,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 15.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
video: (`List[np.ndarray]` or `torch.FloatTensor`):
|
||||
`video` frames or tensor representing a video batch, that will be used as the starting point for the
|
||||
process. Can also accpet video latents as `image`, if passing latents directly, it will not be encoded
|
||||
again.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
|
||||
will be used as a starting point, adding more noise to it the larger the `strength`. The number of
|
||||
denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
|
||||
be maximum and the denoising process will run for the full number of iterations specified in
|
||||
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality videos at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower video quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the video generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`. Latents should be of shape
|
||||
`(batch_size, num_channel, num_frames, height, width)`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generate video. Choose between `torch.FloatTensor` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
||||
When returning a tuple, the first element is a list with the generated frames.
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
num_images_per_prompt = 1
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Preprocess video
|
||||
video = preprocess_video(video)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator)
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# reshape latents
|
||||
bsz, channel, frames, width, height = latents.shape
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
|
||||
noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# reshape latents back
|
||||
latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
if output_type == "latent":
|
||||
return TextToVideoSDPipelineOutput(frames=latents)
|
||||
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
|
||||
video_tensor = self.decode_latents(latents)
|
||||
|
||||
if output_type == "pt":
|
||||
video = video_tensor
|
||||
else:
|
||||
video = tensor2vid(video_tensor)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return TextToVideoSDPipelineOutput(frames=video)
|
||||
@@ -782,6 +782,21 @@ class VersatileDiffusionTextToImagePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class VideoToVideoSDPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class VQDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -652,11 +652,11 @@ class PipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_without_offload = pipe(**inputs)[0]
|
||||
output_without_offload = pipe(**inputs)[0].cpu()
|
||||
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_with_offload = pipe(**inputs)[0]
|
||||
output_with_offload = pipe(**inputs)[0].cpu()
|
||||
|
||||
if test_max_difference:
|
||||
max_diff = np.abs(output_with_offload - output_without_offload).max()
|
||||
|
||||
195
tests/pipelines/text_to_video/test_video_to_video.py
Normal file
195
tests/pipelines/text_to_video/test_video_to_video.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
UNet3DConditionModel,
|
||||
VideoToVideoSDPipeline,
|
||||
)
|
||||
from diffusers.utils import floats_tensor, is_xformers_available, skip_mps
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, slow, torch_device
|
||||
|
||||
from ..pipeline_params import (
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@skip_mps
|
||||
class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = VideoToVideoSDPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"video"}) - {"image", "width", "height"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"video"}) - {"image"}
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
|
||||
test_attention_slicing = False
|
||||
|
||||
# No `output_type`.
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback",
|
||||
"callback_steps",
|
||||
]
|
||||
)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet3DConditionModel(
|
||||
block_out_channels=(32, 64, 64, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"),
|
||||
up_block_types=("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
||||
cross_attention_dim=32,
|
||||
attention_head_dim=4,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=512,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
# 3 frames
|
||||
video = floats_tensor((1, 3, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"video": video,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_text_to_video_default_case(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = VideoToVideoSDPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["output_type"] = "np"
|
||||
frames = sd_pipe(**inputs).frames
|
||||
image_slice = frames[0][-3:, -3:, -1]
|
||||
|
||||
assert frames[0].shape == (32, 32, 3)
|
||||
expected_slice = np.array([106, 117, 113, 174, 137, 112, 148, 151, 131])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=5e-3)
|
||||
|
||||
# (todo): sayakpaul
|
||||
@unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
|
||||
def test_inference_batch_consistent(self):
|
||||
pass
|
||||
|
||||
# (todo): sayakpaul
|
||||
@unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
|
||||
def test_inference_batch_single_identical(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
|
||||
def test_num_images_per_prompt(self):
|
||||
pass
|
||||
|
||||
def test_progress_bar(self):
|
||||
return super().test_progress_bar()
|
||||
|
||||
|
||||
@slow
|
||||
@skip_mps
|
||||
class VideoToVideoSDPipelineSlowTests(unittest.TestCase):
|
||||
def test_two_step_model(self):
|
||||
pipe = VideoToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_XL", torch_dtype=torch.float16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
# 10 frames
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
video = torch.randn((1, 10, 3, 1024, 576), generator=generator)
|
||||
video = video.to("cuda")
|
||||
|
||||
prompt = "Spiderman is surfing"
|
||||
|
||||
video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="pt").frames
|
||||
|
||||
expected_array = np.array([-1.0458984, -1.1279297, -0.9663086, -0.91503906, -0.75097656])
|
||||
assert np.abs(video_frames.cpu().numpy()[0, 0, 0, 0, -5:] - expected_array).sum() < 1e-2
|
||||
Reference in New Issue
Block a user