Mirror of https://github.com/huggingface/diffusers.git
Standardise outputs for video pipelines (#6626)
* update * update * update * update * update * update * update * clean up * clean up
@@ -67,10 +67,7 @@ EXAMPLE_DOC_STRING = """
 """
 
 
-def tensor2vid(video: torch.Tensor, processor, output_type="np"):
-    # Based on:
-    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-
+def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -79,6 +76,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
 
         outputs.append(batch_output)
 
+    if output_type == "np":
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
     return outputs
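For context, a minimal sketch (not part of the diff) of how the standardised helper behaves per output_type; the import path, the vae_scale_factor, and the input shapes are assumptions for illustration:

import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.animatediff.pipeline_animatediff import tensor2vid

processor = VaeImageProcessor(vae_scale_factor=8)            # scale factor assumed for illustration
video = torch.rand(2, 3, 16, 64, 64) * 2 - 1                 # (batch, channels, frames, height, width) in [-1, 1]

frames_np = tensor2vid(video, processor, output_type="np")   # np.ndarray of shape (2, 16, 64, 64, 3)
frames_pt = tensor2vid(video, processor, output_type="pt")   # torch.Tensor of shape (2, 16, 3, 64, 64)
frames_pil = tensor2vid(video, processor, output_type="pil") # list of 2 lists of 16 PIL.Image frames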
@@ -805,11 +811,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             return AnimateDiffPipelineOutput(frames=latents)
 
         video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
-        else:
-            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
 
         if not return_dict:
             return (video,)
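The practical effect: output_type="pt" no longer short-circuits past postprocessing, so every non-latent type flows through tensor2vid. A hedged usage sketch (model IDs are examples, not taken from the diff):

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

result = pipe("a rocket launching", num_frames=16, output_type="pt")
print(result.frames.shape)  # denormalized tensor, e.g. torch.Size([1, 16, 3, 512, 512])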
@@ -40,10 +40,8 @@ def _append_dims(x, target_dims):
     return x[(...,) + (None,) * dims_to_append]
 
 
-def tensor2vid(video: torch.Tensor, processor, output_type="np"):
-    # Based on:
-    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -53,7 +51,13 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
         outputs.append(batch_output)
 
     if output_type == "np":
-        return np.stack(outputs)
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
 
     return outputs
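The hunk above replaces the early `return np.stack(outputs)` with a fall-through to a single `return outputs`, so "pt" is now stacked too and an unknown output_type raises instead of silently returning the unstacked per-batch list. A sketch, assuming the module path as of this commit:

import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import tensor2vid

processor = VaeImageProcessor(vae_scale_factor=8)  # assumed for illustration
video = torch.rand(1, 3, 4, 32, 32) * 2 - 1        # (batch, channels, frames, height, width)

frames = tensor2vid(video, processor, output_type="pt")
print(type(frames), frames.shape)  # <class 'torch.Tensor'> torch.Size([1, 4, 3, 32, 32])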
@@ -19,6 +19,7 @@ import numpy as np
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer
 
+from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
@@ -58,22 +59,26 @@ EXAMPLE_DOC_STRING = """
 """
 
 
-def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
-    # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-    # reshape to ncfhw
-    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
-    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
-    # unnormalize back to [0,1]
-    video = video.mul_(std).add_(mean)
-    video.clamp_(0, 1)
-    # prepare the final outputs
-    i, c, f, h, w = video.shape
-    images = video.permute(2, 3, 0, 4, 1).reshape(
-        f, h, i * w, c
-    )  # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
-    images = images.unbind(dim=0)  # prepare a list of individual (consecutive frames)
-    images = [(image.cpu().numpy() * 255).astype("uint8") for image in images]  # f h w c
-    return images
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
+    batch_size, channels, num_frames, height, width = video.shape
+    outputs = []
+    for batch_idx in range(batch_size):
+        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+        batch_output = processor.postprocess(batch_vid, output_type)
+
+        outputs.append(batch_output)
+
+    if output_type == "np":
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
+    return outputs
 
 
 class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
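This is the most visible behavioural change in the PR: the removed ModelScope-derived helper tiled the whole batch along the width and always returned a list of uint8 frames, while the standardised helper keeps the batch dimension and honours output_type. A sketch reproducing just the reshaping portion of the removed code (input values assumed already in [0, 1]):

import torch

video = torch.rand(2, 3, 8, 32, 32)  # (batch, channels, frames, height, width)

# Old behaviour: batches tiled side by side along the width.
i, c, f, h, w = video.shape
grid = video.permute(2, 3, 0, 4, 1).reshape(f, h, i * w, c)  # (frames, h, batch * w, c)
old_frames = [(frame.numpy() * 255).astype("uint8") for frame in grid.unbind(dim=0)]
print(len(old_frames), old_frames[0].shape)  # 8 (32, 64, 3)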
@@ -122,6 +127,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
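The pipeline now owns the same VaeImageProcessor the image pipelines use, configured from the VAE's scale factor. A sketch of what that processor does with decoder output (values assumed in [-1, 1]):

import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)   # 2 ** (len(block_out_channels) - 1) for a 4-block VAE
decoded = torch.rand(16, 3, 64, 64) * 2 - 1         # 16 frames as the VAE decodes them
frames = processor.postprocess(decoded, output_type="np")  # (16, 64, 64, 3) float32 in [0, 1]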
@@ -717,11 +723,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
             return TextToVideoSDPipelineOutput(frames=latents)
 
         video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
-        else:
-            video = tensor2vid(video_tensor)
+        video = tensor2vid(video_tensor, self.image_processor, output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
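With the unified path, the default "np" output of TextToVideoSDPipeline becomes one stacked float array in [0, 1] with an explicit batch dimension, rather than a list of uint8 frames. A hedged usage sketch (model ID is an example only):

import torch
from diffusers import TextToVideoSDPipeline

pipe = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
).to("cuda")

result = pipe("a panda playing guitar", num_frames=16)
print(result.frames.shape, result.frames.dtype)  # e.g. (1, 16, 256, 256, 3) float32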
@@ -20,6 +20,7 @@ import PIL.Image
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer
 
+from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
@@ -93,22 +94,26 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")
 
 
-def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
-    # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-    # reshape to ncfhw
-    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
-    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
-    # unnormalize back to [0,1]
-    video = video.mul_(std).add_(mean)
-    video.clamp_(0, 1)
-    # prepare the final outputs
-    i, c, f, h, w = video.shape
-    images = video.permute(2, 3, 0, 4, 1).reshape(
-        f, h, i * w, c
-    )  # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
-    images = images.unbind(dim=0)  # prepare a list of individual (consecutive frames)
-    images = [(image.cpu().numpy() * 255).astype("uint8") for image in images]  # f h w c
-    return images
+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
+def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
+    batch_size, channels, num_frames, height, width = video.shape
+    outputs = []
+    for batch_idx in range(batch_size):
+        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+        batch_output = processor.postprocess(batch_vid, output_type)
+
+        outputs.append(batch_output)
+
+    if output_type == "np":
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
+    return outputs
 
 
 def preprocess_video(video):
@@ -198,6 +203,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             scheduler=scheduler,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
     def enable_vae_slicing(self):
@@ -812,12 +818,11 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
 
-        video_tensor = self.decode_latents(latents)
         if output_type == "latent":
             return TextToVideoSDPipelineOutput(frames=latents)
 
-        if output_type == "pt":
-            video = video_tensor
-        else:
-            video = tensor2vid(video_tensor)
+        video_tensor = self.decode_latents(latents)
+        video = tensor2vid(video_tensor, self.image_processor, output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
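Besides routing through tensor2vid, this hunk moves decode_latents below the "latent" early return, so asking for latents no longer pays for a needless VAE decode. A hedged sketch (model ID and input frames are placeholders):

import numpy as np
import torch
from diffusers import VideoToVideoSDPipeline

pipe = VideoToVideoSDPipeline.from_pretrained(
    "cerspense/zeroscope_v2_576w", torch_dtype=torch.float16
).to("cuda")

frames = [np.random.rand(320, 576, 3).astype(np.float32) for _ in range(8)]  # placeholder input video
out = pipe("a beach at sunset", video=frames, strength=0.6, output_type="latent")
latents = out.frames  # raw latents; the VAE decode is skipped entirely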