Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
restore CSE'd node metadata in runtime asserts pass (#134516)
Adds `val` (and optionally `stack_trace` and `nn_module_stack`) metadata back to the SymInt compute nodes that we CSE, via a hook on `graph.create_node()`. It is unclear whether there is other metadata we should populate here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134516
Approved by: https://github.com/ezyang
Committed by PyTorch MergeBot · parent 9f00317997 · commit 1dfb105239
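For context on the hook mentioned in the description: the diff imports `_set_node_metadata_hook` from torch/_export/passes/_node_metadata_hook.py, whose body is not part of this change. The snippet below is a minimal, self-contained sketch of the idea (a callback scoped over node creation so freshly created nodes get their metadata repopulated), not the actual helper; the names `set_node_metadata_hook` and `metadata_hook` are illustrative, and the sketch patches `graph.create_node` directly rather than using whatever registration mechanism the real helper relies on.

import contextlib
import functools
from typing import Callable, Optional

import torch
import torch.fx


@contextlib.contextmanager
def set_node_metadata_hook(gm: torch.fx.GraphModule, hook: Callable[[torch.fx.Node], None]):
    # Run `hook` on every node created via gm.graph.create_node inside the
    # `with` block, then restore the original method.
    original_create_node = gm.graph.create_node

    def create_node_and_run_hook(*args, **kwargs):
        node = original_create_node(*args, **kwargs)
        hook(node)  # repopulate metadata on the freshly created node
        return node

    gm.graph.create_node = create_node_and_run_hook  # type: ignore[method-assign]
    try:
        yield
    finally:
        gm.graph.create_node = original_create_node  # type: ignore[method-assign]


def metadata_hook(node: torch.fx.Node, stack_trace: Optional[str] = None) -> None:
    # Hypothetical hook body: copy over whatever metadata the pass wants to keep.
    if stack_trace is not None:
        node.meta["stack_trace"] = stack_trace


# Usage, mirroring the pattern in the diff below:
# with set_node_metadata_hook(gm, functools.partial(metadata_hook, stack_trace="...")):
#     ...  # rebuild the CSE'd SymInt compute nodes here

This is only meant to make the diff easier to read; the pass itself wires the hook up through `_set_node_metadata_hook(gm, _node_metadata_hook)` as shown in the hunks that follow.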
@@ -423,7 +423,6 @@ class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
         opt_model = torch.compile(dynamic=True)(model)
         opt_model(torch.randn(20, 512), torch.tensor([12, 13]))
 
-    @unittest.expectedFailure  # https://github.com/pytorch/pytorch/issues/130534"
     @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
     def test_unbacked_symbol_splitting_no_binding(self):
         class Model(nn.Module):
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import functools
 import logging
 import operator
 import sys
@@ -43,6 +44,18 @@ def _get_example_value(node: fx.Node) -> Optional[str]:
     return None
 
 
+def _get_example_value_key(node: fx.Node) -> Optional[str]:
+    """
+    actually just run this once at start of pass, based on first node, and constantly use that.
+    """
+    if "example_value" in node.meta:
+        return "example_value"
+    elif "val" in node.meta:
+        return "val"
+    else:
+        return None
+
+
 def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]:
     val = _get_example_value(node)
     if isinstance(val, py_sym_types):
@@ -92,6 +105,7 @@ def insert_deferred_runtime_asserts(
     # Import sympy locally
     import sympy
 
+    from torch._export.passes._node_metadata_hook import _set_node_metadata_hook
     from torch.fx.experimental.symbolic_shapes import (
         _has_uninterpretable_sympy_function,
         CallMethodKey,
@@ -145,6 +159,30 @@ def insert_deferred_runtime_asserts(
         )
     )
 
+    # Figure out what key to use, val or example_value
+    val_key = "val"
+    for node in graph.nodes:
+        if "example_value" in node.meta:
+            val_key = "example_value"
+            break
+        elif "val" in node.meta:
+            break
+
+    def _node_metadata_hook(
+        node: torch.fx.Node,
+        stack_trace: Optional[str] = None,
+        nn_module_stack: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        fake_args = [
+            _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
+            for arg in node.args
+        ]
+        node.meta[val_key] = node.target(*fake_args)  # type: ignore[operator]
+        if stack_trace is not None:
+            node.meta["stack_trace"] = stack_trace
+        if nn_module_stack is not None:
+            node.meta["nn_module_stack"] = nn_module_stack
+
     # Track asserts/checks we've added
     added_asserts: Set[sympy.Expr] = set()
     constrained_unbacked_symbols: Set[sympy.Symbol] = set()
@@ -247,7 +285,8 @@ def insert_deferred_runtime_asserts(
                 and isinstance(s := symint.node.expr, sympy.Symbol)
                 and s not in expr_to_proxy
             ):
-                expr_to_proxy[s] = fx.Proxy(cb())
+                with _set_node_metadata_hook(gm, _node_metadata_hook):
+                    expr_to_proxy[s] = fx.Proxy(cb())
                 log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
 
         match_symbol(example_value, lambda: node)
@@ -331,7 +370,15 @@ def insert_deferred_runtime_asserts(
                 if _is_intermediate_tensor_sym_call(
                     node
                 ):  # reify from input shapes
-                    expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr)  # type: ignore[arg-type]
+                    with _set_node_metadata_hook(
+                        gm,
+                        functools.partial(
+                            _node_metadata_hook,
+                            stack_trace=node.meta.get("stack_trace"),
+                            nn_module_stack=node.meta.get("nn_module_stack"),
+                        ),
+                    ):
+                        expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr)  # type: ignore[arg-type]
                     # won't try DCE-ing tensor compute here
                     hash_node = expr_to_proxy[sym_expr].node  # type: ignore[arg-type]
                     node.replace_all_uses_with(hash_node)
@@ -436,7 +483,8 @@ def insert_deferred_runtime_asserts(
                     raise AssertionError(f"unrecognized keypath {keypath}")
 
                 if s not in expr_to_proxy:
-                    expr_to_proxy[s] = fx.Proxy(go(node, keypath))
+                    with _set_node_metadata_hook(gm, _node_metadata_hook):
+                        expr_to_proxy[s] = fx.Proxy(go(node, keypath))
                     log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
 
             for i0 in defs:
@@ -519,26 +567,34 @@ def insert_deferred_runtime_asserts(
                 # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts
                 # raises AOTAutograd errors on cast_symbool_to_symint_guardless
 
-                if (min_val := convert(vr.lower)) is not None:
-                    ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
-                    graph.call_function(
-                        torch.ops.aten._assert_scalar.default,
-                        (
-                            ge,
-                            f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
-                        ),
-                    )
-                    added_asserts.add(i0 >= min_val)
-                if (max_val := convert(vr.upper)) is not None:
-                    le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
-                    graph.call_function(
-                        torch.ops.aten._assert_scalar.default,
-                        (
-                            le,
-                            f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
-                        ),
-                    )
-                    added_asserts.add(i0 <= max_val)
+                with _set_node_metadata_hook(
+                    gm,
+                    functools.partial(
+                        _node_metadata_hook,
+                        stack_trace=node.meta.get("stack_trace"),
+                        nn_module_stack=node.meta.get("nn_module_stack"),
+                    ),
+                ):
+                    if (min_val := convert(vr.lower)) is not None:
+                        ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
+                        graph.call_function(
+                            torch.ops.aten._assert_scalar.default,
+                            (
+                                ge,
+                                f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
+                            ),
+                        )
+                        added_asserts.add(i0 >= min_val)
+                    if (max_val := convert(vr.upper)) is not None:
+                        le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
+                        graph.call_function(
+                            torch.ops.aten._assert_scalar.default,
+                            (
+                                le,
+                                f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
+                            ),
+                        )
+                        added_asserts.add(i0 <= max_val)
 
                 constrained_unbacked_symbols.add(i0)
                 add_runtime_asserts(ras)