[Perf] Validate @config in pre-commit instead of dynamically (#20200)

Signed-off-by: Lionel Villard <villard@us.ibm.com>
This commit is contained in:
Lionel Villard
2025-07-01 05:10:28 -04:00
committed by GitHub
parent 787b13389e
commit c05596f1a3
6 changed files with 220 additions and 57 deletions

View File

@ -160,6 +160,13 @@ repos:
types: [python]
pass_filenames: false
additional_dependencies: [pathspec, regex]
- id: validate-config
name: Validate configuration has default values and that each field has a docstring
entry: python tools/validate_config.py
language: python
types: [python]
pass_filenames: true
files: vllm/config.py|tests/test_config.py
# Keep `suggestion` last
- id: suggestion
name: Suggestion

View File

@ -2,49 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import MISSING, Field, asdict, dataclass, field
from typing import Literal, Union
import pytest
from vllm.compilation.backends import VllmBackend
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
config, get_field)
get_field)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
class _TestConfig1:
pass
@dataclass
class _TestConfig2:
a: int
"""docstring"""
@dataclass
class _TestConfig3:
a: int = 1
@dataclass
class _TestConfig4:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""
@pytest.mark.parametrize(("test_config", "expected_error"), [
(_TestConfig1, "must be a dataclass"),
(_TestConfig2, "must have a default"),
(_TestConfig3, "must have a docstring"),
(_TestConfig4, "must use a single Literal"),
])
def test_config(test_config, expected_error):
with pytest.raises(Exception, match=expected_error):
config(test_config)
def test_compile_config_repr_succeeds():
# setup: VllmBackend mutates the config object
config = VllmConfig()

0
tests/tools/__init__.py Normal file
View File

View File

@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import pytest
from tools.validate_config import validate_ast
_TestConfig1 = '''
@config
class _TestConfig1:
pass
'''
_TestConfig2 = '''
@config
@dataclass
class _TestConfig2:
a: int
"""docstring"""
'''
_TestConfig3 = '''
@config
@dataclass
class _TestConfig3:
a: int = 1
'''
_TestConfig4 = '''
@config
@dataclass
class _TestConfig4:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""
'''
@pytest.mark.parametrize(("test_config", "expected_error"), [
(_TestConfig1, "must be a dataclass"),
(_TestConfig2, "must have a default"),
(_TestConfig3, "must have a docstring"),
(_TestConfig4, "must use a single Literal"),
])
def test_config(test_config, expected_error):
tree = ast.parse(test_config)
with pytest.raises(Exception, match=expected_error):
validate_ast(tree)

158
tools/validate_config.py Normal file
View File

