[dynamo] Updated check that detects differences between a PR's unimplemented_v2() callsites and the graph_break_registry JSON file (#156237)

This PR runs an automatic check as part of dynamo_wrapped to make sure that all unimplemented_v2() callsites are mapped to the JSON file. It also fixes the issue of the CI not being able to expand the hints, which was the root cause of the previous workflow failure. If any callsite is not mapped to the JSON file, the dev gets a message with instructions on how to update it. I also changed a dynamic gb_type to a static one and updated its expected output in test_error_messages to include the GBID link for the graph break (previously the link would not be produced).

Testing:
I ran the script with the new `verify` argument to ensure all cases were covered, and also confirmed that the check runs in CI.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156237
Approved by: https://github.com/williamwen42
This commit is contained in:
Sidharth
2025-06-23 23:38:46 -07:00
committed by PyTorch MergeBot
parent 2d7e6c6241
commit a00a697c17
5 changed files with 135 additions and 11 deletions

View File

@ -347,6 +347,7 @@ test_dynamo_wrapped_shard() {
exit 1
fi
python tools/dynamo/verify_dynamo.py
python tools/dynamo/gb_id_mapping.py verify
# PLEASE DO NOT ADD ADDITIONAL EXCLUDES HERE.
# Instead, use @skipIfTorchDynamo on your tests.
time python test/run_test.py --dynamo \
@ -361,6 +362,7 @@ test_dynamo_wrapped_shard() {
assert_git_not_dirty
}
test_inductor_distributed() {
# Smuggle a few multi-gpu tests here so that we don't have to request another large node
echo "Testing multi_gpu tests in test_torchinductor"

View File

@ -290,7 +290,7 @@ Backend compiler exception
return x + 1
For more details about this graph break, please visit: None""",
For more details about this graph break, please visit: https://compile-graph-break-site.vercel.app/gb/GB0219""",
)
def test_unsupported_builtin(self):
@ -896,7 +896,7 @@ Data-dependent branching
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: None
For more details about this graph break, please visit: https://compile-graph-break-site.vercel.app/gb/GB0170
from user code:
File "test_error_messages.py", line N, in fn

View File

@ -112,12 +112,16 @@ def find_unimplemented_v2_calls(path):
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
if node.name == "unimplemented_v2":
if node.name in (
"unimplemented_v2",
"unimplemented_v2_with_warning",
):
continue
if (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Name)
and node.func.id == "unimplemented_v2"
and node.func.id
in ("unimplemented_v2", "unimplemented_v2_with_warning")
):
info = {
"gb_type": None,
@ -266,6 +270,83 @@ def cmd_update_gb_type(
return True
def test_verify_gb_id_mapping(dynamo_dir, registry_path):
    """
    Verifies that all unimplemented_v2 calls in torch/_dynamo match entries in the registry.

    Every ``*.py`` file under ``dynamo_dir`` is scanned with
    ``find_unimplemented_v2_calls`` and each call found is compared (context,
    explanation, hints) against the first registry entry that has the same
    ``Gb_type``.

    Args:
        dynamo_dir: Directory to search recursively. When falsy, defaults to
            ``torch/_dynamo`` resolved relative to this script's repo root.
        registry_path: Path to the registry JSON file. When falsy, defaults to
            ``torch/_dynamo/graph_break_registry.json`` under the repo root.

    Returns:
        True if every call matches the registry. Otherwise prints the
        mismatches together with update instructions and returns False.
    """
    # Honor the caller-supplied locations (e.g. the CLI's --dynamo_dir and
    # --registry-path) and only fall back to the repo-relative defaults when
    # nothing was passed, instead of unconditionally overwriting them.
    repo_root = Path(__file__).resolve().parent.parent.parent
    default_dynamo_dir = repo_root / "torch" / "_dynamo"
    dynamo_dir = Path(dynamo_dir) if dynamo_dir else default_dynamo_dir
    registry_path = (
        Path(registry_path)
        if registry_path
        else default_dynamo_dir / "graph_break_registry.json"
    )

    reg = load_registry(registry_path)
    # Only the first entry listed under each ID is compared against the code.
    # NOTE(review): presumably the first entry is the current version and the
    # rest are history -- confirm against the registry update commands.
    gb_type_to_entry = {entries[0]["Gb_type"]: entries[0] for entries in reg.values()}

    mismatches = []
    # Sorted so the report is ordered deterministically across runs.
    for file_path in sorted(dynamo_dir.glob("**/*.py")):
        for call in find_unimplemented_v2_calls(file_path):
            gb_type = call["gb_type"]
            entry = gb_type_to_entry.get(gb_type)
            if entry is None:
                mismatches.append((gb_type, file_path, "Not found in registry"))
            elif call["context"] != entry["Context"]:
                mismatches.append((gb_type, file_path, "Context mismatch"))
            elif call["explanation"] != entry["Explanation"]:
                mismatches.append((gb_type, file_path, "Explanation mismatch"))
            # Hint order is not significant, so compare them sorted.
            elif sorted(call["hints"]) != sorted(entry["Hints"]):
                mismatches.append((gb_type, file_path, "Hints mismatch"))

    if not mismatches:
        print("All unimplemented_v2 calls match the registry.")
        return True

    print(
        "Found the unimplemented_v2 or unimplemented_v2_with_warning calls below that "
        "don't match the registry in graph_break_registry.json."
    )
    for gb_type, file_path, reason in mismatches:
        print(f" - {gb_type} in {file_path}: {reason}")

    print("Please update the registry using one of these commands:")

    print(
        "- If you added a new callsite: python tools/dynamo/gb_id_mapping.py add "
        '"GB_TYPE" PATH_TO_FILE --additional-info "INFO"'
    )
    # One print per bullet: adjacent string literals are concatenated without
    # a newline, which previously ran these three bullets together on one line.
    print(" • GB_TYPE: The graph break type string used in your unimplemented_v2 call")
    print(" • PATH_TO_FILE: Path to the file containing your new unimplemented_v2 call")
    print(
        " • --additional-info: Optional extra information to include in the registry entry"
    )

    print(
        '- If you updated an existing callsite: python tools/dynamo/gb_id_mapping.py update "GB_TYPE" PATH_TO_FILE '
        '--new_gb_type "NEW_NAME" --additional-info "INFO"'
    )
    print(" • GB_TYPE: The original graph break type to update")
    print(" • PATH_TO_FILE: Path to the file containing the updated call")
    print(" • --new_gb_type: New name if you changed the graph break type")
    print(" • --additional-info: Optional extra information to add")

    print(
        "- Recreate registry (Only do this if a complete reset is needed): python tools/dynamo/gb_id_mapping.py create"
    )

    print(
        "If you have also wrote a test for the new graph break, please update the test as well "
        "using EXPECTTEST_ACCEPT=1 so the message includes the respective webpage "
    )

    print(
        "Note: If you've reset the entire registry file, you can force push to bypass this check."
    )
    return False
def create_registry(dynamo_dir, registry_path):
calls = find_unimplemented_v2_calls(dynamo_dir)
registry = {}
@ -293,9 +374,8 @@ def create_registry(dynamo_dir, registry_path):
def main():
script_dir = Path(__file__).resolve().parent
repo_root = script_dir.parent.parent
registry_path = script_dir / "graph_break_registry.json"
repo_root = Path(__file__).resolve().parent.parent.parent
registry_path = repo_root / "torch" / "_dynamo" / "graph_break_registry.json"
try:
import torch._dynamo
@ -339,6 +419,16 @@ def main():
"--additional-info", help="Optional additional information to include"
)
verify_parser = subparsers.add_parser(
"verify", help="Verify all unimplemented_v2 calls match registry entries"
)
verify_parser.add_argument(
"--dynamo_dir",
type=str,
default=default_dynamo_dir,
help="Directory to search for unimplemented_v2 calls.",
)
parser.add_argument(
"--registry-path",
type=str,
@ -366,6 +456,10 @@ def main():
)
if not success:
sys.exit(1)
elif args.command == "verify":
success = test_verify_gb_id_mapping(args.dynamo_dir, args.registry_path)
if not success:
sys.exit(1)
else:
parser.print_help()

View File

@ -168,7 +168,7 @@
{
"Gb_type": "Attempted to wrap torch._higher_order_ops.invoke_subgraph",
"Context": "",
"Explanation": "Directly using invoke_subgraph is not supported. Use mark_compile_region",
"Explanation": "Directly using invoke_subgraph is not supported. Use nested_compile_region",
"Hints": []
}
],
@ -222,7 +222,7 @@
{
"Gb_type": "Builtin `operator.*` comparison with constant `self` failed",
"Context": "call_method {self} {name} {args} {kwargs}",
"Explanation": "\"Failed to compare {self} with {other}, because {other} is not a Python constant or its mutation check fails.\"",
"Explanation": "\"Failed to compare {self} with {other}, \" + f\"because {other} is not a Python constant or its mutation check fails.\"",
"Hints": []
}
],
@ -1658,6 +1658,21 @@
}
],
"GB0170": [
{
"Gb_type": "Data-dependent branching",
"Context": "attempted to jump with {value}",
"Explanation": "_explanation",
"Hints": [
"Use `torch.cond` to express dynamic control flow.",
"This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround."
]
},
{
"Gb_type": "Data-dependent branching",
"Context": "attempted to jump with {value}",
"Explanation": "_explanation",
"Hints": []
},
{
"Gb_type": "_gb_type",
"Context": "attempted to jump with {value}",
@ -2137,5 +2152,15 @@
"Explanation": "Dynamo does not support tracing builtin index() on a Tensor",
"Hints": []
}
],
"GB0219": [
{
"Gb_type": "Backend compiler exception",
"Context": "Backend: {name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}",
"Explanation": "Backend compiler `{name}` failed with {str(e)}. Adding a graph break.",
"Hints": [
"Report an issue to the backend compiler repo."
]
}
]
}

View File

@ -813,10 +813,13 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
self.jump(inst)
else:
unimplemented_v2(
gb_type=_gb_type,
gb_type="Data-dependent branching",
context=f"attempted to jump with {value}",
explanation=_explanation,
hints=_hints,
hints=[
*graph_break_hints.FUNDAMENTAL,
"Use `torch.cond` to express dynamic control flow.",
],
)
return inner