[docs/nits] Fix return values based on return_dict and minor doc updates (#7105)
* fix returns and docs
* handle latent output_type correctly
* revert to old tensor2vid impl
* make fix-copies
* fix return in community animatediff pipes
* fix return docstring
* fix return docs
* add missing quote

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
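
The recurring change in this diff is the post-processing block of each pipeline's __call__: when output_type="latent", the pipelines used to return their output dataclass immediately and ignore return_dict; they now route the latents through the common return path. A condensed sketch of the resulting control flow, abridged from the hunks below (the exact output class differs per pipeline):

# 9. Post processing
if output_type == "latent":
    # skip VAE decoding and hand the latents straight to the shared return path
    video = latents
else:
    # decode latents and convert them to np / pt / pil frames
    video_tensor = self.decode_latents(latents)
    video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

# 10. Offload all models
self.maybe_free_model_hooks()

if not return_dict:
    return (video,)

return AnimateDiffPipelineOutput(frames=video)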

@@ -13,7 +13,6 @@
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -27,6 +26,7 @@ from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionL
from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.unets.unet_motion_model import MotionAdapter
from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.schedulers import (
@@ -37,7 +37,7 @@ from diffusers.schedulers import (
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor


@@ -91,10 +91,8 @@ EXAMPLE_DOC_STRING = """
"""


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
    # Based on:
    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
@@ -103,14 +101,18 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


@dataclass
class AnimateDiffControlNetPipelineOutput(BaseOutput):
    frames: Union[torch.Tensor, np.ndarray]


class AnimateDiffControlNetPipeline(
    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin
):
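
For reference, the helper that the "Copied from" comment points at lives in diffusers.pipelines.animatediff.pipeline_animatediff. The loop body is not visible in the hunks above, so the following sketch fills it in from the upstream implementation and should be treated as approximate:

import numpy as np
import torch


def tensor2vid(video: torch.Tensor, processor, output_type="np"):
    # video is laid out as (batch, channels, frames, height, width)
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        # move frames onto the batch axis so the VAE image processor can post-process them
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)
        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)
    elif output_type == "pt":
        outputs = torch.stack(outputs)
    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs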

@@ -843,8 +845,8 @@ class AnimateDiffControlNetPipeline(
        Examples:

        Returns:
            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """

@@ -1020,7 +1022,7 @@ class AnimateDiffControlNetPipeline(
            ]
            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

        # Denoising loop
        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
@@ -1096,21 +1098,17 @@ class AnimateDiffControlNetPipeline(
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 9. Post processing
        if output_type == "latent":
            return AnimateDiffControlNetPipelineOutput(frames=latents)

        # Post-processing
        video_tensor = self.decode_latents(latents)

        if output_type == "pt":
            video = video_tensor
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        # Offload all models
        # 10. Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return AnimateDiffControlNetPipelineOutput(frames=video)
        return AnimateDiffPipelineOutput(frames=video)
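
With the last two lines above, the community ControlNet pipeline now returns the shared AnimateDiffPipelineOutput instead of its former local output class, and return_dict=False actually yields a tuple. A hypothetical call site, assuming pipe is an already-loaded AnimateDiffControlNetPipeline and conditioning_frames has been prepared (prompt and frame count are illustrative):

# return_dict=True (the default): frames live on the output object
result = pipe(prompt="a panda surfing", num_frames=16, conditioning_frames=conditioning_frames)
frames = result.frames[0]

# return_dict=False: the pipeline now returns a plain tuple instead of the output object
frames = pipe(
    prompt="a panda surfing",
    num_frames=16,
    conditioning_frames=conditioning_frames,
    return_dict=False,
)[0][0]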

@@ -158,10 +158,8 @@ def slerp(
    return v2


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor, output_type="np"):
    # Based on:
    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
@@ -170,6 +168,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):

        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)

    elif output_type == "pt":
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


@@ -826,8 +833,8 @@ class AnimateDiffImgToVideoPipeline(
        Examples:

        Returns:
            [`AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """
        # 0. Default height and width to unet
@@ -958,11 +965,10 @@ class AnimateDiffImgToVideoPipeline(
            return AnimateDiffPipelineOutput(frames=latents)

        # 10. Post-processing
        video_tensor = self.decode_latents(latents)

        if output_type == "pt":
            video = video_tensor
        if output_type == "latent":
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        # 11. Offload all models

@@ -81,7 +81,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs

@@ -668,8 +668,8 @@ class AnimateDiffPipeline(
        Examples:

        Returns:
            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """

@@ -790,6 +790,8 @@ class AnimateDiffPipeline(

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 8. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
@@ -829,13 +831,14 @@ class AnimateDiffPipeline(
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 9. Post processing
        if output_type == "latent":
            return AnimateDiffPipelineOutput(frames=latents)
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        video_tensor = self.decode_latents(latents)
        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        # 9. Offload all models
        # 10. Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:

@@ -100,7 +100,7 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs

@@ -828,8 +828,8 @@ class AnimateDiffVideoToVideoPipeline(
        Examples:

        Returns:
            [`AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
            [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """

@@ -942,6 +942,7 @@ class AnimateDiffVideoToVideoPipeline(

        self._num_timesteps = len(timesteps)
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        # 8. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
@@ -980,15 +981,11 @@ class AnimateDiffVideoToVideoPipeline(
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if output_type == "latent":
            return AnimateDiffPipelineOutput(frames=latents)

        # 9. Post-processing
        video_tensor = self.decode_latents(latents)

        if output_type == "pt":
            video = video_tensor
        if output_type == "latent":
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        # 10. Offload all models

@@ -83,7 +83,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs

@@ -726,13 +726,14 @@ class I2VGenXLPipeline(
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # 8. Post processing
        if output_type == "latent":
            return I2VGenXLPipelineOutput(frames=latents)
            video = latents
        else:
            video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        # Offload all models
        # 9. Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:

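After this change, requesting output_type="latent" from I2VGenXLPipeline also respects return_dict, and the latents can be decoded later with the pipeline's own helper. A hypothetical sketch, assuming pipe is an already-loaded I2VGenXLPipeline and image is the conditioning image (prompt and chunk size are illustrative):

# ask for raw latents; with return_dict=False this is now a plain tuple, not an output object
latents = pipe(prompt="library at dusk", image=image, output_type="latent", return_dict=False)[0]

# decode in chunks to bound peak memory, mirroring the decode_latents call in the hunk above
video_tensor = pipe.decode_latents(latents, decode_chunk_size=4)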

@@ -107,7 +107,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs

@@ -860,8 +860,8 @@ class PIAPipeline(
        Examples:

        Returns:
            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
            [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """
        # 0. Default height and width to unet
@@ -1018,13 +1018,14 @@ class PIAPipeline(
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # 9. Post processing
        if output_type == "latent":
            return PIAPipelineOutput(frames=latents)
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        video_tensor = self.decode_latents(latents)
        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

        # 9. Offload all models
        # 10. Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:

@@ -74,7 +74,7 @@ def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: s
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs


@@ -76,7 +76,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs

@@ -647,13 +647,14 @@ class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInve
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        # 8. Post processing
        if output_type == "latent":
            return TextToVideoSDPipelineOutput(frames=latents)
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type)

        video_tensor = self.decode_latents(latents)
        video = tensor2vid(video_tensor, self.image_processor, output_type)

        # Offload all models
        # 9. Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:

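The same pattern applies to TextToVideoSDPipeline: non-latent output types still go through decode_latents and tensor2vid, so the usual export workflow is unchanged. A hypothetical end-to-end call (checkpoint name, prompt, and file name are illustrative; the checkpoint below is the one used in the diffusers docs):

import torch

from diffusers import TextToVideoSDPipeline
from diffusers.utils import export_to_video

pipe = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
)
pipe = pipe.to("cuda")

# default output_type="np": frames come back as a stacked numpy array per prompt
frames = pipe("an astronaut riding a horse", num_inference_steps=25).frames[0]
export_to_video(frames, "astronaut.mp4")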

@@ -111,7 +111,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
        outputs = torch.stack(outputs)

    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs

@@ -694,13 +694,13 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 5. Prepare latent variables
        # 6. Prepare latent variables
        latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator)

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
@@ -740,20 +740,18 @@ class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInv
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if output_type == "latent":
            return TextToVideoSDPipelineOutput(frames=latents)

        # manually for max memory savings
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.unet.to("cpu")

        # 9. Post processing
        if output_type == "latent":
            return TextToVideoSDPipelineOutput(frames=latents)
            video = latents
        else:
            video_tensor = self.decode_latents(latents)
            video = tensor2vid(video_tensor, self.image_processor, output_type)

        video_tensor = self.decode_latents(latents)
        video = tensor2vid(video_tensor, self.image_processor, output_type)

        # Offload all models
        # 10. Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
