Change arg_kwarg_vals propagation strategy (#148046)

Instead of always propagating arg_kwarg_vals in _COPY_META_FIELDS, we
special-case the pattern matcher to propagate arg_kwarg_vals when
it sees triton_kernel_wrapper_functional.

The strategy is:
1) trace out the replacement graph with arg_kwarg_vals (which have accurate eager-mode metadata)
2) trace out the replacement graph with vals (which have the accurate Inductor metadata)
3) Propagate the arg_kwarg_vals from the first graph to the second.
4) Use the second graph as the replacement graph.

We use this strategy because we want to extend it to handle
auto_functionalized later in the stack.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148046
Approved by: https://github.com/eellison
This commit is contained in:
rzou
2025-04-01 21:07:25 -07:00
committed by PyTorch MergeBot
parent 03138733ba
commit c41fbb4f78
3 changed files with 76 additions and 5 deletions

View File

@ -21,7 +21,7 @@ from torch._inductor.lowering import (
)
from torch._inductor.virtualized import V
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
from torch.fx.immutable_collections import immutable_dict
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.passes.reinplace import _is_view_op
from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet
@ -720,6 +720,14 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
kwargs = dict(node.kwargs)
kwargs["tensors_to_clone"] = tensors_to_clone
node.kwargs = immutable_dict(kwargs)
if "arg_kwarg_vals" in node.meta:
# We changed the kwargs, so we need to update arg_kwarg_vals
# to something sane.
args, kwargs = node.meta["arg_kwarg_vals"]
new_kwargs = {**kwargs}
new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone)
new_kwargs = immutable_dict(new_kwargs)
node.meta["arg_kwarg_vals"] = (args, new_kwargs)
elif (
inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
) is not None:

View File

@ -251,14 +251,73 @@ class Match:
else contextlib.nullcontext()
)
def should_propagate_arg_kwarg_vals(nodes: list[torch.fx.Node]) -> bool:
if len(nodes) != 1:
return False
node = nodes[0]
if "arg_kwarg_vals" not in node.meta:
return False
return node.target in OrderedSet(
[
torch.ops.higher_order.triton_kernel_wrapper_functional,
]
)
with context:
if trace_fn is None:
trace_fn = functools.partial(
fwd_only, run_functional_passes=run_functional_passes
)
replacement = trace_fn(
replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
if should_propagate_arg_kwarg_vals(self.nodes):
# Our strategy is:
# 1) trace out the graph with arg_kwarg_vals (which have accurate eager-mode metadata)
# 2) trace out the graph with vals (which have the accurate Inductor metadata)
# 3) Propagate the arg_kwarg_vals from the first graph to the second.
# 4) Use the second graph as the replacement graph.
# Construct a map of node -> FakeTensor val in arg_kwarg_vals
node_to_val = {}
fake_args, fake_kwargs = self.nodes[0].meta["arg_kwarg_vals"]
fake_kwargs = {**fake_kwargs}
match_args, match_kwargs = tuple(self.args), self.kwargs
def record(node: torch.fx.Node, val: Any) -> None:
    """Store ``val`` in ``node_to_val`` keyed by ``node``.

    Anything that is not a ``torch.fx.Node`` is ignored, so the function
    can be called on arbitrary leaves without filtering them first.
    """
    if not isinstance(node, torch.fx.Node):
        return
    node_to_val[node] = val
torch.utils._pytree.tree_map(
record, (match_args, match_kwargs), (fake_args, fake_kwargs)
)
# map args to their FakeTensor val in arg_kwarg_vals
example_vals = torch.fx.map_arg(args, lambda arg: node_to_val[arg])
# first graph
graph_with_eager_vals = trace_fn(replacement_fn, example_vals)
# second graph
example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"])
replacement = trace_fn(graph_with_eager_vals, example_vals)
# propagate metadata from first graph to second
# NB: This assertion might not be true in general, but it is true for
# the two use cases we have
# (triton_kernel_wrapper_functional, auto_functionalized)
assert len(graph_with_eager_vals.graph.nodes) == len(
replacement.graph.nodes
)
for old_node, new_node in zip(
graph_with_eager_vals.graph.nodes, replacement.graph.nodes
):
if "arg_kwarg_vals" in old_node.meta:
new_node.meta["arg_kwarg_vals"] = old_node.meta[
"arg_kwarg_vals"
]
else:
example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"])
replacement = trace_fn(replacement_fn, example_vals)
if len(self.nodes) == 1:
for n in replacement.graph.nodes:
_transfer_meta(
@ -1083,6 +1142,11 @@ class ReplacementPatternEntry(PatternEntry):
old_node=node,
pass_name="Interpreter_Replacer",
)
# This function copy-pastes the replacement graph into
# the graph. If the replacement graph had any arg_kwarg_vals,
# or val/tensor_meta, we propagate those over.
if "arg_kwarg_vals" in node.meta:
result.meta["arg_kwarg_vals"] = node.meta["arg_kwarg_vals"]
if "val" in node.meta and "val" not in result.meta:
result.meta["val"] = node.meta["val"]
if isinstance(node.meta["val"], torch.Tensor):

View File

@ -116,7 +116,6 @@ _COPY_META_FIELDS = [
"_numeric_debug_handle", # TODO deprecated
"custom",
"partitioner_tag",
"arg_kwarg_vals",
]