[dynamo] Fix graph break registry loading in fbcode (#161550)

Summary: Add `torch/_dynamo/graph_break_registry.json` as an internal dependency. Minor related fixes.

Test Plan:
Test on OSS.

Rollback Plan:

Differential Revision: D81078973

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161550
Approved by: https://github.com/Lucaskabela, https://github.com/anijain2305
This commit is contained in:
William Wen
2025-08-27 19:25:15 +00:00
committed by PyTorch MergeBot
parent 443452ca2f
commit 2efcf9d081
2 changed files with 33 additions and 20 deletions

View File

@ -47,7 +47,7 @@ class GenericCtxMgr:
pass
class GraphBreakMessagesTest(LoggingTestCase):
class ErrorMessagesTest(LoggingTestCase):
def test_dynamic_shape_operator(self):
def fn():
return torch.nonzero(torch.rand([10, 10]))
@ -783,12 +783,12 @@ from user code:
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
"""\
Reconstruction failure
Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(<function GraphBreakMessagesTest.test_reconstruction_failure.<locals>.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)).
Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(<function ErrorMessagesTest.test_reconstruction_failure.<locals>.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)).
Hint: If Dynamo is attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have reconstruction rules may be fundamentally unreconstructable.
Developer debug context: UserMethodVariable(<function GraphBreakMessagesTest.test_reconstruction_failure.<locals>.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo))
Developer debug context: UserMethodVariable(<function ErrorMessagesTest.test_reconstruction_failure.<locals>.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo))
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0092.html
@ -839,12 +839,12 @@ User code traceback:
post_munge(munge_exc(records[1].exc_info[1], suppress_suffix=True, skip=0)),
"""\
Reconstruction failure
Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(<function GraphBreakMessagesTest.test_reconstruction_failure_gb.<locals>.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)).
Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(<function ErrorMessagesTest.test_reconstruction_failure_gb.<locals>.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)).
Hint: If Dynamo is attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have reconstruction rules may be fundamentally unreconstructable.
Developer debug context: UserMethodVariable(<function GraphBreakMessagesTest.test_reconstruction_failure_gb.<locals>.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo))
Developer debug context: UserMethodVariable(<function ErrorMessagesTest.test_reconstruction_failure_gb.<locals>.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo))
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0092.html
@ -1298,10 +1298,10 @@ call to a lru_cache wrapped function at: test_error_messages.py:N
lambda: outer(f, torch.randn(3)),
"""\
Skip calling `torch.compiler.disable()`d function
Explanation: Skip calling function `<function GraphBreakMessagesTest.test_disable_message.<locals>.f at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: None)
Explanation: Skip calling function `<function ErrorMessagesTest.test_disable_message.<locals>.f at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: None)
Hint: Remove the `torch.compiler.disable` call
Developer debug context: <function GraphBreakMessagesTest.test_disable_message.<locals>.f at 0xmem_addr>
Developer debug context: <function ErrorMessagesTest.test_disable_message.<locals>.f at 0xmem_addr>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0098.html
@ -1320,10 +1320,10 @@ from user code:
lambda: outer(g, torch.randn(3)),
"""\
Skip calling `torch.compiler.disable()`d function
Explanation: Skip calling function `<function GraphBreakMessagesTest.test_disable_message.<locals>.g at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: test message)
Explanation: Skip calling function `<function ErrorMessagesTest.test_disable_message.<locals>.g at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: test message)
Hint: Remove the `torch.compiler.disable` call
Developer debug context: <function GraphBreakMessagesTest.test_disable_message.<locals>.g at 0xmem_addr>
Developer debug context: <function ErrorMessagesTest.test_disable_message.<locals>.g at 0xmem_addr>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0098.html

View File

@ -39,6 +39,7 @@ from traceback import extract_stack, format_exc, format_list, StackSummary
from typing import Any, NoReturn, Optional, TYPE_CHECKING
import torch._guards
from torch._utils_internal import get_file_path_2
from . import config
from .utils import counters
@ -512,18 +513,29 @@ def format_graph_break_message(
@lru_cache(maxsize=1)
def _load_graph_break_registry() -> dict[str, Any]:
def _load_gb_type_to_gb_id_map() -> dict[str, Any]:
"""
Loads the graph break registry from JSON file with caching.
Loads the gb_type to gb_id map from the graph break registry from JSON file with caching.
Includes historical gb_type (mapping behavior of duplicate gb_types with different gb_ids is undefined).
"""
try:
script_dir = Path(__file__).resolve().parent
registry_path = script_dir / "graph_break_registry.json"
with registry_path.open() as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
registry_path = get_file_path_2(
"", str(script_dir), "graph_break_registry.json"
)
with open(registry_path) as f:
registry = json.load(f)
except Exception as e:
log.error("Error accessing the registry file: %s", e)
return {}
registry = {}
mapping = {}
for k, v in registry.items():
for entry in v:
mapping[entry["Gb_type"]] = k
return mapping
def get_gbid_documentation_link(gb_type: str) -> Optional[str]:
@ -540,11 +552,12 @@ def get_gbid_documentation_link(gb_type: str) -> Optional[str]:
"https://meta-pytorch.github.io/compile-graph-break-site/gb/" # @lint-ignore
)
registry = _load_graph_break_registry()
gb_type_to_gb_id_map = _load_gb_type_to_gb_id_map()
for k, v in registry.items():
if v and v[0].get("Gb_type") == gb_type:
return f"{GRAPH_BREAK_SITE_URL}gb{k.lstrip('GB')}.html"
if gb_type in gb_type_to_gb_id_map:
return (
f"{GRAPH_BREAK_SITE_URL}gb{gb_type_to_gb_id_map[gb_type].lstrip('GB')}.html"
)
return None