mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155821 Approved by: https://github.com/ezyang, https://github.com/zou3519
759 lines
29 KiB
Python
759 lines
29 KiB
Python
# mypy: allow-untyped-defs
|
|
import itertools
|
|
import logging
|
|
import operator
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Union
|
|
|
|
import torch
|
|
from torch._C._dynamo.guards import compute_overlapping_tensors
|
|
from torch._dispatch.python import enable_python_dispatcher
|
|
from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger
|
|
from torch._higher_order_ops.triton_kernel_wrap import (
|
|
kernel_side_table,
|
|
triton_kernel_wrapper_functional,
|
|
)
|
|
from torch._inductor import config, inductor_prims
|
|
from torch._inductor.fx_utils import get_node_storage, is_node_realized
|
|
from torch._inductor.lowering import (
|
|
inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings,
|
|
)
|
|
from torch._inductor.virtualized import V
|
|
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
|
|
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
|
from torch.fx.passes.reinplace import _is_view_op
|
|
from torch.utils import _pytree as pytree
|
|
from torch.utils._ordered_set import OrderedSet
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
aten = torch.ops.aten
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class InplaceableOp:
|
|
inplace_op: Callable[..., Any]
|
|
mutated_arg: int
|
|
extra_check: Callable[[torch.fx.Node], bool] = lambda node: True
|
|
|
|
|
|
_SCATTER_OP_TO_VIEW = {
|
|
torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default,
|
|
torch.ops.aten.select_scatter.default: torch.ops.aten.select.int,
|
|
torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor,
|
|
torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default,
|
|
}
|
|
_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()}
|
|
|
|
|
|
def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs):
|
|
fake_args, fake_kwargs = pytree.tree_map(
|
|
lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
|
|
(args, kwargs),
|
|
)
|
|
with V.fake_mode:
|
|
fake_result = fn(*fake_args, **fake_kwargs)
|
|
|
|
node = graph.call_function(fn, args, kwargs)
|
|
node.meta["val"] = fake_result
|
|
return node
|
|
|
|
|
|
@dataclass
|
|
class ViewOp:
|
|
target: torch._ops.OpOverload
|
|
args: tuple[Any, ...]
|
|
kwargs: dict[str, Any]
|
|
|
|
|
|
def _inplace_generalized_scatter(
|
|
inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp]
|
|
) -> torch.Tensor:
|
|
tmp = inp
|
|
for view in view_ops:
|
|
fake_args, fake_kwargs = pytree.tree_map(
|
|
lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node,
|
|
(view.args, view.kwargs),
|
|
)
|
|
tmp = view.target(tmp, *fake_args, **fake_kwargs)
|
|
try:
|
|
tmp.copy_(src)
|
|
except RuntimeError as e:
|
|
raise RuntimeError(
|
|
f"shape error in scatter op, can not broadcast {src.shape} to {tmp.shape}"
|
|
) from e
|
|
return inp
|
|
|
|
|
|
def _generalized_scatter(
|
|
inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp]
|
|
) -> torch.Tensor:
|
|
out = inp.clone()
|
|
return _inplace_generalized_scatter(out, src, view_ops)
|
|
|
|
|
|
def _decompose_scatter_functional_helper(
|
|
graph: torch.fx.Graph,
|
|
inp: torch.Tensor,
|
|
src: torch.Tensor,
|
|
view_ops: list[ViewOp],
|
|
) -> torch.fx.Node:
|
|
view_op, view_ops_tail = view_ops[0], view_ops[1:]
|
|
|
|
if view_ops_tail:
|
|
view = graph_call_function(
|
|
graph, view_op.target, inp, *view_op.args, **view_op.kwargs
|
|
)
|
|
src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:]) # type: ignore[assignment]
|
|
|
|
return graph_call_function(
|
|
graph,
|
|
_VIEW_OP_TO_SCATTER[view_op.target],
|
|
inp,
|
|
src,
|
|
*view_op.args,
|
|
**view_op.kwargs,
|
|
)
|
|
|
|
|
|
def _decompose_scatter_functional(
|
|
graph: torch.fx.Graph, node: torch.fx.Node
|
|
) -> torch.fx.Node:
|
|
"""Decompose _generalized_scatter to a sequence of view_scatter operations
|
|
|
|
e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])
|
|
|
|
will become
|
|
|
|
view = aten.slice(inp, 0, 0, 10)
|
|
view_updated = aten.slice_scatter(view, src, 1, 10, -10)
|
|
inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10)
|
|
"""
|
|
assert node.target is _generalized_scatter
|
|
return _decompose_scatter_functional_helper(graph, *node.args) # type: ignore[arg-type]
|
|
|
|
|
|
def _decompose_scatter_mutating(
|
|
graph: torch.fx.Graph, node: torch.fx.Node
|
|
) -> torch.fx.Node:
|
|
"""Decompose _generalized_scatter using mutations
|
|
|
|
e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)])
|
|
|
|
will become
|
|
|
|
inp_updated = aten.clone(inp)
|
|
slice1 = aten.slice(inp_updated, 0, 0, 10)
|
|
slice2 = aten.slice(slice1, 1, 10, -10)
|
|
slice2.copy_(src)
|
|
|
|
"""
|
|
assert node.target in (_generalized_scatter, _inplace_generalized_scatter)
|
|
inp, src, view_ops = node.args
|
|
assert not node.kwargs
|
|
|
|
if node.target is _generalized_scatter:
|
|
inp = graph_call_function(graph, aten.clone, inp)
|
|
|
|
tmp = inp
|
|
for view in view_ops: # type: ignore[union-attr]
|
|
tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr]
|
|
|
|
graph_call_function(graph, aten.copy_.default, tmp, src)
|
|
return inp # type: ignore[return-value]
|
|
|
|
|
|
# View ops whose view_scatter op is lowered into mutations anyway,
|
|
# so is never a pessimisation to decompose.
|
|
_ALWAYS_MUTATING_SCATTER_OPS = OrderedSet(
|
|
[
|
|
aten.as_strided.default,
|
|
aten.diagonal.default,
|
|
]
|
|
)
|
|
|
|
|
|
def scatter_always_uses_mutation(node: torch.fx.Node) -> bool:
|
|
_, _, view_ops = node.args
|
|
return any(view.target in _ALWAYS_MUTATING_SCATTER_OPS for view in view_ops) # type: ignore[union-attr]
|
|
|
|
|
|
def should_reinplace_scatter(node: torch.fx.Node) -> bool:
|
|
"""Choose between mutating and functional scatter decompositions
|
|
|
|
Reinplacing view scatter ops can be pessimising as it blocks fusion with the
|
|
input or output tensor computations. However, it is still profitable if the
|
|
input and output would have been realized anyway.
|
|
|
|
"""
|
|
inp, _src, _view_ops = node.args
|
|
|
|
# Mutating scatter ops unconditionally realize input and output
|
|
if scatter_always_uses_mutation(node):
|
|
return True
|
|
|
|
if is_node_realized(inp) and is_node_realized(node): # type: ignore[arg-type]
|
|
return True
|
|
|
|
# If the output is copied back into the input, this forces both to be
|
|
# realized as the output is a user of the input
|
|
if inp.op in ("placeholder", "get_attr") and any( # type: ignore[union-attr]
|
|
user.target is aten.copy_.default and user.args[0] is inp for user in node.users
|
|
):
|
|
return True
|
|
|
|
# Otherwise, assume fusions will make functional variants profitable
|
|
return False
|
|
|
|
|
|
def decompose_generalized_scatter(graph: torch.fx.Graph) -> None:
|
|
"""Replace _generalized_scatter with normal aten ops"""
|
|
for node in itertools.chain(
|
|
graph.find_nodes(op="call_function", target=_generalized_scatter),
|
|
graph.find_nodes(op="call_function", target=_inplace_generalized_scatter),
|
|
):
|
|
use_mutation = (
|
|
node.target is _inplace_generalized_scatter
|
|
or scatter_always_uses_mutation(node)
|
|
)
|
|
|
|
with graph.inserting_before(node):
|
|
if use_mutation:
|
|
new_node = _decompose_scatter_mutating(graph, node)
|
|
else:
|
|
new_node = _decompose_scatter_functional(graph, node)
|
|
|
|
node.replace_all_uses_with(new_node)
|
|
graph.erase_node(node)
|
|
|
|
|
|
def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None:
|
|
"""
|
|
This canonicalizes view scatter ops into a generalized form, defined as:
|
|
def scatter(inp, src, views):
|
|
tmp = inp.clone()
|
|
for view in views:
|
|
tmp = view(tmp)
|
|
tmp.copy_(src)
|
|
|
|
We also fuse consecutive view scatter ops of the form
|
|
a = scatter(view2(self), src, [view1])
|
|
b = scatter(self, a, [view2])
|
|
which can be rewritten as
|
|
b = scatter(self, src, [view2, view1])
|
|
a = view2(b)
|
|
|
|
This is both more efficient as we only do a single scatter, and also
|
|
easier to reinplace since there is only one use of `self`
|
|
"""
|
|
|
|
node_to_view_base: dict[torch.fx.Node, torch.fx.Node] = {}
|
|
node_to_view_op: dict[torch.fx.Node, list[ViewOp]] = defaultdict(list)
|
|
|
|
def handle_views(node: torch.fx.Node):
|
|
inp = node.args[0]
|
|
node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type, assignment]
|
|
node_to_view_op[node] = [
|
|
*node_to_view_op[inp], # type: ignore[index]
|
|
ViewOp(
|
|
node.target, # type: ignore[arg-type]
|
|
args=node.args[1:],
|
|
kwargs=node.kwargs,
|
|
),
|
|
]
|
|
|
|
def handle_view_scatter(node: torch.fx.Node):
|
|
assert len(node.args) >= 2
|
|
inp, src = node.args[:2]
|
|
|
|
scatter_view_op = ViewOp(
|
|
_SCATTER_OP_TO_VIEW[node.target],
|
|
args=node.args[2:],
|
|
kwargs=node.kwargs,
|
|
)
|
|
|
|
def can_fuse():
|
|
if src.target is not _generalized_scatter: # type: ignore[union-attr]
|
|
return False
|
|
src_inp, _src_src, _src_scatter_view_op = src.args # type: ignore[union-attr]
|
|
|
|
inp_base = node_to_view_base.get(inp, inp) # type: ignore[arg-type]
|
|
src_base = node_to_view_base.get(src_inp, src_inp) # type: ignore[arg-type]
|
|
return inp_base is src_base and node_to_view_op[src_inp] == [ # type: ignore[index]
|
|
*node_to_view_op[inp], # type: ignore[index]
|
|
scatter_view_op,
|
|
]
|
|
|
|
if not can_fuse():
|
|
with graph.inserting_before(node):
|
|
new_node = graph_call_function(
|
|
graph,
|
|
_generalized_scatter,
|
|
inp,
|
|
src,
|
|
[scatter_view_op],
|
|
)
|
|
node.replace_all_uses_with(new_node)
|
|
graph.erase_node(node)
|
|
return
|
|
|
|
_src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr]
|
|
with graph.inserting_before(src): # type: ignore[arg-type]
|
|
new_node = graph_call_function(
|
|
graph,
|
|
_generalized_scatter,
|
|
inp,
|
|
src_src,
|
|
[scatter_view_op, *src_scatter_view_op], # type: ignore[misc]
|
|
)
|
|
node.replace_all_uses_with(new_node)
|
|
graph.erase_node(node)
|
|
|
|
if src.users: # type: ignore[union-attr]
|
|
new_src = graph_call_function(
|
|
graph,
|
|
_SCATTER_OP_TO_VIEW[node.target],
|
|
new_node,
|
|
*node.args[2:],
|
|
**node.kwargs,
|
|
)
|
|
|
|
handle_views(new_src)
|
|
src.replace_all_uses_with(new_src) # type: ignore[union-attr]
|
|
|
|
graph.erase_node(src) # type: ignore[arg-type]
|
|
|
|
for node in graph.nodes:
|
|
if _is_view_op(node.target):
|
|
handle_views(node)
|
|
elif node.target in _SCATTER_OP_TO_VIEW:
|
|
handle_view_scatter(node)
|
|
|
|
|
|
inplaceable_ops = {
|
|
aten.index_put.default: InplaceableOp(aten.index_put_.default, 0),
|
|
aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0),
|
|
_generalized_scatter: InplaceableOp(
|
|
_inplace_generalized_scatter,
|
|
0,
|
|
extra_check=should_reinplace_scatter,
|
|
),
|
|
}
|
|
|
|
try:
|
|
c10d_functional = torch.ops._c10d_functional
|
|
inplaceable_collective_ops = {
|
|
c10d_functional.all_reduce.default: InplaceableOp(
|
|
c10d_functional.all_reduce_.default, 0
|
|
),
|
|
c10d_functional.all_reduce_coalesced.default: InplaceableOp(
|
|
c10d_functional.all_reduce_coalesced_.default, 0
|
|
),
|
|
}
|
|
inplaceable_ops.update(inplaceable_collective_ops)
|
|
except AttributeError:
|
|
# _c10d_functional ops are only available when torch
|
|
# is built with USE_DISTRIBUTED=1.
|
|
pass
|
|
|
|
inplaceable_foreach_ops: dict[torch._ops.OpOverload, InplaceableOp] = {}
|
|
for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items():
|
|
inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0)
|
|
|
|
|
|
inplaceable_triton_ops = OrderedSet([triton_kernel_wrapper_functional])
|
|
|
|
|
|
# Operators that don't depend on the tensor data
|
|
META_ONLY_OPS = OrderedSet(
|
|
[
|
|
aten.sym_size.int,
|
|
aten.sym_stride.int,
|
|
aten.sym_numel.default,
|
|
aten.sym_storage_offset.default,
|
|
]
|
|
)
|
|
|
|
|
|
def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
|
|
"""
|
|
Reinplaces in-placeable operations.
|
|
If there are no uses of a view of the mutated arg after the current node,
|
|
it is possible to inplace the op.
|
|
This above algorithm could be justified by observing side effects. While
|
|
we traverse the graph in forwards direction, only latter nodes could view
|
|
side effects of the current node. If the current node is not used later as
|
|
well as no view of this node is used later in the graph, then it is safe to
|
|
inplace as there would be no way to observe the side effects.
|
|
This condition is slightly different for graph inputs where they can only
|
|
be inplaced if the above condition is true and there's a copy_ in the
|
|
epilogue that signals that the caller wants to observe the mutation.
|
|
|
|
Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from
|
|
input args, so instead of checking mutation on placeholder, AOTInductor
|
|
checks mutation on get_attr. This is subject to change in future.
|
|
"""
|
|
|
|
copy_args_to_copy_nodes = {}
|
|
# maps argument to the first copy_ node that mutates it.
|
|
copy_nodes = {}
|
|
mutated_inputs = OrderedSet[Any]()
|
|
storage_to_nodes = defaultdict(list)
|
|
node_order: dict[Any, int] = {}
|
|
for i, node in enumerate(reversed(graph.nodes)):
|
|
node_order[node] = len(graph.nodes) - i - 1
|
|
storage_to_nodes[get_node_storage(node)].append(node)
|
|
if node.target == aten.copy_.default and node.args[0].op in (
|
|
"placeholder",
|
|
"get_attr",
|
|
):
|
|
dst = node.args[0]
|
|
src = node.args[1]
|
|
# If the target is a getitem and it indexes a possible clone,
|
|
# then skip over it
|
|
if src.target == operator.getitem and (
|
|
(
|
|
src.args[0].target == triton_kernel_wrapper_functional
|
|
and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0]
|
|
)
|
|
or (src.args[0].target in inplaceable_foreach_ops)
|
|
or (src.args[0].target == torch.ops.higher_order.auto_functionalized)
|
|
):
|
|
src = src.args[0]
|
|
|
|
copy_args_to_copy_nodes[(dst, src)] = node
|
|
copy_nodes[dst] = node
|
|
|
|
mutated_inputs.add(node.args[0])
|
|
|
|
def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg):
|
|
node_loc = node_order[node]
|
|
copy_node_loc = node_order[copy_node] if copy_node is not None else None
|
|
|
|
def is_meta_only_user(node):
|
|
if _is_view_op(node.target):
|
|
return all(is_meta_only_user(u) for u in node.users)
|
|
return node.target in META_ONLY_OPS
|
|
|
|
for view in shared_view_nodes:
|
|
for user in view.users:
|
|
user_loc = node_order[user]
|
|
# Skip all users before node
|
|
if user_loc <= node_loc:
|
|
continue
|
|
# Ignore uses after the copy_ epilogue node, where the input
|
|
# has already been mutated anyway
|
|
if copy_node_loc is not None and copy_node_loc <= user_loc:
|
|
continue
|
|
# Reinplacing does not change shape metadata
|
|
if is_meta_only_user(user):
|
|
continue
|
|
# If our graph looks like:
|
|
# foo(mutated_arg)
|
|
# mutated_arg.copy_(other)
|
|
# then it's safe for us to reinplace foo because mutated_arg
|
|
# will get overwritten anyways.
|
|
if (
|
|
user.target is torch.ops.aten.copy_.default
|
|
and mutated_arg is user.args[0]
|
|
):
|
|
continue
|
|
return True
|
|
return False
|
|
|
|
def can_inplace(node, mutated_arg):
|
|
# ls should be a list of tensors that all shares the same storage.
|
|
def _overlap(ls) -> bool:
|
|
try:
|
|
return len(compute_overlapping_tensors(ls)) != 0
|
|
except GuardOnDataDependentSymNode:
|
|
# If we fail with data dependent error we assume they all overlap.
|
|
return True
|
|
|
|
if isinstance(mutated_arg, (list, tuple)):
|
|
# TODO Using _overlap here causes a several issues.
|
|
unique_storages = OrderedSet(get_node_storage(arg) for arg in mutated_arg)
|
|
if len(unique_storages) != len(mutated_arg):
|
|
# At least two Tensors in mutated_arg alias each other, so we can't reinplace it.
|
|
# We can probably do better (that is, reinplace one of them and clone the other)
|
|
# but that requires more work and mutable List[Tensor] are not that common.
|
|
return False
|
|
return all(can_inplace(node, arg) for arg in mutated_arg)
|
|
|
|
if get_node_storage(mutated_arg) is None:
|
|
return False
|
|
|
|
shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)]
|
|
|
|
# Only keep tensor that might overlap with mutated_arg.
|
|
shared_view_nodes = [
|
|
v
|
|
for v in shared_view_nodes
|
|
if _overlap([mutated_arg.meta["val"], v.meta["val"]])
|
|
]
|
|
|
|
if mutated_arg.op in ("placeholder", "get_attr"):
|
|
# Get the first copy_ node that mutates the mutated_arg.
|
|
copy_node = copy_nodes.get(mutated_arg, None)
|
|
if copy_node is None:
|
|
# There is no copy_ back to the candidate mutated_arg (which is a graph input).
|
|
# Therefore the semantics of the program are that it does not mutate
|
|
# mutated_arg, so we cannot re-inplace it.
|
|
return False
|
|
if any_use_of_views_after_node(
|
|
node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg
|
|
):
|
|
return False
|
|
|
|
return True
|
|
elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes):
|
|
# This should never happen in auto_functionalize_v2 non-inference mode,
|
|
# since all mutated_arg are bases.
|
|
|
|
# If mutated arg is view of any of the inputs of the graph,
|
|
# do not allow for inplacing.
|
|
# This would require more sophisticated algorithm to handle
|
|
return False
|
|
else:
|
|
return not any_use_of_views_after_node(
|
|
node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg
|
|
)
|
|
|
|
def log_inplace_results(
|
|
node_name,
|
|
old_tensors_to_clone,
|
|
tensors_to_clone,
|
|
missed_args,
|
|
missed_nodes,
|
|
trigger,
|
|
):
|
|
# Total size of possibly_missed_reinplacing_opportunities for tensors with static shapes.
|
|
missed_bytes = 0
|
|
|
|
def bytes(node):
|
|
t = node.meta.get("val", None)
|
|
if (
|
|
t is not None
|
|
and isinstance(t.element_size(), int)
|
|
and isinstance(t.numel(), int)
|
|
):
|
|
return t.element_size() * t.numel()
|
|
else:
|
|
return 0
|
|
|
|
for node in missed_nodes:
|
|
if isinstance(node, (list, tuple)):
|
|
for n in node:
|
|
missed_bytes += bytes(n)
|
|
else:
|
|
missed_bytes += bytes(node)
|
|
|
|
log.info(
|
|
"For node %s, attempted to reinplace %s. We were unable to reinplace %s; "
|
|
"%s (if non-empty) are possible missed reinplacing opportunities that may be bad for "
|
|
"memory usage and performance. Total size of missed opportunities with static shapes is"
|
|
" : %s bytes.",
|
|
node_name,
|
|
old_tensors_to_clone,
|
|
tensors_to_clone,
|
|
missed_args,
|
|
missed_bytes,
|
|
)
|
|
|
|
ReinplaceCounters.add_missed_opportunities(trigger, len(missed_args))
|
|
ReinplaceCounters.add_missed_bytes(trigger, missed_bytes)
|
|
|
|
replace_dict: dict[torch.fx.Node, torch.fx.Node] = {}
|
|
|
|
def reinplace_and_refine_tensors_to_clone(
|
|
old_tensors_to_clone, kwargs, node_name, trigger
|
|
):
|
|
tensors_to_clone: list[str] = []
|
|
storage_of_reinplaced_args = OrderedSet[Union[int, None]]()
|
|
|
|
# Those used to count possibly_missed_reinplacing_opportunities
|
|
missed_nodes = []
|
|
missed_args = []
|
|
|
|
# TODO this logic can be made more precise using _overlap
|
|
def tensor_with_same_storage_already_reinplaced(arg):
|
|
if isinstance(arg, (list, tuple)):
|
|
return any(
|
|
get_node_storage(a) in storage_of_reinplaced_args for a in arg
|
|
)
|
|
return get_node_storage(mutated_arg) in storage_of_reinplaced_args
|
|
|
|
for arg in old_tensors_to_clone:
|
|
assert arg in kwargs
|
|
|
|
mutated_arg = kwargs[arg]
|
|
|
|
# Let's say we have:
|
|
# - op(x, y) that mutates both x and y
|
|
# - new_x, new_y = functional_op(x, y) is the functional variant
|
|
# If we are presented with functional_op(x, x), we must not reinplace
|
|
# this into op(x, x), because then it would be writing to the same Tensor.
|
|
# Instead, it's OK to reinplace one of them and to clone the other:
|
|
# >>> y = x.clone()
|
|
# >>> op(x, y)
|
|
# This also applies if we have views: functional_op(x, x[0])
|
|
# should not reinplace into op(x, x[0]).
|
|
should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced(
|
|
mutated_arg
|
|
)
|
|
if should_attempt_reinplace and can_inplace(node, mutated_arg):
|
|
# In general, we probably do not need those optimizations.
|
|
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
|
|
if copy_node is not None:
|
|
replace_dict[copy_node] = copy_node.args[0]
|
|
if not trigger == ReInplaceTrigger.AUTO_FUNC_V2:
|
|
for user in node.users:
|
|
# For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to
|
|
# output atindex size(out)+i.
|
|
# This used to compare string with integers before for auto_functionalize_v2. Not sure
|
|
# if it was needed for inplaceable_triton_ops?
|
|
if user.target == operator.getitem and user.args[1] == arg:
|
|
replace_dict[user] = mutated_arg
|
|
|
|
if isinstance(mutated_arg, (list, tuple)):
|
|
for a in mutated_arg:
|
|
storage_of_reinplaced_args.add(get_node_storage(a))
|
|
else:
|
|
storage_of_reinplaced_args.add(get_node_storage(mutated_arg))
|
|
else:
|
|
if should_attempt_reinplace:
|
|
missed_args.append(arg)
|
|
missed_nodes.append(mutated_arg)
|
|
|
|
tensors_to_clone.append(arg)
|
|
|
|
log_inplace_results(
|
|
node_name,
|
|
old_tensors_to_clone,
|
|
tensors_to_clone,
|
|
missed_args,
|
|
missed_nodes,
|
|
trigger,
|
|
)
|
|
return tensors_to_clone
|
|
|
|
for node in graph.nodes:
|
|
if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None:
|
|
mutated_arg = node.args[inplaceable_op.mutated_arg]
|
|
if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node):
|
|
# TODO(yifu): this doesn't properly remove copy epilogues for
|
|
# ops that mutate multiple inputs. Need to revise the copy
|
|
# node tracking logic to support the case.
|
|
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
|
|
if copy_node is not None:
|
|
replace_dict[copy_node] = copy_node.args[0]
|
|
node.target = inplaceable_op.inplace_op
|
|
elif node.target == torch.ops.higher_order.auto_functionalized_v2:
|
|
_mutable_op = node.args[0]
|
|
kwargs = node.kwargs
|
|
|
|
all_bases = kwargs["_all_bases"]
|
|
bases_to_clone = range(len(all_bases))
|
|
base_tensors_dct = dict(enumerate(all_bases))
|
|
new_bases_to_clone: list[int] = reinplace_and_refine_tensors_to_clone(
|
|
bases_to_clone,
|
|
base_tensors_dct,
|
|
node.target,
|
|
ReInplaceTrigger.AUTO_FUNC_V2,
|
|
)
|
|
# Stash the metadata. There is a pass later on where we decompose
|
|
# auto_functionalized into clones + a mutable op; this metadata
|
|
# tells the decomp to only clone the following inputs
|
|
node.meta["only_clone_these_tensors"] = new_bases_to_clone
|
|
elif node.target == torch.ops.higher_order.auto_functionalized:
|
|
_mutable_op = node.args[0]
|
|
from torch._higher_order_ops.auto_functionalize import get_mutable_args
|
|
|
|
tensors_to_clone, _ = get_mutable_args(_mutable_op)
|
|
# Don't try to reinplace Optional[Tensor] args that are None.
|
|
tensors_to_clone = [
|
|
t for t in tensors_to_clone if node.kwargs[t] is not None
|
|
]
|
|
tensors_to_clone = reinplace_and_refine_tensors_to_clone(
|
|
tensors_to_clone,
|
|
node.kwargs,
|
|
_mutable_op._name,
|
|
ReInplaceTrigger.AUTO_FUNC_V1,
|
|
)
|
|
|
|
# Stash the metadata. There is a pass later on where we decompose
|
|
# auto_functionalized into clones + a mutable op; this metadata
|
|
# tells the decomp to only clone the following inputs
|
|
node.meta["only_clone_these_tensors"] = tensors_to_clone
|
|
elif node.target in inplaceable_triton_ops:
|
|
kernel_idx = node.kwargs["kernel_idx"]
|
|
kernel = kernel_side_table.get_kernel(kernel_idx)
|
|
from triton.runtime.autotuner import Autotuner
|
|
from triton.runtime.jit import JITFunction
|
|
|
|
if isinstance(kernel, JITFunction):
|
|
kernel_name = kernel.fn.__name__
|
|
elif isinstance(kernel, Autotuner):
|
|
if config.is_fbcode():
|
|
# Autotuner has different implementations for AMD and NV
|
|
if torch.version.hip is None:
|
|
kernel_name = kernel.base_fn.__name__
|
|
else:
|
|
kernel_name = kernel.fn.__name__
|
|
else:
|
|
kernel_name = kernel.base_fn.__name__
|
|
else:
|
|
raise AssertionError("Unknown triton kernel type")
|
|
|
|
# inplaceable_triton_ops take an additional argument called
|
|
# tensors_to_clone which contain a list of tensors to clone
|
|
# This pass iterates over them and sees which ones are safe
|
|
# to eliminate (i.e. no longer need the clones)
|
|
tensors_to_clone = reinplace_and_refine_tensors_to_clone(
|
|
node.kwargs["tensors_to_clone"],
|
|
node.kwargs["kwargs"],
|
|
kernel_name,
|
|
ReInplaceTrigger.TRITON_OPS,
|
|
)
|
|
|
|
kwargs = dict(node.kwargs)
|
|
kwargs["tensors_to_clone"] = tensors_to_clone
|
|
node.kwargs = immutable_dict(kwargs)
|
|
if "eager_input_vals" in node.meta:
|
|
# We changed the kwargs, so we need to update eager_input_vals
|
|
# to something sane.
|
|
args, kwargs = node.meta["eager_input_vals"]
|
|
new_kwargs = {**kwargs}
|
|
new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone)
|
|
new_kwargs = immutable_dict(new_kwargs)
|
|
node.meta["eager_input_vals"] = (args, new_kwargs)
|
|
elif (
|
|
inplaceable_op := inplaceable_foreach_ops.get(node.target, None)
|
|
) is not None:
|
|
mutated_args = node.args[inplaceable_op.mutated_arg]
|
|
|
|
if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args):
|
|
continue
|
|
|
|
if can_inplace(node, mutated_args):
|
|
for arg in mutated_args:
|
|
copy_node = copy_args_to_copy_nodes[(arg, node)]
|
|
replace_dict[copy_node] = copy_node.args[0]
|
|
|
|
node.target = inplaceable_op.inplace_op
|
|
for node, replacement in replace_dict.items():
|
|
while replacement in replace_dict:
|
|
replacement = replace_dict[replacement]
|
|
replace_dict[node] = replacement
|
|
|
|
node.replace_all_uses_with(replacement)
|
|
graph.erase_node(node)
|
|
|
|
|
|
def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None:
|
|
with enable_python_dispatcher():
|
|
canonicalize_view_scatter_ops(graph)
|
|
reinplace_inplaceable_ops_core(graph)
|
|
decompose_generalized_scatter(graph)
|