[aoti-fx] Add meta["val"] metadata (#161019)
Summary: Added a `_set_node_metadata_hook` which automatically adds `node.meta["val"]` to every new node that gets created under this context.

Test Plan:
`buck2 test //mtia/host_runtime/afg/tests:test_dynamic_shapes_advanced_ops`
https://www.internalfb.com/buck2/866439a2-2ba6-42d1-8e43-508d60456e2e
`buck2 test //mtia/host_runtime/afg/tests:test_dynamic_shapes_basic_ops`
https://www.internalfb.com/intern/testinfra/testrun/11540474149662857

Rollback Plan:

Differential Revision: D80579336

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161019
Approved by: https://github.com/blaine-rister
committed by PyTorch MergeBot
parent a6401cb5aa
commit 3dacaf0e1e
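The mechanics in one place: `_set_node_metadata_hook` registers a create-node callback on the GraphModule, and `_node_metadata_hook` fills in `node.meta["val"]` by re-running the op on the fake values of its argument nodes. A minimal sketch of the visible effect, assuming only the signatures in this diff (the module, op, and shapes are illustrative, and both helpers are private `torch._export` APIs):

import torch
from torch._export.passes._node_metadata_hook import (
    _node_metadata_hook,
    _set_node_metadata_hook,
)


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(M())
ph = next(iter(gm.graph.find_nodes(op="placeholder")))
ph.meta["val"] = torch.empty(2, 2)  # seed the input's example value

with _set_node_metadata_hook(gm, _node_metadata_hook):
    with gm.graph.inserting_after(ph):
        # No manual metadata bookkeeping: the hook runs aten.relu on
        # ph.meta["val"] and stores the result in relu_node.meta["val"].
        relu_node = gm.graph.call_function(torch.ops.aten.relu.default, (ph,))

assert relu_node.meta["val"].shape == (2, 2)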
@@ -557,6 +557,13 @@ class AOTFxirTestCase(InductorTestCase):
         )
         self.assertTrue(torch.allclose(model(*inp), gm(*inp)))

+        for node in gm.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target != triton_kernel_wrapper_mutation
+            ):
+                self.assertTrue(node.meta.get("val", None) is not None)
+
     def test_aoti_fx_add(self):
         class M(torch.nn.Module):
             def forward(self, x, y):
@@ -3,6 +3,9 @@ import contextlib
 from typing import Any, Optional

 import torch
+import torch.utils._pytree as pytree
+from torch._dispatch.python import enable_python_dispatcher
+from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.graph_module import GraphModule
@@ -10,7 +13,9 @@ _EMPTY_NN_MODULE_STACK_KEY = "_empty_nn_module_stack_from_metadata_hook"


 def _node_metadata_hook(
-    node: torch.fx.Node, metadata: Optional[dict[str, Any]] = None
+    node: torch.fx.Node,
+    metadata: Optional[dict[str, Any]] = None,
+    fake_mode: Optional[FakeTensorMode] = None,
 ) -> None:
     """
     Hook for adding the appropriate metadata to nodes that are created during a
@@ -27,11 +32,11 @@ def _node_metadata_hook(
     that nodes being added are only call_function nodes, and copies over the
     first argument node's nn_module_stack.
     """
-    assert node.op == "call_function" and callable(node.target)
-
-    arg_meta = [arg.meta for arg in node.args if isinstance(arg, torch.fx.Node)]
-    assert len(arg_meta) >= 1
-    arg_meta = arg_meta[0]
+    fake_mode = fake_mode or contextlib.nullcontext()
+
+    assert node.op == "call_function" and callable(node.target), (
+        f"node: {node}, target: {node.target}"
+    )

     if (
         isinstance(node.target, torch._ops.OpOverload)
@@ -39,34 +44,48 @@ def _node_metadata_hook(
     ):
         node.meta["val"] = None
     else:
-        fake_args = [
-            arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
-            for arg in node.args
-        ]
-        fake_res = node.target(*fake_args)
+        fake_args, fake_kwargs = pytree.tree_map_only(
+            torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs)
+        )
+        with fake_mode, enable_python_dispatcher():
+            fake_res = node.target(*fake_args, **fake_kwargs)
         node.meta["val"] = fake_res

-    node.meta["nn_module_stack"] = arg_meta.get(
-        "nn_module_stack",
-        {
-            _EMPTY_NN_MODULE_STACK_KEY: (
-                _EMPTY_NN_MODULE_STACK_KEY,
-                _EMPTY_NN_MODULE_STACK_KEY,
-            )
-        },
-    )
-
-    node.meta["torch_fn"] = (
-        f"{node.target.__name__}_0",
-        f"{node.target.__class__.__name__}.{node.target.__name__}",
-    )
-
     # Hook specified metadata takes precedence over all previously set
     # metadata, so this goes last
     if metadata is not None:
         for k, v in metadata.items():
             node.meta[k] = v

+    # Copy over metadata from argument nodes
+    arg_meta = [
+        arg.meta
+        for arg in pytree.tree_flatten((node.args, node.kwargs))[0]
+        if isinstance(arg, torch.fx.Node)
+    ]
+    if len(arg_meta) == 0:
+        return
+    arg_meta = arg_meta[0]
+
+    node.meta["nn_module_stack"] = node.meta.get(
+        "nn_module_stack",
+        arg_meta.get(
+            "nn_module_stack",
+            {
+                _EMPTY_NN_MODULE_STACK_KEY: (
+                    _EMPTY_NN_MODULE_STACK_KEY,
+                    _EMPTY_NN_MODULE_STACK_KEY,
+                )
+            },
+        ),
+    )
+
+    node.meta["torch_fn"] = node.meta.get(
+        "torch_fn",
+        (
+            f"{node.target.__name__}_0",
+            f"{node.target.__class__.__name__}.{node.target.__name__}",
+        ),
+    )
+

 @contextlib.contextmanager
 def _set_node_metadata_hook(gm: torch.fx.GraphModule, f):
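The switch from a list comprehension over `node.args` to `pytree.tree_map_only` means fake values are substituted for `torch.fx.Node` objects anywhere in the `(args, kwargs)` tree, kwargs and nested containers included. A self-contained sketch of that substitution pattern (the graph and values below are stand-ins, not the real hook inputs):

import torch
import torch.utils._pytree as pytree

graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(2, 3)  # stand-in for a FakeTensor

args, kwargs = (x, [x, 1.0]), {"scale": 2.0}

# Swap every fx.Node in the nested (args, kwargs) structure for its
# meta["val"]; non-Node leaves (the floats here) pass through untouched.
fake_args, fake_kwargs = pytree.tree_map_only(
    torch.fx.Node, lambda n: n.meta["val"], (args, kwargs)
)
assert isinstance(fake_args[0], torch.Tensor)
assert isinstance(fake_args[1][0], torch.Tensor) and fake_args[1][1] == 1.0
assert fake_kwargs["scale"] == 2.0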
@@ -10,6 +10,11 @@ from typing import Any, Callable, Optional, Union
 import sympy

 import torch
+from torch._export.passes._node_metadata_hook import (
+    _node_metadata_hook,
+    _set_node_metadata_hook,
+)
+from torch._export.utils import _detect_fake_mode_from_gm
 from torch._higher_order_ops.triton_kernel_wrap import (
     TraceableTritonKernelWrapper,
     tracing_triton_hopifier_singleton,
@@ -198,11 +203,6 @@ class FxConverter:
             device=device,
         )

-    def _create_meta_from_buffer(
-        self, node: torch.fx.Node, buffer: CodegenBuffer
-    ) -> None:
-        node.meta["val"] = buffer.get_example()
-
     def _create_as_strided(
         self,
         input_node: torch.fx.Node,
@@ -266,31 +266,6 @@ class FxConverter:
         Converts graph inputs to FX placeholders.
         """

-        def _codegen_symbol(
-            sym_or_exp: Union[sympy.Symbol, sympy.Expr],
-            base_node: torch.fx.Node,
-            target: torch._ops.OpOverload,
-            dim: int,
-        ) -> None:
-            if isinstance(sym_or_exp, sympy.Symbol):
-                buffer = SymbolBuffer(sym_or_exp)
-
-                if buffer.get_name() in self.buffer_to_node:
-                    return
-
-                size_node = self.gm.graph.call_function(target, (base_node, dim))
-                size_proxy = torch.fx.Proxy(size_node, tracer=self.tracer)
-
-                self._create_meta_from_buffer(size_node, buffer)
-                self._record_allocation(buffer, size_node)
-                self.expr_to_proxy[sym_or_exp] = size_proxy
-
-            elif isinstance(sym_or_exp, sympy.Integer):
-                return
-
-            elif isinstance(sym_or_exp, sympy.Expr):
-                self._sympy_interp(sym_or_exp)
-
         for node in V.graph.module.graph.find_nodes(op="placeholder"):  # type: ignore[operator, union-attr]
             name = node.name
             if name in V.graph.graph_inputs:
@@ -303,18 +278,48 @@ class FxConverter:
                     else self._get_buffer(ir_node)
                 )
                 placeholder_node = self.gm.graph.placeholder(buffer.get_name())
-                self._create_meta_from_buffer(placeholder_node, buffer)
+                placeholder_node.meta["val"] = buffer.get_example()
                 self._record_allocation(buffer, placeholder_node)

                 # not sure if this is needed...
                 if isinstance(ir_node, (sympy.Symbol)):
                     placeholder_proxy = torch.fx.Proxy(
                         placeholder_node, tracer=self.tracer
                     )
                     self.expr_to_proxy[ir_node] = placeholder_proxy
+            elif V.aot_compilation:
+                # Create dummy input nodes to match the input signature
+                self.gm.graph.placeholder(name)

+    # Generate nodes for dynamic input sizes/strides.
+    def _generate_graph_input_shapes(self) -> None:
+        """
+        Generate nodes creating symints that are part of graph input
+        shape/strides.
+        """
+
+        def _codegen_symbol(
+            sym_or_exp: Union[sympy.Symbol, sympy.Expr],
+            base_node: torch.fx.Node,
+            target: torch._ops.OpOverload,
+            dim: int,
+        ) -> None:
+            if isinstance(sym_or_exp, sympy.Symbol):
+                if sym_or_exp in self.expr_to_proxy:
+                    return
+
+                size_node = self.gm.graph.call_function(target, (base_node, dim))
+                size_proxy = torch.fx.Proxy(size_node, tracer=self.tracer)
+
+                self.expr_to_proxy[sym_or_exp] = size_proxy
+
+            elif isinstance(sym_or_exp, sympy.Integer):
+                return
+
+            elif isinstance(sym_or_exp, sympy.Expr):
+                self._sympy_interp(sym_or_exp)
+
+        for node in V.graph.module.graph.find_nodes(op="placeholder"):  # type: ignore[operator, union-attr]
+            name = node.name
+            if name in V.graph.graph_inputs:
+                ir_node = V.graph.graph_inputs[name]
+                if isinstance(ir_node, ir.TensorBox):
+                    buffer = self._get_buffer(ir_node)
+                    placeholder_node = self.buffer_to_node[buffer.get_name()]

                     for dim, size in enumerate(ir_node.get_size()):
                         _codegen_symbol(
                             size, placeholder_node, torch.ops.aten.sym_size.int, dim
@@ -324,10 +329,6 @@ class FxConverter:
                             stride, placeholder_node, torch.ops.aten.sym_stride.int, dim
                         )

-            elif V.aot_compilation:
-                # Create dummy input nodes to match the input signature
-                self.gm.graph.placeholder(name)
-
     def _generate_graph_constants(self) -> None:
         for name, value in V.graph.constants.items():
             node = self.gm.graph.get_attr(name)
@@ -397,24 +398,32 @@ class FxConverter:
         self._generate_graph_inputs()
         self._generate_graph_constants()

-        # Generate FX IR from Wrapper IR lines.
-        for line in self.lines:
-            if isinstance(line, WrapperLine):
-                line.codegen_fx(self)(line)
-            elif isinstance(line, LineContext):
-                # Ignore line context in FX IR.
-                pass
-            else:
-                raise NotImplementedError(
-                    textwrap.dedent(
-                        f"""
-                        Found line of unrecognized type '{type(line)}':
-                            '{line}'
-
-                        FX conversion only supports Wrapper IR lines.
-                        """
-                    )
-                )
+        fake_mode = _detect_fake_mode_from_gm(self.gm)
+
+        with _set_node_metadata_hook(
+            self.gm,
+            functools.partial(_node_metadata_hook, fake_mode=fake_mode),
+        ):
+            self._generate_graph_input_shapes()
+
+            # Generate FX IR from Wrapper IR lines.
+            for line in self.lines:
+                if isinstance(line, WrapperLine):
+                    line.codegen_fx(self)(line)
+                elif isinstance(line, LineContext):
+                    # Ignore line context in FX IR.
+                    pass
+                else:
+                    raise NotImplementedError(
+                        textwrap.dedent(
+                            f"""
+                            Found line of unrecognized type '{type(line)}':
+                                '{line}'
+
+                            FX conversion only supports Wrapper IR lines.
+                            """
+                        )
+                    )

         self._generate_output()
         self.gm.recompile()
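The ordering in `generate()` above is deliberate: placeholders and constants are created first with `meta["val"]` set explicitly from their buffers, the fake mode is then recovered from those example values, and only after that is the hook installed, so every node emitted for the Wrapper IR lines is faked in the same mode and shares its symbolic shapes. A condensed sketch of that wiring (`gm` is assumed in scope with placeholders already carrying fake `meta["val"]` entries, as in `FxConverter`):

import functools

from torch._export.passes._node_metadata_hook import (
    _node_metadata_hook,
    _set_node_metadata_hook,
)
from torch._export.utils import _detect_fake_mode_from_gm

# Recover the FakeTensorMode used by the graph's existing example values
# (None when the module carries no fake tensors).
fake_mode = _detect_fake_mode_from_gm(gm)

with _set_node_metadata_hook(
    gm, functools.partial(_node_metadata_hook, fake_mode=fake_mode)
):
    ...  # every call_function node created here gets meta["val"] filled in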
@@ -512,7 +521,6 @@ class FxConverter:
         )
         assert name
         node.name = name
-        self._create_meta_from_buffer(node, buffer)
         self._record_allocation(buffer, node)

     def _generate_comment(self, line: WrapperLine) -> None:
@@ -583,7 +591,6 @@ class FxConverter:
         # Map ReinterpretView to as_strided.
         result_node = self._create_as_strided(input_node, size, stride, offset)
         result_node.name = name
-        result_node.meta["val"] = layout.get_example()
         self._record_allocation(result_buffer, result_node)

     def _generate_reuse(self, line: WrapperLine) -> None:
@@ -606,7 +613,6 @@ class FxConverter:
             or old.get_offset() != offset
         ):
             result_node = self._create_as_strided(old_node, size, stride, offset)
-            self._create_meta_from_buffer(result_node, new)

         self._record_allocation(new, result_node)

@@ -635,7 +641,6 @@ class FxConverter:
         idx = inds[0]

         node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx))
-        node.meta["val"] = arg_node.meta["val"][idx]
         node.name = line.result_name
         self.buffer_to_node[line.result_name] = node

@@ -778,14 +783,6 @@ class FxConverter:
         fx_node.name = result_buffer
         self.buffer_to_node[result_buffer] = fx_node

-        arg_tensors = [
-            arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg
-            for arg in args
-        ]
-
-        # Run the operation to propagate metadata.
-        fx_node.meta["val"] = op(*arg_tensors, **kwargs)
-
     def _generate_kernel_call(self, line: WrapperLine) -> None:
         assert isinstance(line, KernelCallLine)
         if not line.triton: