Mirror of https://github.com/huggingface/diffusers.git
[Core] introduce videoprocessor. (#7776)
* introduce videoprocessor.
* fix quality
* address yiyi's feedback
* fix preprocess_video call.
* video_processor -> image_processor
* fix
* fix more.
* quality
* image_processor -> video_processor
* support List[List[PIL.Image.Image]]
* change to video_processor.
* documentation
* Apply suggestions from code review
* changes
* remove print.
* refactor video processor (part # 7776) (#7861)
  * update
  * update remove deprecate
  * Update src/diffusers/video_processor.py
  * update
  * Apply suggestions from code review
  * deprecate list of 5d for video and list of 4d for image + apply other feedbacks
  * up

  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* add doc.
* tensor2vid -> postprocess_video.
* refactor preprocess with preprocess_video
* set default values.
* empty commit
* more refactoring of prepare_latents in animatediff vid2vid
* checking documentation
* remove documentation for now.
* fix animatediff sdxl
* fix test failure [part of video processor PR] (#7905) up
* remove preceed_with_frames.
* doc
* fix
* fix
* remove video input as a single-frame video.

Co-authored-by: YiYi Xu <yixu310@gmail.com>
@@ -439,6 +439,8 @@
    title: Utilities
  - local: api/image_processor
    title: VAE Image Processor
  - local: api/video_processor
    title: Video Processor
  title: Internal classes
isExpanded: false
title: API
15 docs/source/en/api/video_processor.md Normal file
@@ -0,0 +1,15 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Video Processor

The `VideoProcessor` provides a unified API for video pipelines to prepare inputs for VAE encoding and post-processing outputs once they're decoded. The class inherits [`VaeImageProcessor`] so it includes transformations such as resizing, normalization, and conversion between PIL Image, PyTorch, and NumPy arrays.
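A minimal usage sketch (the shapes, frame count, and processor settings below are illustrative assumptions; only `preprocess_video` and `postprocess_video` themselves come from this commit):

```python
import torch

from diffusers.video_processor import VideoProcessor

# Illustrative 5D input: (batch_size, num_frames, num_channels, height, width), values in [0, 1].
video = torch.rand(1, 8, 3, 64, 64)

video_processor = VideoProcessor(do_resize=False, do_normalize=True)

# preprocess_video returns a (batch_size, num_channels, num_frames, height, width)
# tensor normalized to [-1, 1], ready to be fed to a VAE encoder.
video_tensor = video_processor.preprocess_video(video)

# postprocess_video converts a decoded tensor with the same layout back to
# "np", "pt", or lists of PIL images ("pil").
frames = video_processor.postprocess_video(video_tensor, output_type="pil")
```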
@@ -29,15 +29,34 @@ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
PipelineImageInput = Union[
    PIL.Image.Image,
    np.ndarray,
    torch.FloatTensor,
    torch.Tensor,
    List[PIL.Image.Image],
    List[np.ndarray],
    List[torch.FloatTensor],
    List[torch.Tensor],
]

PipelineDepthInput = PipelineImageInput


def is_valid_image(image):
    return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)


def is_valid_image_imagelist(images):
    # check if the image input is one of the supported formats for image and image list:
    # it can be either one of below 3
    # (1) a 4d pytorch tensor or numpy array,
    # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
    # (3) a list of valid image
    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
        return True
    elif is_valid_image(images):
        return True
    elif isinstance(images, list):
        return all(is_valid_image(image) for image in images)
    return False


class VaeImageProcessor(ConfigMixin):
    """
    Image processor for VAE.
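A tiny sketch of how the new validators classify inputs (the arrays below are made-up placeholders):

```python
import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

frame = PIL.Image.new("RGB", (64, 64))              # a single valid image
batch = np.zeros((2, 64, 64, 3))                    # a 4d array is a batch, not a single image
frames = [torch.rand(3, 64, 64) for _ in range(4)]  # a list of valid images

assert is_valid_image(frame)
assert not is_valid_image(batch)          # 4d inputs are not single images
assert is_valid_image_imagelist(batch)    # ...but they are accepted as an image batch
assert is_valid_image_imagelist(frames)
```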
@@ -110,7 +129,7 @@ class VaeImageProcessor(ConfigMixin):
        return images

    @staticmethod
    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
        """
        Convert a NumPy image to a PyTorch tensor.
        """
@@ -121,7 +140,7 @@ class VaeImageProcessor(ConfigMixin):
        return images

    @staticmethod
    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
        """
        Convert a PyTorch tensor to a NumPy image.
        """
@@ -497,12 +516,27 @@ class VaeImageProcessor(ConfigMixin):
            else:
                image = np.expand_dims(image, axis=-1)

        if isinstance(image, supported_formats):
            image = [image]
        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d np.ndarray is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
                FutureWarning,
            )
            image = np.concatenate(image, axis=0)
        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d torch.Tensor is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
                FutureWarning,
            )
            image = torch.cat(image, axis=0)

        if not is_valid_image_imagelist(image):
            raise ValueError(
                f"Input is in incorrect format. Currently, we only support {', '.join(supported_formats)}"
            )
        if not isinstance(image, list):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            if crops_coords is not None:
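In practice, the deprecation added above means callers should concatenate 4d chunks along the batch dimension themselves rather than passing a list; a brief sketch with placeholder tensors:

```python
import torch

chunks = [torch.rand(2, 3, 64, 64) for _ in range(3)]  # several 4d batches

# Deprecated: passing `chunks` directly to `preprocess` now warns and concatenates internally.
# Preferred: concatenate first and pass a single 4d tensor of shape (6, 3, 64, 64).
image = torch.cat(chunks, dim=0)
```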
@@ -561,15 +595,15 @@ class VaeImageProcessor(ConfigMixin):
    def postprocess(
        self,
        image: torch.FloatTensor,
        image: torch.Tensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Postprocess the image output from tensor to `output_type`.

        Args:
            image (`torch.FloatTensor`):
            image (`torch.Tensor`):
                The image input, should be a pytorch tensor with shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -578,7 +612,7 @@ class VaeImageProcessor(ConfigMixin):
                `VaeImageProcessor` config.

        Returns:
            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                The postprocessed image.
        """
        if not isinstance(image, torch.Tensor):
@@ -738,15 +772,15 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
    def postprocess(
        self,
        image: torch.FloatTensor,
        image: torch.Tensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Postprocess the image output from tensor to `output_type`.

        Args:
            image (`torch.FloatTensor`):
            image (`torch.Tensor`):
                The image input, should be a pytorch tensor with shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -755,7 +789,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
                `VaeImageProcessor` config.

        Returns:
            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                The postprocessed image.
        """
        if not isinstance(image, torch.Tensor):
@@ -793,8 +827,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
    def preprocess(
        self,
        rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
        rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
        depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
        height: Optional[int] = None,
        width: Optional[int] = None,
        target_res: Optional[int] = None,
@@ -933,13 +967,13 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
        )

    @staticmethod
    def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
        """
        Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
        aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

        Args:
            mask (`torch.FloatTensor`):
            mask (`torch.Tensor`):
                The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
            batch_size (`int`):
                The batch size.
@@ -949,7 +983,7 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
                The dimensionality of the value embeddings.

        Returns:
            `torch.FloatTensor`:
            `torch.Tensor`:
                The downsampled mask tensor.

        """
@@ -15,11 +15,10 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -41,6 +40,7 @@ from ...utils import (
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -65,27 +65,6 @@ EXAMPLE_DOC_STRING = """
"""


def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


class AnimateDiffPipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
@@ -159,7 +138,7 @@ class AnimateDiffPipeline(
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
    def encode_prompt(
@@ -836,7 +815,7 @@ class AnimateDiffPipeline(
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

        # 10. Offload all models
        self.maybe_free_model_hooks()
@@ -15,7 +15,6 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import (
    CLIPImageProcessor,
@@ -25,7 +24,7 @@ from transformers import (
    CLIPVisionModelWithProjection,
)

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import (
    FromSingleFileMixin,
    IPAdapterMixin,
@@ -57,6 +56,7 @@ from ...utils import (
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -113,28 +113,6 @@ EXAMPLE_DOC_STRING = """
"""


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
@@ -320,7 +298,7 @@ class AnimateDiffSDXLPipeline(
        )
        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)

        self.default_sample_size = self.unet.config.sample_size

@@ -1291,7 +1269,7 @@ class AnimateDiffSDXLPipeline(
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

        # cast back to fp16 if needed
        if needs_upcasting:
@@ -15,11 +15,10 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -34,6 +33,7 @@ from ...schedulers import (
)
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -95,28 +95,6 @@ EXAMPLE_DOC_STRING = """
"""


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -264,7 +242,7 @@ class AnimateDiffVideoToVideoPipeline(
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
    def encode_prompt(
@@ -650,16 +628,7 @@ class AnimateDiffVideoToVideoPipeline(
        generator,
        latents=None,
    ):
        # video must be a list of list of images
        # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
        # as a list of images
        if video and not isinstance(video[0], list):
            video = [video]
        if latents is None:
            video = torch.cat(
                [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
            )
            video = video.to(device=device, dtype=dtype)
            num_frames = video.shape[1]
        else:
            num_frames = latents.shape[2]
@@ -943,6 +912,11 @@ class AnimateDiffVideoToVideoPipeline(
        latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)

        # 5. Prepare latent variables
        if latents is None:
            video = self.video_processor.preprocess_video(video, height=height, width=width)
            # Move the number of frames before the number of channels.
            video = video.permute(0, 2, 1, 3, 4)
            video = video.to(device=device, dtype=prompt_embeds.dtype)
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            video=video,
@@ -1023,7 +997,7 @@ class AnimateDiffVideoToVideoPipeline(
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

        # 10. Offload all models
        self.maybe_free_model_hooks()
@@ -31,6 +31,7 @@ from ...utils import (
    replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin


@@ -70,28 +71,6 @@ EXAMPLE_DOC_STRING = """
"""


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


@dataclass
class I2VGenXLPipelineOutput(BaseOutput):
    r"""
@@ -156,7 +135,7 @@ class I2VGenXLPipeline(
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        # `do_resize=False` as we do custom resizing.
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False)
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False)

    @property
    def guidance_scale(self):
@@ -342,8 +321,8 @@ class I2VGenXLPipeline(
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.image_processor.pil_to_numpy(image)
            image = self.image_processor.numpy_to_pt(image)
            image = self.video_processor.pil_to_numpy(image)
            image = self.video_processor.numpy_to_pt(image)

            # Normalize the image with CLIP training stats.
            image = self.feature_extractor(
@@ -657,7 +636,7 @@ class I2VGenXLPipeline(

        # 3.2.2 Image latents.
        resized_image = _center_crop_wide(image, (width, height))
        image = self.image_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype)
        image = self.video_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype)
        image_latents = self.prepare_image_latents(
            image,
            device=device,
@@ -737,7 +716,7 @@ class I2VGenXLPipeline(
            video = latents
        else:
            video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

        # 9. Offload all models
        self.maybe_free_model_hooks()
@@ -21,7 +21,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -43,6 +43,7 @@ from ...utils import (
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin

@@ -89,28 +90,6 @@ RANGE_LIST = [
]


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int):
    assert num_frames > 0, "video_length should be greater than 0"

@@ -218,7 +197,7 @@ class PIAPipeline(
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
    def encode_prompt(
@@ -621,7 +600,7 @@ class PIAPipeline(
        )
        _, _, _, scaled_height, scaled_width = shape

        image = self.image_processor.preprocess(image)
        image = self.video_processor.preprocess(image)
        image = image.to(device, dtype)

        if isinstance(generator, list):
@@ -959,7 +938,7 @@ class PIAPipeline(
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

        # 10. Offload all models
        self.maybe_free_model_hooks()
@@ -21,11 +21,12 @@ import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline


@@ -61,28 +62,6 @@ def _append_dims(x, target_dims):
    return x[(...,) + (None,) * dims_to_append]


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
@@ -199,7 +178,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

    def _encode_image(
        self,
@@ -211,8 +190,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.image_processor.pil_to_numpy(image)
            image = self.image_processor.numpy_to_pt(image)
            image = self.video_processor.pil_to_numpy(image)
            image = self.video_processor.numpy_to_pt(image)

            # We normalize the image before resizing to match with the original implementation.
            # Then we unnormalize it after resizing.
@@ -520,7 +499,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
        fps = fps - 1

        # 4. Encode input image using VAE
        image = self.image_processor.preprocess(image, height=height, width=width).to(device)
        image = self.video_processor.preprocess(image, height=height, width=width).to(device)
        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
        image = image + noise_aug_strength * noise

@@ -626,7 +605,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
            frames = tensor2vid(frames, self.image_processor, output_type=output_type)
            frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
        else:
            frames = latents
@@ -15,11 +15,9 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -33,6 +31,7 @@ from ...utils import (
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import TextToVideoSDPipelineOutput

@@ -59,28 +58,6 @@ EXAMPLE_DOC_STRING = """
"""


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation.
@@ -127,7 +104,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
@@ -652,7 +629,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

        # 9. Offload all models
        self.maybe_free_model_hooks()
@@ -16,11 +16,9 @@ import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -34,6 +32,7 @@ from ...utils import (
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import TextToVideoSDPipelineOutput

@@ -94,69 +93,6 @@ def retrieve_latents(
        raise AttributeError("Could not access latents of provided encoder_output")


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


def preprocess_video(video):
    supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image)

    if isinstance(video, supported_formats):
        video = [video]
    elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)):
        raise ValueError(
            f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}"
        )

    if isinstance(video[0], PIL.Image.Image):
        video = [np.array(frame) for frame in video]

    if isinstance(video[0], np.ndarray):
        video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0)

        if video.dtype == np.uint8:
            video = np.array(video).astype(np.float32) / 255.0

        if video.ndim == 4:
            video = video[None, ...]

        video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3))

    elif isinstance(video[0], torch.Tensor):
        video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0)

        # don't need any preprocess if the video is latents
        channel = video.shape[1]
        if channel == 4:
            return video

        # move channels before num_frames
        video = video.permute(0, 2, 1, 3, 4)

    # normalize video
    video = 2.0 * video - 1.0

    return video


class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin):
    r"""
    Pipeline for text-guided video-to-video generation.
@@ -203,7 +139,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
@@ -687,7 +623,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Preprocess video
        video = preprocess_video(video)
        video = self.video_processor.preprocess_video(video)

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -749,7 +685,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type)
            video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

        # 10. Offload all models
        self.maybe_free_model_hooks()
111 src/diffusers/video_processor.py Normal file
@@ -0,0 +1,111 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import List, Optional, Union

import numpy as np
import PIL
import torch

from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist


class VideoProcessor(VaeImageProcessor):
    r"""Simple video processor."""

    def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor:
        r"""
        Preprocesses input video(s).

        Args:
            video: The input video. It can be one of the following:
                * List of the PIL images.
                * List of list of PIL images.
                * 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
                * 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
                * List of 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
                * List of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
                * 5D NumPy arrays: expected shape for each array: (batch_size, num_frames, height, width,
                  num_channels).
                * 5D Torch tensors: expected shape for each array: (batch_size, num_frames, num_channels, height,
                  width).
            height (`int`, *optional*, defaults to `None`):
                The height in preprocessed frames of the video. If `None`, will use `get_default_height_width()` to
                get the default height.
            width (`int`, *optional*, defaults to `None`):
                The width in preprocessed frames of the video. If `None`, will use `get_default_height_width()` to
                get the default width.
        """
        if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
            warnings.warn(
                "Passing `video` as a list of 5d np.ndarray is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
                FutureWarning,
            )
            video = np.concatenate(video, axis=0)
        if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
            warnings.warn(
                "Passing `video` as a list of 5d torch.Tensor is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
                FutureWarning,
            )
            video = torch.cat(video, axis=0)

        # ensure the input is a list of videos:
        # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
        # - if it is a single video, it is converted to a list of one video.
        if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
            video = list(video)
        elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
            video = [video]
        elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
            video = video
        else:
            raise ValueError(
                "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
            )

        video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0)

        # move the number of channels before the number of frames.
        video = video.permute(0, 2, 1, 3, 4)

        return video

    def postprocess_video(
        self, video: torch.Tensor, output_type: str = "np"
    ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
        r"""
        Converts a video tensor to a list of frames for export.

        Args:
            video (`torch.Tensor`): The video as a tensor.
            output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
        """
        batch_size = video.shape[0]
        outputs = []
        for batch_idx in range(batch_size):
            batch_vid = video[batch_idx].permute(1, 0, 2, 3)
            batch_output = self.postprocess(batch_vid, output_type)
            outputs.append(batch_output)

        if output_type == "np":
            outputs = np.stack(outputs)
        elif output_type == "pt":
            outputs = torch.stack(outputs)
        elif not output_type == "pil":
            raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

        return outputs
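To make the accepted input forms above concrete, here is a small sketch (not part of the file; shapes are illustrative): a list of PIL frames, a single 4d tensor of frames, and a 5d batched tensor all normalize to the same `(batch_size, num_channels, num_frames, height, width)` layout.

```python
import PIL.Image
import torch

from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, do_normalize=True)

frames_pil = [PIL.Image.new("RGB", (32, 32)) for _ in range(8)]  # one video as a list of PIL frames
frames_pt = torch.rand(8, 3, 32, 32)                             # one video as a 4d tensor
batch_pt = torch.rand(2, 8, 3, 32, 32)                           # a batch of two videos as a 5d tensor

print(video_processor.preprocess_video(frames_pil).shape)  # torch.Size([1, 3, 8, 32, 32])
print(video_processor.preprocess_video(frames_pt).shape)   # torch.Size([1, 3, 8, 32, 32])
print(video_processor.preprocess_video(batch_pt).shape)    # torch.Size([2, 3, 8, 32, 32])
```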
169 tests/others/test_video_processor.py Normal file
@@ -0,0 +1,169 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import PIL.Image
import torch
from parameterized import parameterized

from diffusers.video_processor import VideoProcessor


np.random.seed(0)
torch.manual_seed(0)


class VideoProcessorTest(unittest.TestCase):
    def get_dummy_sample(self, input_type):
        batch_size = 1
        num_frames = 5
        num_channels = 3
        height = 8
        width = 8

        def generate_image():
            return PIL.Image.fromarray(np.random.randint(0, 256, size=(height, width, num_channels)).astype("uint8"))

        def generate_4d_array():
            return np.random.rand(num_frames, height, width, num_channels)

        def generate_5d_array():
            return np.random.rand(batch_size, num_frames, height, width, num_channels)

        def generate_4d_tensor():
            return torch.rand(num_frames, num_channels, height, width)

        def generate_5d_tensor():
            return torch.rand(batch_size, num_frames, num_channels, height, width)

        if input_type == "list_images":
            sample = [generate_image() for _ in range(num_frames)]
        elif input_type == "list_list_images":
            sample = [[generate_image() for _ in range(num_frames)] for _ in range(num_frames)]
        elif input_type == "list_4d_np":
            sample = [generate_4d_array() for _ in range(num_frames)]
        elif input_type == "list_list_4d_np":
            sample = [[generate_4d_array() for _ in range(num_frames)] for _ in range(num_frames)]
        elif input_type == "list_5d_np":
            sample = [generate_5d_array() for _ in range(num_frames)]
        elif input_type == "5d_np":
            sample = generate_5d_array()
        elif input_type == "list_4d_pt":
            sample = [generate_4d_tensor() for _ in range(num_frames)]
        elif input_type == "list_list_4d_pt":
            sample = [[generate_4d_tensor() for _ in range(num_frames)] for _ in range(num_frames)]
        elif input_type == "list_5d_pt":
            sample = [generate_5d_tensor() for _ in range(num_frames)]
        elif input_type == "5d_pt":
            sample = generate_5d_tensor()

        return sample

    def to_np(self, video):
        # List of images.
        if isinstance(video[0], PIL.Image.Image):
            video = np.stack([np.array(i) for i in video], axis=0)

        # List of list of images.
        elif isinstance(video, list) and isinstance(video[0][0], PIL.Image.Image):
            frames = []
            for vid in video:
                all_current_frames = np.stack([np.array(i) for i in vid], axis=0)
                frames.append(all_current_frames)
            video = np.stack([np.array(frame) for frame in frames], axis=0)

        # List of 4d/5d {ndarrays, torch tensors}.
        elif isinstance(video, list) and isinstance(video[0], (torch.Tensor, np.ndarray)):
            if isinstance(video[0], np.ndarray):
                video = np.stack(video, axis=0) if video[0].ndim == 4 else np.concatenate(video, axis=0)
            else:
                if video[0].ndim == 4:
                    video = np.stack([i.cpu().numpy().transpose(0, 2, 3, 1) for i in video], axis=0)
                elif video[0].ndim == 5:
                    video = np.concatenate([i.cpu().numpy().transpose(0, 1, 3, 4, 2) for i in video], axis=0)

        # List of list of 4d/5d {ndarrays, torch tensors}.
        elif (
            isinstance(video, list)
            and isinstance(video[0], list)
            and isinstance(video[0][0], (torch.Tensor, np.ndarray))
        ):
            all_frames = []
            for list_of_videos in video:
                temp_frames = []
                for vid in list_of_videos:
                    if vid.ndim == 4:
                        current_vid_frames = np.stack(
                            [i if isinstance(i, np.ndarray) else i.cpu().numpy().transpose(1, 2, 0) for i in vid],
                            axis=0,
                        )
                    elif vid.ndim == 5:
                        current_vid_frames = np.concatenate(
                            [i if isinstance(i, np.ndarray) else i.cpu().numpy().transpose(0, 2, 3, 1) for i in vid],
                            axis=0,
                        )
                    temp_frames.append(current_vid_frames)
                temp_frames = np.stack(temp_frames, axis=0)
                all_frames.append(temp_frames)

            video = np.concatenate(all_frames, axis=0)

        # Just 5d {ndarrays, torch tensors}.
        elif isinstance(video, (torch.Tensor, np.ndarray)) and video.ndim == 5:
            video = video if isinstance(video, np.ndarray) else video.cpu().numpy().transpose(0, 1, 3, 4, 2)

        return video

    @parameterized.expand(["list_images", "list_list_images"])
    def test_video_processor_pil(self, input_type):
        video_processor = VideoProcessor(do_resize=False, do_normalize=True)

        input = self.get_dummy_sample(input_type=input_type)

        for output_type in ["pt", "np", "pil"]:
            out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type)
            out_np = self.to_np(out)
            input_np = self.to_np(input).astype("float32") / 255.0 if output_type != "pil" else self.to_np(input)
            assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"

    @parameterized.expand(["list_4d_np", "list_5d_np", "5d_np"])
    def test_video_processor_np(self, input_type):
        video_processor = VideoProcessor(do_resize=False, do_normalize=True)

        input = self.get_dummy_sample(input_type=input_type)

        for output_type in ["pt", "np", "pil"]:
            out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type)
            out_np = self.to_np(out)
            input_np = (
                (self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input)
            )
            assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"

    @parameterized.expand(["list_4d_pt", "list_5d_pt", "5d_pt"])
    def test_video_processor_pt(self, input_type):
        video_processor = VideoProcessor(do_resize=False, do_normalize=True)

        input = self.get_dummy_sample(input_type=input_type)

        for output_type in ["pt", "np", "pil"]:
            out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type)
            out_np = self.to_np(out)
            input_np = (
                (self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input)
            )
            assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"