diff --git a/utils/generate_model_tests.py b/utils/generate_model_tests.py
new file mode 100644
index 0000000000..ffd600dfdf
--- /dev/null
+++ b/utils/generate_model_tests.py
@@ -0,0 +1,489 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utility script to generate test suites for diffusers model classes.
+
+Usage:
+    python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py
+
+This will analyze the model file and generate a test file with appropriate
+test classes based on the model's mixins and attributes.
+"""
+
+import argparse
+import ast
+import sys
+from pathlib import Path
+
+
+MIXIN_TO_TESTER = {
+    "ModelMixin": "ModelTesterMixin",
+    "PeftAdapterMixin": "LoraTesterMixin",
+}
+
+ATTRIBUTE_TO_TESTER = {
+    "_cp_plan": "ContextParallelTesterMixin",
+    "_supports_gradient_checkpointing": "TrainingTesterMixin",
+}
+
+ALWAYS_INCLUDE_TESTERS = [
+    "ModelTesterMixin",
+    "MemoryTesterMixin",
+    "AttentionTesterMixin",
+    "TorchCompileTesterMixin",
+]
+
+OPTIONAL_TESTERS = [
+    ("BitsAndBytesTesterMixin", "bnb"),
+    ("QuantoTesterMixin", "quanto"),
+    ("TorchAoTesterMixin", "torchao"),
+    ("GGUFTesterMixin", "gguf"),
+    ("ModelOptTesterMixin", "modelopt"),
+    ("SingleFileTesterMixin", "single_file"),
+    ("IPAdapterTesterMixin", "ip_adapter"),
+]
+
+
+class ModelAnalyzer(ast.NodeVisitor):
+    def __init__(self):
+        self.model_classes = []
+        self.current_class = None
+
+    def visit_ClassDef(self, node: ast.ClassDef):
+        base_names = []
+        for base in node.bases:
+            if isinstance(base, ast.Name):
+                base_names.append(base.id)
+            elif isinstance(base, ast.Attribute):
+                base_names.append(base.attr)
+
+        if "ModelMixin" in base_names:
+            class_info = {
+                "name": node.name,
+                "bases": base_names,
+                "attributes": {},
+                "has_forward": False,
+                "init_params": [],
+            }
+
+            for item in node.body:
+                if isinstance(item, ast.Assign):
+                    for target in item.targets:
+                        if isinstance(target, ast.Name):
+                            attr_name = target.id
+                            if attr_name.startswith("_"):
+                                class_info["attributes"][attr_name] = self._get_value(item.value)
+
+                elif isinstance(item, ast.FunctionDef):
+                    if item.name == "forward":
+                        class_info["has_forward"] = True
+                        class_info["forward_params"] = self._extract_func_params(item)
+                    elif item.name == "__init__":
+                        class_info["init_params"] = self._extract_func_params(item)
+
+            self.model_classes.append(class_info)
+
+        self.generic_visit(node)
+
+    def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]:
+        params = []
+        args = func_node.args
+
+        num_defaults = len(args.defaults)
+        num_args = len(args.args)
+        first_default_idx = num_args - num_defaults
+
+        for i, arg in enumerate(args.args):
+            if arg.arg == "self":
+                continue
+
+            param_info = {"name": arg.arg, "type": None, "default": None}
+
+            if arg.annotation:
+                param_info["type"] = self._get_annotation_str(arg.annotation)
+
+            default_idx = i - first_default_idx
+            if 0 <= default_idx < len(args.defaults):
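+                # ast stores defaults right-aligned against args.args, so this
+                # index selects the default belonging to the current positional arg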
param_info["default"] = self._get_value(args.defaults[default_idx]) + + params.append(param_info) + + return params + + def _get_annotation_str(self, node) -> str: + if isinstance(node, ast.Name): + return node.id + elif isinstance(node, ast.Constant): + return repr(node.value) + elif isinstance(node, ast.Subscript): + base = self._get_annotation_str(node.value) + if isinstance(node.slice, ast.Tuple): + args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts) + else: + args = self._get_annotation_str(node.slice) + return f"{base}[{args}]" + elif isinstance(node, ast.Attribute): + return f"{self._get_annotation_str(node.value)}.{node.attr}" + elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): + left = self._get_annotation_str(node.left) + right = self._get_annotation_str(node.right) + return f"{left} | {right}" + elif isinstance(node, ast.Tuple): + return ", ".join(self._get_annotation_str(el) for el in node.elts) + return "Any" + + def _get_value(self, node): + if isinstance(node, ast.Constant): + return node.value + elif isinstance(node, ast.Name): + if node.id == "None": + return None + elif node.id == "True": + return True + elif node.id == "False": + return False + return node.id + elif isinstance(node, ast.List): + return [self._get_value(el) for el in node.elts] + elif isinstance(node, ast.Dict): + return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)} + return "" + + +def analyze_model_file(filepath: str) -> list[dict]: + with open(filepath) as f: + source = f.read() + + tree = ast.parse(source) + analyzer = ModelAnalyzer() + analyzer.visit(tree) + + return analyzer.model_classes + + +def determine_testers(model_info: dict, include_optional: list[str]) -> list[str]: + testers = list(ALWAYS_INCLUDE_TESTERS) + + for base in model_info["bases"]: + if base in MIXIN_TO_TESTER: + tester = MIXIN_TO_TESTER[base] + if tester not in testers: + testers.append(tester) + + for attr, tester in ATTRIBUTE_TO_TESTER.items(): + if attr in model_info["attributes"]: + value = model_info["attributes"][attr] + if value is not None and value is not False: + if tester not in testers: + testers.append(tester) + + if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None: + if "ContextParallelTesterMixin" not in testers: + testers.append("ContextParallelTesterMixin") + + for tester, flag in OPTIONAL_TESTERS: + if flag in include_optional: + if tester not in testers: + testers.append(tester) + + return testers + + +def generate_config_class(model_info: dict, model_name: str) -> str: + class_name = f"{model_name}TesterConfig" + model_class = model_info["name"] + forward_params = model_info.get("forward_params", []) + init_params = model_info.get("init_params", []) + + lines = [ + f"class {class_name}:", + f" model_class = {model_class}", + ' pretrained_model_name_or_path = ""', + ' pretrained_model_kwargs = {"subfolder": "transformer"}', + "", + " @property", + " def generator(self):", + ' return torch.Generator("cpu").manual_seed(0)', + "", + " def get_init_dict(self) -> dict[str, int | list[int]]:", + ] + + if init_params: + lines.append(" # __init__ parameters:") + for param in init_params: + type_str = f": {param['type']}" if param["type"] else "" + default_str = f" = {param['default']}" if param["default"] is not None else "" + lines.append(f" # {param['name']}{type_str}{default_str}") + + lines.extend( + [ + " return {}", + "", + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + ] + ) + + if 
+    if forward_params:
+        lines.append("        # forward() parameters:")
+        for param in forward_params:
+            type_str = f": {param['type']}" if param["type"] else ""
+            default_str = f" = {param['default']}" if param["default"] is not None else ""
+            lines.append(f"        # {param['name']}{type_str}{default_str}")
+
+    lines.extend(
+        [
+            "        # TODO: Fill in dummy inputs",
+            "        return {}",
+            "",
+            "    @property",
+            "    def input_shape(self) -> tuple[int, ...]:",
+            "        return (1, 1)",
+            "",
+            "    @property",
+            "    def output_shape(self) -> tuple[int, ...]:",
+            "        return (1, 1)",
+        ]
+    )
+
+    return "\n".join(lines)
+
+
+def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
+    tester_short = tester.replace("TesterMixin", "")
+    class_name = f"Test{model_name}{tester_short}"
+
+    lines = [f"class {class_name}({config_class}, {tester}):"]
+
+    if tester == "TorchCompileTesterMixin":
+        lines.extend(
+            [
+                "    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
+                "",
+                "    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
+                "        # TODO: Implement dynamic input generation",
+                "        return {}",
+            ]
+        )
+    elif tester == "IPAdapterTesterMixin":
+        lines.extend(
+            [
+                "    ip_adapter_processor_cls = None  # TODO: Set processor class",
+                "",
+                "    def modify_inputs_for_ip_adapter(self, model, inputs_dict):",
+                "        # TODO: Add IP adapter image embeds to inputs",
+                "        return inputs_dict",
+                "",
+                "    def create_ip_adapter_state_dict(self, model):",
+                "        # TODO: Create IP adapter state dict",
+                "        return {}",
+            ]
+        )
+    elif tester == "SingleFileTesterMixin":
+        lines.extend(
+            [
+                '    ckpt_path = ""  # TODO: Set checkpoint path',
+                "    alternate_keys_ckpt_paths = []",
+                '    pretrained_model_name_or_path = ""',
+                '    subfolder = "transformer"',
+            ]
+        )
+    elif tester == "GGUFTesterMixin":
+        lines.extend(
+            [
+                '    gguf_filename = ""  # TODO: Set GGUF filename',
+                "",
+                "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
+                "        # TODO: Override with larger inputs for quantization tests",
+                "        return {}",
+            ]
+        )
+    elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]:
+        lines.extend(
+            [
+                "    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
+                "        # TODO: Override with larger inputs for quantization tests",
+                "        return {}",
+            ]
+        )
+    elif tester == "LoraHotSwappingForModelTesterMixin":
+        lines.extend(
+            [
+                "    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]",
+                "",
+                "    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
+                "        # TODO: Implement dynamic input generation",
+                "        return {}",
+            ]
+        )
+    else:
+        lines.append("    pass")
+
+    return "\n".join(lines)
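+
+
+# For illustration: given FluxTransformer2DModel (model_name "FluxTransformer"),
+# generate_test_class("FluxTransformer", "FluxTransformerTesterConfig", "ModelTesterMixin")
+# returns:
+#
+#     class TestFluxTransformerModel(FluxTransformerTesterConfig, ModelTesterMixin):
+#         pass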
+
+
+def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str]) -> str:
+    model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
+    testers = determine_testers(model_info, include_optional)
+    tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})
+
+    lines = [
+        "# coding=utf-8",
+        "# Copyright 2025 HuggingFace Inc.",
+        "#",
+        '# Licensed under the Apache License, Version 2.0 (the "License");',
+        "# you may not use this file except in compliance with the License.",
+        "# You may obtain a copy of the License at",
+        "#",
+        "#     http://www.apache.org/licenses/LICENSE-2.0",
+        "#",
+        "# Unless required by applicable law or agreed to in writing, software",
+        '# distributed under the License is distributed on an "AS IS" BASIS,',
+        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
+        "# See the License for the specific language governing permissions and",
+        "# limitations under the License.",
+        "",
+        "import torch",
+        "",
+        f"from diffusers import {model_info['name']}",
+        "from diffusers.utils.torch_utils import randn_tensor",
+        "",
+        "from ...testing_utils import enable_full_determinism, torch_device",
+    ]
+
+    if "LoraTesterMixin" in testers:
+        lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")
+
+    lines.extend(
+        [
+            "from ..testing_utils import (",
+            *[f"    {tester}," for tester in sorted(tester_imports)],
+            ")",
+            "",
+            "",
+            "enable_full_determinism()",
+            "",
+            "",
+        ]
+    )
+
+    config_class = f"{model_name}TesterConfig"
+    lines.append(generate_config_class(model_info, model_name))
+    lines.append("")
+    lines.append("")
+
+    for tester in testers:
+        lines.append(generate_test_class(model_name, config_class, tester))
+        lines.append("")
+        lines.append("")
+
+    if "LoraTesterMixin" in testers:
+        lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
+        lines.append("")
+        lines.append("")
+
+    return "\n".join(lines).rstrip() + "\n"
+
+
+def get_test_output_path(model_filepath: str) -> str:
+    path = Path(model_filepath)
+    model_filename = path.stem
+
+    if "transformers" in path.parts:
+        return f"tests/models/transformers/test_models_{model_filename}.py"
+    elif "unets" in path.parts:
+        return f"tests/models/unets/test_models_{model_filename}.py"
+    elif "autoencoders" in path.parts:
+        return f"tests/models/autoencoders/test_models_{model_filename}.py"
+    else:
+        return f"tests/models/test_models_{model_filename}.py"
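+
+
+# For example, get_test_output_path("src/diffusers/models/transformers/transformer_flux.py")
+# returns "tests/models/transformers/test_models_transformer_flux.py".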
Available: {available}", file=sys.stderr) + sys.exit(1) + else: + model_info = model_classes[0] + if len(model_classes) > 1: + print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr) + print("Use --class-name to specify a different class", file=sys.stderr) + + include_optional = args.include + if "all" in include_optional: + include_optional = [flag for _, flag in OPTIONAL_TESTERS] + + generated_code = generate_test_file(model_info, args.model_filepath, include_optional) + + if args.dry_run: + print(generated_code) + else: + output_path = args.output or get_test_output_path(args.model_filepath) + output_dir = Path(output_path).parent + output_dir.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w") as f: + f.write(generated_code) + + print(f"Generated test file: {output_path}") + print(f"Model class: {model_info['name']}") + print(f"Detected attributes: {list(model_info['attributes'].keys())}") + + +if __name__ == "__main__": + main()