@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Ensures all fields in a config dataclass have default values
and that each field has a docstring.
"""
import ast
import inspect
import sys
def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
"""
Get any docstrings placed after attribute assignments in a class body.
Adapted from https://davidism.com/attribute-docstrings/
https://davidism.com/mit-license/
"""
def pairwise(iterable):
"""
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
Can be removed when Python 3.9 support is dropped.
"""
iterator = iter(iterable)
a = next(iterator, None)
for b in iterator:
yield a, b
a = b
out = {}
# Consider each pair of nodes.
for a, b in pairwise(cls_node.body):
# Must be an assignment then a constant string.
if (not isinstance(a, (ast.Assign, ast.AnnAssign))
or not isinstance(b, ast.Expr)
or not isinstance(b.value, ast.Constant)
or not isinstance(b.value.value, str)):
continue
doc = inspect.cleandoc(b.value.value)
# An assignment can have multiple targets (a = b = v), but an
# annotated assignment only has one target.
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
for target in targets:
# Must be assigning to a plain name.
if not isinstance(target, ast.Name):
continue
out[target.id] = doc
return out
class ConfigValidator(ast.NodeVisitor):
def __init__(self):
...
def visit_ClassDef(self, node):
# Validate class with both @config and @dataclass decorators
decorators = [
id for d in node.decorator_list if (isinstance(d, ast.Name) and (
(id := d.id) == 'config' or id == 'dataclass')) or
(isinstance(d, ast.Call) and (isinstance(d.func, ast.Name) and
(id := d.func.id) == 'dataclass'))
]
if set(decorators) == {'config', 'dataclass'}:
validate_class(node)
elif set(decorators) == {'config'}:
fail(
f"Class {node.name} with config decorator must be a dataclass.",
node)
self.generic_visit(node)
def validate_class(class_node: ast.ClassDef):
attr_docs = get_attr_docs(class_node)
for stmt in class_node.body:
# A field is defined as a class variable that has a type annotation.
if isinstance(stmt, ast.AnnAssign):
# Skip ClassVar
# see https://docs.python.org/3/library/dataclasses.html#class-variables
if isinstance(stmt.annotation, ast.Subscript) and isinstance(
stmt.annotation.value,
ast.Name) and stmt.annotation.value.id == "ClassVar":
continue
if isinstance(stmt.target, ast.Name):
field_name = stmt.target.id
if stmt.value is None:
fail(
f"Field '{field_name}' in {class_node.name} must have "
"a default value.", stmt)
if field_name not in attr_docs:
fail(
f"Field '{field_name}' in {class_node.name} must have "
"a docstring.", stmt)
if isinstance(stmt.annotation, ast.Subscript) and \
isinstance(stmt.annotation.value, ast.Name) \
and stmt.annotation.value.id == "Union" and \
isinstance(stmt.annotation.slice, ast.Tuple):
args = stmt.annotation.slice.elts
literal_args = [
arg for arg in args
if isinstance(arg, ast.Subscript) and isinstance(
arg.value, ast.Name) and arg.value.id == "Literal"
]
if len(literal_args) > 1:
fail(
f"Field '{field_name}' in {class_node.name} must "
"use a single "
"Literal type. Please use 'Literal[Literal1, "
"Literal2]' instead of 'Union[Literal1, Literal2]'"
".", stmt)
def validate_ast(tree: ast.stmt):
ConfigValidator().visit(tree)
def validate_file(file_path: str):
try:
print(f"validating {file_path} config dataclasses ", end="")
with open(file_path, encoding="utf-8") as f:
source = f.read()
tree = ast.parse(source, filename=file_path)
validate_ast(tree)
except ValueError as e:
print(e)
SystemExit(2)
else:
print("")
def fail(message: str, node: ast.stmt):
raise ValueError(f"❌ line({node.lineno}): {message}")
def main():
for filename in sys.argv[1:]:
validate_file(filename)
if __name__ == "__main__":
main()

View File

@ -18,7 +18,7 @@ from functools import cached_property
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, get_origin)
Protocol, TypeVar, Union, cast, get_args)
import regex as re
import torch
@ -193,28 +193,10 @@ def config(cls: ConfigT) -> ConfigT:
(i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT`
requires custom construction from CLI (i.e. `CompilationConfig`), it can
have a `from_cli` method, which will be called instead.
Config validation is performed by the tools/validate_config.py
script, which is invoked during the pre-commit checks.
"""
if not is_dataclass(cls):
raise TypeError("The decorated class must be a dataclass.")
attr_docs = get_attr_docs(cls)
for f in fields(cls):
if f.init and f.default is MISSING and f.default_factory is MISSING:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a default value."
)
if f.name not in attr_docs:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must have a docstring.")
if get_origin(f.type) is Union:
args = get_args(f.type)
literal_args = [arg for arg in args if get_origin(arg) is Literal]
if len(literal_args) > 1:
raise ValueError(
f"Field '{f.name}' in {cls.__name__} must use a single "
"Literal type. Please use 'Literal[Literal1, Literal2]' "
"instead of 'Union[Literal1, Literal2]'.")
return cls
@ -1798,7 +1780,7 @@ class ParallelConfig:
eplb_step_interval: int = 3000
"""
Interval for rearranging experts in expert parallelism.
Note that if this is greater than the EPLB window size, only the metrics
of the last `eplb_window_size` steps will be used for rearranging experts.
"""