## CacheDiT

CacheDiT is a unified, flexible, and training-free cache acceleration framework designed to support nearly all DiT-based pipelines in Diffusers. It provides a unified cache API that supports the automatic block adapter, DBCache, and more.

To learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.

Install a stable release of CacheDiT from PyPI, or install the latest version from GitHub.

<hfoptions id="install">
<hfoption id="PyPI">

```bash
pip3 install -U cache-dit
```

</hfoption>
<hfoption id="source">

```bash
pip3 install git+https://github.com/vipshop/cache-dit.git
```

</hfoption>
</hfoptions>

Run the command below to view supported DiT pipelines.

```python
>>> import cache_dit
>>> cache_dit.supported_pipelines()
(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])
```

For a complete benchmark, refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).

## Unified Cache API

CacheDiT works by matching specific input/output patterns of the transformer blocks.

Call the `enable_cache()` function on a pipeline to enable cache acceleration. This function is the entry point to many of CacheDiT's features.

```python
import cache_dit
from diffusers import DiffusionPipeline

# Can be any diffusion pipeline
pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")

# One-line code with default cache options.
cache_dit.enable_cache(pipe)

# Just call the pipe as normal.
output = pipe(...)

# Disable cache and run original pipe.
cache_dit.disable_cache(pipe)
```
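
To make the idea of input/output pattern matching more concrete, here is a minimal, library-free sketch that inspects a block's `forward` signature. The pattern names, the `PATTERNS` table, and the toy block classes are assumptions made for illustration; they are not CacheDiT's actual pattern definitions.

```python
import inspect

# Names of the tensors a block may take in or return.
TRACKED = ("hidden_states", "encoder_hidden_states")

# Hypothetical pattern table: pattern name -> (inputs, outputs).
PATTERNS = {
    "dual_in_dual_out": (TRACKED, TRACKED),
    "dual_in_swapped_out": (TRACKED, TRACKED[::-1]),
    "single_in_single_out": (TRACKED[:1], TRACKED[:1]),
}


def tracked_inputs(forward):
    """Return the tracked tensor names that appear in forward's signature."""
    params = inspect.signature(forward).parameters
    return tuple(name for name in TRACKED if name in params)


def candidate_patterns(forward):
    """List the patterns whose inputs agree with forward's signature."""
    inputs = tracked_inputs(forward)
    return [name for name, (ins, _outs) in PATTERNS.items() if ins == inputs]


class DualStreamBlock:
    def forward(self, hidden_states, encoder_hidden_states, temb=None):
        return encoder_hidden_states, hidden_states


class SingleStreamBlock:
    def forward(self, hidden_states, temb=None):
        return hidden_states


dual = candidate_patterns(DualStreamBlock().forward)
single = candidate_patterns(SingleStreamBlock().forward)
print(dual)    # ['dual_in_dual_out', 'dual_in_swapped_out']
print(single)  # ['single_in_single_out']
```

The signature alone cannot reveal the order of the returned tensors, so the sketch reports two candidates for the dual-stream block.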

## Automatic Block Adapter

For custom or modified pipelines or transformers not included in Diffusers, use the `BlockAdapter` in `auto` mode or via manual configuration. Please check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details. Refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.

```python
import cache_dit
from cache_dit import ForwardPattern, BlockAdapter

# Use 🔥BlockAdapter with `auto` mode.
cache_dit.enable_cache(
    BlockAdapter(
        # Any DiffusionPipeline, Qwen-Image, etc.
        pipe=pipe, auto=True,
        # Check the `📚Forward Pattern Matching` documentation and read the
        # Qwen-Image code: you will find that it satisfies `FORWARD_PATTERN_1`.
        forward_pattern=ForwardPattern.Pattern_1,
    ),
)

# Or, manually set up the transformer configuration.
cache_dit.enable_cache(
    BlockAdapter(
        pipe=pipe,  # Qwen-Image, etc.
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_1,
    ),
)
```

Sometimes, a Transformer class contains more than one list of transformer blocks. For example, FLUX.1 (as well as HiDream, Chroma, etc.) contains `transformer_blocks` and `single_transformer_blocks`, which have different forward patterns. The `BlockAdapter` can detect this hybrid pattern type as well.
Refer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.

```python
# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
# single_transformer_blocks have different forward patterns.
cache_dit.enable_cache(
    BlockAdapter(
        pipe=pipe,  # FLUX.1, etc.
        transformer=pipe.transformer,
        blocks=[
            pipe.transformer.transformer_blocks,
            pipe.transformer.single_transformer_blocks,
        ],
        forward_pattern=[
            ForwardPattern.Pattern_1,
            ForwardPattern.Pattern_3,
        ],
    ),
)
```

This also works if the pipeline contains more than one transformer (namely `transformer` and `transformer_2`). Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.

## Patch Functor

For any pattern not included in CacheDiT, use the Patch Functor to convert the pattern into a known pattern. You need to subclass the Patch Functor and may also need to fuse the operations inside the `for` loop over the blocks into the block's `forward`. After implementing a Patch Functor, set the `patch_functor` property in `BlockAdapter`.

Some Patch Functors are already provided in CacheDiT, such as [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_hidream.py) and [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/patch_functors/functor_chroma.py).

```python
@BlockAdapterRegistry.register("HiDream")
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import HiDreamImageTransformer2DModel
    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor

    assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=[
            pipe.transformer.double_stream_blocks,
            pipe.transformer.single_stream_blocks,
        ],
        forward_pattern=[
            ForwardPattern.Pattern_0,
            ForwardPattern.Pattern_3,
        ],
        # NOTE: Setup your custom patch functor here.
        patch_functor=HiDreamPatchFunctor(),
        **kwargs,
    )
```
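
The general idea of converting an unsupported pattern into a known one can be sketched without any library. The example below is purely illustrative: `DictOutputBlock` and `ToyPatchFunctor` are hypothetical names, and the real Patch Functor base class and its methods in CacheDiT may look quite different. It only shows how a block's `forward` could be wrapped so that its output takes the tuple form a known pattern expects.

```python
class DictOutputBlock:
    """Toy block that returns a dict, which no tuple-based pattern expects."""

    def forward(self, hidden_states, encoder_hidden_states):
        return {"img": hidden_states + 1, "txt": encoder_hidden_states + 1}


class ToyPatchFunctor:
    """Hypothetical functor that rewrites each block's forward output."""

    def apply(self, blocks):
        for block in blocks:
            original = block.forward  # bound method of this block

            # Bind `original` as a default argument so every block keeps
            # its own forward instead of sharing the last loop value.
            def patched(hidden_states, encoder_hidden_states, _orig=original):
                out = _orig(hidden_states, encoder_hidden_states)
                return out["img"], out["txt"]

            block.forward = patched
        return blocks


blocks = ToyPatchFunctor().apply([DictOutputBlock(), DictOutputBlock()])
result = blocks[0].forward(1, 2)
print(result)  # (2, 3)
```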

Finally, you can call the `cache_dit.summary()` function on a pipeline after it has completed inference to get the cache acceleration details.

```python
stats = cache_dit.summary(pipe)
```

```text
⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline

| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
| 23          | 0.045     | 0.084     | 0.114     | 0.147     | 0.241     | 0.297     |
```
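
For readers unfamiliar with the columns in this table, the sketch below shows one way such a percentile summary could be computed from a list of residual diffs using only the standard library. It assumes one diff value is recorded per cached step. The function name and sample values are made up for illustration and do not reflect how `cache_dit.summary()` is implemented.

```python
import statistics


def summarize_diffs(diffs):
    """Summarize residual diffs with min, percentiles, and max.

    Requires at least two values.
    """
    ordered = sorted(diffs)
    # quantiles(n=100) returns 99 cut points; index k - 1 is the k-th percentile.
    cuts = statistics.quantiles(ordered, n=100, method="inclusive")
    return {
        "cache_steps": len(ordered),
        "min": ordered[0],
        "p25": cuts[24],
        "p50": cuts[49],
        "p75": cuts[74],
        "p95": cuts[94],
        "max": ordered[-1],
    }


# Made-up diffs, one per cached step.
diffs = [0.05, 0.08, 0.11, 0.15, 0.24, 0.30]
stats = summarize_diffs(diffs)
for name, value in stats.items():
    print(f"{name}: {value}")
```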

## DBCache: Dual Block Cache

DBCache (Dual Block Caching) supports different configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.

- `Fn_compute_blocks`: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
- `Bn_compute_blocks`: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.

```python
import torch
import cache_dit
from diffusers import FluxPipeline

pipe_or_adapter = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Default options, F8B0, 8 warmup steps, and unlimited cached
# steps for good balance between performance and precision
cache_dit.enable_cache(pipe_or_adapter)

# Custom options, F8B8, higher precision
from cache_dit import BasicCacheConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    cache_config=BasicCacheConfig(
        max_warmup_steps=8,  # warmup steps are never cached
        max_cached_steps=-1,  # -1 means no limit
        Fn_compute_blocks=8,  # Fn, F8, etc.
        Bn_compute_blocks=8,  # Bn, B8, etc.
        residual_diff_threshold=0.12,
    ),
)
```
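
The interplay of the first-n blocks, the last-n blocks, and `residual_diff_threshold` can be pictured with a toy model in plain Python. This is a deliberately simplified sketch, not CacheDiT's implementation. Hidden states are lists of floats and blocks are ordinary functions. The class name, the relative L1 formula, and the choice to refresh the reference only on fully computed steps are all assumptions.

```python
def rel_l1_diff(current, previous):
    """Relative L1 difference between two equal-length vectors."""
    num = sum(abs(c - p) for c, p in zip(current, previous))
    den = sum(abs(p) for p in previous)
    return num / den if den else float("inf")


class ToyDualBlockCache:
    """Toy model of Fn/Bn caching; hidden states are plain lists of floats."""

    def __init__(self, blocks, fn, bn, threshold, warmup=0):
        self.front = blocks[:fn]                    # always computed
        self.middle = blocks[fn:len(blocks) - bn]   # skipped on cache hits
        self.back = blocks[len(blocks) - bn:]       # always computed
        self.threshold, self.warmup = threshold, warmup
        self.step = 0
        self.cached_steps = 0
        self.ref_front = None   # front output at the last fully computed step
        self.residual = None    # middle-block residual from that same step

    @staticmethod
    def _run(blocks, x):
        for block in blocks:
            x = block(x)
        return x

    def __call__(self, x):
        h = self._run(self.front, x)
        hit = (
            self.step >= self.warmup
            and self.ref_front is not None
            and rel_l1_diff(h, self.ref_front) < self.threshold
        )
        if hit:
            # Approximate the middle blocks with the cached residual.
            mid = [a + r for a, r in zip(h, self.residual)]
            self.cached_steps += 1
        else:
            mid = self._run(self.middle, h)
            self.residual = [m - a for m, a in zip(mid, h)]
            self.ref_front = h
        self.step += 1
        return self._run(self.back, mid)


blocks = [
    lambda v: [2 * x for x in v],    # front block (Fn = 1)
    lambda v: [x + 1 for x in v],    # middle block, skipped when cached
    lambda v: [0.5 * x for x in v],  # back block (Bn = 1)
]
cache = ToyDualBlockCache(blocks, fn=1, bn=1, threshold=0.12)
outputs = [cache(x) for x in ([1.0, 2.0], [1.05, 2.05], [2.0, 4.0])]
print(cache.cached_steps)  # 1: only the second call reuses the residual
```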

Check the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.

## TaylorSeer Calibrator

The [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache when the number of cached steps is large (Hybrid TaylorSeer + DBCache). When timesteps are far apart, feature similarity in diffusion models decreases substantially, which significantly harms generation quality.

TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predicts features at future timesteps with a Taylor series expansion. The TaylorSeer implementation in CacheDiT supports both hidden-state and residual cache types, so the predicted feature (F_pred) can be either a residual or a hidden state.

```python
import cache_dit
from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    # Basic DBCache w/ FnBn configurations
    cache_config=BasicCacheConfig(
        max_warmup_steps=8,  # warmup steps are never cached
        max_cached_steps=-1,  # -1 means no limit
        Fn_compute_blocks=8,  # Fn, F8, etc.
        Bn_compute_blocks=8,  # Bn, B8, etc.
        residual_diff_threshold=0.12,
    ),
    # Then, use the TaylorSeer Calibrator to approximate the values
    # in cached steps. The default taylorseer_order is 1.
    calibrator_config=TaylorSeerCalibratorConfig(
        taylorseer_order=1,
    ),
)
```
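
The forecasting idea described above can be illustrated with a few lines of plain Python: approximate the derivatives with backward differences, then extrapolate with a truncated Taylor series. This is a minimal sketch that assumes scalar features and equally spaced steps. The function names are hypothetical and it is not the TaylorSeer implementation shipped in CacheDiT.

```python
from math import factorial


def latest_differences(history, order):
    """Return the newest backward difference of each order 0..order.

    history holds features at equally spaced steps, newest last.
    """
    if order >= len(history):
        raise ValueError("need at least order + 1 history entries")
    level = list(history)
    latest = [level[-1]]
    for _ in range(order):
        level = [b - a for a, b in zip(level, level[1:])]
        latest.append(level[-1])
    return latest


def taylor_predict(history, k, order=1):
    """Forecast the feature k steps ahead with a truncated Taylor series."""
    terms = latest_differences(history, order)
    return sum(d * k**i / factorial(i) for i, d in enumerate(terms))


history = [1.0, 2.0, 3.0]                         # features from three computed steps
reuse = taylor_predict(history, k=2, order=0)     # order 0: plain cache reuse
forecast = taylor_predict(history, k=2, order=1)  # order 1: linear forecast
print(reuse, forecast)  # 3.0 5.0
```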

> [!TIP]
> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. DBCache's `Bn_compute_blocks` also acts as a calibrator, so you can choose either `Bn_compute_blocks` > 0 or TaylorSeer. We recommend using the configuration scheme of TaylorSeer + DBCache FnB0.

## Hybrid Cache CFG

CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG in the forward step, set the `enable_separate_cfg` parameter to `False` (or leave it at the default, `None`). Otherwise, set it to `True`.

```python
import cache_dit
from cache_dit import BasicCacheConfig

cache_dit.enable_cache(
    pipe_or_adapter,
    cache_config=BasicCacheConfig(
        ...,
        # For example, set it as True for Wan 2.1, Qwen-Image
        # and set it as False for FLUX.1, HunyuanVideo, etc.
        enable_separate_cfg=True,
    ),
)
```
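
One way to picture why this flag matters: when CFG runs as separate forward passes, the transformer is called twice per denoising step, so bookkeeping that treats every call as a new step would mix the two branches. The sketch below only illustrates that bookkeeping. The function name and the conditional-first ordering are assumptions, not CacheDiT's implementation.

```python
def call_schedule(num_calls, separate_cfg):
    """Map each transformer call to a (denoising step, branch) pair."""
    schedule = []
    for call in range(num_calls):
        if separate_cfg:
            # Assumption: the conditional pass runs first, then the
            # unconditional pass, so two calls make up one step.
            branch = "cond" if call % 2 == 0 else "uncond"
            schedule.append((call // 2, branch))
        else:
            # Fused (or no) CFG: one call per denoising step.
            schedule.append((call, "fused"))
    return schedule


separate = call_schedule(4, separate_cfg=True)
fused = call_schedule(4, separate_cfg=False)
print(separate)  # [(0, 'cond'), (0, 'uncond'), (1, 'cond'), (1, 'uncond')]
print(fused)     # [(0, 'fused'), (1, 'fused'), (2, 'fused'), (3, 'fused')]
```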

## torch.compile

CacheDiT is designed to work with `torch.compile` for even better performance. Call `torch.compile` after enabling the cache.

```python
import torch
import cache_dit

cache_dit.enable_cache(pipe)

# Compile the Transformer module
pipe.transformer = torch.compile(pipe.transformer)
```

If you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode.

```python
torch._dynamo.config.recompile_limit = 96  # default is 8
torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
```

Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.