# coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Auto Docstring Generator for Modular Pipeline Blocks This script scans Python files for classes that have `# auto_docstring` comment above them and inserts/updates the docstring from the class's `doc` property. Run from the root of the repo: python utils/modular_auto_docstring.py [path] [--fix_and_overwrite] Examples: # Check for auto_docstring markers (will error if found without proper docstring) python utils/modular_auto_docstring.py # Check specific directory python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/ # Fix and overwrite the docstrings python utils/modular_auto_docstring.py --fix_and_overwrite Usage in code: # auto_docstring class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks): # docstring will be automatically inserted here @property def doc(self): return "Your docstring content..." """ import argparse import ast import glob import importlib import os import re import sys # All paths are set with the intent you should run this script from the root of the repo DIFFUSERS_PATH = "src/diffusers" REPO_PATH = "." # Pattern to match the auto_docstring comment AUTO_DOCSTRING_PATTERN = re.compile(r"^\s*#\s*auto_docstring\s*$") def setup_diffusers_import(): """Setup import path to use the local diffusers module.""" src_path = os.path.join(REPO_PATH, "src") if src_path not in sys.path: sys.path.insert(0, src_path) def get_module_from_filepath(filepath: str) -> str: """Convert a filepath to a module name.""" filepath = os.path.normpath(filepath) if filepath.startswith("src" + os.sep): filepath = filepath[4:] if filepath.endswith(".py"): filepath = filepath[:-3] module_name = filepath.replace(os.sep, ".") return module_name def load_module(filepath: str): """Load a module from filepath.""" setup_diffusers_import() module_name = get_module_from_filepath(filepath) try: module = importlib.import_module(module_name) return module except Exception as e: print(f"Warning: Could not import module {module_name}: {e}") return None def get_doc_from_class(module, class_name: str) -> str: """Get the doc property from an instantiated class.""" if module is None: return None cls = getattr(module, class_name, None) if cls is None: return None try: instance = cls() if hasattr(instance, "doc"): return instance.doc except Exception as e: print(f"Warning: Could not instantiate {class_name}: {e}") return None def find_auto_docstring_classes(filepath: str) -> list: """ Find all classes in a file that have # auto_docstring comment above them. Returns list of (class_name, class_line_number, has_existing_docstring, docstring_end_line) """ with open(filepath, "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() # Parse AST to find class locations and their docstrings content = "".join(lines) try: tree = ast.parse(content) except SyntaxError as e: print(f"Syntax error in {filepath}: {e}") return [] # Build a map of class_name -> (class_line, has_docstring, docstring_end_line) class_info = {} for node in ast.walk(tree): if isinstance(node, ast.ClassDef): has_docstring = False docstring_end_line = node.lineno # default to class line if node.body and isinstance(node.body[0], ast.Expr): first_stmt = node.body[0] if isinstance(first_stmt.value, ast.Constant) and isinstance(first_stmt.value.value, str): has_docstring = True docstring_end_line = first_stmt.end_lineno or first_stmt.lineno class_info[node.name] = (node.lineno, has_docstring, docstring_end_line) # Now scan for # auto_docstring comments classes_to_update = [] for i, line in enumerate(lines): if AUTO_DOCSTRING_PATTERN.match(line): # Found the marker, look for class definition on next non-empty, non-comment line j = i + 1 while j < len(lines): next_line = lines[j].strip() if next_line and not next_line.startswith("#"): break j += 1 if j < len(lines) and lines[j].strip().startswith("class "): # Extract class name match = re.match(r"class\s+(\w+)", lines[j].strip()) if match: class_name = match.group(1) if class_name in class_info: class_line, has_docstring, docstring_end_line = class_info[class_name] classes_to_update.append((class_name, class_line, has_docstring, docstring_end_line)) return classes_to_update def strip_class_name_line(doc: str, class_name: str) -> str: """Remove the 'class ClassName' line from the doc if present.""" lines = doc.strip().split("\n") if lines and lines[0].strip() == f"class {class_name}": # Remove the class line and any blank line following it lines = lines[1:] while lines and not lines[0].strip(): lines = lines[1:] return "\n".join(lines) def format_docstring(doc: str, indent: str = " ") -> str: """Format a doc string as a properly indented docstring.""" lines = doc.strip().split("\n") if len(lines) == 1: return f'{indent}"""{lines[0]}"""\n' else: result = [f'{indent}"""\n'] for line in lines: if line.strip(): result.append(f"{indent}{line}\n") else: result.append("\n") result.append(f'{indent}"""\n') return "".join(result) def process_file(filepath: str, overwrite: bool = False) -> list: """ Process a file and find/insert docstrings for # auto_docstring marked classes. Returns list of classes that need updating. """ classes_to_update = find_auto_docstring_classes(filepath) if not classes_to_update: return [] if not overwrite: # Just return the list of classes that need updating return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update] # Load the module to get doc properties module = load_module(filepath) with open(filepath, "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() # Process in reverse order to maintain line numbers updated = False for class_name, class_line, has_docstring, docstring_end_line in reversed(classes_to_update): doc = get_doc_from_class(module, class_name) if doc is None: print(f"Warning: Could not get doc for {class_name} in {filepath}") continue # Remove the "class ClassName" line since it's redundant in a docstring doc = strip_class_name_line(doc, class_name) # Format the new docstring with 4-space indent new_docstring = format_docstring(doc, " ") if has_docstring: # Replace existing docstring (line after class definition to docstring_end_line) # class_line is 1-indexed, we want to replace from class_line+1 to docstring_end_line lines = lines[:class_line] + [new_docstring] + lines[docstring_end_line:] else: # Insert new docstring right after class definition line # class_line is 1-indexed, so lines[class_line-1] is the class line # Insert at position class_line (which is right after the class line) lines = lines[:class_line] + [new_docstring] + lines[class_line:] updated = True print(f"Updated docstring for {class_name} in {filepath}") if updated: with open(filepath, "w", encoding="utf-8", newline="\n") as f: f.writelines(lines) return [(filepath, cls_name, line) for cls_name, line, _, _ in classes_to_update] def check_auto_docstrings(path: str = None, overwrite: bool = False): """ Check all files for # auto_docstring markers and optionally fix them. """ if path is None: path = DIFFUSERS_PATH if os.path.isfile(path): all_files = [path] else: all_files = glob.glob(os.path.join(path, "**/*.py"), recursive=True) all_markers = [] for filepath in all_files: markers = process_file(filepath, overwrite) all_markers.extend(markers) if not overwrite and len(all_markers) > 0: message = "\n".join([f"- {f}: {cls} at line {line}" for f, cls, line in all_markers]) raise ValueError( f"Found the following # auto_docstring markers that need docstrings:\n{message}\n\n" f"Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them." ) if overwrite and len(all_markers) > 0: print(f"\nUpdated {len(all_markers)} docstring(s).") elif len(all_markers) == 0: print("No # auto_docstring markers found.") if __name__ == "__main__": parser = argparse.ArgumentParser( description="Check and fix # auto_docstring markers in modular pipeline blocks", ) parser.add_argument("path", nargs="?", default=None, help="File or directory to process (default: src/diffusers)") parser.add_argument( "--fix_and_overwrite", action="store_true", help="Whether to fix the docstrings by inserting them from doc property.", ) args = parser.parse_args() check_auto_docstrings(args.path, args.fix_and_overwrite)