mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* up * up up * update outputs * style * add modular_auto_docstring! * more auto docstring * style * up up up * more more * up * address feedbacks * add TODO in the description for empty docstring * refactor based on dhruv's feedback: remove the class method * add template method * up * up up up * apply auto docstring * make style * rmove space in make docstring * Apply suggestions from code review * revert change in z * fix * Apply style fixes * include auto-docstring check in the modular ci. (#13004) * Run ruff format after auto docstring generation * up * upup * upup * style --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
353 lines
12 KiB
Python
353 lines
12 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Auto Docstring Generator for Modular Pipeline Blocks
|
|
|
|
This script scans Python files for classes that have `# auto_docstring` comment above them
|
|
and inserts/updates the docstring from the class's `doc` property.
|
|
|
|
Run from the root of the repo:
|
|
python utils/modular_auto_docstring.py [path] [--fix_and_overwrite]
|
|
|
|
Examples:
|
|
# Check for auto_docstring markers (will error if found without proper docstring)
|
|
python utils/modular_auto_docstring.py
|
|
|
|
# Check specific directory
|
|
python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/
|
|
|
|
# Fix and overwrite the docstrings
|
|
python utils/modular_auto_docstring.py --fix_and_overwrite
|
|
|
|
Usage in code:
|
|
# auto_docstring
|
|
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
|
|
# docstring will be automatically inserted here
|
|
|
|
@property
|
|
def doc(self):
|
|
return "Your docstring content..."
|
|
"""
|
|
|
|
import argparse
|
|
import ast
|
|
import glob
|
|
import importlib
|
|
import os
|
|
import re
|
|
import subprocess
|
|
import sys
|
|
|
|
|
|
# All paths are set with the intent you should run this script from the root of the repo
|
|
DIFFUSERS_PATH = "src/diffusers"
|
|
REPO_PATH = "."
|
|
|
|
# Pattern to match the auto_docstring comment
|
|
AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$")
|
|
|
|
|
|
def setup_diffusers_import():
|
|
"""Setup import path to use the local diffusers module."""
|
|
src_path = os.path.join(REPO_PATH, "src")
|
|
if src_path not in sys.path:
|
|
sys.path.insert(0, src_path)
|
|
|
|
|
|
def get_module_from_filepath(filepath: str) -> str:
|
|
"""Convert a filepath to a module name."""
|
|
filepath = os.path.normpath(filepath)
|
|
|
|
if filepath.startswith("src" + os.sep):
|
|
filepath = filepath[4:]
|
|
|
|
if filepath.endswith(".py"):
|
|
filepath = filepath[:-3]
|
|
|
|
module_name = filepath.replace(os.sep, ".")
|
|
return module_name
|
|
|
|
|
|
def load_module(filepath: str):
|
|
"""Load a module from filepath."""
|
|
setup_diffusers_import()
|
|
module_name = get_module_from_filepath(filepath)
|
|
|
|
try:
|
|
module = importlib.import_module(module_name)
|
|
return module
|
|
except Exception as e:
|
|
print(f"Warning: Could not import module {module_name}: {e}")
|
|
return None
|
|
|
|
|
|
def get_doc_from_class(module, class_name: str) -> str:
|
|
"""Get the doc property from an instantiated class."""
|
|
if module is None:
|
|
return None
|
|
|
|
cls = getattr(module, class_name, None)
|
|
if cls is None:
|
|
return None
|
|
|
|
try:
|
|
instance = cls()
|
|
if hasattr(instance, "doc"):
|
|
return instance.doc
|
|
except Exception as e:
|
|
print(f"Warning: Could not instantiate {class_name}: {e}")
|
|
|
|
return None
|
|
|
|
|
|
def find_auto_docstring_classes(filepath: str) -> list:
|
|
"""
|
|
Find all classes in a file that have # auto_docstring comment above them.
|
|
|
|
Returns list of (class_name, class_line_number, has_existing_docstring, docstring_end_line)
|
|
"""
|
|
with open(filepath, "r", encoding="utf-8", newline="\n") as f:
|
|
lines = f.readlines()
|
|
|
|
# Parse AST to find class locations and their docstrings
|
|
content = "".join(lines)
|
|
try:
|
|
tree = ast.parse(content)
|
|
except SyntaxError as e:
|
|
print(f"Syntax error in {filepath}: {e}")
|
|
return []
|
|
|
|
# Build a map of class_name -> (class_line, has_docstring, docstring_end_line)
|
|
class_info = {}
|
|
for node in ast.walk(tree):
|
|
if isinstance(node, ast.ClassDef):
|
|
has_docstring = False
|
|
docstring_end_line = node.lineno # default to class line
|
|
|
|
if node.body and isinstance(node.body[0], ast.Expr):
|
|
first_stmt = node.body[0]
|
|
if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str):
|
|
has_docstring = True
|
|
docstring_end_line = first_stmt.end_lineno or first_stmt.lineno
|
|
|
|
class_info[node.name] = (node.lineno, has_docstring, docstring_end_line)
|
|
|
|
# Now scan for # auto_docstring comments
|
|
classes_to_update = []
|
|
|
|
for i, line in enumerate(lines):
|
|
if AUTO_DOCSTRING_PATTERN.match(line):
|
|
# Found the marker, look for class definition on next non-empty, non-comment line
|
|
j = i + 1
|
|
while j < len(lines):
|
|
next_line = lines[j].strip()
|
|
if next_line and not next_line.startswith("#"):
|
|
break
|
|
j += 1
|
|
|
|
if j < len(lines) and lines[j].strip().startswith("class "):
|
|
# Extract class name
|
|
match = re.match(r"class\s+(\w+)", lines[j].strip())
|
|
if match:
|
|
class_name = match.group(1)
|
|
if class_name in class_info:
|
|
class_line, has_docstring, docstring_end_line = class_info[class_name]
|
|
classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line))
|
|
|
|
return classes_to_update
|
|
|
|
|
|
def strip_class_name_line(doc: str, class_name: str) -> str:
|
|
"""Remove the 'class ClassName' line from the doc if present."""
|
|
lines = doc.strip().split("\n")
|
|
if lines and lines[0].strip() == f"class {class_name}":
|
|
# Remove the class line and any blank line following it
|
|
lines = lines[1:]
|
|
while lines and not lines[0].strip():
|
|
lines = lines[1:]
|
|
return "\n".join(lines)
|
|
|
|
|
|
def format_docstring(doc: str, indent: str = " ") -> str:
|
|
"""Format a doc string as a properly indented docstring."""
|
|
lines = doc.strip().split("\n")
|
|
|
|
if len(lines) == 1:
|
|
return f'{indent}"""{lines[0]}"""\n'
|
|
else:
|
|
result = [f'{indent}"""\n']
|
|
for line in lines:
|
|
if line.strip():
|
|
result.append(f"{indent}{line}\n")
|
|
else:
|
|
result.append("\n")
|
|
result.append(f'{indent}"""\n')
|
|
return "".join(result)
|
|
|
|
|
|
def run_ruff_format(filepath: str):
|
|
"""Run ruff check --fix, ruff format, and doc-builder style on a file to ensure consistent formatting."""
|
|
try:
|
|
# First run ruff check --fix to fix any linting issues (including line length)
|
|
subprocess.run(
|
|
["ruff", "check", "--fix", filepath],
|
|
check=False, # Don't fail if there are unfixable issues
|
|
capture_output=True,
|
|
text=True,
|
|
)
|
|
# Then run ruff format for code formatting
|
|
subprocess.run(
|
|
["ruff", "format", filepath],
|
|
check=True,
|
|
capture_output=True,
|
|
text=True,
|
|
)
|
|
# Finally run doc-builder style for docstring formatting
|
|
subprocess.run(
|
|
["doc-builder", "style", filepath, "--max_len", "119"],
|
|
check=False, # Don't fail if doc-builder has issues
|
|
capture_output=True,
|
|
text=True,
|
|
)
|
|
print(f"Formatted {filepath}")
|
|
except subprocess.CalledProcessError as e:
|
|
print(f"Warning: formatting failed for {filepath}: {e.stderr}")
|
|
except FileNotFoundError as e:
|
|
print(f"Warning: tool not found ({e}). Skipping formatting.")
|
|
except Exception as e:
|
|
print(f"Warning: unexpected error formatting {filepath}: {e}")
|
|
|
|
|
|
def get_existing_docstring(lines: list, class_line: int, docstring_end_line: int) -> str:
|
|
"""Extract the existing docstring content from lines."""
|
|
# class_line is 1-indexed, docstring starts at class_line (0-indexed: class_line)
|
|
# and ends at docstring_end_line (1-indexed, inclusive)
|
|
docstring_lines = lines[class_line:docstring_end_line]
|
|
return "".join(docstring_lines)
|
|
|
|
|
|
def process_file(filepath: str, overwrite: bool = False) -> list:
|
|
"""
|
|
Process a file and find/insert docstrings for # auto_docstring marked classes.
|
|
|
|
Returns list of classes that need updating.
|
|
"""
|
|
classes_to_update = find_auto_docstring_classes(filepath)
|
|
|
|
if not classes_to_update:
|
|
return []
|
|
|
|
if not overwrite:
|
|
# Check mode: only verify that docstrings exist
|
|
# Content comparison is not reliable due to formatting differences
|
|
classes_needing_update = []
|
|
for class_name, class_line, has_docstring, docstring_end_line in classes_to_update:
|
|
if not has_docstring:
|
|
# No docstring exists, needs update
|
|
classes_needing_update.append((filepath, class_name, class_line))
|
|
return classes_needing_update
|
|
|
|
# Load the module to get doc properties
|
|
module = load_module(filepath)
|
|
|
|
with open(filepath, "r", encoding="utf-8", newline="\n") as f:
|
|
lines = f.readlines()
|
|
|
|
# Process in reverse order to maintain line numbers
|
|
updated = False
|
|
for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update):
|
|
doc = get_doc_from_class(module, class_name)
|
|
|
|
if doc is None:
|
|
print(f"Warning: Could not get doc for {class_name} in {filepath}")
|
|
continue
|
|
|
|
# Remove the "class ClassName" line since it's redundant in a docstring
|
|
doc = strip_class_name_line(doc, class_name)
|
|
|
|
# Format the new docstring with 4-space indent
|
|
new_docstring = format_docstring(doc, " ")
|
|
|
|
if has_docstring:
|
|
# Replace existing docstring (line after class definition to docstring_end_line)
|
|
# class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line
|
|
lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:]
|
|
else:
|
|
# Insert new docstring right after class definition line
|
|
# class_line is 1-indexed, so lines[class_line-1] is the class line
|
|
# Insert at position class_line (which is right after the class line)
|
|
lines = lines[:class_line] + [new_docstring] + lines[class_line:]
|
|
|
|
updated = True
|
|
print(f"Updated docstring for {class_name} in {filepath}")
|
|
|
|
if updated:
|
|
with open(filepath, "w", encoding="utf-8", newline="\n") as f:
|
|
f.writelines(lines)
|
|
# Run ruff format to ensure consistent line wrapping
|
|
run_ruff_format(filepath)
|
|
|
|
return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update]
|
|
|
|
|
|
def check_auto_docstrings(path: str = None, overwrite: bool = False):
|
|
"""
|
|
Check all files for # auto_docstring markers and optionally fix them.
|
|
"""
|
|
if path is None:
|
|
path = DIFFUSERS_PATH
|
|
|
|
if os.path.isfile(path):
|
|
all_files = [path]
|
|
else:
|
|
all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True)
|
|
|
|
all_markers = []
|
|
|
|
for filepath in all_files:
|
|
markers = process_file(filepath, overwrite)
|
|
all_markers.extend(markers)
|
|
|
|
if not overwrite and len(all_markers) > 0:
|
|
message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers])
|
|
raise ValueError(
|
|
f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n"
|
|
f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them."
|
|
)
|
|
|
|
if overwrite and len(all_markers) > 0:
|
|
print(f"\nProcessed {len(all_markers)} docstring(s).")
|
|
elif not overwrite and len(all_markers) == 0:
|
|
print("All # auto_docstring markers have valid docstrings.")
|
|
elif len(all_markers) == 0:
|
|
print("No # auto_docstring markers found.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser(
|
|
description="Check and fix # auto_docstring markers in modular pipeline blocks",
|
|
)
|
|
parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)")
|
|
parser.add_argument(
|
|
"--fix_and_overwrite",
|
|
action="store_true",
|
|
help="Whether to fix the docstrings by inserting them from doc property.",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
check_auto_docstrings(args.path, args.fix_and_overwrite)
|