mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo] lintrunner for gb_registry adds/updates (#158460)
This PR adds automation to adding/updating the JSON registry through the lintrunner. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158460 Approved by: https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
64e8d7d66b
commit
4d5d56a30e
@ -365,7 +365,6 @@ test_dynamo_wrapped_shard() {
|
||||
exit 1
|
||||
fi
|
||||
python tools/dynamo/verify_dynamo.py
|
||||
python tools/dynamo/gb_id_mapping.py verify
|
||||
# PLEASE DO NOT ADD ADDITIONAL EXCLUDES HERE.
|
||||
# Instead, use @skipIfTorchDynamo on your tests.
|
||||
time python test/run_test.py --dynamo \
|
||||
|
@ -1794,3 +1794,12 @@ include_patterns = [
|
||||
'torch/header_only_apis.txt',
|
||||
]
|
||||
is_formatter = false
|
||||
|
||||
|
||||
[[linter]]
|
||||
code = "GB_REGISTRY"
|
||||
include_patterns = ["torch/_dynamo/**/*.py"]
|
||||
command = [
|
||||
"python3",
|
||||
"tools/linter/adapters/gb_registry_linter.py",
|
||||
]
|
||||
|
1
setup.py
1
setup.py
@ -1325,6 +1325,7 @@ def main() -> None:
|
||||
"utils/model_dump/code.js",
|
||||
"utils/model_dump/*.mjs",
|
||||
"_dynamo/graph_break_registry.json",
|
||||
"tools/dynamo/gb_id_mapping.py",
|
||||
]
|
||||
|
||||
if not BUILD_LIBTORCH_WHL:
|
||||
|
@ -4,7 +4,6 @@ import argparse
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@ -26,7 +25,7 @@ def save_registry(reg, path):
|
||||
|
||||
def next_gb_id(reg):
|
||||
ids = [int(x[2:]) for x in reg if x.startswith("GB") and x[2:].isdigit()]
|
||||
return f"GB{(max(ids, default=0) + 1):04d}"
|
||||
return f"GB{(max(ids, default=-1) + 1):04d}"
|
||||
|
||||
|
||||
def clean_string(s):
|
||||
@ -176,202 +175,6 @@ def find_unimplemented_v2_calls(path, dynamo_dir=None):
|
||||
return results
|
||||
|
||||
|
||||
def cmd_add_new_gb_type(gb_type, file_path, registry_path, additional_info=None):
|
||||
"""
|
||||
Add a new graph break type to the registry.
|
||||
|
||||
Args:
|
||||
gb_type: The graph break type to add
|
||||
file_path: Path to the file containing the unimplemented_v2 call
|
||||
registry_path: Path to the registry JSON file
|
||||
"""
|
||||
registry_path = Path(registry_path)
|
||||
reg = load_registry(registry_path)
|
||||
|
||||
existing_gb_types = {entry[0]["Gb_type"] for entry in reg.values()}
|
||||
if gb_type in existing_gb_types:
|
||||
print(
|
||||
f"Error: gb_type '{gb_type}' already exists in registry. Please rename the gb_type so it can be unique."
|
||||
)
|
||||
return False
|
||||
dynamo_dir = Path(registry_path).parent
|
||||
calls = find_unimplemented_v2_calls(Path(file_path), dynamo_dir)
|
||||
matching_call = next((call for call in calls if call["gb_type"] == gb_type), None)
|
||||
|
||||
if not matching_call:
|
||||
print(
|
||||
f"Error: Could not find unimplemented_v2 call with gb_type '{gb_type}' in {file_path}"
|
||||
)
|
||||
return False
|
||||
|
||||
gb_id = next_gb_id(reg)
|
||||
reg[gb_id] = [
|
||||
{
|
||||
"Gb_type": gb_type,
|
||||
"Context": matching_call["context"],
|
||||
"Explanation": matching_call["explanation"],
|
||||
"Hints": matching_call["hints"] or [],
|
||||
**({"Additional_Info": [additional_info]} if additional_info else {}),
|
||||
}
|
||||
]
|
||||
|
||||
save_registry(reg, registry_path)
|
||||
print(f"Added {gb_type} to registry with ID {gb_id}")
|
||||
return True
|
||||
|
||||
|
||||
def cmd_update_gb_type(
|
||||
old_gb_type,
|
||||
file_path,
|
||||
registry_path,
|
||||
new_gb_type=None,
|
||||
additional_info=None,
|
||||
):
|
||||
"""
|
||||
Update an existing graph break type in the registry by adding a new version
|
||||
to the version history list.
|
||||
|
||||
Args:
|
||||
old_gb_type: The current graph break type to update
|
||||
file_path: Path to the file containing the updated unimplemented_v2 call
|
||||
registry_path: Path to the registry JSON file
|
||||
new_gb_type: Optional new gb_type name to replace the old one
|
||||
"""
|
||||
registry_path = Path(registry_path)
|
||||
reg = load_registry(registry_path)
|
||||
|
||||
gb_id_map = {entry[0]["Gb_type"]: id for id, entry in reg.items()}
|
||||
gb_id = gb_id_map.get(old_gb_type)
|
||||
|
||||
if gb_id is None:
|
||||
print(f"Error: gb_type '{old_gb_type}' not found in registry.")
|
||||
return False
|
||||
|
||||
search_gb_type = new_gb_type if new_gb_type else old_gb_type
|
||||
dynamo_dir = Path(registry_path).parent
|
||||
calls = find_unimplemented_v2_calls(Path(file_path), dynamo_dir)
|
||||
matching_call = next(
|
||||
(call for call in calls if call["gb_type"] == search_gb_type), None
|
||||
)
|
||||
|
||||
if not matching_call:
|
||||
print(
|
||||
f"Error: Could not find unimplemented_v2 call with gb_type '{search_gb_type}' in {file_path}"
|
||||
)
|
||||
return False
|
||||
|
||||
if (
|
||||
matching_call["gb_type"] != old_gb_type
|
||||
and matching_call["gb_type"] in gb_id_map
|
||||
):
|
||||
print(
|
||||
f"Error: New gb_type '{matching_call['gb_type']}' already exists in registry. Please use a unique gb_type."
|
||||
)
|
||||
return False
|
||||
|
||||
new_entry = {
|
||||
"Gb_type": matching_call["gb_type"],
|
||||
"Context": matching_call["context"],
|
||||
"Explanation": matching_call["explanation"],
|
||||
"Hints": matching_call["hints"] or [],
|
||||
}
|
||||
|
||||
if additional_info:
|
||||
additional_info_list = reg[gb_id][0].get("Additional_Info", [])
|
||||
new_entry["Additional_Info"] = (
|
||||
additional_info_list + [additional_info]
|
||||
if additional_info_list
|
||||
else [additional_info]
|
||||
)
|
||||
elif "Additional_Info" in reg[gb_id][0]:
|
||||
new_entry["Additional_Info"] = reg[gb_id][0]["Additional_Info"]
|
||||
|
||||
reg[gb_id].insert(0, new_entry)
|
||||
|
||||
save_registry(reg, registry_path)
|
||||
print(
|
||||
f"Updated {old_gb_type} to {matching_call['gb_type']} in registry with ID {gb_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def test_verify_gb_id_mapping(dynamo_dir, registry_path):
|
||||
"""
|
||||
Verifies that all unimplemented_v2 calls in torch/_dynamo match entries in the registry.
|
||||
"""
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
dynamo_dir = script_dir.parent.parent / "torch" / "_dynamo"
|
||||
registry_path = (
|
||||
script_dir.parent.parent / "torch" / "_dynamo" / "graph_break_registry.json"
|
||||
)
|
||||
|
||||
python_files = list(dynamo_dir.glob("**/*.py"))
|
||||
|
||||
reg = load_registry(registry_path)
|
||||
gb_type_to_entry = {entries[0]["Gb_type"]: entries[0] for _, entries in reg.items()}
|
||||
|
||||
mismatches = []
|
||||
for file_path in python_files:
|
||||
calls = find_unimplemented_v2_calls(file_path, dynamo_dir)
|
||||
for call in calls:
|
||||
gb_type = call["gb_type"]
|
||||
if gb_type not in gb_type_to_entry:
|
||||
mismatches.append((gb_type, file_path, "Not found in registry"))
|
||||
continue
|
||||
|
||||
entry = gb_type_to_entry[gb_type]
|
||||
if call["context"] != entry["Context"]:
|
||||
mismatches.append((gb_type, file_path, "Context mismatch"))
|
||||
elif call["explanation"] != entry["Explanation"]:
|
||||
mismatches.append((gb_type, file_path, "Explanation mismatch"))
|
||||
elif sorted(call["hints"]) != sorted(entry["Hints"]):
|
||||
mismatches.append((gb_type, file_path, "Hints mismatch"))
|
||||
|
||||
if mismatches:
|
||||
print(
|
||||
"Found the unimplemented_v2 or unimplemented_v2_with_warning calls below that "
|
||||
"don't match the registry in graph_break_registry.json."
|
||||
)
|
||||
for gb_type, file_path, reason in mismatches:
|
||||
print(f" - {gb_type} in {file_path}: {reason}")
|
||||
|
||||
print("Please update the registry using one of these commands:")
|
||||
|
||||
print(
|
||||
"- If you added a new callsite: python tools/dynamo/gb_id_mapping.py add "
|
||||
'"GB_TYPE" PATH_TO_FILE --additional-info "INFO"'
|
||||
)
|
||||
|
||||
print(
|
||||
" • GB_TYPE: The graph break type string used in your unimplemented_v2 call"
|
||||
" • PATH_TO_FILE: Path to the file containing your new unimplemented_v2 call"
|
||||
" • --additional-info: Optional extra information to include in the registry entry"
|
||||
)
|
||||
|
||||
print(
|
||||
'- If you updated an existing callsite: python tools/dynamo/gb_id_mapping.py update "GB_TYPE" PATH_TO_FILE '
|
||||
'--new_gb_type "NEW_NAME" --additional-info "INFO"'
|
||||
)
|
||||
print(" • GB_TYPE: The original graph break type to update")
|
||||
print(" • PATH_TO_FILE: Path to the file containing the updated call")
|
||||
print(" • --new_gb_type: New name if you changed the graph break type")
|
||||
print(" • --additional-info: Optional extra information to add")
|
||||
print(
|
||||
"- Recreate registry (Only do this if a complete reset is needed): python tools/dynamo/gb_id_mapping.py create"
|
||||
)
|
||||
print(
|
||||
"If you have also wrote a test for the new graph break, please update the test as well "
|
||||
"using EXPECTTEST_ACCEPT=1 so the message includes the respective webpage "
|
||||
)
|
||||
print(
|
||||
"Note: If you've reset the entire registry file, you can force push to bypass this check."
|
||||
)
|
||||
return False
|
||||
|
||||
print("All unimplemented_v2 calls match the registry.")
|
||||
return True
|
||||
|
||||
|
||||
def create_registry(dynamo_dir, registry_path):
|
||||
calls = find_unimplemented_v2_calls(dynamo_dir)
|
||||
registry = {}
|
||||
@ -420,40 +223,6 @@ def main():
|
||||
help="Directory to search for unimplemented_v2 calls.",
|
||||
)
|
||||
|
||||
add_parser = subparsers.add_parser("add", help="Add a gb_type to registry")
|
||||
add_parser.add_argument("gb_type", help="The gb_type to add")
|
||||
add_parser.add_argument(
|
||||
"file_path", help="Path to the file containing the unimplemented_v2 call"
|
||||
)
|
||||
add_parser.add_argument(
|
||||
"--additional-info", help="Optional additional information to include"
|
||||
)
|
||||
|
||||
update_parser = subparsers.add_parser(
|
||||
"update", help="Update an existing gb_type in registry"
|
||||
)
|
||||
update_parser.add_argument("gb_type", help="The gb_type to update")
|
||||
update_parser.add_argument(
|
||||
"file_path",
|
||||
help="Path to the file containing the updated unimplemented_v2 call",
|
||||
)
|
||||
update_parser.add_argument(
|
||||
"--new_gb_type", help="New gb_type name if it has changed", default=None
|
||||
)
|
||||
update_parser.add_argument(
|
||||
"--additional-info", help="Optional additional information to include"
|
||||
)
|
||||
|
||||
verify_parser = subparsers.add_parser(
|
||||
"verify", help="Verify all unimplemented_v2 calls match registry entries"
|
||||
)
|
||||
verify_parser.add_argument(
|
||||
"--dynamo_dir",
|
||||
type=str,
|
||||
default=default_dynamo_dir,
|
||||
help="Directory to search for unimplemented_v2 calls.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--registry-path",
|
||||
type=str,
|
||||
@ -465,26 +234,6 @@ def main():
|
||||
|
||||
if args.command == "create":
|
||||
create_registry(args.dynamo_dir, args.registry_path)
|
||||
elif args.command == "add":
|
||||
success = cmd_add_new_gb_type(
|
||||
args.gb_type, args.file_path, args.registry_path, args.additional_info
|
||||
)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
elif args.command == "update":
|
||||
success = cmd_update_gb_type(
|
||||
args.gb_type,
|
||||
args.file_path,
|
||||
args.registry_path,
|
||||
args.new_gb_type,
|
||||
args.additional_info,
|
||||
)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
elif args.command == "verify":
|
||||
success = test_verify_gb_id_mapping(args.dynamo_dir, args.registry_path)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
293
tools/linter/adapters/gb_registry_linter.py
Normal file
293
tools/linter/adapters/gb_registry_linter.py
Normal file
@ -0,0 +1,293 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[3]
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
|
||||
from tools.dynamo.gb_id_mapping import (
|
||||
find_unimplemented_v2_calls,
|
||||
load_registry,
|
||||
next_gb_id,
|
||||
)
|
||||
|
||||
|
||||
LINTER_CODE = "GB_REGISTRY"
|
||||
|
||||
|
||||
class LintSeverity(str, Enum):
|
||||
ERROR = "error"
|
||||
WARNING = "warning"
|
||||
ADVICE = "advice"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class LintMessage(NamedTuple):
|
||||
path: str | None
|
||||
line: int | None
|
||||
char: int | None
|
||||
code: str
|
||||
severity: LintSeverity
|
||||
name: str
|
||||
original: str | None
|
||||
replacement: str | None
|
||||
description: str | None
|
||||
|
||||
|
||||
def _collect_all_calls(
|
||||
dynamo_dir: Path,
|
||||
) -> dict[str, list[tuple[dict[str, Any], Path]]]:
|
||||
"""Return mapping *gb_type → list[(call_info, file_path)]* for all occurrences."""
|
||||
gb_type_calls: dict[str, list[tuple[dict[str, Any], Path]]] = {}
|
||||
|
||||
for py_file in dynamo_dir.rglob("*.py"):
|
||||
for call in find_unimplemented_v2_calls(py_file, dynamo_dir):
|
||||
gb_type = call["gb_type"]
|
||||
if gb_type not in gb_type_calls:
|
||||
gb_type_calls[gb_type] = []
|
||||
gb_type_calls[gb_type].append((call, py_file))
|
||||
|
||||
return gb_type_calls
|
||||
|
||||
|
||||
def _create_registry_entry(
|
||||
gb_type: str, context: str, explanation: str, hints: list[str]
|
||||
) -> dict[str, Any]:
|
||||
"""Create a registry entry with consistent format."""
|
||||
return {
|
||||
"Gb_type": gb_type,
|
||||
"Context": context,
|
||||
"Explanation": explanation,
|
||||
"Hints": hints or [],
|
||||
}
|
||||
|
||||
|
||||
def _update_registry_with_changes(
|
||||
registry: dict,
|
||||
calls: dict[str, tuple[dict[str, Any], Path]],
|
||||
renames: dict[str, str] | None = None,
|
||||
) -> dict:
|
||||
"""Calculate what the updated registry should look like."""
|
||||
renames = renames or {}
|
||||
updated_registry = dict(registry)
|
||||
|
||||
latest_entry: dict[str, Any] = {
|
||||
entries[0]["Gb_type"]: entries[0] for entries in registry.values()
|
||||
}
|
||||
gb_type_to_key: dict[str, str] = {
|
||||
entries[0]["Gb_type"]: key for key, entries in registry.items()
|
||||
}
|
||||
|
||||
# Method for determining add vs. update:
|
||||
# - If gb_type exists in registry but content differs: UPDATE (append new entry to preserve history)
|
||||
# - If gb_type is new but content matches existing entry: RENAME (append new entry with new gb_type)
|
||||
# - If gb_type is completely new: ADD (create new registry entry with a new GBID)
|
||||
|
||||
for old_gb_type, new_gb_type in renames.items():
|
||||
registry_key = gb_type_to_key[old_gb_type]
|
||||
old_entry = updated_registry[registry_key][0]
|
||||
|
||||
new_entry = _create_registry_entry(
|
||||
new_gb_type,
|
||||
old_entry["Context"],
|
||||
old_entry["Explanation"],
|
||||
old_entry["Hints"],
|
||||
)
|
||||
updated_registry[registry_key] = [new_entry] + updated_registry[registry_key]
|
||||
|
||||
latest_entry[new_gb_type] = new_entry
|
||||
gb_type_to_key[new_gb_type] = registry_key
|
||||
del latest_entry[old_gb_type]
|
||||
del gb_type_to_key[old_gb_type]
|
||||
|
||||
for gb_type, (call, file_path) in calls.items():
|
||||
if gb_type in latest_entry:
|
||||
existing_entry = latest_entry[gb_type]
|
||||
|
||||
if not (
|
||||
call["context"] == existing_entry["Context"]
|
||||
and call["explanation"] == existing_entry["Explanation"]
|
||||
and sorted(call["hints"]) == sorted(existing_entry["Hints"])
|
||||
):
|
||||
registry_key = gb_type_to_key[gb_type]
|
||||
new_entry = _create_registry_entry(
|
||||
gb_type, call["context"], call["explanation"], call["hints"]
|
||||
)
|
||||
updated_registry[registry_key] = [new_entry] + updated_registry[
|
||||
registry_key
|
||||
]
|
||||
else:
|
||||
new_key = next_gb_id(updated_registry)
|
||||
new_entry = _create_registry_entry(
|
||||
gb_type, call["context"], call["explanation"], call["hints"]
|
||||
)
|
||||
updated_registry[new_key] = [new_entry]
|
||||
|
||||
return updated_registry
|
||||
|
||||
|
||||
def check_registry_sync(dynamo_dir: Path, registry_path: Path) -> list[LintMessage]:
|
||||
"""Check registry sync and return lint messages."""
|
||||
lint_messages = []
|
||||
|
||||
all_calls = _collect_all_calls(dynamo_dir)
|
||||
|
||||
duplicates = []
|
||||
for gb_type, call_list in all_calls.items():
|
||||
if len(call_list) > 1:
|
||||
first_call = call_list[0][0]
|
||||
for call, file_path in call_list[1:]:
|
||||
if (
|
||||
call["context"] != first_call["context"]
|
||||
or call["explanation"] != first_call["explanation"]
|
||||
or sorted(call["hints"]) != sorted(first_call["hints"])
|
||||
):
|
||||
duplicates.append({"gb_type": gb_type, "calls": call_list})
|
||||
break
|
||||
|
||||
for dup in duplicates:
|
||||
gb_type = dup["gb_type"]
|
||||
calls = dup["calls"]
|
||||
|
||||
description = f"The gb_type '{gb_type}' is used {len(calls)} times with different content. "
|
||||
description += "Each gb_type must be unique across your entire codebase."
|
||||
|
||||
lint_messages.append(
|
||||
LintMessage(
|
||||
path=str(calls[0][1]),
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.ERROR,
|
||||
name="Duplicate gb_type",
|
||||
original=None,
|
||||
replacement=None,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
|
||||
if duplicates:
|
||||
return lint_messages
|
||||
|
||||
calls = {gb_type: calls[0] for gb_type, calls in all_calls.items()}
|
||||
|
||||
registry = load_registry(registry_path)
|
||||
latest_entry: dict[str, Any] = {
|
||||
entries[0]["Gb_type"]: entries[0] for entries in registry.values()
|
||||
}
|
||||
|
||||
renames: dict[str, str] = {}
|
||||
remaining_calls = dict(calls)
|
||||
|
||||
for gb_type, (call, file_path) in calls.items():
|
||||
if gb_type not in latest_entry:
|
||||
for existing_gb_type, existing_entry in latest_entry.items():
|
||||
if (
|
||||
call["context"] == existing_entry["Context"]
|
||||
and call["explanation"] == existing_entry["Explanation"]
|
||||
and sorted(call["hints"]) == sorted(existing_entry["Hints"])
|
||||
):
|
||||
renames[existing_gb_type] = gb_type
|
||||
del remaining_calls[gb_type]
|
||||
break
|
||||
|
||||
needs_update = bool(renames)
|
||||
|
||||
for gb_type, (call, file_path) in remaining_calls.items():
|
||||
if gb_type in latest_entry:
|
||||
existing_entry = latest_entry[gb_type]
|
||||
|
||||
if not (
|
||||
call["context"] == existing_entry["Context"]
|
||||
and call["explanation"] == existing_entry["Explanation"]
|
||||
and sorted(call["hints"] or []) == sorted(existing_entry["Hints"] or [])
|
||||
):
|
||||
needs_update = True
|
||||
break
|
||||
else:
|
||||
needs_update = True
|
||||
break
|
||||
|
||||
if needs_update:
|
||||
updated_registry = _update_registry_with_changes(
|
||||
registry, remaining_calls, renames
|
||||
)
|
||||
|
||||
original_content = registry_path.read_text(encoding="utf-8")
|
||||
|
||||
replacement_content = (
|
||||
json.dumps(updated_registry, indent=2, ensure_ascii=False) + "\n"
|
||||
)
|
||||
|
||||
changes = []
|
||||
if renames:
|
||||
for old, new in renames.items():
|
||||
changes.append(f"renamed '{old}' → '{new}'")
|
||||
if remaining_calls:
|
||||
new_count = sum(
|
||||
1 for gb_type in remaining_calls if gb_type not in latest_entry
|
||||
)
|
||||
if new_count:
|
||||
changes.append(f"added {new_count} new gb_types")
|
||||
|
||||
description = f"Registry sync needed ({', '.join(changes)}). Run `lintrunner -a` to apply changes."
|
||||
|
||||
lint_messages.append(
|
||||
LintMessage(
|
||||
path=str(registry_path),
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.WARNING,
|
||||
name="Registry sync needed",
|
||||
original=original_content,
|
||||
replacement=replacement_content,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
|
||||
return lint_messages
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
script_dir = Path(__file__).resolve()
|
||||
repo_root = script_dir.parents[3]
|
||||
default_registry_path = (
|
||||
repo_root / "torch" / "_dynamo" / "graph_break_registry.json"
|
||||
)
|
||||
|
||||
default_dynamo_dir = repo_root / "torch" / "_dynamo"
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Auto-sync graph break registry with source code"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamo-dir",
|
||||
type=Path,
|
||||
default=default_dynamo_dir,
|
||||
help=f"Path to the dynamo directory (default: {default_dynamo_dir})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--registry-path",
|
||||
type=Path,
|
||||
default=default_registry_path,
|
||||
help=f"Path to the registry file (default: {default_registry_path})",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
lint_messages = check_registry_sync(
|
||||
dynamo_dir=args.dynamo_dir, registry_path=args.registry_path
|
||||
)
|
||||
|
||||
for lint_message in lint_messages:
|
||||
print(json.dumps(lint_message._asdict()), flush=True)
|
397
tools/test/test_gb_registry_linter.py
Normal file
397
tools/test/test_gb_registry_linter.py
Normal file
@ -0,0 +1,397 @@
|
||||
# mypy: ignore-errors
|
||||
import json
|
||||
import shutil
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from tools.linter.adapters.gb_registry_linter import (
|
||||
check_registry_sync,
|
||||
LINTER_CODE,
|
||||
LintMessage,
|
||||
LintSeverity,
|
||||
)
|
||||
|
||||
|
||||
class TestGraphBreakRegistryLinter(unittest.TestCase):
|
||||
"""
|
||||
Test the graph break registry linter functionality
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
script_dir = Path(__file__).resolve()
|
||||
self.test_data_dir = script_dir.parent / "graph_break_registry_linter_testdata"
|
||||
self.test_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.registry_path = self.test_data_dir / "graph_break_test_registry.json"
|
||||
with open(self.registry_path, "w") as f:
|
||||
json.dump({}, f)
|
||||
|
||||
self.callsite_file = self.test_data_dir / "callsite_test.py"
|
||||
callsite_content = """from torch._dynamo.exc import unimplemented_v2
|
||||
|
||||
def test(self):
|
||||
unimplemented_v2(
|
||||
gb_type="testing",
|
||||
context="testing",
|
||||
explanation="testing",
|
||||
hints=["testing"],
|
||||
)
|
||||
"""
|
||||
with open(self.callsite_file, "w") as f:
|
||||
f.write(callsite_content)
|
||||
|
||||
def tearDown(self):
|
||||
if self.test_data_dir.exists():
|
||||
shutil.rmtree(self.test_data_dir)
|
||||
|
||||
def test_case1_new_gb_type(self):
|
||||
"""Test Case 1: Adding a completely new gb_type to an empty registry."""
|
||||
with open(self.registry_path) as f:
|
||||
original_content = f.read()
|
||||
|
||||
messages = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
|
||||
expected_registry = {
|
||||
"GB0000": [
|
||||
{
|
||||
"Gb_type": "testing",
|
||||
"Context": "testing",
|
||||
"Explanation": "testing",
|
||||
"Hints": ["testing"],
|
||||
}
|
||||
]
|
||||
}
|
||||
expected_replacement = (
|
||||
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
|
||||
)
|
||||
expected_msg = LintMessage(
|
||||
path=str(self.registry_path),
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.WARNING,
|
||||
name="Registry sync needed",
|
||||
original=original_content,
|
||||
replacement=expected_replacement,
|
||||
description="Registry sync needed (added 1 new gb_types). Run `lintrunner -a` to apply changes.",
|
||||
)
|
||||
self.assertEqual(messages, [expected_msg])
|
||||
|
||||
if messages and messages[0].replacement:
|
||||
with open(self.registry_path, "w") as f:
|
||||
f.write(messages[0].replacement)
|
||||
|
||||
messages_after_fix = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
self.assertEqual(
|
||||
len(messages_after_fix), 0, "Should have no messages after applying the fix"
|
||||
)
|
||||
|
||||
def test_case2_rename_gb_type(self):
|
||||
"""Test Case 2: Renaming a gb_type while keeping other content the same."""
|
||||
registry_data = {
|
||||
"GB0000": [
|
||||
{
|
||||
"Gb_type": "testing",
|
||||
"Context": "testing",
|
||||
"Explanation": "testing",
|
||||
"Hints": ["testing"],
|
||||
}
|
||||
]
|
||||
}
|
||||
with open(self.registry_path, "w") as f:
|
||||
json.dump(registry_data, f, indent=2)
|
||||
|
||||
renamed_callsite_content = """from torch._dynamo.exc import unimplemented_v2
|
||||
def test(self):
|
||||
unimplemented_v2(gb_type="renamed_testing", context="testing", explanation="testing", hints=["testing"])
|
||||
"""
|
||||
with open(self.callsite_file, "w") as f:
|
||||
f.write(renamed_callsite_content)
|
||||
|
||||
with open(self.registry_path) as f:
|
||||
original_content = f.read()
|
||||
|
||||
messages = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
expected_registry = {
|
||||
"GB0000": [
|
||||
{
|
||||
"Gb_type": "renamed_testing",
|
||||
"Context": "testing",
|
||||
"Explanation": "testing",
|
||||
"Hints": ["testing"],
|
||||
},
|
||||
{
|
||||
"Gb_type": "testing",
|
||||
"Context": "testing",
|
||||
"Explanation": "testing",
|
||||
"Hints": ["testing"],
|
||||
},
|
||||
]
|
||||
}
|
||||
expected_replacement = (
|
||||
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
|
||||
)
|
||||
expected_msg = LintMessage(
|
||||
path=str(self.registry_path),
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.WARNING,
|
||||
name="Registry sync needed",
|
||||
original=original_content,
|
||||
replacement=expected_replacement,
|
||||
description="Registry sync needed (renamed 'testing' → 'renamed_testing'). Run `lintrunner -a` to apply changes.",
|
||||
)
|
||||
self.assertEqual(messages, [expected_msg])
|
||||
|
||||
if messages and messages[0].replacement:
|
||||
with open(self.registry_path, "w") as f:
|
||||
f.write(messages[0].replacement)
|
||||
|
||||
messages_after_fix = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
self.assertEqual(
|
||||
len(messages_after_fix), 0, "Should have no messages after applying the fix"
|
||||
)
|
||||
|
||||
def test_case3_content_change(self):
|
||||
"""Test Case 3: Changing the content of an existing gb_type."""
|
||||
registry_data = {
|
||||
"GB0000": [
|
||||
{
|
||||
"Gb_type": "testing",
|
||||
"Context": "old_context",
|
||||
"Explanation": "old_explanation",
|
||||
"Hints": ["old_hint"],
|
||||
}
|
||||
]
|
||||
}
|
||||
with open(self.registry_path, "w") as f:
|
||||
json.dump(registry_data, f, indent=2)
|
||||
|
||||
updated_callsite_content = """from torch._dynamo.exc import unimplemented_v2
|
||||
def test(self):
|
||||
unimplemented_v2(gb_type="testing", context="new_context", explanation="new_explanation", hints=["new_hint"])
|
||||
"""
|
||||
with open(self.callsite_file, "w") as f:
|
||||
f.write(updated_callsite_content)
|
||||
|
||||
with open(self.registry_path) as f:
|
||||
original_content = f.read()
|
||||
|
||||
messages = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
expected_registry = {
|
||||
"GB0000": [
|
||||
{
|
||||
"Gb_type": "testing",
|
||||
"Context": "new_context",
|
||||
"Explanation": "new_explanation",
|
||||
"Hints": ["new_hint"],
|
||||
},
|
||||
{
|
||||
"Gb_type": "testing",
|
||||
"Context": "old_context",
|
||||
"Explanation": "old_explanation",
|
||||
"Hints": ["old_hint"],
|
||||
},
|
||||
]
|
||||
}
|
||||
expected_replacement = (
|
||||
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
|
||||
)
|
||||
expected_msg = LintMessage(
|
||||
path=str(self.registry_path),
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.WARNING,
|
||||
name="Registry sync needed",
|
||||
original=original_content,
|
||||
replacement=expected_replacement,
|
||||
description="Registry sync needed (). Run `lintrunner -a` to apply changes.",
|
||||
)
|
||||
self.assertEqual(messages, [expected_msg])
|
||||
|
||||
if messages and messages[0].replacement:
|
||||
with open(self.registry_path, "w") as f:
|
||||
f.write(messages[0].replacement)
|
||||
|
||||
messages_after_fix = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
self.assertEqual(
|
||||
len(messages_after_fix), 0, "Should have no messages after applying the fix"
|
||||
)
|
||||
|
||||
def test_case4_no_changes(self):
|
||||
"""Test Case 4: Ensuring no message is produced when the registry is in sync."""
|
||||
registry_data = {
|
||||
"GB0000": [
|
||||
{
|
||||
"Gb_type": "testing",
|
||||
"Context": "testing",
|
||||
"Explanation": "testing",
|
||||
"Hints": ["testing"],
|
||||
}
|
||||
]
|
||||
}
|
||||
with open(self.registry_path, "w") as f:
|
||||
json.dump(registry_data, f, indent=2)
|
||||
|
||||
messages = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
self.assertEqual(
|
||||
len(messages), 0, "Should have no messages when registry is already in sync"
|
||||
)
|
||||
|
||||
def test_case5_new_gbid_on_full_change(self):
|
||||
"""Test Case 5: A completely new entry should get a new GB ID."""
|
||||
registry_data = {
|
||||
"GB0000": [
|
||||
{
|
||||
"Gb_type": "original_testing",
|
||||
"Context": "original_context",
|
||||
"Explanation": "original_explanation",
|
||||
"Hints": ["original_hint"],
|
||||
}
|
||||
]
|
||||
}
|
||||
with open(self.registry_path, "w") as f:
|
||||
json.dump(registry_data, f, indent=2)
|
||||
|
||||
new_callsite_content = """from torch._dynamo.exc import unimplemented_v2
|
||||
def test(self):
|
||||
unimplemented_v2(
|
||||
gb_type="completely_new_testing",
|
||||
context="completely_new_context",
|
||||
explanation="completely_new_explanation",
|
||||
hints=["completely_new_hint"],
|
||||
)
|
||||
"""
|
||||
with open(self.callsite_file, "w") as f:
|
||||
f.write(new_callsite_content)
|
||||
|
||||
with open(self.registry_path) as f:
|
||||
original_content = f.read()
|
||||
|
||||
messages = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
expected_registry = {
|
||||
"GB0000": [
|
||||
{
|
||||
"Gb_type": "original_testing",
|
||||
"Context": "original_context",
|
||||
"Explanation": "original_explanation",
|
||||
"Hints": ["original_hint"],
|
||||
}
|
||||
],
|
||||
"GB0001": [
|
||||
{
|
||||
"Gb_type": "completely_new_testing",
|
||||
"Context": "completely_new_context",
|
||||
"Explanation": "completely_new_explanation",
|
||||
"Hints": ["completely_new_hint"],
|
||||
}
|
||||
],
|
||||
}
|
||||
expected_replacement = (
|
||||
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
|
||||
)
|
||||
expected_msg = LintMessage(
|
||||
path=str(self.registry_path),
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.WARNING,
|
||||
name="Registry sync needed",
|
||||
original=original_content,
|
||||
replacement=expected_replacement,
|
||||
description="Registry sync needed (added 1 new gb_types). Run `lintrunner -a` to apply changes.",
|
||||
)
|
||||
self.assertEqual(messages, [expected_msg])
|
||||
|
||||
# Apply the fix and verify the file's final state
|
||||
if messages and messages[0].replacement:
|
||||
with open(self.registry_path, "w") as f:
|
||||
f.write(messages[0].replacement)
|
||||
|
||||
messages_after_fix = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
self.assertEqual(
|
||||
len(messages_after_fix), 0, "Should have no messages after applying the fix"
|
||||
)
|
||||
|
||||
def test_case6_dynamic_hints_from_variable(self):
|
||||
"""Test Case 6: Verifies hints can be unpacked from an imported variable."""
|
||||
mock_hints_file = self.test_data_dir / "graph_break_hints.py"
|
||||
init_py = self.test_data_dir / "__init__.py"
|
||||
try:
|
||||
supportable_string = (
|
||||
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you "
|
||||
"encounter this graph break often and it is causing performance issues."
|
||||
)
|
||||
mock_hints_content = f'SUPPORTABLE = ["{supportable_string}"]'
|
||||
with open(mock_hints_file, "w") as f:
|
||||
f.write(mock_hints_content)
|
||||
|
||||
init_py.touch()
|
||||
|
||||
dynamic_hints_callsite = """from torch._dynamo.exc import unimplemented_v2
|
||||
from torch._dynamo import graph_break_hints
|
||||
|
||||
def test(self):
|
||||
unimplemented_v2(
|
||||
gb_type="testing_with_graph_break_hints",
|
||||
context="testing_with_graph_break_hints",
|
||||
explanation="testing_with_graph_break_hints",
|
||||
hints=[*graph_break_hints.SUPPORTABLE],
|
||||
)
|
||||
"""
|
||||
with open(self.callsite_file, "w") as f:
|
||||
f.write(dynamic_hints_callsite)
|
||||
|
||||
with open(self.registry_path) as f:
|
||||
original_content = f.read()
|
||||
|
||||
messages = check_registry_sync(self.test_data_dir, self.registry_path)
|
||||
|
||||
expected_registry = {
|
||||
"GB0000": [
|
||||
{
|
||||
"Gb_type": "testing_with_graph_break_hints",
|
||||
"Context": "testing_with_graph_break_hints",
|
||||
"Explanation": "testing_with_graph_break_hints",
|
||||
"Hints": [supportable_string],
|
||||
}
|
||||
]
|
||||
}
|
||||
expected_replacement = (
|
||||
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
|
||||
)
|
||||
expected_msg = LintMessage(
|
||||
path=str(self.registry_path),
|
||||
line=None,
|
||||
char=None,
|
||||
code=LINTER_CODE,
|
||||
severity=LintSeverity.WARNING,
|
||||
name="Registry sync needed",
|
||||
original=original_content,
|
||||
replacement=expected_replacement,
|
||||
description="Registry sync needed (added 1 new gb_types). Run `lintrunner -a` to apply changes.",
|
||||
)
|
||||
|
||||
self.assertEqual(messages, [expected_msg])
|
||||
|
||||
if messages and messages[0].replacement:
|
||||
with open(self.registry_path, "w") as f:
|
||||
f.write(messages[0].replacement)
|
||||
|
||||
messages_after_fix = check_registry_sync(
|
||||
self.test_data_dir, self.registry_path
|
||||
)
|
||||
self.assertEqual(
|
||||
len(messages_after_fix),
|
||||
0,
|
||||
"Should have no messages after applying the fix",
|
||||
)
|
||||
finally:
|
||||
mock_hints_file.unlink()
|
||||
init_py.unlink()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -2510,4 +2510,4 @@
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user