Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Utils] add utilities for checking if certain utilities are properly documented (#7763)
* add utility to check if attn_procs, norms, acts are properly documented
* add support listing to the workflows
* change to 2024
* small fixes
* does adding detailed docstrings help?
* uncomment image processor check
* quality
* fix, thanks to @mishig
* Apply suggestions from code review (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* style
* JointAttnProcessor2_0
* fixes
* Update docs/source/en/api/normalization.md (Co-authored-by: hlky <hlky@hlky.ac>)

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
.github/workflows/pr_tests.yml (vendored, 1 line changed)
@@ -64,6 +64,7 @@ jobs:
        run: |
          python utils/check_copies.py
          python utils/check_dummies.py
          python utils/check_support_list.py
          make deps_table_check_updated
      - name: Check if failure
        if: ${{ failure() }}
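For anyone reproducing this CI step locally, here is a minimal sketch of running the same checks from Python; it assumes you are at the repository root and that `make` is available for the dependency-table check.

```python
import subprocess

# Run the repo consistency checks the same way the CI step does (from the repo root).
for cmd in (
    ["python", "utils/check_copies.py"],
    ["python", "utils/check_dummies.py"],
    ["python", "utils/check_support_list.py"],
    ["make", "deps_table_check_updated"],
):
    subprocess.run(cmd, check=True)  # raises CalledProcessError on the first failing check
```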
@@ -25,3 +25,16 @@ Customized activation functions for supporting various models in 🤗 Diffusers.
## ApproximateGELU

[[autodoc]] models.activations.ApproximateGELU

## SwiGLU

[[autodoc]] models.activations.SwiGLU

## FP32SiLU

[[autodoc]] models.activations.FP32SiLU

## LinearActivation

[[autodoc]] models.activations.LinearActivation
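These autodoc entries point at small activation modules. As a concrete illustration, `ApproximateGELU` is built around the sigmoid approximation of GELU from https://arxiv.org/abs/1606.08415; the sketch below shows just that formula (illustrative only — the documented class itself is an `nn.Module` that may also wrap a linear projection).

```python
import torch

def approximate_gelu(x: torch.Tensor) -> torch.Tensor:
    # GELU(x) ≈ x * sigmoid(1.702 * x), the approximation from https://arxiv.org/abs/1606.08415
    return x * torch.sigmoid(1.702 * x)

print(approximate_gelu(torch.tensor([-1.0, 0.0, 1.0])))
```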
@@ -147,3 +147,20 @@ An attention processor is a class for applying different types of attention mechanisms.
## XLAFlashAttnProcessor2_0

[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0

## XFormersJointAttnProcessor

[[autodoc]] models.attention_processor.XFormersJointAttnProcessor

## IPAdapterXFormersAttnProcessor

[[autodoc]] models.attention_processor.IPAdapterXFormersAttnProcessor

## FluxIPAdapterJointAttnProcessor2_0

[[autodoc]] models.attention_processor.FluxIPAdapterJointAttnProcessor2_0

## XLAFluxFlashAttnProcessor2_0

[[autodoc]] models.attention_processor.XLAFluxFlashAttnProcessor2_0
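As context for these entries: an attention processor is a callable that the `Attention` module dispatches its forward pass to. Below is a rough, illustrative sketch of the shape such a class takes — a plain single-head scaled-dot-product version; the real processors in `models/attention_processor.py` take extra arguments (attention masks, IP-Adapter state) and handle head reshaping.

```python
import torch.nn.functional as F

class MyAttnProcessor:
    """Illustrative processor: plain scaled dot-product attention (no head splitting)."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, **kwargs):
        # `attn` is the Attention module that owns the projection layers.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states  # self-attention
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        hidden_states = F.scaled_dot_product_attention(query, key, value)
        # attn.to_out is typically [Linear, Dropout] in diffusers; apply the output projection.
        return attn.to_out[0](hidden_states)
```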
@@ -29,3 +29,43 @@ Customized normalization layers for supporting various models in 🤗 Diffusers.
## AdaGroupNorm

[[autodoc]] models.normalization.AdaGroupNorm

## AdaLayerNormContinuous

[[autodoc]] models.normalization.AdaLayerNormContinuous

## RMSNorm

[[autodoc]] models.normalization.RMSNorm

## GlobalResponseNorm

[[autodoc]] models.normalization.GlobalResponseNorm

## LuminaLayerNormContinuous

[[autodoc]] models.normalization.LuminaLayerNormContinuous

## SD35AdaLayerNormZeroX

[[autodoc]] models.normalization.SD35AdaLayerNormZeroX

## AdaLayerNormZeroSingle

[[autodoc]] models.normalization.AdaLayerNormZeroSingle

## LuminaRMSNormZero

[[autodoc]] models.normalization.LuminaRMSNormZero

## LpNorm

[[autodoc]] models.normalization.LpNorm

## CogView3PlusAdaLayerNormZeroTextImage

[[autodoc]] models.normalization.CogView3PlusAdaLayerNormZeroTextImage

## CogVideoXLayerNormZero

[[autodoc]] models.normalization.CogVideoXLayerNormZero

## MochiRMSNormZero

[[autodoc]] models.transformers.transformer_mochi.MochiRMSNormZero

## MochiRMSNorm

[[autodoc]] models.normalization.MochiRMSNorm
@@ -306,6 +306,20 @@ class AdaGroupNorm(nn.Module):
class AdaLayerNormContinuous(nn.Module):
    r"""
    Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

    Args:
        embedding_dim (`int`): Embedding dimension to use during projection.
        conditioning_embedding_dim (`int`): Dimension of the input condition.
        elementwise_affine (`bool`, defaults to `True`):
            Boolean flag to denote if affine transformation should be applied.
        eps (`float`, defaults to 1e-5): Epsilon factor.
        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
        norm_type (`str`, defaults to `"layer_norm"`):
            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
    """

    def __init__(
        self,
        embedding_dim: int,
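The docstring above describes the constructor; conceptually, the layer projects the conditioning embedding to a scale and shift and uses them to modulate the normalized input. A simplified sketch of that idea (not the file's exact forward pass):

```python
import torch
import torch.nn as nn

class AdaLayerNormContinuousSketch(nn.Module):
    # Minimal illustration: SiLU + Linear produce (scale, shift) from the condition,
    # which then modulate the layer-normalized hidden states.
    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
        emb = self.linear(self.silu(conditioning_embedding))
        scale, shift = torch.chunk(emb, 2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

layer = AdaLayerNormContinuousSketch(embedding_dim=8, conditioning_embedding_dim=4)
out = layer(torch.randn(2, 10, 8), torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 10, 8])
```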
@@ -462,6 +476,17 @@ else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        r"""
        LayerNorm with the bias parameter.

        Args:
            dim (`int`): Dimensionality to use for the parameters.
            eps (`float`, defaults to 1e-5): Epsilon factor.
            elementwise_affine (`bool`, defaults to `True`):
                Boolean flag to denote if affine transformation should be applied.
            bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
        """

        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()
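The point of this class is simply to expose an optional `bias` on top of the functional layer norm (per the TODO above, the builtin `nn.LayerNorm` only covers this from torch 2.1 onward). A minimal sketch of the idea, not the file's exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNormWithOptionalBias(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5, bias: bool = True):
        super().__init__()
        self.eps = eps
        self.dim = (dim,)
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None  # bias is optional

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, self.dim, self.weight, self.bias, self.eps)
```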
@@ -484,6 +509,17 @@ else:
class RMSNorm(nn.Module):
    r"""
    RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.

    Args:
        dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is `True`.
        eps (`float`): Small value to use when calculating the reciprocal of the square-root.
        elementwise_affine (`bool`, defaults to `True`):
            Boolean flag to denote if affine transformation should be applied.
        bias (`bool`, defaults to `False`): Boolean flag to denote if a `bias` parameter should also be trained.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()
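For reference, the core computation from the RMS norm paper cited above is just a root-mean-square rescaling followed by a learned gain; a minimal sketch (ignoring the optional bias and the dtype handling in the actual class):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Rescale by the reciprocal RMS of the last dimension, then apply the learned gain.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

x = torch.randn(2, 4, 8)
print(rms_norm(x, torch.ones(8)).shape)  # torch.Size([2, 4, 8])
```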
@@ -573,6 +609,13 @@ class MochiRMSNorm(nn.Module):
class GlobalResponseNorm(nn.Module):
    r"""
    Global response normalization as introduced in ConvNeXt-v2 (https://arxiv.org/abs/2301.00808).

    Args:
        dim (`int`): Number of dimensions to use for the `gamma` and `beta`.
    """

    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    def __init__(self, dim):
        super().__init__()
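A compact sketch of global response normalization as described in the ConvNeXt-V2 paper referenced above, written for channels-last input as in the linked reference implementation; treat it as illustrative rather than this file's exact forward pass.

```python
import torch

def global_response_norm(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    # x is channels-last: (batch, height, width, channels)
    gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)    # global feature aggregation
    nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)      # divisive feature normalization
    return gamma * (x * nx) + beta + x                     # calibrate and add the residual

x = torch.randn(1, 8, 8, 16)
out = global_response_norm(x, torch.zeros(1, 1, 1, 16), torch.zeros(1, 1, 1, 16))
print(out.shape)  # torch.Size([1, 8, 8, 16])
```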
tests/others/test_check_support_list.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import os
import sys
import unittest
from unittest.mock import mock_open, patch


git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))

from check_support_list import check_documentation  # noqa: E402


class TestCheckSupportList(unittest.TestCase):
    def setUp(self):
        # Mock doc and source contents that we can reuse
        self.doc_content = """# Documentation
## FooProcessor

[[autodoc]] module.FooProcessor

## BarProcessor

[[autodoc]] module.BarProcessor
"""
        self.source_content = """
class FooProcessor(nn.Module):
    pass

class BarProcessor(nn.Module):
    pass
"""

    def test_check_documentation_all_documented(self):
        # In this test, both FooProcessor and BarProcessor are documented
        with patch("builtins.open", mock_open(read_data=self.doc_content)) as doc_file:
            doc_file.side_effect = [
                mock_open(read_data=self.doc_content).return_value,
                mock_open(read_data=self.source_content).return_value,
            ]

            undocumented = check_documentation(
                doc_path="fake_doc.md",
                src_path="fake_source.py",
                doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
                src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
            )
            self.assertEqual(len(undocumented), 0, f"Expected no undocumented classes, got {undocumented}")

    def test_check_documentation_missing_class(self):
        # In this test, only FooProcessor is documented, but BarProcessor is missing from the docs
        doc_content_missing = """# Documentation
## FooProcessor

[[autodoc]] module.FooProcessor
"""
        with patch("builtins.open", mock_open(read_data=doc_content_missing)) as doc_file:
            doc_file.side_effect = [
                mock_open(read_data=doc_content_missing).return_value,
                mock_open(read_data=self.source_content).return_value,
            ]

            undocumented = check_documentation(
                doc_path="fake_doc.md",
                src_path="fake_source.py",
                doc_regex=r"\[\[autodoc\]\]\s([^\n]+)",
                src_regex=r"class\s+(\w+Processor)\(.*?nn\.Module.*?\):",
            )
            self.assertIn("BarProcessor", undocumented, f"BarProcessor should be undocumented, got {undocumented}")
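The tests above rely on one slightly subtle trick: `check_documentation` opens two files, so the patched `open` is given a `side_effect` list containing two independent `mock_open(...).return_value` handles, consumed one per call in the order the files are read. A standalone sketch of the same pattern:

```python
from unittest.mock import mock_open, patch

def read_two_files():
    with open("first.txt") as f:
        first = f.read()
    with open("second.txt") as f:
        second = f.read()
    return first, second

with patch("builtins.open", mock_open()) as fake_open:
    # Each call to open() consumes the next prepared handle.
    fake_open.side_effect = [
        mock_open(read_data="doc text").return_value,
        mock_open(read_data="source text").return_value,
    ]
    print(read_two_files())  # ('doc text', 'source text')
```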
utils/check_support_list.py (new file, 124 lines)
@@ -0,0 +1,124 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
"""
Utility that checks that modules like attention processors are listed in the documentation files.

```bash
python utils/check_support_list.py
```

It has no auto-fix mode.
"""

import os
import re


# All paths are set with the intent that you run this script from the root of the repo
REPO_PATH = "."


def read_documented_classes(doc_path, autodoc_regex=r"\[\[autodoc\]\]\s([^\n]+)"):
    """
    Reads documented classes from a doc file using a regex to find lines like `[[autodoc]] my.module.Class`.
    Returns a list of documented class names (just the class name portion).
    """
    with open(os.path.join(REPO_PATH, doc_path), "r") as f:
        doctext = f.read()
    matches = re.findall(autodoc_regex, doctext)
    return [match.split(".")[-1] for match in matches]


def read_source_classes(src_path, class_regex, exclude_conditions=None):
    """
    Reads class names from a source file using a regex that captures class definitions.
    Optionally excludes classes based on a list of conditions (functions that take a class name and return a bool).
    """
    if exclude_conditions is None:
        exclude_conditions = []
    with open(os.path.join(REPO_PATH, src_path), "r") as f:
        doctext = f.read()
    classes = re.findall(class_regex, doctext)
    # Filter out classes that meet any of the exclude conditions
    filtered_classes = [c for c in classes if not any(cond(c) for cond in exclude_conditions)]
    return filtered_classes


def check_documentation(doc_path, src_path, doc_regex, src_regex, exclude_conditions=None):
    """
    Generic function to check if all classes defined in `src_path` are documented in `doc_path`.
    Returns a sorted list of undocumented class names.
    """
    documented = set(read_documented_classes(doc_path, doc_regex))
    source_classes = set(read_source_classes(src_path, src_regex, exclude_conditions=exclude_conditions))

    # Find which classes in the source are not documented, in a deterministic order.
    undocumented = sorted(source_classes - documented)
    return undocumented


if __name__ == "__main__":
    # Define the checks we need to perform
    checks = {
        "Attention Processors": {
            "doc_path": "docs/source/en/api/attnprocessor.md",
            "src_path": "src/diffusers/models/attention_processor.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
            "exclude_conditions": [lambda c: "LoRA" in c, lambda c: c == "Attention"],
        },
        "Image Processors": {
            "doc_path": "docs/source/en/api/image_processor.md",
            "src_path": "src/diffusers/image_processor.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+Processor(?:\d*_?\d*))[:(]",
        },
        "Activations": {
            "doc_path": "docs/source/en/api/activations.md",
            "src_path": "src/diffusers/models/activations.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
        },
        "Normalizations": {
            "doc_path": "docs/source/en/api/normalization.md",
            "src_path": "src/diffusers/models/normalization.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
            "exclude_conditions": [
                # Exclude LayerNorm as it's an intentional exception
                lambda c: c == "LayerNorm"
            ],
        },
        "LoRA Mixins": {
            "doc_path": "docs/source/en/api/loaders/lora.md",
            "src_path": "src/diffusers/loaders/lora_pipeline.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):",
        },
    }

    missing_items = {}
    for category, params in checks.items():
        undocumented = check_documentation(
            doc_path=params["doc_path"],
            src_path=params["src_path"],
            doc_regex=params["doc_regex"],
            src_regex=params["src_regex"],
            exclude_conditions=params.get("exclude_conditions"),
        )
        if undocumented:
            missing_items[category] = undocumented

    # If we have any missing items, raise a single combined error
    if missing_items:
        error_msg = ["Some classes are not documented properly:\n"]
        for category, classes in missing_items.items():
            error_msg.append(f"- {category}: {', '.join(sorted(classes))}")
        raise ValueError("\n".join(error_msg))
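To make the matching concrete, here is a small, self-contained example of what the two regexes extract and how the undocumented set falls out — the same patterns as above, applied to inline strings instead of repository files:

```python
import re

doc_text = """
## RMSNorm

[[autodoc]] models.normalization.RMSNorm
"""
src_text = """
class RMSNorm(nn.Module):
    pass

class GlobalResponseNorm(nn.Module):
    pass
"""

doc_regex = r"\[\[autodoc\]\]\s([^\n]+)"
src_regex = r"class\s+(\w+)\s*\(.*?nn\.Module.*?\):"

documented = {m.split(".")[-1] for m in re.findall(doc_regex, doc_text)}
source_classes = set(re.findall(src_regex, src_text))

print(sorted(source_classes - documented))  # ['GlobalResponseNorm'] -> reported as undocumented
```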