Mirror of https://github.com/huggingface/diffusers.git

[Core] introduce videoprocessor. (#7776)

* introduce videoprocessor.

* fix quality

* address yiyi's feedback

* fix preprocess_video call.

* video_processor -> image_processor

* fix

* fix more.

* quality

* image_processor -> video_processor

* support List[List[PIL.Image.Image]]

* change to video_processor.

* documentation

* Apply suggestions from code review

* changes

* remove print.

* refactor video processor (part # 7776) (#7861)

* update

* update remove deprecate

* Update src/diffusers/video_processor.py

* update

* Apply suggestions from code review

* deprecate list of 5d for video and list of 4d for image + apply other feedback

* up

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* add doc.

* tensor2vid -> postprocess_video.

* refactor preprocess with preprocess_video

* set default values.

* empty commit

* more refactoring of prepare_latents in animatediff vid2vid

* checking documentation

* remove documentation for now.

* fix animatediff sdxl

* fix test failure [part of video processor PR] (#7905)

up

* remove preceed_with_frames.

* doc

* fix

* fix

* remove video input as a single-frame video.

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Commit 04f4bd54ea (parent 82be58c512)
Author: Sayak Paul
Committed via GitHub on 2024-05-10 21:02:36 +02:00
13 changed files with 395 additions and 283 deletions

View File

@@ -29,15 +29,34 @@ from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
PipelineImageInput = Union[
PIL.Image.Image,
np.ndarray,
torch.FloatTensor,
torch.Tensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.FloatTensor],
List[torch.Tensor],
]
PipelineDepthInput = PipelineImageInput
def is_valid_image(image):
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
def is_valid_image_imagelist(images):
# check if the image input is one of the supported formats for image and image list:
# it can be one of the following 3:
# (1) a 4d pytorch tensor or numpy array,
# (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
# (3) a list of valid images
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
return True
elif is_valid_image(images):
return True
elif isinstance(images, list):
return all(is_valid_image(image) for image in images)
return False
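For illustration (not part of the commit itself), a minimal sketch of how the two validators above behave, assuming they are imported from `diffusers.image_processor`, the module this hunk modifies:

import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

pil_frame = PIL.Image.new("RGB", (64, 64))
gray = np.zeros((64, 64))            # 2-d array: grayscale image
rgb = torch.zeros(3, 64, 64)         # 3-d tensor: single image
batch = np.zeros((2, 64, 64, 3))     # 4-d array: batch of images

assert is_valid_image(pil_frame) and is_valid_image(gray) and is_valid_image(rgb)
assert is_valid_image_imagelist(batch)                        # case (1): 4-d array
assert is_valid_image_imagelist(pil_frame)                    # case (2): single valid image
assert is_valid_image_imagelist([pil_frame, gray, rgb])       # case (3): list of valid images
assert not is_valid_image_imagelist(np.zeros((1, 2, 64, 64, 3)))  # a 5-d array is not a valid image input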
class VaeImageProcessor(ConfigMixin):
"""
Image processor for VAE.
@@ -110,7 +129,7 @@ class VaeImageProcessor(ConfigMixin):
return images
@staticmethod
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
"""
Convert a NumPy image to a PyTorch tensor.
"""
@@ -121,7 +140,7 @@ class VaeImageProcessor(ConfigMixin):
return images
@staticmethod
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
"""
Convert a PyTorch tensor to a NumPy image.
"""
@@ -497,12 +516,27 @@ class VaeImageProcessor(ConfigMixin):
else:
image = np.expand_dims(image, axis=-1)
if isinstance(image, supported_formats):
image = [image]
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
warnings.warn(
"Passing `image` as a list of 4d np.ndarray is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
FutureWarning,
)
image = np.concatenate(image, axis=0)
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
warnings.warn(
"Passing `image` as a list of 4d torch.Tensor is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
FutureWarning,
)
image = torch.cat(image, axis=0)
if not is_valid_image_imagelist(image):
raise ValueError(
f"Input is in incorrect format. Currently, we only support {', '.join(supported_formats)}"
)
if not isinstance(image, list):
image = [image]
if isinstance(image[0], PIL.Image.Image):
if crops_coords is not None:
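To illustrate the deprecation introduced in this hunk, a small sketch (assumptions: `VaeImageProcessor` imported from `diffusers.image_processor`, 64x64 float inputs in [0, 1]): a list of 4-d arrays still works but emits a `FutureWarning`, and the preferred call passes a single concatenated array.

import numpy as np
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)
frames = [np.random.rand(1, 64, 64, 3).astype(np.float32) for _ in range(4)]

# Deprecated path: warns, then concatenates the list along the batch dimension internally.
batch = processor.preprocess(frames)

# Preferred path: concatenate explicitly and pass one 4-d array.
batch = processor.preprocess(np.concatenate(frames, axis=0))
print(batch.shape)  # torch.Size([4, 3, 64, 64]), values normalized to [-1, 1]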
@@ -561,15 +595,15 @@ class VaeImageProcessor(ConfigMixin):
def postprocess(
self,
image: torch.FloatTensor,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.
Args:
image (`torch.FloatTensor`):
image (`torch.Tensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -578,7 +612,7 @@ class VaeImageProcessor(ConfigMixin):
`VaeImageProcessor` config.
Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The postprocessed image.
"""
if not isinstance(image, torch.Tensor):
@@ -738,15 +772,15 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
def postprocess(
self,
image: torch.FloatTensor,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.
Args:
image (`torch.FloatTensor`):
image (`torch.Tensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -755,7 +789,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
`VaeImageProcessor` config.
Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The postprocessed image.
"""
if not isinstance(image, torch.Tensor):
@@ -793,8 +827,8 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
def preprocess(
self,
rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
height: Optional[int] = None,
width: Optional[int] = None,
target_res: Optional[int] = None,
@@ -933,13 +967,13 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
)
@staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
"""
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
Args:
mask (`torch.FloatTensor`):
mask (`torch.Tensor`):
The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
batch_size (`int`):
The batch size.
@@ -949,7 +983,7 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
The dimensionality of the value embeddings.
Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The downsampled mask tensor.
"""

View File

@@ -15,11 +15,10 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -41,6 +40,7 @@ from ...utils import (
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -65,27 +65,6 @@ EXAMPLE_DOC_STRING = """
"""
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
class AnimateDiffPipeline(
DiffusionPipeline,
StableDiffusionMixin,
@@ -159,7 +138,7 @@ class AnimateDiffPipeline(
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
@@ -836,7 +815,7 @@ class AnimateDiffPipeline(
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
self.maybe_free_model_hooks()
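A hedged sketch of the replacement call shown in this file: `postprocess_video` consumes the same `B x C x F x H x W` tensor that `tensor2vid` used to take, with the processor configured at construction time instead of passed as an argument (module path assumed to be `diffusers.video_processor`).

import torch
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, vae_scale_factor=8)

# Stand-in for `self.decode_latents(latents)`: B x C x F x H x W in [-1, 1].
video_tensor = torch.rand(1, 3, 16, 64, 64) * 2 - 1

frames_np = video_processor.postprocess_video(video=video_tensor, output_type="np")
frames_pil = video_processor.postprocess_video(video=video_tensor, output_type="pil")
print(frames_np.shape)          # (1, 16, 64, 64, 3)
print(type(frames_pil[0][0]))   # one PIL.Image.Image per frame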

View File

@@ -15,7 +15,6 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import (
CLIPImageProcessor,
@@ -25,7 +24,7 @@ from transformers import (
CLIPVisionModelWithProjection,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import (
FromSingleFileMixin,
IPAdapterMixin,
@@ -57,6 +56,7 @@ from ...utils import (
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -113,28 +113,6 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
@@ -320,7 +298,7 @@ class AnimateDiffSDXLPipeline(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
@@ -1291,7 +1269,7 @@ class AnimateDiffSDXLPipeline(
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# cast back to fp16 if needed
if needs_upcasting:

View File

@@ -15,11 +15,10 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -34,6 +33,7 @@ from ...schedulers import (
)
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -95,28 +95,6 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -264,7 +242,7 @@ class AnimateDiffVideoToVideoPipeline(
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
@@ -650,16 +628,7 @@ class AnimateDiffVideoToVideoPipeline(
generator,
latents=None,
):
# video must be a list of list of images
# the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
# as a list of images
if video and not isinstance(video[0], list):
video = [video]
if latents is None:
video = torch.cat(
[self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0
)
video = video.to(device=device, dtype=dtype)
num_frames = video.shape[1]
else:
num_frames = latents.shape[2]
@@ -943,6 +912,11 @@ class AnimateDiffVideoToVideoPipeline(
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
# 5. Prepare latent variables
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
# Move the number of frames before the number of channels.
video = video.permute(0, 2, 1, 3, 4)
video = video.to(device=device, dtype=prompt_embeds.dtype)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
video=video,
@@ -1023,7 +997,7 @@ class AnimateDiffVideoToVideoPipeline(
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
self.maybe_free_model_hooks()
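Shape sketch of the new vid2vid preprocessing path above (illustrative sizes; import path assumed): `preprocess_video` returns `B x C x F x H x W`, and the pipeline then permutes frames before channels.

import PIL.Image
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(vae_scale_factor=8)
frames = [PIL.Image.new("RGB", (256, 256)) for _ in range(16)]   # one video as a list of PIL frames

video = video_processor.preprocess_video(frames, height=256, width=256)
print(video.shape)                    # torch.Size([1, 3, 16, 256, 256]) -> B x C x F x H x W

video = video.permute(0, 2, 1, 3, 4)  # frames before channels, as done in the pipeline
print(video.shape)                    # torch.Size([1, 16, 3, 256, 256]) -> B x F x C x H x W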

View File

@@ -31,6 +31,7 @@ from ...utils import (
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -70,28 +71,6 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
@dataclass
class I2VGenXLPipelineOutput(BaseOutput):
r"""
@@ -156,7 +135,7 @@ class I2VGenXLPipeline(
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# `do_resize=False` as we do custom resizing.
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False)
@property
def guidance_scale(self):
@@ -342,8 +321,8 @@ class I2VGenXLPipeline(
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.image_processor.pil_to_numpy(image)
image = self.image_processor.numpy_to_pt(image)
image = self.video_processor.pil_to_numpy(image)
image = self.video_processor.numpy_to_pt(image)
# Normalize the image with CLIP training stats.
image = self.feature_extractor(
@@ -657,7 +636,7 @@ class I2VGenXLPipeline(
# 3.2.2 Image latents.
resized_image = _center_crop_wide(image, (width, height))
image = self.image_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype)
image = self.video_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype)
image_latents = self.prepare_image_latents(
image,
device=device,
@@ -737,7 +716,7 @@ class I2VGenXLPipeline(
video = latents
else:
video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 9. Offload all models
self.maybe_free_model_hooks()
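The conditioning-image path in this file now reaches the same `pil_to_numpy` / `numpy_to_pt` helpers through the video processor; a short sketch, assuming a 224x224 CLIP-sized input:

import PIL.Image
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(vae_scale_factor=8, do_resize=False)

image = PIL.Image.new("RGB", (224, 224))
image = video_processor.pil_to_numpy(image)   # (1, 224, 224, 3), float32 in [0, 1]
image = video_processor.numpy_to_pt(image)    # torch.Size([1, 3, 224, 224])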

View File

@@ -21,7 +21,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -43,6 +43,7 @@ from ...utils import (
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -89,28 +90,6 @@ RANGE_LIST = [
]
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int):
assert num_frames > 0, "video_length should be greater than 0"
@@ -218,7 +197,7 @@ class PIAPipeline(
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
@@ -621,7 +600,7 @@ class PIAPipeline(
)
_, _, _, scaled_height, scaled_width = shape
image = self.image_processor.preprocess(image)
image = self.video_processor.preprocess(image)
image = image.to(device, dtype)
if isinstance(generator, list):
@@ -959,7 +938,7 @@ class PIAPipeline(
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
self.maybe_free_model_hooks()

View File

@@ -21,11 +21,12 @@ import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
@@ -61,28 +62,6 @@ def _append_dims(x, target_dims):
return x[(...,) + (None,) * dims_to_append]
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -199,7 +178,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
def _encode_image(
self,
@@ -211,8 +190,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.image_processor.pil_to_numpy(image)
image = self.image_processor.numpy_to_pt(image)
image = self.video_processor.pil_to_numpy(image)
image = self.video_processor.numpy_to_pt(image)
# We normalize the image before resizing to match with the original implementation.
# Then we unnormalize it after resizing.
@@ -520,7 +499,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
fps = fps - 1
# 4. Encode input image using VAE
image = self.image_processor.preprocess(image, height=height, width=width).to(device)
image = self.video_processor.preprocess(image, height=height, width=width).to(device)
noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
image = image + noise_aug_strength * noise
@@ -626,7 +605,7 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
if needs_upcasting:
self.vae.to(dtype=torch.float16)
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
frames = tensor2vid(frames, self.image_processor, output_type=output_type)
frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
else:
frames = latents

View File

@@ -15,11 +15,9 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -33,6 +31,7 @@ from ...utils import (
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import TextToVideoSDPipelineOutput
@@ -59,28 +58,6 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
@@ -127,7 +104,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
@@ -652,7 +629,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 9. Offload all models
self.maybe_free_model_hooks()

View File

@@ -16,11 +16,9 @@ import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -34,6 +32,7 @@ from ...utils import (
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import TextToVideoSDPipelineOutput
@@ -94,69 +93,6 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
def preprocess_video(video):
supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image)
if isinstance(video, supported_formats):
video = [video]
elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}"
)
if isinstance(video[0], PIL.Image.Image):
video = [np.array(frame) for frame in video]
if isinstance(video[0], np.ndarray):
video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0)
if video.dtype == np.uint8:
video = np.array(video).astype(np.float32) / 255.0
if video.ndim == 4:
video = video[None, ...]
video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3))
elif isinstance(video[0], torch.Tensor):
video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0)
# don't need any preprocess if the video is latents
channel = video.shape[1]
if channel == 4:
return video
# move channels before num_frames
video = video.permute(0, 2, 1, 3, 4)
# normalize video
video = 2.0 * video - 1.0
return video
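For comparison with the removed module-level helper above, a sketch of the replacement call used later in this file (`self.video_processor.preprocess_video(video)`), assuming a float video array with values in [0, 1]:

import numpy as np
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, vae_scale_factor=8)

video = np.random.rand(8, 256, 256, 3).astype(np.float32)  # F x H x W x C, values in [0, 1]
out = video_processor.preprocess_video(video)
print(out.shape)  # torch.Size([1, 3, 8, 256, 256]), normalized to [-1, 1]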
class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-guided video-to-video generation.
@@ -203,7 +139,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
@@ -687,7 +623,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Preprocess video
video = preprocess_video(video)
video = self.video_processor.preprocess_video(video)
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -749,7 +685,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
self.maybe_free_model_hooks()

View File

@@ -0,0 +1,111 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List, Optional, Union
import numpy as np
import PIL
import torch
from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
class VideoProcessor(VaeImageProcessor):
r"""Simple video processor."""
def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> torch.Tensor:
r"""
Preprocesses input video(s).
Args:
video: The input video. It can be one of the following:
* A list of PIL images.
* A list of lists of PIL images.
* A 4D Torch tensor (expected shape: (num_frames, num_channels, height, width)).
* A 4D NumPy array (expected shape: (num_frames, height, width, num_channels)).
* A list of 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
* A list of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
* A 5D NumPy array (expected shape: (batch_size, num_frames, height, width, num_channels)).
* A 5D Torch tensor (expected shape: (batch_size, num_frames, num_channels, height, width)).
height (`int`, *optional*, defaults to `None`):
The height of the preprocessed video frames. If `None`, `get_default_height_width()` is used to get the default height.
width (`int`, *optional*, defaults to `None`):
The width of the preprocessed video frames. If `None`, `get_default_height_width()` is used to get the default width.
"""
if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
warnings.warn(
"Passing `video` as a list of 5d np.ndarray is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
FutureWarning,
)
video = np.concatenate(video, axis=0)
if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
warnings.warn(
"Passing `video` as a list of 5d torch.Tensor is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
FutureWarning,
)
video = torch.cat(video, axis=0)
# ensure the input is a list of videos:
# - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
# - if it is a single video, it is converted to a list of one video.
if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
video = list(video)
elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
video = [video]
elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
video = video
else:
raise ValueError(
"Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
)
video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0)
# move the number of channels before the number of frames.
video = video.permute(0, 2, 1, 3, 4)
return video
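A sketch of the batched-tensor branch handled above: a 5-d `B x F x C x H x W` tensor is split into per-video 4-d tensors, preprocessed frame-wise, and returned as `B x C x F x H x W` (illustrative shapes).

import torch
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, vae_scale_factor=8)

batch = torch.rand(2, 8, 3, 64, 64)   # two videos, 8 frames each, values in [0, 1]
out = video_processor.preprocess_video(batch)
print(out.shape)                      # torch.Size([2, 3, 8, 64, 64])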
def postprocess_video(
self, video: torch.Tensor, output_type: str = "np"
) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
r"""
Converts a video tensor to a list of frames for export.
Args:
video (`torch.Tensor`): The video as a tensor.
output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
"""
batch_size = video.shape[0]
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = self.postprocess(batch_vid, output_type)
outputs.append(batch_output)
if output_type == "np":
outputs = np.stack(outputs)
elif output_type == "pt":
outputs = torch.stack(outputs)
elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
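Finally, a round-trip sketch tying the two methods together (illustrative sizes): preprocess to the model-facing `B x C x F x H x W` layout, then post-process back to per-video frame lists.

import PIL.Image
from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, vae_scale_factor=8)

frames = [PIL.Image.new("RGB", (64, 64), color=(128, 128, 128)) for _ in range(4)]
tensor = video_processor.preprocess_video(frames)               # torch.Size([1, 3, 4, 64, 64]), in [-1, 1]
restored = video_processor.postprocess_video(tensor, output_type="pil")
print(len(restored), len(restored[0]))                          # 1 video, 4 frames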