
Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Animesh Jain
Date: 2025-06-22 21:42:44 -07:00
Committed by: Animesh Jain
Parent: f794d66f1e
Commit: 932914f45d
9 changed files with 73 additions and 25 deletions

View File

@@ -152,9 +152,39 @@ Compilation is slow the first time, but once compiled, it is significantly faste
### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately.
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
For many diffusion architectures, this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10×**.
To make this effortless, `ModelMixin` exposes the **`compile_repeated_blocks`** API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
# Compile only the repeated Transformer layers inside the UNet
pipe.unet.compile_repeated_blocks(fullgraph=True)
```
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
```py
class MyUNet(ModelMixin):
_repeated_blocks = ("Transformer2DModel",) # ← compiled by default
```
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
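To see the full contract end to end, here is a minimal, self-contained sketch; `ToyBlock` and `ToyModel` are illustrative classes invented for this example (only `ModelMixin`, the `_repeated_blocks` attribute, and `compile_repeated_blocks` come from diffusers, and the helper assumes a diffusers version that ships it):
```py
import torch
import torch.nn as nn

from diffusers import ModelMixin


class ToyBlock(nn.Module):
    # Hypothetical repeated block, standing in for a Transformer layer
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x))


class ToyModel(ModelMixin):
    # Every submodule whose class name appears here is compiled by compile_repeated_blocks
    _repeated_blocks = ("ToyBlock",)

    def __init__(self, dim: int = 64, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = ToyModel()
model.compile_repeated_blocks(fullgraph=True)  # each ToyBlock reuses one compiled artifact
out = model(torch.randn(1, 64))
```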
**Relation to Accelerate's `compile_regions`.** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index), [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78), which takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
@@ -167,6 +197,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(

```py
# pip install -U accelerate
import torch
from accelerate.utils import compile_regions  # see the accelerate link above
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
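Side by side, the two entry points look like this; a sketch that reuses the `pipe` and `pipeline` objects created in the snippets above, with illustrative flags:
```py
# Explicit (diffusers): compile exactly the classes listed in the UNet's _repeated_blocks
pipe.unet.compile_repeated_blocks(fullgraph=True)

# Automatic (accelerate): discover candidate blocks and compile the rest of the graph separately
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```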
### Graph breaks
@@ -241,4 +273,4 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```

View File

@@ -1414,39 +1414,28 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
can reduce end-to-end compile time substantially, while preserving the
runtime speed-ups you would expect from a full `torch.compile`.
The set of sub-modules to compile is discovered in one of two ways:
1. **`_repeated_blocks`** - Preferred. Define this attribute on your
subclass as a list/tuple of class names (strings). Every module whose
class name matches will be compiled.
2. **`_no_split_modules`** - Fallback. If the preferred attribute is
missing or empty, we fall back to the legacy Diffusers attribute
`_no_split_modules`.
The set of sub-modules to compile is discovered by the presence of the
**`_repeated_blocks`** attribute in the model definition. Define this
attribute on your model subclass as a list/tuple of class names
(strings). Every module whose class name matches will be compiled.
Once discovered, each matching sub-module is compiled by calling
``submodule.compile(*args, **kwargs)``. Any positional or keyword
arguments you supply to :meth:`compile_repeated_blocks` are forwarded
`submodule.compile(*args, **kwargs)`. Any positional or keyword
arguments you supply to `compile_repeated_blocks` are forwarded
verbatim to `torch.compile`.
"""
repeated_blocks = getattr(self, "_repeated_blocks", None)
if not repeated_blocks:
logger.warning("_repeated_blocks attribute is empty. Using _no_split_modules to find compile regions.")
repeated_blocks = getattr(self, "_no_split_modules", None)
if not repeated_blocks:
raise ValueError(
"Both _repeated_blocks and _no_split_modules attribute are empty. "
"Set _repeated_blocks for the model to benefit from faster compilation. "
"`_repeated_blocks` attribute is empty. "
f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
)
has_compiled_region = False
for submod in self.modules():
if submod.__class__.__name__ in repeated_blocks:
has_compiled_region = True
submod.compile(*args, **kwargs)
has_compiled_region = True
if not has_compiled_region:
raise ValueError(
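Because the arguments are forwarded verbatim, any `torch.compile` flag can be passed through the helper. A hedged usage sketch; the Flux checkpoint and the `mode`/`fullgraph` choices are illustrative:
```py
import torch
from diffusers import FluxTransformer2DModel

# FluxTransformer2DModel lists its two transformer block classes in _repeated_blocks,
# so only those submodules are compiled; the kwargs go straight to torch.compile.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.compile_repeated_blocks(fullgraph=True, mode="max-autotune")
```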

View File

@@ -407,6 +407,7 @@ class ChromaTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
_repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config

View File

@@ -227,7 +227,7 @@ class FluxTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = _no_split_modules
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
_repeated_blocks = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
@register_to_config
def __init__(

View File

@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTXVideoTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -345,7 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = _no_split_modules
_repeated_blocks = ["WanTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -167,6 +167,7 @@ class UNet2DConditionModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["BasicTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -1936,6 +1936,24 @@ class TorchCompileTesterMixin:
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_torch_compile_repeated_blocks(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
torch.no_grad(),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_compile_with_group_offloading(self):
torch._dynamo.config.cache_size_limit = 10000