restore CSE'd node metadata in runtime asserts pass (#134516)

Adds `val` metadata (and optionally `stack_trace` & `nn_module_stack`) back to the SymInt compute nodes that we CSE, using a hook on `graph.create_node()`. Not sure if there's other metadata we want to populate here?
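For context, the hook mechanism is roughly the following (a minimal sketch, assuming the graph's `create_node` can be temporarily wrapped; the helper name below is made up and this is not the actual `torch._export.passes._node_metadata_hook` implementation):

```python
import contextlib
from typing import Any, Callable

import torch.fx as fx


@contextlib.contextmanager
def _metadata_hook_sketch(gm: fx.GraphModule, hook: Callable[[fx.Node], None]):
    """Apply `hook` to every node created on gm.graph while the context is active."""
    orig_create_node = gm.graph.create_node

    def create_node(*args: Any, **kwargs: Any) -> fx.Node:
        node = orig_create_node(*args, **kwargs)
        hook(node)  # e.g. fill node.meta["val"], "stack_trace", "nn_module_stack"
        return node

    gm.graph.create_node = create_node  # type: ignore[method-assign]
    try:
        yield
    finally:
        gm.graph.create_node = orig_create_node  # type: ignore[method-assign]
```

The pass then wraps the calls that materialize SymInt compute nodes (`fx.Proxy(cb())`, `_sympy_interp(...)`, the `_assert_scalar` range checks) in `with _set_node_metadata_hook(gm, ...)`, using `functools.partial` to forward `stack_trace` / `nn_module_stack` from the node being replaced, as the hunks below show.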

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134516
Approved by: https://github.com/ezyang
Authored by Pian Pawakapan on 2024-09-04 05:56:28 +00:00; committed by PyTorch MergeBot
parent 9f00317997
commit 1dfb105239
2 changed files with 79 additions and 24 deletions

View File

@@ -423,7 +423,6 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
         opt_model = torch.compile(dynamic=True)(model)
         opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
 
-    @unittest.expectedFailure  # https://github.com/pytorch/pytorch/issues/130534
     @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
     def test_unbacked_symbol_splitting_no_binding(self):
         class Model(nn.Module):

View File

@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import functools
 import logging
 import operator
 import sys
@@ -43,6 +44,18 @@ def _get_example_value(node: fx.Node) -> Optional[str]:
     return None
 
 
+def _get_example_value_key(node: fx.Node) -> Optional[str]:
+    """
+    Actually just run this once at the start of the pass, based on the first node, and use that key throughout.
+    """
+    if "example_value" in node.meta:
+        return "example_value"
+    elif "val" in node.meta:
+        return "val"
+    else:
+        return None
+
+
 def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]:
     val = _get_example_value(node)
     if isinstance(val, py_sym_types):
@@ -92,6 +105,7 @@ def insert_deferred_runtime_asserts(
     # Import sympy locally
     import sympy
 
+    from torch._export.passes._node_metadata_hook import _set_node_metadata_hook
     from torch.fx.experimental.symbolic_shapes import (
         _has_uninterpretable_sympy_function,
         CallMethodKey,
@@ -145,6 +159,30 @@ def insert_deferred_runtime_asserts(
             )
         )
 
+    # Figure out what key to use, val or example_value
+    val_key = "val"
+    for node in graph.nodes:
+        if "example_value" in node.meta:
+            val_key = "example_value"
+            break
+        elif "val" in node.meta:
+            break
+
+    def _node_metadata_hook(
+        node: torch.fx.Node,
+        stack_trace: Optional[str] = None,
+        nn_module_stack: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        fake_args = [
+            _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
+            for arg in node.args
+        ]
+        node.meta[val_key] = node.target(*fake_args)  # type: ignore[operator]
+        if stack_trace is not None:
+            node.meta["stack_trace"] = stack_trace
+        if nn_module_stack is not None:
+            node.meta["nn_module_stack"] = nn_module_stack
+
     # Track asserts/checks we've added
     added_asserts: Set[sympy.Expr] = set()
     constrained_unbacked_symbols: Set[sympy.Symbol] = set()
@@ -247,7 +285,8 @@ def insert_deferred_runtime_asserts(
                         and isinstance(s := symint.node.expr, sympy.Symbol)
                         and s not in expr_to_proxy
                     ):
-                        expr_to_proxy[s] = fx.Proxy(cb())
+                        with _set_node_metadata_hook(gm, _node_metadata_hook):
+                            expr_to_proxy[s] = fx.Proxy(cb())
                         log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
 
                 match_symbol(example_value, lambda: node)
@@ -331,7 +370,15 @@ def insert_deferred_runtime_asserts(
                     if _is_intermediate_tensor_sym_call(
                         node
                     ):  # reify from input shapes
-                        expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr)  # type: ignore[arg-type]
+                        with _set_node_metadata_hook(
+                            gm,
+                            functools.partial(
+                                _node_metadata_hook,
+                                stack_trace=node.meta.get("stack_trace"),
+                                nn_module_stack=node.meta.get("nn_module_stack"),
+                            ),
+                        ):
+                            expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr)  # type: ignore[arg-type]
                         # won't try DCE-ing tensor compute here
                     hash_node = expr_to_proxy[sym_expr].node  # type: ignore[arg-type]
                     node.replace_all_uses_with(hash_node)
@@ -436,7 +483,8 @@ def insert_deferred_runtime_asserts(
                             raise AssertionError(f"unrecognized keypath {keypath}")
 
                     if s not in expr_to_proxy:
-                        expr_to_proxy[s] = fx.Proxy(go(node, keypath))
+                        with _set_node_metadata_hook(gm, _node_metadata_hook):
+                            expr_to_proxy[s] = fx.Proxy(go(node, keypath))
                         log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
 
             for i0 in defs:
@@ -519,26 +567,34 @@ def insert_deferred_runtime_asserts(
                     # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts
                     # raises AOTAutograd errors on cast_symbool_to_symint_guardless
-                    if (min_val := convert(vr.lower)) is not None:
-                        ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
-                        graph.call_function(
-                            torch.ops.aten._assert_scalar.default,
-                            (
-                                ge,
-                                f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
-                            ),
-                        )
-                        added_asserts.add(i0 >= min_val)
-                    if (max_val := convert(vr.upper)) is not None:
-                        le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
-                        graph.call_function(
-                            torch.ops.aten._assert_scalar.default,
-                            (
-                                le,
-                                f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
-                            ),
-                        )
-                        added_asserts.add(i0 <= max_val)
+                    with _set_node_metadata_hook(
+                        gm,
+                        functools.partial(
+                            _node_metadata_hook,
+                            stack_trace=node.meta.get("stack_trace"),
+                            nn_module_stack=node.meta.get("nn_module_stack"),
+                        ),
+                    ):
+                        if (min_val := convert(vr.lower)) is not None:
+                            ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
+                            graph.call_function(
+                                torch.ops.aten._assert_scalar.default,
+                                (
+                                    ge,
+                                    f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
+                                ),
+                            )
+                            added_asserts.add(i0 >= min_val)
+                        if (max_val := convert(vr.upper)) is not None:
+                            le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
+                            graph.call_function(
+                                torch.ops.aten._assert_scalar.default,
+                                (
+                                    le,
+                                    f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
+                                ),
+                            )
+                            added_asserts.add(i0 <= max_val)
                 constrained_unbacked_symbols.add(i0)
                 add_runtime_asserts(ras)
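
A rough, hypothetical way to eyeball the result (not part of this PR; the module, input shapes, and `batch` dim name below are made up): export a small model with a dynamic dimension and list any `call_function` nodes that still lack `val` metadata.

```python
import torch
from torch.export import Dim, export


class M(torch.nn.Module):
    def forward(self, x):
        # x.shape[0] is a SymInt, so the exported graph contains sym-size compute nodes
        return x.reshape(-1) + x.shape[0]


ep = export(M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: Dim("batch")}})
for node in ep.graph.nodes:
    if node.op == "call_function" and "val" not in node.meta:
        print("missing val metadata:", node.format_node())
```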