[aoti-fx] Add meta["val"] metadata (#161019)

Summary: Added a `_set_node_metadata_hook` which automatically adds node.meta["val"] to every new node that gets created under this context.

Test Plan:
` buck2 test //mtia/host_runtime/afg/tests:test_dynamic_shapes_advanced_ops`
https://www.internalfb.com/buck2/866439a2-2ba6-42d1-8e43-508d60456e2e

`buck2 test //mtia/host_runtime/afg/tests:test_dynamic_shapes_basic_ops`
https://www.internalfb.com/intern/testinfra/testrun/11540474149662857

Rollback Plan:

Differential Revision: D80579336

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161019
Approved by: https://github.com/blaine-rister
This commit is contained in:
Angela Yi
2025-08-21 16:45:41 +00:00
committed by PyTorch MergeBot
parent a6401cb5aa
commit 3dacaf0e1e
3 changed files with 120 additions and 97 deletions

View File

@ -557,6 +557,13 @@ class AOTFxirTestCase(InductorTestCase):
)
self.assertTrue(torch.allclose(model(*inp), gm(*inp)))
for node in gm.graph.nodes:
if (
node.op == "call_function"
and node.target != triton_kernel_wrapper_mutation
):
self.assertTrue(node.meta.get("val", None) is not None)
def test_aoti_fx_add(self):
class M(torch.nn.Module):
def forward(self, x, y):

View File

@ -3,6 +3,9 @@ import contextlib
from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.graph_module import GraphModule
@ -10,7 +13,9 @@ _EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook"
def _node_metadata_hook(
node: torch.fx.Node, metadata: Optional[dict[str, Any]] = None
node: torch.fx.Node,
metadata: Optional[dict[str, Any]] = None,
fake_mode: Optional[FakeTensorMode] = None,
) -> None:
"""
Hook for adding the appropriate metadata to nodes that are created during a
@ -27,11 +32,11 @@ def _node_metadata_hook(
that nodes being added are only call_function nodes, and copies over the
first argument node's nn_module_stack.
"""
assert node.op == "call_function" and callable(node.target)
fake_mode = fake_mode or contextlib.nullcontext()
arg_meta = [arg.meta for arg in node.args if isinstance(arg, torch.fx.Node)]
assert len(arg_meta) >= 1
arg_meta = arg_meta[0]
assert node.op == "call_function" and callable(node.target), (
f"node: {node}, target: {node.target}"
)
if (
isinstance(node.target, torch._ops.OpOverload)
@ -39,34 +44,48 @@ def _node_metadata_hook(
):
node.meta["val"] = None
else:
fake_args = [
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
for arg in node.args
]
fake_res = node.target(*fake_args)
fake_args, fake_kwargs = pytree.tree_map_only(
torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs)
)
with fake_mode, enable_python_dispatcher():
fake_res = node.target(*fake_args, **fake_kwargs)
node.meta["val"] = fake_res
node.meta["nn_module_stack"] = arg_meta.get(
"nn_module_stack",
{
_EMPTY_NN_MODULE_STACK_KEY: (
_EMPTY_NN_MODULE_STACK_KEY,
_EMPTY_NN_MODULE_STACK_KEY,
)
},
)
node.meta["torch_fn"] = (
f"{node.target.__name__}_0",
f"{node.target.__class__.__name__}.{node.target.__name__}",
)
# Hook specified metadata takes precedence over all previously set
# metadata, so this goes last
if metadata is not None:
for k, v in metadata.items():
node.meta[k] = v
# Copy over metadata from argument nodes
arg_meta = [
arg.meta
for arg in pytree.tree_flatten((node.args, node.kwargs))[0]
if isinstance(arg, torch.fx.Node)
]
if len(arg_meta) == 0:
return
arg_meta = arg_meta[0]
node.meta["nn_module_stack"] = node.meta.get(
"nn_module_stack",
arg_meta.get(
"nn_module_stack",
{
_EMPTY_NN_MODULE_STACK_KEY: (
_EMPTY_NN_MODULE_STACK_KEY,
_EMPTY_NN_MODULE_STACK_KEY,
)
},
),
)
node.meta["torch_fn"] = node.meta.get(
"torch_fn",
(
f"{node.target.__name__}_0",
f"{node.target.__class__.__name__}.{node.target.__name__}",
),
)
@contextlib.contextmanager
def _set_node_metadata_hook(gm: torch.fx.GraphModule, f):

View File

