
[docs/nits] Fix return values based on return_dict and minor doc updates (#7105)

* fix returns and docs

* handle latent output_type correctly

* revert to old tensor2vid impl

* make fix-copies

* fix return in community animatediff pipes

* fix return docstring

* fix return docs

* add missing quote

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Aryan authored 2024-03-09 10:17:24 +05:30, committed by GitHub
parent 6f2b310a17
commit cd6e1f1171
9 changed files with 87 additions and 82 deletions
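
The change applied across these pipelines is the same: instead of returning the output class early when `output_type` is `"latent"` (which bypassed `return_dict`), the latents now flow through the normal post-processing branch, so `return_dict` decides the return shape in every case. Condensed from the hunks below, with `AnimateDiffPipelineOutput` standing in for whichever output class a given pipeline returns:

# "latent" is handled like any other output_type instead of returning early
if output_type == "latent":
    video = latents
else:
    video_tensor = self.decode_latents(latents)
    video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

# Offload all models
self.maybe_free_model_hooks()

if not return_dict:
    return (video,)

return AnimateDiffPipelineOutput(frames=video)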

View File

@@ -13,7 +13,6 @@
 # limitations under the License.
 import inspect
-from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -27,6 +26,7 @@ from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionL
 from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.schedulers import (
@@ -37,7 +37,7 @@ from diffusers.schedulers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
@@ -91,10 +91,8 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
# Based on:
# https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
@@ -103,14 +101,18 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
         outputs.append(batch_output)

+    if output_type == "np":
+        outputs = np.stack(outputs)
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
     return outputs

-@dataclass
-class AnimateDiffControlNetPipelineOutput(BaseOutput):
-    frames: Union[torch.Tensor, np.ndarray]

 class AnimateDiffControlNetPipeline(
     DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin
 ):
@@ -843,8 +845,8 @@ class AnimateDiffControlNetPipeline(
         Examples:

         Returns:
-            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
@@ -1020,7 +1022,7 @@ class AnimateDiffControlNetPipeline(
             ]
             controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

-        # Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1096,21 +1098,17 @@ class AnimateDiffControlNetPipeline(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)

+        # 9. Post processing
         if output_type == "latent":
-            return AnimateDiffControlNetPipelineOutput(frames=latents)
-
-        # Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
+            video = latents
         else:
+            video_tensor = self.decode_latents(latents)
             video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

-        # Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:
             return (video,)

-        return AnimateDiffControlNetPipelineOutput(frames=video)
+        return AnimateDiffPipelineOutput(frames=video)
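
From the caller's side, the corrected contract is the usual diffusers one. A minimal usage sketch, assuming a standard AnimateDiff text-to-video setup (the checkpoint names are illustrative and not part of this diff):

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Illustrative checkpoints; any AnimateDiff-compatible base model and motion adapter work.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

prompt = "a rocket lifting off, cinematic lighting"

# return_dict=True (default): an AnimateDiffPipelineOutput with a .frames attribute
frames = pipe(prompt, num_frames=16).frames[0]

# return_dict=False: a plain tuple whose first element holds the generated frames
(frames,) = pipe(prompt, num_frames=16, return_dict=False)

# output_type="latent": the undecoded latents, still wrapped according to return_dict
latents = pipe(prompt, num_frames=16, output_type="latent").frames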

View File

@@ -158,10 +158,8 @@ def slerp(
     return v2

+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
 def tensor2vid(video: torch.Tensor, processor, output_type="np"):
-    # Based on:
-    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -170,6 +168,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
         outputs.append(batch_output)

+    if output_type == "np":
+        outputs = np.stack(outputs)
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
     return outputs
@@ -826,8 +833,8 @@ class AnimateDiffImgToVideoPipeline(
         Examples:

         Returns:
-            [`AnimateDiffPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """

         # 0. Default height and width to unet
@@ -958,11 +965,10 @@ class AnimateDiffImgToVideoPipeline(
             return AnimateDiffPipelineOutput(frames=latents)

         # 10. Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
+        if output_type == "latent":
+            video = latents
         else:
+            video_tensor = self.decode_latents(latents)
             video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

         # 11. Offload all models

View File

@@ -81,7 +81,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs
@@ -668,8 +668,8 @@ class AnimateDiffPipeline(
         Examples:

         Returns:
-            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
@@ -790,6 +790,8 @@ class AnimateDiffPipeline(
         self._num_timesteps = len(timesteps)
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
@@ -829,13 +831,14 @@ class AnimateDiffPipeline(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)

+        # 9. Post processing
         if output_type == "latent":
-            return AnimateDiffPipelineOutput(frames=latents)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
-
-        # 9. Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:
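
For reference, the shared helper that the `# Copied from` comments point at reads roughly as follows after this commit. This is a sketch: the per-batch loop body (permute plus `processor.postprocess`) is not visible in the hunks above and is filled in from the library implementation, so treat it as illustrative rather than authoritative.

import numpy as np
import torch


def tensor2vid(video: torch.Tensor, processor, output_type: str = "np"):
    # video is laid out as (batch, channels, frames, height, width)
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        # (channels, frames, H, W) -> (frames, channels, H, W) so each frame
        # can be post-processed like a regular image
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)
        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)
    elif output_type == "pt":
        outputs = torch.stack(outputs)
    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs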

View File

@@ -100,7 +100,7 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs
@@ -828,8 +828,8 @@ class AnimateDiffVideoToVideoPipeline(
         Examples:

         Returns:
-            [`AnimateDiffPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
+            [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
@@ -942,6 +942,7 @@ class AnimateDiffVideoToVideoPipeline(
         self._num_timesteps = len(timesteps)
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -980,15 +981,11 @@ class AnimateDiffVideoToVideoPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

-        if output_type == "latent":
-            return AnimateDiffPipelineOutput(frames=latents)
-
         # 9. Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
+        if output_type == "latent":
+            video = latents
         else:
+            video_tensor = self.decode_latents(latents)
             video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

         # 10. Offload all models

View File

@@ -83,7 +83,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs
@@ -726,13 +726,14 @@ class I2VGenXLPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+        # 8. Post processing
         if output_type == "latent":
-            return I2VGenXLPipelineOutput(frames=latents)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

-        video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
-
-        # Offload all models
+        # 9. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:

View File

@@ -107,7 +107,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs
@@ -860,8 +860,8 @@ class PIAPipeline(
         Examples:

         Returns:
-            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+            [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """

         # 0. Default height and width to unet
@@ -1018,13 +1018,14 @@ class PIAPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+        # 9. Post processing
         if output_type == "latent":
-            return PIAPipelineOutput(frames=latents)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
-
-        # 9. Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:

View File

@@ -74,7 +74,7 @@ def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: s
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs

View File

@@ -76,7 +76,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs
@@ -647,13 +647,14 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+        # 8. Post processing
         if output_type == "latent":
-            return TextToVideoSDPipelineOutput(frames=latents)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type)

-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type)
-
-        # Offload all models
+        # 9. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:

View File

@@ -111,7 +111,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs
@@ -694,13 +694,13 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

-        # 5. Prepare latent variables
+        # 6. Prepare latent variables
         latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator)

-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 7. Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -740,20 +740,18 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

-        if output_type == "latent":
-            return TextToVideoSDPipelineOutput(frames=latents)
-
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")

+        # 9. Post processing
         if output_type == "latent":
-            return TextToVideoSDPipelineOutput(frames=latents)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type)

-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type)
-
-        # Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict: