[dynamo] lintrunner for gb_registry adds/updates (#158460)

This PR adds automation to adding/updating the JSON registry through the lintrunner.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158460
Approved by: https://github.com/williamwen42
This commit is contained in:
Sidharth
2025-07-23 12:05:49 -07:00
committed by PyTorch MergeBot
parent 64e8d7d66b
commit 4d5d56a30e
7 changed files with 702 additions and 254 deletions

View File

@ -365,7 +365,6 @@ test_dynamo_wrapped_shard() {
exit 1
fi
python tools/dynamo/verify_dynamo.py
python tools/dynamo/gb_id_mapping.py verify
# PLEASE DO NOT ADD ADDITIONAL EXCLUDES HERE.
# Instead, use @skipIfTorchDynamo on your tests.
time python test/run_test.py --dynamo \

View File

@ -1794,3 +1794,12 @@ include_patterns = [
'torch/header_only_apis.txt',
]
is_formatter = false
[[linter]]
code = "GB_REGISTRY"
include_patterns = ["torch/_dynamo/**/*.py"]
command = [
"python3",
"tools/linter/adapters/gb_registry_linter.py",
]

View File

@ -1325,6 +1325,7 @@ def main() -> None:
"utils/model_dump/code.js",
"utils/model_dump/*.mjs",
"_dynamo/graph_break_registry.json",
"tools/dynamo/gb_id_mapping.py",
]
if not BUILD_LIBTORCH_WHL:

View File

@ -4,7 +4,6 @@ import argparse
import ast
import json
import re
import sys
from pathlib import Path
@ -26,7 +25,7 @@ def save_registry(reg, path):
def next_gb_id(reg):
ids = [int(x[2:]) for x in reg if x.startswith("GB") and x[2:].isdigit()]
return f"GB{(max(ids, default=0) + 1):04d}"
return f"GB{(max(ids, default=-1) + 1):04d}"
def clean_string(s):
@ -176,202 +175,6 @@ def find_unimplemented_v2_calls(path, dynamo_dir=None):
return results
def cmd_add_new_gb_type(gb_type, file_path, registry_path, additional_info=None):
"""
Add a new graph break type to the registry.
Args:
gb_type: The graph break type to add
file_path: Path to the file containing the unimplemented_v2 call
registry_path: Path to the registry JSON file
"""
registry_path = Path(registry_path)
reg = load_registry(registry_path)
existing_gb_types = {entry[0]["Gb_type"] for entry in reg.values()}
if gb_type in existing_gb_types:
print(
f"Error: gb_type '{gb_type}' already exists in registry. Please rename the gb_type so it can be unique."
)
return False
dynamo_dir = Path(registry_path).parent
calls = find_unimplemented_v2_calls(Path(file_path), dynamo_dir)
matching_call = next((call for call in calls if call["gb_type"] == gb_type), None)
if not matching_call:
print(
f"Error: Could not find unimplemented_v2 call with gb_type '{gb_type}' in {file_path}"
)
return False
gb_id = next_gb_id(reg)
reg[gb_id] = [
{
"Gb_type": gb_type,
"Context": matching_call["context"],
"Explanation": matching_call["explanation"],
"Hints": matching_call["hints"] or [],
**({"Additional_Info": [additional_info]} if additional_info else {}),
}
]
save_registry(reg, registry_path)
print(f"Added {gb_type} to registry with ID {gb_id}")
return True
def cmd_update_gb_type(
old_gb_type,
file_path,
registry_path,
new_gb_type=None,
additional_info=None,
):
"""
Update an existing graph break type in the registry by adding a new version
to the version history list.
Args:
old_gb_type: The current graph break type to update
file_path: Path to the file containing the updated unimplemented_v2 call
registry_path: Path to the registry JSON file
new_gb_type: Optional new gb_type name to replace the old one
"""
registry_path = Path(registry_path)
reg = load_registry(registry_path)
gb_id_map = {entry[0]["Gb_type"]: id for id, entry in reg.items()}
gb_id = gb_id_map.get(old_gb_type)
if gb_id is None:
print(f"Error: gb_type '{old_gb_type}' not found in registry.")
return False
search_gb_type = new_gb_type if new_gb_type else old_gb_type
dynamo_dir = Path(registry_path).parent
calls = find_unimplemented_v2_calls(Path(file_path), dynamo_dir)
matching_call = next(
(call for call in calls if call["gb_type"] == search_gb_type), None
)
if not matching_call:
print(
f"Error: Could not find unimplemented_v2 call with gb_type '{search_gb_type}' in {file_path}"
)
return False
if (
matching_call["gb_type"] != old_gb_type
and matching_call["gb_type"] in gb_id_map
):
print(
f"Error: New gb_type '{matching_call['gb_type']}' already exists in registry. Please use a unique gb_type."
)
return False
new_entry = {
"Gb_type": matching_call["gb_type"],
"Context": matching_call["context"],
"Explanation": matching_call["explanation"],
"Hints": matching_call["hints"] or [],
}
if additional_info:
additional_info_list = reg[gb_id][0].get("Additional_Info", [])
new_entry["Additional_Info"] = (
additional_info_list + [additional_info]
if additional_info_list
else [additional_info]
)
elif "Additional_Info" in reg[gb_id][0]:
new_entry["Additional_Info"] = reg[gb_id][0]["Additional_Info"]
reg[gb_id].insert(0, new_entry)
save_registry(reg, registry_path)
print(
f"Updated {old_gb_type} to {matching_call['gb_type']} in registry with ID {gb_id}"
)
return True
def test_verify_gb_id_mapping(dynamo_dir, registry_path):
"""
Verifies that all unimplemented_v2 calls in torch/_dynamo match entries in the registry.
"""
script_dir = Path(__file__).resolve().parent
dynamo_dir = script_dir.parent.parent / "torch" / "_dynamo"
registry_path = (
script_dir.parent.parent / "torch" / "_dynamo" / "graph_break_registry.json"
)
python_files = list(dynamo_dir.glob("**/*.py"))
reg = load_registry(registry_path)
gb_type_to_entry = {entries[0]["Gb_type"]: entries[0] for _, entries in reg.items()}
mismatches = []
for file_path in python_files:
calls = find_unimplemented_v2_calls(file_path, dynamo_dir)
for call in calls:
gb_type = call["gb_type"]
if gb_type not in gb_type_to_entry:
mismatches.append((gb_type, file_path, "Not found in registry"))
continue
entry = gb_type_to_entry[gb_type]
if call["context"] != entry["Context"]:
mismatches.append((gb_type, file_path, "Context mismatch"))
elif call["explanation"] != entry["Explanation"]:
mismatches.append((gb_type, file_path, "Explanation mismatch"))
elif sorted(call["hints"]) != sorted(entry["Hints"]):
mismatches.append((gb_type, file_path, "Hints mismatch"))
if mismatches:
print(
"Found the unimplemented_v2 or unimplemented_v2_with_warning calls below that "
"don't match the registry in graph_break_registry.json."
)
for gb_type, file_path, reason in mismatches:
print(f" - {gb_type} in {file_path}: {reason}")
print("Please update the registry using one of these commands:")
print(
"- If you added a new callsite: python tools/dynamo/gb_id_mapping.py add "
'"GB_TYPE" PATH_TO_FILE --additional-info "INFO"'
)
print(
" • GB_TYPE: The graph break type string used in your unimplemented_v2 call"
" • PATH_TO_FILE: Path to the file containing your new unimplemented_v2 call"
" • --additional-info: Optional extra information to include in the registry entry"
)
print(
'- If you updated an existing callsite: python tools/dynamo/gb_id_mapping.py update "GB_TYPE" PATH_TO_FILE '
'--new_gb_type "NEW_NAME" --additional-info "INFO"'
)
print(" • GB_TYPE: The original graph break type to update")
print(" • PATH_TO_FILE: Path to the file containing the updated call")
print(" • --new_gb_type: New name if you changed the graph break type")
print(" • --additional-info: Optional extra information to add")
print(
"- Recreate registry (Only do this if a complete reset is needed): python tools/dynamo/gb_id_mapping.py create"
)
print(
"If you have also wrote a test for the new graph break, please update the test as well "
"using EXPECTTEST_ACCEPT=1 so the message includes the respective webpage "
)
print(
"Note: If you've reset the entire registry file, you can force push to bypass this check."
)
return False
print("All unimplemented_v2 calls match the registry.")
return True
def create_registry(dynamo_dir, registry_path):
calls = find_unimplemented_v2_calls(dynamo_dir)
registry = {}
@ -420,40 +223,6 @@ def main():
help="Directory to search for unimplemented_v2 calls.",
)
add_parser = subparsers.add_parser("add", help="Add a gb_type to registry")
add_parser.add_argument("gb_type", help="The gb_type to add")
add_parser.add_argument(
"file_path", help="Path to the file containing the unimplemented_v2 call"
)
add_parser.add_argument(
"--additional-info", help="Optional additional information to include"
)
update_parser = subparsers.add_parser(
"update", help="Update an existing gb_type in registry"
)
update_parser.add_argument("gb_type", help="The gb_type to update")
update_parser.add_argument(
"file_path",
help="Path to the file containing the updated unimplemented_v2 call",
)
update_parser.add_argument(
"--new_gb_type", help="New gb_type name if it has changed", default=None
)
update_parser.add_argument(
"--additional-info", help="Optional additional information to include"
)
verify_parser = subparsers.add_parser(
"verify", help="Verify all unimplemented_v2 calls match registry entries"
)
verify_parser.add_argument(
"--dynamo_dir",
type=str,
default=default_dynamo_dir,
help="Directory to search for unimplemented_v2 calls.",
)
parser.add_argument(
"--registry-path",
type=str,
@ -465,26 +234,6 @@ def main():
if args.command == "create":
create_registry(args.dynamo_dir, args.registry_path)
elif args.command == "add":
success = cmd_add_new_gb_type(
args.gb_type, args.file_path, args.registry_path, args.additional_info
)
if not success:
sys.exit(1)
elif args.command == "update":
success = cmd_update_gb_type(
args.gb_type,
args.file_path,
args.registry_path,
args.new_gb_type,
args.additional_info,
)
if not success:
sys.exit(1)
elif args.command == "verify":
success = test_verify_gb_id_mapping(args.dynamo_dir, args.registry_path)
if not success:
sys.exit(1)
else:
parser.print_help()

View File

@ -0,0 +1,293 @@
# mypy: ignore-errors
from __future__ import annotations
import argparse
import json
import sys
from enum import Enum
from pathlib import Path
from typing import Any, NamedTuple
REPO_ROOT = Path(__file__).resolve().parents[3]
sys.path.insert(0, str(REPO_ROOT))
from tools.dynamo.gb_id_mapping import (
find_unimplemented_v2_calls,
load_registry,
next_gb_id,
)
LINTER_CODE = "GB_REGISTRY"
class LintSeverity(str, Enum):
ERROR = "error"
WARNING = "warning"
ADVICE = "advice"
DISABLED = "disabled"
class LintMessage(NamedTuple):
path: str | None
line: int | None
char: int | None
code: str
severity: LintSeverity
name: str
original: str | None
replacement: str | None
description: str | None
def _collect_all_calls(
dynamo_dir: Path,
) -> dict[str, list[tuple[dict[str, Any], Path]]]:
"""Return mapping *gb_type → list[(call_info, file_path)]* for all occurrences."""
gb_type_calls: dict[str, list[tuple[dict[str, Any], Path]]] = {}
for py_file in dynamo_dir.rglob("*.py"):
for call in find_unimplemented_v2_calls(py_file, dynamo_dir):
gb_type = call["gb_type"]
if gb_type not in gb_type_calls:
gb_type_calls[gb_type] = []
gb_type_calls[gb_type].append((call, py_file))
return gb_type_calls
def _create_registry_entry(
gb_type: str, context: str, explanation: str, hints: list[str]
) -> dict[str, Any]:
"""Create a registry entry with consistent format."""
return {
"Gb_type": gb_type,
"Context": context,
"Explanation": explanation,
"Hints": hints or [],
}
def _update_registry_with_changes(
registry: dict,
calls: dict[str, tuple[dict[str, Any], Path]],
renames: dict[str, str] | None = None,
) -> dict:
"""Calculate what the updated registry should look like."""
renames = renames or {}
updated_registry = dict(registry)
latest_entry: dict[str, Any] = {
entries[0]["Gb_type"]: entries[0] for entries in registry.values()
}
gb_type_to_key: dict[str, str] = {
entries[0]["Gb_type"]: key for key, entries in registry.items()
}
# Method for determining add vs. update:
# - If gb_type exists in registry but content differs: UPDATE (append new entry to preserve history)
# - If gb_type is new but content matches existing entry: RENAME (append new entry with new gb_type)
# - If gb_type is completely new: ADD (create new registry entry with a new GBID)
for old_gb_type, new_gb_type in renames.items():
registry_key = gb_type_to_key[old_gb_type]
old_entry = updated_registry[registry_key][0]
new_entry = _create_registry_entry(
new_gb_type,
old_entry["Context"],
old_entry["Explanation"],
old_entry["Hints"],
)
updated_registry[registry_key] = [new_entry] + updated_registry[registry_key]
latest_entry[new_gb_type] = new_entry
gb_type_to_key[new_gb_type] = registry_key
del latest_entry[old_gb_type]
del gb_type_to_key[old_gb_type]
for gb_type, (call, file_path) in calls.items():
if gb_type in latest_entry:
existing_entry = latest_entry[gb_type]
if not (
call["context"] == existing_entry["Context"]
and call["explanation"] == existing_entry["Explanation"]
and sorted(call["hints"]) == sorted(existing_entry["Hints"])
):
registry_key = gb_type_to_key[gb_type]
new_entry = _create_registry_entry(
gb_type, call["context"], call["explanation"], call["hints"]
)
updated_registry[registry_key] = [new_entry] + updated_registry[
registry_key
]
else:
new_key = next_gb_id(updated_registry)
new_entry = _create_registry_entry(
gb_type, call["context"], call["explanation"], call["hints"]
)
updated_registry[new_key] = [new_entry]
return updated_registry
def check_registry_sync(dynamo_dir: Path, registry_path: Path) -> list[LintMessage]:
"""Check registry sync and return lint messages."""
lint_messages = []
all_calls = _collect_all_calls(dynamo_dir)
duplicates = []
for gb_type, call_list in all_calls.items():
if len(call_list) > 1:
first_call = call_list[0][0]
for call, file_path in call_list[1:]:
if (
call["context"] != first_call["context"]
or call["explanation"] != first_call["explanation"]
or sorted(call["hints"]) != sorted(first_call["hints"])
):
duplicates.append({"gb_type": gb_type, "calls": call_list})
break
for dup in duplicates:
gb_type = dup["gb_type"]
calls = dup["calls"]
description = f"The gb_type '{gb_type}' is used {len(calls)} times with different content. "
description += "Each gb_type must be unique across your entire codebase."
lint_messages.append(
LintMessage(
path=str(calls[0][1]),
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.ERROR,
name="Duplicate gb_type",
original=None,
replacement=None,
description=description,
)
)
if duplicates:
return lint_messages
calls = {gb_type: calls[0] for gb_type, calls in all_calls.items()}
registry = load_registry(registry_path)
latest_entry: dict[str, Any] = {
entries[0]["Gb_type"]: entries[0] for entries in registry.values()
}
renames: dict[str, str] = {}
remaining_calls = dict(calls)
for gb_type, (call, file_path) in calls.items():
if gb_type not in latest_entry:
for existing_gb_type, existing_entry in latest_entry.items():
if (
call["context"] == existing_entry["Context"]
and call["explanation"] == existing_entry["Explanation"]
and sorted(call["hints"]) == sorted(existing_entry["Hints"])
):
renames[existing_gb_type] = gb_type
del remaining_calls[gb_type]
break
needs_update = bool(renames)
for gb_type, (call, file_path) in remaining_calls.items():
if gb_type in latest_entry:
existing_entry = latest_entry[gb_type]
if not (
call["context"] == existing_entry["Context"]
and call["explanation"] == existing_entry["Explanation"]
and sorted(call["hints"] or []) == sorted(existing_entry["Hints"] or [])
):
needs_update = True
break
else:
needs_update = True
break
if needs_update:
updated_registry = _update_registry_with_changes(
registry, remaining_calls, renames
)
original_content = registry_path.read_text(encoding="utf-8")
replacement_content = (
json.dumps(updated_registry, indent=2, ensure_ascii=False) + "\n"
)
changes = []
if renames:
for old, new in renames.items():
changes.append(f"renamed '{old}''{new}'")
if remaining_calls:
new_count = sum(
1 for gb_type in remaining_calls if gb_type not in latest_entry
)
if new_count:
changes.append(f"added {new_count} new gb_types")
description = f"Registry sync needed ({', '.join(changes)}). Run `lintrunner -a` to apply changes."
lint_messages.append(
LintMessage(
path=str(registry_path),
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.WARNING,
name="Registry sync needed",
original=original_content,
replacement=replacement_content,
description=description,
)
)
return lint_messages
if __name__ == "__main__":
script_dir = Path(__file__).resolve()
repo_root = script_dir.parents[3]
default_registry_path = (
repo_root / "torch" / "_dynamo" / "graph_break_registry.json"
)
default_dynamo_dir = repo_root / "torch" / "_dynamo"
parser = argparse.ArgumentParser(
description="Auto-sync graph break registry with source code"
)
parser.add_argument(
"--dynamo-dir",
type=Path,
default=default_dynamo_dir,
help=f"Path to the dynamo directory (default: {default_dynamo_dir})",
)
parser.add_argument(
"--registry-path",
type=Path,
default=default_registry_path,
help=f"Path to the registry file (default: {default_registry_path})",
)
args = parser.parse_args()
lint_messages = check_registry_sync(
dynamo_dir=args.dynamo_dir, registry_path=args.registry_path
)
for lint_message in lint_messages:
print(json.dumps(lint_message._asdict()), flush=True)

View File

@ -0,0 +1,397 @@
# mypy: ignore-errors
import json
import shutil
import unittest
from pathlib import Path
from tools.linter.adapters.gb_registry_linter import (
check_registry_sync,
LINTER_CODE,
LintMessage,
LintSeverity,
)
class TestGraphBreakRegistryLinter(unittest.TestCase):
"""
Test the graph break registry linter functionality
"""
def setUp(self):
script_dir = Path(__file__).resolve()
self.test_data_dir = script_dir.parent / "graph_break_registry_linter_testdata"
self.test_data_dir.mkdir(parents=True, exist_ok=True)
self.registry_path = self.test_data_dir / "graph_break_test_registry.json"
with open(self.registry_path, "w") as f:
json.dump({}, f)
self.callsite_file = self.test_data_dir / "callsite_test.py"
callsite_content = """from torch._dynamo.exc import unimplemented_v2
def test(self):
unimplemented_v2(
gb_type="testing",
context="testing",
explanation="testing",
hints=["testing"],
)
"""
with open(self.callsite_file, "w") as f:
f.write(callsite_content)
def tearDown(self):
if self.test_data_dir.exists():
shutil.rmtree(self.test_data_dir)
def test_case1_new_gb_type(self):
"""Test Case 1: Adding a completely new gb_type to an empty registry."""
with open(self.registry_path) as f:
original_content = f.read()
messages = check_registry_sync(self.test_data_dir, self.registry_path)
expected_registry = {
"GB0000": [
{
"Gb_type": "testing",
"Context": "testing",
"Explanation": "testing",
"Hints": ["testing"],
}
]
}
expected_replacement = (
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
)
expected_msg = LintMessage(
path=str(self.registry_path),
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.WARNING,
name="Registry sync needed",
original=original_content,
replacement=expected_replacement,
description="Registry sync needed (added 1 new gb_types). Run `lintrunner -a` to apply changes.",
)
self.assertEqual(messages, [expected_msg])
if messages and messages[0].replacement:
with open(self.registry_path, "w") as f:
f.write(messages[0].replacement)
messages_after_fix = check_registry_sync(self.test_data_dir, self.registry_path)
self.assertEqual(
len(messages_after_fix), 0, "Should have no messages after applying the fix"
)
def test_case2_rename_gb_type(self):
"""Test Case 2: Renaming a gb_type while keeping other content the same."""
registry_data = {
"GB0000": [
{
"Gb_type": "testing",
"Context": "testing",
"Explanation": "testing",
"Hints": ["testing"],
}
]
}
with open(self.registry_path, "w") as f:
json.dump(registry_data, f, indent=2)
renamed_callsite_content = """from torch._dynamo.exc import unimplemented_v2
def test(self):
unimplemented_v2(gb_type="renamed_testing", context="testing", explanation="testing", hints=["testing"])
"""
with open(self.callsite_file, "w") as f:
f.write(renamed_callsite_content)
with open(self.registry_path) as f:
original_content = f.read()
messages = check_registry_sync(self.test_data_dir, self.registry_path)
expected_registry = {
"GB0000": [
{
"Gb_type": "renamed_testing",
"Context": "testing",
"Explanation": "testing",
"Hints": ["testing"],
},
{
"Gb_type": "testing",
"Context": "testing",
"Explanation": "testing",
"Hints": ["testing"],
},
]
}
expected_replacement = (
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
)
expected_msg = LintMessage(
path=str(self.registry_path),
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.WARNING,
name="Registry sync needed",
original=original_content,
replacement=expected_replacement,
description="Registry sync needed (renamed 'testing''renamed_testing'). Run `lintrunner -a` to apply changes.",
)
self.assertEqual(messages, [expected_msg])
if messages and messages[0].replacement:
with open(self.registry_path, "w") as f:
f.write(messages[0].replacement)
messages_after_fix = check_registry_sync(self.test_data_dir, self.registry_path)
self.assertEqual(
len(messages_after_fix), 0, "Should have no messages after applying the fix"
)
def test_case3_content_change(self):
"""Test Case 3: Changing the content of an existing gb_type."""
registry_data = {
"GB0000": [
{
"Gb_type": "testing",
"Context": "old_context",
"Explanation": "old_explanation",
"Hints": ["old_hint"],
}
]
}
with open(self.registry_path, "w") as f:
json.dump(registry_data, f, indent=2)
updated_callsite_content = """from torch._dynamo.exc import unimplemented_v2
def test(self):
unimplemented_v2(gb_type="testing", context="new_context", explanation="new_explanation", hints=["new_hint"])
"""
with open(self.callsite_file, "w") as f:
f.write(updated_callsite_content)
with open(self.registry_path) as f:
original_content = f.read()
messages = check_registry_sync(self.test_data_dir, self.registry_path)
expected_registry = {
"GB0000": [
{
"Gb_type": "testing",
"Context": "new_context",
"Explanation": "new_explanation",
"Hints": ["new_hint"],
},
{
"Gb_type": "testing",
"Context": "old_context",
"Explanation": "old_explanation",
"Hints": ["old_hint"],
},
]
}
expected_replacement = (
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
)
expected_msg = LintMessage(
path=str(self.registry_path),
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.WARNING,
name="Registry sync needed",
original=original_content,
replacement=expected_replacement,
description="Registry sync needed (). Run `lintrunner -a` to apply changes.",
)
self.assertEqual(messages, [expected_msg])
if messages and messages[0].replacement:
with open(self.registry_path, "w") as f:
f.write(messages[0].replacement)
messages_after_fix = check_registry_sync(self.test_data_dir, self.registry_path)
self.assertEqual(
len(messages_after_fix), 0, "Should have no messages after applying the fix"
)
def test_case4_no_changes(self):
"""Test Case 4: Ensuring no message is produced when the registry is in sync."""
registry_data = {
"GB0000": [
{
"Gb_type": "testing",
"Context": "testing",
"Explanation": "testing",
"Hints": ["testing"],
}
]
}
with open(self.registry_path, "w") as f:
json.dump(registry_data, f, indent=2)
messages = check_registry_sync(self.test_data_dir, self.registry_path)
self.assertEqual(
len(messages), 0, "Should have no messages when registry is already in sync"
)
def test_case5_new_gbid_on_full_change(self):
"""Test Case 5: A completely new entry should get a new GB ID."""
registry_data = {
"GB0000": [
{
"Gb_type": "original_testing",
"Context": "original_context",
"Explanation": "original_explanation",
"Hints": ["original_hint"],
}
]
}
with open(self.registry_path, "w") as f:
json.dump(registry_data, f, indent=2)
new_callsite_content = """from torch._dynamo.exc import unimplemented_v2
def test(self):
unimplemented_v2(
gb_type="completely_new_testing",
context="completely_new_context",
explanation="completely_new_explanation",
hints=["completely_new_hint"],
)
"""
with open(self.callsite_file, "w") as f:
f.write(new_callsite_content)
with open(self.registry_path) as f:
original_content = f.read()
messages = check_registry_sync(self.test_data_dir, self.registry_path)
expected_registry = {
"GB0000": [
{
"Gb_type": "original_testing",
"Context": "original_context",
"Explanation": "original_explanation",
"Hints": ["original_hint"],
}
],
"GB0001": [
{
"Gb_type": "completely_new_testing",
"Context": "completely_new_context",
"Explanation": "completely_new_explanation",
"Hints": ["completely_new_hint"],
}
],
}
expected_replacement = (
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
)
expected_msg = LintMessage(
path=str(self.registry_path),
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.WARNING,
name="Registry sync needed",
original=original_content,
replacement=expected_replacement,
description="Registry sync needed (added 1 new gb_types). Run `lintrunner -a` to apply changes.",
)
self.assertEqual(messages, [expected_msg])
# Apply the fix and verify the file's final state
if messages and messages[0].replacement:
with open(self.registry_path, "w") as f:
f.write(messages[0].replacement)
messages_after_fix = check_registry_sync(self.test_data_dir, self.registry_path)
self.assertEqual(
len(messages_after_fix), 0, "Should have no messages after applying the fix"
)
def test_case6_dynamic_hints_from_variable(self):
"""Test Case 6: Verifies hints can be unpacked from an imported variable."""
mock_hints_file = self.test_data_dir / "graph_break_hints.py"
init_py = self.test_data_dir / "__init__.py"
try:
supportable_string = (
"It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you "
"encounter this graph break often and it is causing performance issues."
)
mock_hints_content = f'SUPPORTABLE = ["{supportable_string}"]'
with open(mock_hints_file, "w") as f:
f.write(mock_hints_content)
init_py.touch()
dynamic_hints_callsite = """from torch._dynamo.exc import unimplemented_v2
from torch._dynamo import graph_break_hints
def test(self):
unimplemented_v2(
gb_type="testing_with_graph_break_hints",
context="testing_with_graph_break_hints",
explanation="testing_with_graph_break_hints",
hints=[*graph_break_hints.SUPPORTABLE],
)
"""
with open(self.callsite_file, "w") as f:
f.write(dynamic_hints_callsite)
with open(self.registry_path) as f:
original_content = f.read()
messages = check_registry_sync(self.test_data_dir, self.registry_path)
expected_registry = {
"GB0000": [
{
"Gb_type": "testing_with_graph_break_hints",
"Context": "testing_with_graph_break_hints",
"Explanation": "testing_with_graph_break_hints",
"Hints": [supportable_string],
}
]
}
expected_replacement = (
json.dumps(expected_registry, indent=2, ensure_ascii=False) + "\n"
)
expected_msg = LintMessage(
path=str(self.registry_path),
line=None,
char=None,
code=LINTER_CODE,
severity=LintSeverity.WARNING,
name="Registry sync needed",
original=original_content,
replacement=expected_replacement,
description="Registry sync needed (added 1 new gb_types). Run `lintrunner -a` to apply changes.",
)
self.assertEqual(messages, [expected_msg])
if messages and messages[0].replacement:
with open(self.registry_path, "w") as f:
f.write(messages[0].replacement)
messages_after_fix = check_registry_sync(
self.test_data_dir, self.registry_path
)
self.assertEqual(
len(messages_after_fix),
0,
"Should have no messages after applying the fix",
)
finally:
mock_hints_file.unlink()
init_py.unlink()
if __name__ == "__main__":
unittest.main()

View File

@ -2510,4 +2510,4 @@
]
}
]
}
}