@ -10,6 +10,11 @@ from typing import Any, Callable, Optional, Union
import sympy
import torch
from torch._export.passes._node_metadata_hook import (
_node_metadata_hook,
_set_node_metadata_hook,
)
from torch._export.utils import _detect_fake_mode_from_gm
from torch._higher_order_ops.triton_kernel_wrap import (
TraceableTritonKernelWrapper,
tracing_triton_hopifier_singleton,
@ -198,11 +203,6 @@ class FxConverter:
device=device,
)
def _create_meta_from_buffer(
self, node: torch.fx.Node, buffer: CodegenBuffer
) -> None:
node.meta["val"] = buffer.get_example()
def _create_as_strided(
self,
input_node: torch.fx.Node,
@ -266,31 +266,6 @@ class FxConverter:
Converts graph inputs to FX placeholders.
"""
def _codegen_symbol(
sym_or_exp: Union[sympy.Symbol, sympy.Expr],
base_node: torch.fx.Node,
target: torch._ops.OpOverload,
dim: int,
) -> None:
if isinstance(sym_or_exp, sympy.Symbol):
buffer = SymbolBuffer(sym_or_exp)
if buffer.get_name() in self.buffer_to_node:
return
size_node = self.gm.graph.call_function(target, (base_node, dim))
size_proxy = torch.fx.Proxy(size_node, tracer=self.tracer)
self._create_meta_from_buffer(size_node, buffer)
self._record_allocation(buffer, size_node)
self.expr_to_proxy[sym_or_exp] = size_proxy
elif isinstance(sym_or_exp, sympy.Integer):
return
elif isinstance(sym_or_exp, sympy.Expr):
self._sympy_interp(sym_or_exp)
for node in V.graph.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
name = node.name
if name in V.graph.graph_inputs:
@ -303,18 +278,48 @@ class FxConverter:
else self._get_buffer(ir_node)
)
placeholder_node = self.gm.graph.placeholder(buffer.get_name())
self._create_meta_from_buffer(placeholder_node, buffer)
placeholder_node.meta["val"] = buffer.get_example()
self._record_allocation(buffer, placeholder_node)
# not sure if this is needed...
if isinstance(ir_node, (sympy.Symbol)):
placeholder_proxy = torch.fx.Proxy(
placeholder_node, tracer=self.tracer
)
self.expr_to_proxy[ir_node] = placeholder_proxy
elif V.aot_compilation:
# Create dummy input nodes to match the input signature
self.gm.graph.placeholder(name)
# Generate nodes for dynamic input sizes/strides.
def _generate_graph_input_shapes(self) -> None:
"""
Generate nodes creating symints that are part of graph input
shape/strides.
"""
def _codegen_symbol(
sym_or_exp: Union[sympy.Symbol, sympy.Expr],
base_node: torch.fx.Node,
target: torch._ops.OpOverload,
dim: int,
) -> None:
if isinstance(sym_or_exp, sympy.Symbol):
if sym_or_exp in self.expr_to_proxy:
return
size_node = self.gm.graph.call_function(target, (base_node, dim))
size_proxy = torch.fx.Proxy(size_node, tracer=self.tracer)
self.expr_to_proxy[sym_or_exp] = size_proxy
elif isinstance(sym_or_exp, sympy.Integer):
return
elif isinstance(sym_or_exp, sympy.Expr):
self._sympy_interp(sym_or_exp)
for node in V.graph.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
name = node.name
if name in V.graph.graph_inputs:
ir_node = V.graph.graph_inputs[name]
if isinstance(ir_node, ir.TensorBox):
buffer = self._get_buffer(ir_node)
placeholder_node = self.buffer_to_node[buffer.get_name()]
for dim, size in enumerate(ir_node.get_size()):
_codegen_symbol(
size, placeholder_node, torch.ops.aten.sym_size.int, dim
@ -324,10 +329,6 @@ class FxConverter:
stride, placeholder_node, torch.ops.aten.sym_stride.int, dim
)
elif V.aot_compilation:
# Create dummy input nodes to match the input signature
self.gm.graph.placeholder(name)
def _generate_graph_constants(self) -> None:
for name, value in V.graph.constants.items():
node = self.gm.graph.get_attr(name)
@ -397,24 +398,32 @@ class FxConverter:
self._generate_graph_inputs()
self._generate_graph_constants()
# Generate FX IR from Wrapper IR lines.
for line in self.lines:
if isinstance(line, WrapperLine):
line.codegen_fx(self)(line)
elif isinstance(line, LineContext):
# Ignore line context in FX IR.
pass
else:
raise NotImplementedError(
textwrap.dedent(
f"""
Found line of unrecognized type '{type(line)}':
'{line}'
fake_mode = _detect_fake_mode_from_gm(self.gm)
FX conversion only supports Wrapper IR lines.
"""
with _set_node_metadata_hook(
self.gm,
functools.partial(_node_metadata_hook, fake_mode=fake_mode),
):
self._generate_graph_input_shapes()
# Generate FX IR from Wrapper IR lines.
for line in self.lines:
if isinstance(line, WrapperLine):
line.codegen_fx(self)(line)
elif isinstance(line, LineContext):
# Ignore line context in FX IR.
pass
else:
raise NotImplementedError(
textwrap.dedent(
f"""
Found line of unrecognized type '{type(line)}':
'{line}'
FX conversion only supports Wrapper IR lines.
"""
)
)
)
self._generate_output()
self.gm.recompile()
@ -512,7 +521,6 @@ class FxConverter:
)
assert name
node.name = name
self._create_meta_from_buffer(node, buffer)
self._record_allocation(buffer, node)
def _generate_comment(self, line: WrapperLine) -> None:
@ -583,7 +591,6 @@ class FxConverter:
# Map ReinterpretView to as_strided.
result_node = self._create_as_strided(input_node, size, stride, offset)
result_node.name = name
result_node.meta["val"] = layout.get_example()
self._record_allocation(result_buffer, result_node)
def _generate_reuse(self, line: WrapperLine) -> None:
@ -606,7 +613,6 @@ class FxConverter:
or old.get_offset() != offset
):
result_node = self._create_as_strided(old_node, size, stride, offset)
self._create_meta_from_buffer(result_node, new)
self._record_allocation(new, result_node)
@ -635,7 +641,6 @@ class FxConverter:
idx = inds[0]
node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx))
node.meta["val"] = arg_node.meta["val"][idx]
node.name = line.result_name
self.buffer_to_node[line.result_name] = node
@ -778,14 +783,6 @@ class FxConverter:
fx_node.name = result_buffer
self.buffer_to_node[result_buffer] = fx_node
arg_tensors = [
arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
for arg in args
]
# Run the operation to propagate metadata.
fx_node.meta["val"] = op(*arg_tensors, **kwargs)
def _generate_kernel_call(self, line: WrapperLine) -> None:
assert isinstance(line, KernelCallLine)
if not line.triton: