Change arg_kwarg_vals propagation strategy (#148046)
Instead of always propagating arg_kwarg_vals in _COPY_META_FIELDS, we special-case the pattern matcher to propagate arg_kwarg_vals when it sees triton_kernel_wrapper_functional. The strategy is:
1) Trace out the replacement graph with arg_kwarg_vals (which have accurate eager-mode metadata).
2) Trace out the replacement graph with vals (which have the accurate Inductor metadata).
3) Propagate the arg_kwarg_vals from the first graph to the second.
4) Use the second graph as the replacement graph.
We chose this strategy because we want to extend it to handle auto_functionalized later up in the stack.

Test Plan:
- existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148046
Approved by: https://github.com/eellison
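The diff below implements this inside Inductor's pattern matcher, but the core idea can be sketched on its own: trace the replacement callable twice with different example values and copy per-node metadata from the first trace onto the second. The following is a minimal, standalone sketch of that idea using make_fx; replacement_fn, the example inputs, and the use of meta["val"] as a stand-in for arg_kwarg_vals are illustrative assumptions, not the pattern matcher's actual code.

import torch
from torch.fx.experimental.proxy_tensor import make_fx


def replacement_fn(x, y):  # hypothetical replacement callable
    return x + y


# 1) Trace with eager-style example values (plain contiguous tensors).
eager_inputs = (torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8))
graph_with_eager_vals = make_fx(replacement_fn, tracing_mode="fake")(*eager_inputs)

# 2) Trace again with the values Inductor would see; a different memory
#    format stands in for "Inductor metadata" in this sketch.
inductor_inputs = tuple(
    t.contiguous(memory_format=torch.channels_last) for t in eager_inputs
)
replacement = make_fx(graph_with_eager_vals, tracing_mode="fake")(*inductor_inputs)

# 3) Propagate per-node metadata from the first trace onto the second.
#    This relies on both traces producing the same node sequence.
assert len(graph_with_eager_vals.graph.nodes) == len(replacement.graph.nodes)
for old_node, new_node in zip(
    graph_with_eager_vals.graph.nodes, replacement.graph.nodes
):
    if "val" in old_node.meta:
        # The real pass copies node.meta["arg_kwarg_vals"]; meta["val"] is
        # used here only to show the shape of the propagation.
        new_node.meta["arg_kwarg_vals"] = old_node.meta["val"]

# 4) `replacement` is the graph that would be spliced into the main graph.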
@@ -21,7 +21,7 @@ from torch._inductor.lowering import (
 )
 from torch._inductor.virtualized import V
 from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
-from torch.fx.immutable_collections import immutable_dict
+from torch.fx.immutable_collections import immutable_dict, immutable_list
 from torch.fx.passes.reinplace import _is_view_op
 from torch.utils import _pytree as pytree
 from torch.utils._ordered_set import OrderedSet
@@ -720,6 +720,14 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
             kwargs = dict(node.kwargs)
             kwargs["tensors_to_clone"] = tensors_to_clone
             node.kwargs = immutable_dict(kwargs)
+            if "arg_kwarg_vals" in node.meta:
+                # We changed the kwargs, so we need to update arg_kwarg_vals
+                # to something sane.
+                args, kwargs = node.meta["arg_kwarg_vals"]
+                new_kwargs = {**kwargs}
+                new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone)
+                new_kwargs = immutable_dict(new_kwargs)
+                node.meta["arg_kwarg_vals"] = (args, new_kwargs)
         elif (
             inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
         ) is not None:
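Because node.kwargs and the kwargs half of arg_kwarg_vals hold immutable FX collections, the pass rebuilds them rather than mutating in place, which is the copy-update-refreeze dance seen above. A small standalone illustration of that pattern, with made-up keys and values:

from torch.fx.immutable_collections import immutable_dict, immutable_list

kwargs = immutable_dict(
    {"tensors_to_clone": immutable_list(["in_ptr0"]), "grid": (1,)}
)
try:
    kwargs["tensors_to_clone"] = ["in_ptr0", "in_ptr1"]  # mutation is rejected
except (TypeError, NotImplementedError):
    pass

# The supported pattern, as in the diff above: copy into a plain dict,
# update the entry, then re-freeze.
new_kwargs = {**kwargs}
new_kwargs["tensors_to_clone"] = immutable_list(["in_ptr0", "in_ptr1"])
new_kwargs = immutable_dict(new_kwargs)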
@@ -251,14 +251,73 @@ class Match:
             else contextlib.nullcontext()
         )
 
+        def should_propagate_arg_kwarg_vals(nodes: list[torch.fx.Node]) -> bool:
+            if len(nodes) != 1:
+                return False
+            node = nodes[0]
+            if "arg_kwarg_vals" not in node.meta:
+                return False
+            return node.target in OrderedSet(
+                [
+                    torch.ops.higher_order.triton_kernel_wrapper_functional,
+                ]
+            )
+
         with context:
             if trace_fn is None:
                 trace_fn = functools.partial(
                     fwd_only, run_functional_passes=run_functional_passes
                 )
-            replacement = trace_fn(
-                replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"])
-            )
+            if should_propagate_arg_kwarg_vals(self.nodes):
+                # Our strategy is:
+                # 1) trace out the graph with arg_kwarg_vals (which have accurate eager-mode metadata)
+                # 2) trace out the graph with vals (which have the accurate Inductor metadata)
+                # 3) Propagate the arg_kwarg_vals from the first graph to the second.
+                # 4) Use the second graph as the replacement graph.
+
+                # Construct a map of node -> FakeTensor val in arg_kwarg_vals
+                node_to_val = {}
+
+                fake_args, fake_kwargs = self.nodes[0].meta["arg_kwarg_vals"]
+                fake_kwargs = {**fake_kwargs}
+                match_args, match_kwargs = tuple(self.args), self.kwargs
+
+                def record(node: torch.fx.Node, val: Any) -> None:
+                    if isinstance(node, torch.fx.Node):
+                        node_to_val[node] = val
+
+                torch.utils._pytree.tree_map(
+                    record, (match_args, match_kwargs), (fake_args, fake_kwargs)
+                )
+                # map args to their FakeTensor val in arg_kwarg_vals
+                example_vals = torch.fx.map_arg(args, lambda arg: node_to_val[arg])
+
+                # first graph
+                graph_with_eager_vals = trace_fn(replacement_fn, example_vals)
+
+                # second graph
+                example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"])
+                replacement = trace_fn(graph_with_eager_vals, example_vals)
+
+                # propagate metadata from first graph to second
+                # NB: This assertion might not be true in general, but it is true for
+                # the two use cases we have
+                # (triton_kernel_wrapper_functional, auto_functionalized)
+                assert len(graph_with_eager_vals.graph.nodes) == len(
+                    replacement.graph.nodes
+                )
+                for old_node, new_node in zip(
+                    graph_with_eager_vals.graph.nodes, replacement.graph.nodes
+                ):
+                    if "arg_kwarg_vals" in old_node.meta:
+                        new_node.meta["arg_kwarg_vals"] = old_node.meta[
+                            "arg_kwarg_vals"
+                        ]
+
+            else:
+                example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"])
+                replacement = trace_fn(replacement_fn, example_vals)
         if len(self.nodes) == 1:
             for n in replacement.graph.nodes:
                 _transfer_meta(
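The record/tree_map combination above works because (match_args, match_kwargs) and (fake_args, fake_kwargs) are pytrees with the same structure, so mapping over both in lockstep pairs each matched node with its fake value. A standalone sketch of that pairing, using made-up string leaves in place of fx.Node objects and FakeTensors:

import torch.utils._pytree as pytree

match_side = (("node_a", "node_b"), {"scale": "node_c"})
fake_side = (("fake_a", "fake_b"), {"scale": "fake_c"})

node_to_val = {}


def record(node, val):
    # The real pass only records torch.fx.Node leaves; here every leaf counts.
    node_to_val[node] = val


# tree_map walks both trees in lockstep; their structures must line up.
pytree.tree_map(record, match_side, fake_side)
assert node_to_val == {"node_a": "fake_a", "node_b": "fake_b", "node_c": "fake_c"}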
@@ -1083,6 +1142,11 @@ class ReplacementPatternEntry(PatternEntry):
                         old_node=node,
                         pass_name="Interpreter_Replacer",
                     )
+                    # This function copy-pastes the replacement graph into
+                    # the graph. If the replacement graph had any arg_kwarg_vals,
+                    # or val/tensor_meta, we propagate those over.
+                    if "arg_kwarg_vals" in node.meta:
+                        result.meta["arg_kwarg_vals"] = node.meta["arg_kwarg_vals"]
                     if "val" in node.meta and "val" not in result.meta:
                         result.meta["val"] = node.meta["val"]
                     if isinstance(node.meta["val"], torch.Tensor):
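The same propagate-the-meta idea applies whenever a rewrite splices a new node in place of a matched one. A standalone sketch of copying selected meta fields during an FX graph rewrite; the traced function, the add-to-mul swap, and the field list are illustrative, not the pattern matcher itself:

import operator

import torch
import torch.fx as fx


def f(x):
    return x + x


gm = fx.symbolic_trace(f)

for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is operator.add:
        with gm.graph.inserting_after(node):
            result = gm.graph.call_function(torch.mul, node.args)
        # Carry over whatever metadata the old node had, mirroring the
        # arg_kwarg_vals / val handling in the diff above.
        for key in ("arg_kwarg_vals", "val", "tensor_meta"):
            if key in node.meta and key not in result.meta:
                result.meta[key] = node.meta[key]
        node.replace_all_uses_with(result)
        gm.graph.erase_node(node)

gm.recompile()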
@@ -116,7 +116,6 @@ _COPY_META_FIELDS = [
     "_numeric_debug_handle",  # TODO deprecated
     "custom",
     "partitioner_tag",
-    "arg_kwarg_vals",
 ]