1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Files
Steven Liu cc5b31ffc9 [docs] Migrate syntax (#12390)
* change syntax

* make style
2025-09-30 10:11:19 -07:00

305 lines
15 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# 加速推理
Diffusion模型在推理时速度较慢因为生成是一个迭代过程需要经过一定数量的"步数"逐步将噪声细化为图像或视频。要加速这一过程,您可以尝试使用不同的[调度器](../api/schedulers/overview)、降低模型权重的精度以加快计算、使用更高效的内存注意力机制等方法。
将这些技术组合使用,可以比单独使用任何一种技术获得更快的推理速度。
本指南将介绍如何加速推理。
## 模型数据类型
模型权重的精度和数据类型会影响推理速度因为更高的精度需要更多内存来加载也需要更多时间进行计算。PyTorch默认以float32或全精度加载模型权重因此更改数据类型是快速获得更快推理速度的简单方法。
<hfoptions id="dtypes">
<hfoption id="bfloat16">
bfloat16与float16类似但对数值误差更稳健。硬件对bfloat16的支持各不相同但大多数现代GPU都能支持bfloat16。
```py
import torch
from diffusers import StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```
</hfoption>
<hfoption id="float16">
float16与bfloat16类似但可能更容易出现数值误差。
```py
import torch
from diffusers import StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```
</hfoption>
<hfoption id="TensorFloat-32">
[TensorFloat-32 (tf32)](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/)模式在NVIDIA Ampere GPU上受支持它以tf32计算卷积和矩阵乘法运算。存储和其他操作保持在float32。与bfloat16或float16结合使用时可以显著加快计算速度。
PyTorch默认仅对卷积启用tf32模式您需要显式启用矩阵乘法的tf32模式。
```py
import torch
from diffusers import StableDiffusionXLPipeline
torch.backends.cuda.matmul.allow_tf32 = True
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```
更多详情请参阅[混合精度训练](https://huggingface.co/docs/transformers/en/perf_train_gpu_one#mixed-precision)文档。
</hfoption>
</hfoptions>
## 缩放点积注意力
> [!TIP]
> 内存高效注意力优化了推理速度*和*[内存使用](./memory#memory-efficient-attention)
[缩放点积注意力SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)实现了多种注意力后端,包括[FlashAttention](https://github.com/Dao-AILab/flash-attention)、[xFormers](https://github.com/facebookresearch/xformers)和原生C++实现。它会根据您的硬件自动选择最优的后端。
如果您使用的是PyTorch >= 2.0SDPA默认启用无需对代码进行任何额外更改。不过您也可以尝试使用其他注意力后端来自行选择。下面的示例使用[torch.nn.attention.sdpa_kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html)上下文管理器来启用高效注意力。
```py
from torch.nn.attention import SDPBackend, sdpa_kernel
import torch
from diffusers import StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
image = pipeline(prompt, num_inference_steps=30).images[0]
```
## torch.compile
[torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)通过将PyTorch代码和操作编译为优化的内核来加速推理。Diffusers通常会编译计算密集型的模型如UNet、transformer或VAE。
启用以下编译器设置以获得最大速度(更多选项请参阅[完整列表](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py))。
```py
import torch
from diffusers import StableDiffusionXLPipeline
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
```
加载并编译UNet和VAE。有几种不同的模式可供选择`"max-autotune"`通过编译为CUDA图来优化速度。CUDA图通过单个CPU操作启动多个GPU操作有效减少了开销。
> [!TIP]
> 在PyTorch 2.3.1中您可以控制torch.compile的缓存行为。这对于像`"max-autotune"`这样的编译模式特别有用,它会通过网格搜索多个编译标志来找到最优配置。更多详情请参阅[torch.compile中的编译时间缓存](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)教程。
将内存布局更改为[channels_last](./memory#torchchannels_last)也可以优化内存和推理速度。
```py
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.unet.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(
pipeline.unet, mode="max-autotune", fullgraph=True
)
pipeline.vae.decode = torch.compile(
pipeline.vae.decode,
mode="max-autotune",
fullgraph=True
)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```
第一次编译时速度较慢,但一旦编译完成,速度会显著提升。尽量只在相同类型的推理操作上使用编译后的管道。在不同尺寸的图像上调用编译后的管道会重新触发编译,这会很慢且效率低下。
### 动态形状编译
> [!TIP]
> 确保始终使用PyTorch的nightly版本以获得更好的支持。
`torch.compile`会跟踪输入形状和条件如果这些不同它会重新编译模型。例如如果模型是在1024x1024分辨率的图像上编译的而在不同分辨率的图像上使用就会触发重新编译。
为避免重新编译,添加`dynamic=True`以尝试生成更动态的内核,避免条件变化时重新编译。
```diff
+ torch.fx.experimental._config.use_duck_shape = False
+ pipeline.unet = torch.compile(
pipeline.unet, fullgraph=True, dynamic=True
)
```
指定`use_duck_shape=False`会指示编译器是否应使用相同的符号变量来表示相同大小的输入。更多详情请参阅此[评论](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790)。
并非所有模型都能开箱即用地从动态编译中受益,可能需要更改。参考此[PR](https://github.com/huggingface/diffusers/pull/11297/),它改进了[`AuraFlowPipeline`]的实现以受益于动态编译。
如果动态编译对Diffusers模型的效果不如预期请随时提出问题。
### 区域编译
[区域编译](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)通过仅编译模型中*小而频繁重复的块*通常是transformer层来减少冷启动延迟并为每个后续出现的块重用编译后的工件。对于许多diffusion架构这提供了与全图编译相同的运行时加速并将编译时间减少了8-10倍。
使用[`~ModelMixin.compile_repeated_blocks`]方法(一个包装`torch.compile`的辅助函数在任何组件如transformer模型如下所示。
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
# 仅编译UNet中重复的transformer层
pipeline.unet.compile_repeated_blocks(fullgraph=True)
```
要为新模型启用区域编译,请在模型类中添加一个`_repeated_blocks`属性,包含您想要编译的块的类名(作为字符串)。
```py
class MyUNet(ModelMixin):
_repeated_blocks = ("Transformer2DModel",) # ← 默认编译
```
> [!TIP]
> 更多区域编译示例,请参阅参考[PR](https://github.com/huggingface/diffusers/pull/11705)。
[Accelerate](https://huggingface.co/docs/accelerate/index)中还有一个[compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78)方法,可以自动选择模型中的候选块进行编译。其余图会单独编译。这对于快速实验很有用,因为您不需要设置哪些块要编译或调整编译标志。
```py
# pip install -U accelerate
import torch
from diffusers import StableDiffusionXLPipeline
from accelerate.utils import compile regions
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
[`~ModelMixin.compile_repeated_blocks`]是故意显式的。在`_repeated_blocks`中列出要重复的块,辅助函数仅编译这些块。它提供了可预测的行为,并且只需一行代码即可轻松推理缓存重用。
### 图中断
在torch.compile中指定`fullgraph=True`非常重要以确保底层模型中没有图中断。这使您可以充分利用torch.compile而不会降低性能。对于UNet和VAE这会改变您访问返回变量的方式。
```diff
- latents = unet(
- latents, timestep=timestep, encoder_hidden_states=prompt_embeds
-).sample
+ latents = unet(
+ latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
+)[0]
```
### GPU同步
每次去噪器做出预测后,调度器的`step()`函数会被[调用](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L1228),并且`sigmas`变量会被[索引](https://github.com/huggingface/diffusers/blob/1d686bac8146037e97f3fd8c56e4063230f71751/src/diffusers/schedulers/scheduling_euler_discrete.py#L476)。当放在GPU上时这会引入延迟因为CPU和GPU之间需要进行通信同步。当去噪器已经编译时这一点会更加明显。
一般来说,`sigmas`应该[保持在CPU上](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240),以避免通信同步和延迟。
> [!TIP]
> 参阅[torch.compile和Diffusers峰值性能实践指南](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/)博客文章,了解如何为扩散模型最大化`torch.compile`的性能。
### 基准测试
参阅[diffusers/benchmarks](https://huggingface.co/datasets/diffusers/benchmarks)数据集,查看编译管道的推理延迟和内存使用数据。
[diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results)仓库还包含Flux和CogVideoX编译版本的基准测试结果。
## 动态量化
[动态量化](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html)通过降低精度以加快数学运算来提高推理速度。这种特定类型的量化在运行时根据数据确定如何缩放激活,而不是使用固定的缩放因子。因此,缩放因子与数据更准确地匹配。
以下示例使用[torchao](../quantization/torchao)库对UNet和VAE应用[动态int8量化](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html)。
> [!TIP]
> 参阅我们的[torchao](../quantization/torchao)文档了解更多关于如何使用Diffusers torchao集成的信息。
配置编译器标志以获得最大速度。
```py
import torch
from torchao import apply_dynamic_quant
from diffusers import StableDiffusionXLPipeline
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True
```
使用[dynamic_quant_filter_fn](https://github.com/huggingface/diffusion-fast/blob/0f169640b1db106fe6a479f78c1ed3bfaeba3386/utils/pipeline_utils.py#L16)过滤掉UNet和VAE中一些不会从动态量化中受益的线性层。
```py
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")
apply_dynamic_quant(pipeline.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipeline.vae, dynamic_quant_filter_fn)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
pipeline(prompt, num_inference_steps=30).images[0]
```
## 融合投影矩阵
> [!WARNING]
> [fuse_qkv_projections](https://github.com/huggingface/diffusers/blob/58431f102cf39c3c8a569f32d71b2ea8caa461e1/src/diffusers/pipelines/pipeline_utils.py#L2034)方法是实验性的目前主要支持Stable Diffusion管道。参阅此[PR](https://github.com/huggingface/diffusers/pull/6179)了解如何为其他管道启用它。
在注意力块中输入被投影到三个子空间分别由投影矩阵Q、K和V表示。这些投影通常单独计算但您可以水平组合这些矩阵为一个矩阵并在单步中执行投影。这会增加输入投影的矩阵乘法大小并提高量化的效果。
```py
pipeline.fuse_qkv_projections()
```
## 资源
- 阅读[Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/)博客文章,了解如何结合所有这些优化与[TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html)和[AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html),使用[flux-fast](https://github.com/huggingface/flux-fast)的配方获得约2.5倍的加速。
这些配方支持AMD硬件和[Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev)。
- 阅读[torch.compile和Diffusers峰值性能实践指南](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/)博客文章,了解如何在使用`torch.compile`时最大化性能。