mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into cogvideox-lora-and-training
This commit is contained in:
@@ -34,7 +34,7 @@ from utils import ( # noqa: E402
|
||||
|
||||
|
||||
RESOLUTION_MAPPING = {
|
||||
"runwayml/stable-diffusion-v1-5": (512, 512),
|
||||
"Lykon/DreamShaper": (512, 512),
|
||||
"lllyasviel/sd-controlnet-canny": (512, 512),
|
||||
"diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
|
||||
"TencentARC/t2iadapter_canny_sd14v1": (512, 512),
|
||||
@@ -268,7 +268,7 @@ class IPAdapterTextToImageBenchmark(TextToImageBenchmark):
|
||||
class ControlNetBenchmark(TextToImageBenchmark):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
aux_network_class = ControlNetModel
|
||||
root_ckpt = "runwayml/stable-diffusion-v1-5"
|
||||
root_ckpt = "Lykon/DreamShaper"
|
||||
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
|
||||
image = load_image(url).convert("RGB")
|
||||
@@ -311,7 +311,7 @@ class ControlNetSDXLBenchmark(ControlNetBenchmark):
|
||||
class T2IAdapterBenchmark(ControlNetBenchmark):
|
||||
pipeline_class = StableDiffusionAdapterPipeline
|
||||
aux_network_class = T2IAdapter
|
||||
root_ckpt = "CompVis/stable-diffusion-v1-4"
|
||||
root_ckpt = "Lykon/DreamShaper"
|
||||
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
|
||||
image = load_image(url).convert("L")
|
||||
|
||||
@@ -7,7 +7,8 @@ from base_classes import IPAdapterTextToImageBenchmark # noqa: E402
|
||||
|
||||
|
||||
IP_ADAPTER_CKPTS = {
|
||||
"runwayml/stable-diffusion-v1-5": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
|
||||
# because original SD v1.5 has been taken down.
|
||||
"Lykon/DreamShaper": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"),
|
||||
}
|
||||
|
||||
@@ -17,7 +18,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
default="rstabilityai/stable-diffusion-xl-base-1.0",
|
||||
choices=list(IP_ADAPTER_CKPTS.keys()),
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
|
||||
@@ -11,9 +11,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
default="Lykon/DreamShaper",
|
||||
choices=[
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"Lykon/DreamShaper",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
"stabilityai/sdxl-turbo",
|
||||
|
||||
@@ -11,9 +11,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
default="Lykon/DreamShaper",
|
||||
choices=[
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"Lykon/DreamShaper",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
],
|
||||
|
||||
@@ -7,7 +7,7 @@ from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark # noqa
|
||||
|
||||
|
||||
ALL_T2I_CKPTS = [
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
"Lykon/DreamShaper",
|
||||
"segmind/SSD-1B",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"kandinsky-community/kandinsky-2-2-decoder",
|
||||
@@ -21,7 +21,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="runwayml/stable-diffusion-v1-5",
|
||||
default="Lykon/DreamShaper",
|
||||
choices=ALL_T2I_CKPTS,
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
|
||||
@@ -161,6 +161,8 @@
|
||||
title: DeepCache
|
||||
- local: optimization/tgate
|
||||
title: TGATE
|
||||
- local: optimization/xdit
|
||||
title: xDiT
|
||||
- sections:
|
||||
- local: using-diffusers/stable_diffusion_jax_how_to
|
||||
title: JAX/Flax
|
||||
|
||||
@@ -29,6 +29,7 @@ The abstract of the paper is the following:
|
||||
| [AnimateDiffSparseControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py) | *Controlled Video-to-Video Generation with AnimateDiff using SparseCtrl* |
|
||||
| [AnimateDiffSDXLPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py) | *Video-to-Video Generation with AnimateDiff* |
|
||||
| [AnimateDiffVideoToVideoPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py) | *Video-to-Video Generation with AnimateDiff* |
|
||||
| [AnimateDiffVideoToVideoControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py) | *Video-to-Video Generation with AnimateDiff using ControlNet* |
|
||||
|
||||
## Available checkpoints
|
||||
|
||||
@@ -518,6 +519,97 @@ Here are some sample outputs:
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
|
||||
### AnimateDiffVideoToVideoControlNetPipeline
|
||||
|
||||
AnimateDiff can be used together with ControlNets to enhance video-to-video generation by allowing for precise control over the output. ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala, and allows you to condition Stable Diffusion with an additional control image to ensure that the spatial information is preserved throughout the video.
|
||||
|
||||
This pipeline allows you to condition your generation both on the original video and on a sequence of control images.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from controlnet_aux.processor import OpenposeDetector
|
||||
from diffusers import AnimateDiffVideoToVideoControlNetPipeline
|
||||
from diffusers.utils import export_to_gif, load_video
|
||||
from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter, LCMScheduler
|
||||
|
||||
# Load the ControlNet
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
|
||||
# Load the motion adapter
|
||||
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
|
||||
# Load SD 1.5 based finetuned model
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
|
||||
pipe = AnimateDiffVideoToVideoControlNetPipeline.from_pretrained(
|
||||
"SG161222/Realistic_Vision_V5.1_noVAE",
|
||||
motion_adapter=motion_adapter,
|
||||
controlnet=controlnet,
|
||||
vae=vae,
|
||||
).to(device="cuda", dtype=torch.float16)
|
||||
|
||||
# Enable LCM to speed up inference
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
|
||||
pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
|
||||
pipe.set_adapters(["lcm-lora"], [0.8])
|
||||
|
||||
video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif")
|
||||
video = [frame.convert("RGB") for frame in video]
|
||||
|
||||
prompt = "astronaut in space, dancing"
|
||||
negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly"
|
||||
|
||||
# Create controlnet preprocessor
|
||||
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
|
||||
|
||||
# Preprocess controlnet images
|
||||
conditioning_frames = []
|
||||
for frame in tqdm(video):
|
||||
conditioning_frames.append(open_pose(frame))
|
||||
|
||||
strength = 0.8
|
||||
with torch.inference_mode():
|
||||
video = pipe(
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=10,
|
||||
guidance_scale=2.0,
|
||||
controlnet_conditioning_scale=0.75,
|
||||
conditioning_frames=conditioning_frames,
|
||||
strength=strength,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).frames[0]
|
||||
|
||||
video = [frame.resize(conditioning_frames[0].size) for frame in video]
|
||||
export_to_gif(video, f"animatediff_vid2vid_controlnet.gif", fps=8)
|
||||
```
|
||||
|
||||
Here are some sample outputs:
|
||||
|
||||
<table align="center">
|
||||
<tr>
|
||||
<th align="center">Source Video</th>
|
||||
<th align="center">Output Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">
|
||||
anime girl, dancing
|
||||
<br />
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/dance.gif" alt="anime girl, dancing" />
|
||||
</td>
|
||||
<td align="center">
|
||||
astronaut in space, dancing
|
||||
<br/>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff_vid2vid_controlnet.gif" alt="astronaut in space, dancing" />
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
**The lights and composition were transferred from the Source Video.**
|
||||
|
||||
## Using Motion LoRAs
|
||||
|
||||
Motion LoRAs are a collection of LoRAs that work with the `guoyww/animatediff-motion-adapter-v1-5-2` checkpoint. These LoRAs are responsible for adding specific types of motion to the animations.
|
||||
@@ -866,6 +958,12 @@ pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapt
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AnimateDiffVideoToVideoControlNetPipeline
|
||||
|
||||
[[autodoc]] AnimateDiffVideoToVideoControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## AnimateDiffPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput
|
||||
|
||||
@@ -105,3 +105,11 @@ image.save("kolors_ipa_sample.png")
|
||||
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## KolorsImg2ImgPipeline
|
||||
|
||||
[[autodoc]] KolorsImg2ImgPipeline
|
||||
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
122
docs/source/en/optimization/xdit.md
Normal file
122
docs/source/en/optimization/xdit.md
Normal file
@@ -0,0 +1,122 @@
|
||||
# xDiT
|
||||
|
||||
[xDiT](https://github.com/xdit-project/xDiT) is an inference engine designed for the large scale parallel deployment of Diffusion Transformers (DiTs). xDiT provides a suite of efficient parallel approaches for Diffusion Models, as well as GPU kernel accelerations.
|
||||
|
||||
There are four parallel methods supported in xDiT, including [Unified Sequence Parallelism](https://arxiv.org/abs/2405.07719), [PipeFusion](https://arxiv.org/abs/2405.14430), CFG parallelism and data parallelism. The four parallel methods in xDiT can be configured in a hybrid manner, optimizing communication patterns to best suit the underlying network hardware.
|
||||
|
||||
Optimization orthogonal to parallelization focuses on accelerating single GPU performance. In addition to utilizing well-known Attention optimization libraries, we leverage compilation acceleration technologies such as torch.compile and onediff.
|
||||
|
||||
The overview of xDiT is shown as follows.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/methods/xdit_overview.png">
|
||||
</div>
|
||||
You can install xDiT using the following command:
|
||||
|
||||
|
||||
```bash
|
||||
pip install xfuser
|
||||
```
|
||||
|
||||
Here's an example of using xDiT to accelerate inference of a Diffusers model.
|
||||
|
||||
```diff
|
||||
import torch
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
from xfuser import xFuserArgs, xDiTParallel
|
||||
from xfuser.config import FlexibleArgumentParser
|
||||
from xfuser.core.distributed import get_world_group
|
||||
|
||||
def main():
|
||||
+ parser = FlexibleArgumentParser(description="xFuser Arguments")
|
||||
+ args = xFuserArgs.add_cli_args(parser).parse_args()
|
||||
+ engine_args = xFuserArgs.from_cli_args(args)
|
||||
+ engine_config, input_config = engine_args.create_config()
|
||||
|
||||
local_rank = get_world_group().local_rank
|
||||
pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=engine_config.model_config.model,
|
||||
torch_dtype=torch.float16,
|
||||
).to(f"cuda:{local_rank}")
|
||||
|
||||
# do anything you want with pipeline here
|
||||
|
||||
+ pipe = xDiTParallel(pipe, engine_config, input_config)
|
||||
|
||||
pipe(
|
||||
height=input_config.height,
|
||||
width=input_config.height,
|
||||
prompt=input_config.prompt,
|
||||
num_inference_steps=input_config.num_inference_steps,
|
||||
output_type=input_config.output_type,
|
||||
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
|
||||
)
|
||||
|
||||
+ if input_config.output_type == "pil":
|
||||
+ pipe.save("results", "stable_diffusion_3")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
```
|
||||
|
||||
As you can see, we only need to use xFuserArgs from xDiT to get configuration parameters, and pass these parameters along with the pipeline object from the Diffusers library into xDiTParallel to complete the parallelization of a specific pipeline in Diffusers.
|
||||
|
||||
xDiT runtime parameters can be viewed in the command line using `-h`, and you can refer to this [usage](https://github.com/xdit-project/xDiT?tab=readme-ov-file#2-usage) example for more details.
|
||||
|
||||
xDiT needs to be launched using torchrun to support its multi-node, multi-GPU parallel capabilities. For example, the following command can be used for 8-GPU parallel inference:
|
||||
|
||||
```bash
|
||||
torchrun --nproc_per_node=8 ./inference.py --model models/FLUX.1-dev --data_parallel_degree 2 --ulysses_degree 2 --ring_degree 2 --prompt "A snowy mountain" "A small dog" --num_inference_steps 50
|
||||
```
|
||||
|
||||
## Supported models
|
||||
|
||||
A subset of Diffusers models are supported in xDiT, such as Flux.1, Stable Diffusion 3, etc. The latest supported models can be found [here](https://github.com/xdit-project/xDiT?tab=readme-ov-file#-supported-dits).
|
||||
|
||||
## Benchmark
|
||||
We tested different models on various machines, and here is some of the benchmark data.
|
||||
|
||||
|
||||
### Flux.1-schnell
|
||||
<div class="flex justify-center">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/flux/Flux-2k-L40.png">
|
||||
</div>
|
||||
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/flux/Flux-2K-A100.png">
|
||||
</div>
|
||||
|
||||
### Stable Diffusion 3
|
||||
<div class="flex justify-center">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/sd3/L40-SD3.png">
|
||||
</div>
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/sd3/A100-SD3.png">
|
||||
</div>
|
||||
|
||||
### HunyuanDiT
|
||||
<div class="flex justify-center">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/hunuyuandit/L40-HunyuanDiT.png">
|
||||
</div>
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/hunuyuandit/A100-HunyuanDiT.png">
|
||||
</div>
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://github.com/xdit-project/xDiT/raw/main/assets/performance/hunuyuandit/T4-HunyuanDiT.png">
|
||||
</div>
|
||||
|
||||
More detailed performance metric can be found on our [github page](https://github.com/xdit-project/xDiT?tab=readme-ov-file#perf).
|
||||
|
||||
## Reference
|
||||
|
||||
[xDiT-project](https://github.com/xdit-project/xDiT)
|
||||
|
||||
[USP: A Unified Sequence Parallelism Approach for Long Context Generative AI](https://arxiv.org/abs/2405.07719)
|
||||
|
||||
[PipeFusion: Displaced Patch Pipeline Parallelism for Inference of Diffusion Transformer Models](https://arxiv.org/abs/2405.14430)
|
||||
@@ -245,6 +245,7 @@ else:
|
||||
"AnimateDiffPipeline",
|
||||
"AnimateDiffSDXLPipeline",
|
||||
"AnimateDiffSparseControlNetPipeline",
|
||||
"AnimateDiffVideoToVideoControlNetPipeline",
|
||||
"AnimateDiffVideoToVideoPipeline",
|
||||
"AudioLDM2Pipeline",
|
||||
"AudioLDM2ProjectionModel",
|
||||
@@ -694,6 +695,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AnimateDiffPipeline,
|
||||
AnimateDiffSDXLPipeline,
|
||||
AnimateDiffSparseControlNetPipeline,
|
||||
AnimateDiffVideoToVideoControlNetPipeline,
|
||||
AnimateDiffVideoToVideoPipeline,
|
||||
AudioLDM2Pipeline,
|
||||
AudioLDM2ProjectionModel,
|
||||
|
||||
@@ -242,9 +242,12 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
|
||||
def from_transformer(
|
||||
cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
|
||||
):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers or config.num_layers
|
||||
config["extra_conditioning_channels"] = num_extra_conditioning_channels
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
|
||||
@@ -123,6 +123,7 @@ else:
|
||||
"AnimateDiffSDXLPipeline",
|
||||
"AnimateDiffSparseControlNetPipeline",
|
||||
"AnimateDiffVideoToVideoPipeline",
|
||||
"AnimateDiffVideoToVideoControlNetPipeline",
|
||||
]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlNetPipeline",
|
||||
@@ -449,6 +450,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AnimateDiffPipeline,
|
||||
AnimateDiffSDXLPipeline,
|
||||
AnimateDiffSparseControlNetPipeline,
|
||||
AnimateDiffVideoToVideoControlNetPipeline,
|
||||
AnimateDiffVideoToVideoPipeline,
|
||||
)
|
||||
from .audioldm import AudioLDMPipeline
|
||||
|
||||
@@ -26,6 +26,7 @@ else:
|
||||
_import_structure["pipeline_animatediff_sdxl"] = ["AnimateDiffSDXLPipeline"]
|
||||
_import_structure["pipeline_animatediff_sparsectrl"] = ["AnimateDiffSparseControlNetPipeline"]
|
||||
_import_structure["pipeline_animatediff_video2video"] = ["AnimateDiffVideoToVideoPipeline"]
|
||||
_import_structure["pipeline_animatediff_video2video_controlnet"] = ["AnimateDiffVideoToVideoControlNetPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -40,6 +41,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_animatediff_sdxl import AnimateDiffSDXLPipeline
|
||||
from .pipeline_animatediff_sparsectrl import AnimateDiffSparseControlNetPipeline
|
||||
from .pipeline_animatediff_video2video import AnimateDiffVideoToVideoPipeline
|
||||
from .pipeline_animatediff_video2video_controlnet import AnimateDiffVideoToVideoControlNetPipeline
|
||||
from .pipeline_output import AnimateDiffPipelineOutput
|
||||
|
||||
else:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1024,14 +1024,16 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
if denoising_start is None:
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
else:
|
||||
t_start = 0
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
if denoising_start is not None:
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
@@ -1039,7 +1041,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
)
|
||||
)
|
||||
|
||||
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
|
||||
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
|
||||
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
|
||||
# if the scheduler is a 2nd order scheduler we might have to do +1
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
@@ -1050,11 +1052,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
# because t_n+1 >= t_n, we slice the timesteps starting from the end
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
t_start = len(self.scheduler.timesteps) - num_inference_steps
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start)
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def _get_add_time_ids(
|
||||
self,
|
||||
original_size,
|
||||
|
||||
@@ -564,14 +564,16 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
|
||||
if denoising_start is None:
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
else:
|
||||
t_start = 0
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
if denoising_start is not None:
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
@@ -579,7 +581,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
|
||||
)
|
||||
)
|
||||
|
||||
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
|
||||
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
|
||||
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
|
||||
# if the scheduler is a 2nd order scheduler we might have to do +1
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
@@ -590,11 +592,12 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
# because t_n+1 >= t_n, we slice the timesteps starting from the end
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
t_start = len(self.scheduler.timesteps) - num_inference_steps
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start)
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
|
||||
|
||||
@@ -648,14 +648,16 @@ class StableDiffusionXLPAGImg2ImgPipeline(
|
||||
if denoising_start is None:
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
else:
|
||||
t_start = 0
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
if denoising_start is not None:
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
@@ -663,7 +665,7 @@ class StableDiffusionXLPAGImg2ImgPipeline(
|
||||
)
|
||||
)
|
||||
|
||||
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
|
||||
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
|
||||
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
|
||||
# if the scheduler is a 2nd order scheduler we might have to do +1
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
@@ -674,11 +676,12 @@ class StableDiffusionXLPAGImg2ImgPipeline(
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
# because t_n+1 >= t_n, we slice the timesteps starting from the end
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
t_start = len(self.scheduler.timesteps) - num_inference_steps
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start)
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
|
||||
|
||||
@@ -897,14 +897,16 @@ class StableDiffusionXLPAGInpaintPipeline(
|
||||
if denoising_start is None:
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
else:
|
||||
t_start = 0
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
if denoising_start is not None:
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
@@ -912,7 +914,7 @@ class StableDiffusionXLPAGInpaintPipeline(
|
||||
)
|
||||
)
|
||||
|
||||
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
|
||||
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
|
||||
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
|
||||
# if the scheduler is a 2nd order scheduler we might have to do +1
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
@@ -923,11 +925,12 @@ class StableDiffusionXLPAGInpaintPipeline(
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
# because t_n+1 >= t_n, we slice the timesteps starting from the end
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
t_start = len(self.scheduler.timesteps) - num_inference_steps
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start)
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(
|
||||
self,
|
||||
|
||||
@@ -640,14 +640,16 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
if denoising_start is None:
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
else:
|
||||
t_start = 0
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
if denoising_start is not None:
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
@@ -655,7 +657,7 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
)
|
||||
)
|
||||
|
||||
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
|
||||
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
|
||||
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
|
||||
# if the scheduler is a 2nd order scheduler we might have to do +1
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
@@ -666,11 +668,12 @@ class StableDiffusionXLImg2ImgPipeline(
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
# because t_n+1 >= t_n, we slice the timesteps starting from the end
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
t_start = len(self.scheduler.timesteps) - num_inference_steps
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start)
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(
|
||||
self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
|
||||
):
|
||||
|
||||
@@ -901,14 +901,16 @@ class StableDiffusionXLInpaintPipeline(
|
||||
if denoising_start is None:
|
||||
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
else:
|
||||
t_start = 0
|
||||
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
if denoising_start is not None:
|
||||
# Strength is irrelevant if we directly request a timestep to start at;
|
||||
# that is, strength is determined by the denoising_start instead.
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
@@ -916,7 +918,7 @@ class StableDiffusionXLInpaintPipeline(
|
||||
)
|
||||
)
|
||||
|
||||
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
|
||||
num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
|
||||
if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
|
||||
# if the scheduler is a 2nd order scheduler we might have to do +1
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
@@ -927,11 +929,12 @@ class StableDiffusionXLInpaintPipeline(
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
# because t_n+1 >= t_n, we slice the timesteps starting from the end
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
t_start = len(self.scheduler.timesteps) - num_inference_steps
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start)
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(
|
||||
self,
|
||||
|
||||
@@ -152,6 +152,21 @@ class AnimateDiffSparseControlNetPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AnimateDiffVideoToVideoControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -0,0 +1,535 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import diffusers
|
||||
from diffusers import (
|
||||
AnimateDiffVideoToVideoControlNetPipeline,
|
||||
AutoencoderKL,
|
||||
ControlNetModel,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
LCMScheduler,
|
||||
MotionAdapter,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
UNetMotionModel,
|
||||
)
|
||||
from diffusers.models.attention import FreeNoiseTransformerBlock
|
||||
from diffusers.utils import is_xformers_available, logging
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
|
||||
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
def to_np(tensor):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
tensor = tensor.detach().cpu().numpy()
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
class AnimateDiffVideoToVideoControlNetPipelineFastTests(
|
||||
IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = AnimateDiffVideoToVideoControlNetPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS
|
||||
batch_params = VIDEO_TO_VIDEO_BATCH_PARAMS.union({"conditioning_frames"})
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
|
||||
def get_dummy_components(self):
|
||||
cross_attention_dim = 8
|
||||
block_out_channels = (8, 8)
|
||||
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=2,
|
||||
sample_size=8,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="linear",
|
||||
clip_sample=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
controlnet = ControlNetModel(
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=2,
|
||||
in_channels=4,
|
||||
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
conditioning_embedding_out_channels=(8, 8),
|
||||
norm_num_groups=1,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=block_out_channels,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
norm_num_groups=2,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=cross_attention_dim,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
torch.manual_seed(0)
|
||||
motion_adapter = MotionAdapter(
|
||||
block_out_channels=block_out_channels,
|
||||
motion_layers_per_block=2,
|
||||
motion_norm_num_groups=2,
|
||||
motion_num_attention_heads=4,
|
||||
)
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"controlnet": controlnet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"motion_adapter": motion_adapter,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"feature_extractor": None,
|
||||
"image_encoder": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0, num_frames: int = 2):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
video_height = 32
|
||||
video_width = 32
|
||||
video = [Image.new("RGB", (video_width, video_height))] * num_frames
|
||||
|
||||
video_height = 32
|
||||
video_width = 32
|
||||
conditioning_frames = [Image.new("RGB", (video_width, video_height))] * num_frames
|
||||
|
||||
inputs = {
|
||||
"video": video,
|
||||
"conditioning_frames": conditioning_frames,
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_from_pipe_consistent_config(self):
|
||||
assert self.original_pipeline_class == StableDiffusionPipeline
|
||||
original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
|
||||
original_kwargs = {"requires_safety_checker": False}
|
||||
|
||||
# create original_pipeline_class(sd)
|
||||
pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
|
||||
|
||||
# original_pipeline_class(sd) -> pipeline_class
|
||||
pipe_components = self.get_dummy_components()
|
||||
pipe_additional_components = {}
|
||||
for name, component in pipe_components.items():
|
||||
if name not in pipe_original.components:
|
||||
pipe_additional_components[name] = component
|
||||
|
||||
pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
|
||||
|
||||
# pipeline_class -> original_pipeline_class(sd)
|
||||
original_pipe_additional_components = {}
|
||||
for name, component in pipe_original.components.items():
|
||||
if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
|
||||
original_pipe_additional_components[name] = component
|
||||
|
||||
pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
|
||||
|
||||
# compare the config
|
||||
original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
|
||||
original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
|
||||
assert original_config_2 == original_config
|
||||
|
||||
def test_motion_unet_loading(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = AnimateDiffVideoToVideoControlNetPipeline(**components)
|
||||
|
||||
assert isinstance(pipe.unet, UNetMotionModel)
|
||||
|
||||
@unittest.skip("Attention slicing is not enabled in this pipeline")
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
def test_ip_adapter(self):
|
||||
expected_pipe_slice = None
|
||||
if torch_device == "cpu":
|
||||
expected_pipe_slice = np.array(
|
||||
[
|
||||
0.5569,
|
||||
0.6250,
|
||||
0.4144,
|
||||
0.5613,
|
||||
0.5563,
|
||||
0.5213,
|
||||
0.5091,
|
||||
0.4950,
|
||||
0.4950,
|
||||
0.5684,
|
||||
0.3858,
|
||||
0.4863,
|
||||
0.6457,
|
||||
0.4311,
|
||||
0.5517,
|
||||
0.5608,
|
||||
0.4417,
|
||||
0.5377,
|
||||
]
|
||||
)
|
||||
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
|
||||
|
||||
def test_inference_batch_single_identical(
|
||||
self,
|
||||
batch_size=2,
|
||||
expected_max_diff=1e-4,
|
||||
additional_params_copy_to_batched_inputs=["num_inference_steps"],
|
||||
):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for components in pipe.components.values():
|
||||
if hasattr(components, "set_default_attn_processor"):
|
||||
components.set_default_attn_processor()
|
||||
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||
inputs["generator"] = self.get_generator(0)
|
||||
|
||||
logger = logging.get_logger(pipe.__module__)
|
||||
logger.setLevel(level=diffusers.logging.FATAL)
|
||||
|
||||
# batchify inputs
|
||||
batched_inputs = {}
|
||||
batched_inputs.update(inputs)
|
||||
|
||||
for name in self.batch_params:
|
||||
if name not in inputs:
|
||||
continue
|
||||
|
||||
value = inputs[name]
|
||||
if name == "prompt":
|
||||
len_prompt = len(value)
|
||||
batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
|
||||
batched_inputs[name][-1] = 100 * "very long"
|
||||
|
||||
else:
|
||||
batched_inputs[name] = batch_size * [value]
|
||||
|
||||
if "generator" in inputs:
|
||||
batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
|
||||
|
||||
if "batch_size" in inputs:
|
||||
batched_inputs["batch_size"] = batch_size
|
||||
|
||||
for arg in additional_params_copy_to_batched_inputs:
|
||||
batched_inputs[arg] = inputs[arg]
|
||||
|
||||
output = pipe(**inputs)
|
||||
output_batch = pipe(**batched_inputs)
|
||||
|
||||
assert output_batch[0].shape[0] == batch_size
|
||||
|
||||
max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
|
||||
assert max_diff < expected_max_diff
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
|
||||
def test_to_device(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.to("cpu")
|
||||
# pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
self.assertTrue(all(device == "cpu" for device in model_devices))
|
||||
|
||||
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
|
||||
self.assertTrue(np.isnan(output_cpu).sum() == 0)
|
||||
|
||||
pipe.to("cuda")
|
||||
model_devices = [
|
||||
component.device.type for component in pipe.components.values() if hasattr(component, "device")
|
||||
]
|
||||
self.assertTrue(all(device == "cuda" for device in model_devices))
|
||||
|
||||
output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
|
||||
self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
|
||||
|
||||
def test_to_dtype(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
|
||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||
|
||||
pipe.to(dtype=torch.float16)
|
||||
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||
|
||||
def test_prompt_embeds(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs.pop("prompt")
|
||||
inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
|
||||
pipe(**inputs)
|
||||
|
||||
def test_latent_inputs(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
sample_size = pipe.unet.config.sample_size
|
||||
num_frames = len(inputs["conditioning_frames"])
|
||||
inputs["latents"] = torch.randn((1, 4, num_frames, sample_size, sample_size), device=torch_device)
|
||||
inputs.pop("video")
|
||||
pipe(**inputs)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_without_offload = pipe(**inputs).frames[0]
|
||||
output_without_offload = (
|
||||
output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
|
||||
)
|
||||
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_with_offload = pipe(**inputs).frames[0]
|
||||
output_with_offload = (
|
||||
output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
|
||||
)
|
||||
|
||||
max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
|
||||
self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
|
||||
|
||||
def test_free_init(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs_normal = self.get_dummy_inputs(torch_device)
|
||||
frames_normal = pipe(**inputs_normal).frames[0]
|
||||
|
||||
pipe.enable_free_init(
|
||||
num_iters=2,
|
||||
use_fast_sampling=True,
|
||||
method="butterworth",
|
||||
order=4,
|
||||
spatial_stop_frequency=0.25,
|
||||
temporal_stop_frequency=0.25,
|
||||
)
|
||||
inputs_enable_free_init = self.get_dummy_inputs(torch_device)
|
||||
frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
|
||||
|
||||
pipe.disable_free_init()
|
||||
inputs_disable_free_init = self.get_dummy_inputs(torch_device)
|
||||
frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
|
||||
|
||||
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
|
||||
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
|
||||
self.assertGreater(
|
||||
sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
|
||||
)
|
||||
self.assertLess(
|
||||
max_diff_disabled,
|
||||
1e-4,
|
||||
"Disabling of FreeInit should lead to results similar to the default pipeline results",
|
||||
)
|
||||
|
||||
def test_free_init_with_schedulers(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs_normal = self.get_dummy_inputs(torch_device)
|
||||
frames_normal = pipe(**inputs_normal).frames[0]
|
||||
|
||||
schedulers_to_test = [
|
||||
DPMSolverMultistepScheduler.from_config(
|
||||
components["scheduler"].config,
|
||||
timestep_spacing="linspace",
|
||||
beta_schedule="linear",
|
||||
algorithm_type="dpmsolver++",
|
||||
steps_offset=1,
|
||||
clip_sample=False,
|
||||
),
|
||||
LCMScheduler.from_config(
|
||||
components["scheduler"].config,
|
||||
timestep_spacing="linspace",
|
||||
beta_schedule="linear",
|
||||
steps_offset=1,
|
||||
clip_sample=False,
|
||||
),
|
||||
]
|
||||
components.pop("scheduler")
|
||||
|
||||
for scheduler in schedulers_to_test:
|
||||
components["scheduler"] = scheduler
|
||||
pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipe.enable_free_init(num_iters=2, use_fast_sampling=False)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
frames_enable_free_init = pipe(**inputs).frames[0]
|
||||
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
|
||||
|
||||
self.assertGreater(
|
||||
sum_enabled,
|
||||
1e1,
|
||||
"Enabling of FreeInit should lead to results different from the default pipeline results",
|
||||
)
|
||||
|
||||
def test_free_noise_blocks(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
pipe.enable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertTrue(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must be an instance of `FreeNoiseTransformerBlock` after enabling FreeNoise.",
|
||||
)
|
||||
|
||||
pipe.disable_free_noise()
|
||||
for block in pipe.unet.down_blocks:
|
||||
for motion_module in block.motion_modules:
|
||||
for transformer_block in motion_module.transformer_blocks:
|
||||
self.assertFalse(
|
||||
isinstance(transformer_block, FreeNoiseTransformerBlock),
|
||||
"Motion module transformer blocks must not be an instance of `FreeNoiseTransformerBlock` after disabling FreeNoise.",
|
||||
)
|
||||
|
||||
def test_free_noise(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
inputs_normal["num_inference_steps"] = 2
|
||||
inputs_normal["strength"] = 0.5
|
||||
frames_normal = pipe(**inputs_normal).frames[0]
|
||||
|
||||
for context_length in [8, 9]:
|
||||
for context_stride in [4, 6]:
|
||||
pipe.enable_free_noise(context_length, context_stride)
|
||||
|
||||
inputs_enable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
inputs_enable_free_noise["num_inference_steps"] = 2
|
||||
inputs_enable_free_noise["strength"] = 0.5
|
||||
frames_enable_free_noise = pipe(**inputs_enable_free_noise).frames[0]
|
||||
|
||||
pipe.disable_free_noise()
|
||||
inputs_disable_free_noise = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
inputs_disable_free_noise["num_inference_steps"] = 2
|
||||
inputs_disable_free_noise["strength"] = 0.5
|
||||
frames_disable_free_noise = pipe(**inputs_disable_free_noise).frames[0]
|
||||
|
||||
sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_noise)).sum()
|
||||
max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_noise)).max()
|
||||
self.assertGreater(
|
||||
sum_enabled,
|
||||
1e1,
|
||||
"Enabling of FreeNoise should lead to results different from the default pipeline results",
|
||||
)
|
||||
self.assertLess(
|
||||
max_diff_disabled,
|
||||
1e-4,
|
||||
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
|
||||
)
|
||||
|
||||
def test_free_noise_multi_prompt(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe: AnimateDiffVideoToVideoControlNetPipeline = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
context_length = 8
|
||||
context_stride = 4
|
||||
pipe.enable_free_noise(context_length, context_stride)
|
||||
|
||||
# Make sure that pipeline works when prompt indices are within num_frames bounds
|
||||
inputs = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
|
||||
inputs["num_inference_steps"] = 2
|
||||
inputs["strength"] = 0.5
|
||||
pipe(**inputs).frames[0]
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Ensure that prompt indices are within bounds
|
||||
inputs = self.get_dummy_inputs(torch_device, num_frames=16)
|
||||
inputs["num_inference_steps"] = 2
|
||||
inputs["strength"] = 0.5
|
||||
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
|
||||
pipe(**inputs).frames[0]
|
||||
Reference in New Issue
Block a user