Compare commits

...

3 Commits

Author SHA1 Message Date
95173b282e WIP 2024-08-07 16:12:10 -07:00
30438a640b [inductor] Add some more reinplacing tests
Also add some tests around the counters we added in a previous PR.

Test Plan:
- new tests

ghstack-source-id: 33862d21fcaf9c9c4de34d512518be118dda827a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132839
2024-08-07 08:02:00 -07:00
1e8f48a914 Add logging + counter for missed reinplacing opportunities
Summary:
- Add Inductor logs for which tensors we tried to reinplace, which tensors
  we were unable to reinplace, and, of those, which might be bugs (the
  "missed reinplacing opportunities"). This can also be determined by
  reading the Inductor output graph, but the logs make it much easier.
- Add a dynamo_compile counter for missed reinplacing opportunities. The
  goal is to see how widespread existing problems (if any) are. We've had
  trouble covering all of the edge cases in the reinplacing pass; the
  counter will help us hunt down remaining issues, as sketched below.
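
  A rough sketch of reading the counter after a compile (the counter name is
  the one added here; the toy function is illustrative only):

    import torch
    from torch._dynamo.utils import counters

    def f(x):
        return x.sin().cos()  # any compiled function will do

    counters.clear()
    torch.compile(f)(torch.randn(8))
    # Nonzero means the reinplacing pass wanted to reinplace a tensor but
    # could not, which may indicate a bug or a missed optimization.
    print(counters["inductor"]["possibly_missed_reinplacing_opportunities"])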

Test Plan:
- tested locally

ghstack-source-id: e371a6455e735ff7a8e3f31c7cf622eb55d5f513
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132758
2024-08-06 11:27:15 -07:00
9 changed files with 322 additions and 65 deletions

View File

@@ -1,9 +1,16 @@
# Owner(s): ["module: inductor"]
import torch
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA
from functorch import make_fx
from torch._dynamo.utils import counters
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.fx_passes.reinplace import reinplace_inplaceable_ops_core
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU
aten = torch.ops.aten
@@ -13,7 +20,22 @@ const = torch.tensor(0.0)
device = "cuda"
class TestReinplacingPassCorrectness(TestCase):
def num_reinplacing_failures():
return counters["inductor"]["possibly_missed_reinplacing_opportunities"]
@torch.library.custom_op("_reinplacing::sin", mutates_args={"out"})
def sin(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x.sin())
@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
out_sin.copy_(x.sin())
out_cos.copy_(x.cos())
class TestReinplacingPassCorrectness(InductorTestCase):
def _test(self, f):
nf = torch.compile(f)
inp = (
@@ -22,10 +44,10 @@ class TestReinplacingPassCorrectness(TestCase):
)
inp2 = (inp[0].clone(), inp[1].clone())
self.assertEqual(f(*inp), nf(*inp2))
# breakpoint()
self.assertEqual(inp, inp2)
def test_dont_modify_live(self):
@onlyCUDA
def test_dont_modify_live(self, device):
def f(x, y):
x = x.cos()
x2 = x.index_put((y,), const)
@@ -33,7 +55,8 @@ class TestReinplacingPassCorrectness(TestCase):
self._test(f)
def test_dont_modify_view_of_live(self):
@onlyCUDA
def test_dont_modify_view_of_live(self, device):
def f(x, y):
x = x.cos()
x2 = aten.alias(x)
@@ -43,13 +66,15 @@ class TestReinplacingPassCorrectness(TestCase):
self._test(f)
def test_dont_modify_input(self):
@onlyCUDA
def test_dont_modify_input(self, device):
def f(x, y):
return x.index_put((y,), const)
self._test(f)
def test_should_modify_inner(self):
@onlyCUDA
def test_should_modify_inner(self, device):
def f(x, y):
x = x.cos()
x = x.index_put((y,), const)
@@ -57,14 +82,86 @@ class TestReinplacingPassCorrectness(TestCase):
self._test(f)
def test_should_modify_input(self):
@onlyCUDA
def test_should_modify_input(self, device):
def f(x, y):
x = x.index_put_((y,), const)
return x
self._test(f)
def test_counters(self, device):
counters.clear()
def f(x):
out = torch.empty_like(x)
_, new_out = auto_functionalized(sin._opoverload, x=x, out=out)
y = out * new_out
return new_out, y
x = torch.randn(3, device=device)
gm = make_fx(f, tracing_mode="fake")(x)
reinplace_inplaceable_ops_core(gm.graph)
# We shouldn't have been able to reinplace `out` because it was used after
# auto_functionalized. Note that this usually doesn't happen in practice;
# we're artificially creating this example to test the counter.
# IF THIS NUMBER GOES TO ZERO, PLEASE FIND ANOTHER EXAMPLE
self.assertEqual(num_reinplacing_failures(), 1)
def test_multi_output_intermediate(self, device):
for requires_grad in [False, True]:
counters.clear()
def f(x):
out1 = torch.empty_like(x)
out2 = torch.empty_like(x)
sin_cos(x, out1, out2)
return out1, out2, x**2
x = torch.randn(3, device=device, requires_grad=requires_grad)
res1, res2, _ = torch.compile(f)(x)
self.assertEqual(res1, x.sin())
self.assertEqual(res2, x.cos())
self.assertEqual(num_reinplacing_failures(), 0)
def test_multiple_mutations(self, device):
counters.clear()
def f(x, out):
sin(x, out)
sin(out, out)
sin(out, out)
return out
x = torch.randn(3, device=device)
out = torch.randn(3, device=device)
result = torch.compile(f)(x, out)
self.assertEqual(result, x.sin().sin().sin())
self.assertEqual(result, out)
self.assertEqual(num_reinplacing_failures(), 0)
def test_multiple_intermediate(self, device):
counters.clear()
def f(x):
out = torch.empty_like(x)
sin(x, out)
sin(out, out)
sin(out, out)
return out
x = torch.randn(3, device=device)
result = torch.compile(f)(x)
self.assertEqual(result, x.sin().sin().sin())
self.assertEqual(num_reinplacing_failures(), 0)
only_for = ("cpu", "cuda")
instantiate_device_type_tests(
TestReinplacingPassCorrectness, globals(), only_for=only_for
)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA:
run_tests()
if HAS_CPU or HAS_GPU:
run_tests(needs="filelock")

View File

@@ -891,6 +891,9 @@ def _compile(
fail_reason: Optional[str] = None
fail_user_frame_filename: Optional[str] = None
fail_user_frame_lineno: Optional[int] = None
start_possibly_missed_reinplacing_opportunities = torch._dynamo.utils.counters[
"inductor"
]["possibly_missed_reinplacing_opportunities"]
guarded_code = None
try:
guarded_code = compile_inner(code, one_graph, hooks, transform)
@@ -954,6 +957,12 @@ def _compile(
compliant_custom_ops = {
op.__qualname__ for op in output.compliant_custom_ops
}
possibly_missed_reinplacing_opportunities = (
torch._dynamo.utils.counters["inductor"][
"possibly_missed_reinplacing_opportunities"
]
- start_possibly_missed_reinplacing_opportunities
)
else:
guard_count = None
shape_env_guard_count = None
@@ -969,6 +978,7 @@ def _compile(
restart_reasons = set()
# If compilation failed, the entire time is wasted
dynamo_time_before_restart = time.time() - start_time
possibly_missed_reinplacing_opportunities = None
metrics = CompilationMetrics(
str(compile_id),
@@ -997,6 +1007,7 @@ def _compile(
restart_reasons,
dynamo_time_before_restart,
guarded_code is not None,
possibly_missed_reinplacing_opportunities,
)
record_compilation_metrics(metrics)
torch._dynamo.callback_handler.run_end_callbacks()
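
For reference, the bookkeeping above boils down to the following pattern:
snapshot the counter before compiling and record the delta afterwards, or None
if compilation failed. (A condensed sketch; measure_missed_reinplacing is a
hypothetical helper, not part of this PR.)

from typing import Callable, Optional

from torch._dynamo.utils import counters

def measure_missed_reinplacing(compile_fn: Callable[[], object]) -> Optional[int]:
    # Snapshot the running counter so earlier compiles don't leak into
    # this measurement, mirroring the start_* variable added above.
    key = "possibly_missed_reinplacing_opportunities"
    start = counters["inductor"][key]
    try:
        compile_fn()
    except Exception:
        # On a failed compile the metric is not meaningful; report None.
        return None
    return counters["inductor"][key] - start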

View File

@@ -742,6 +742,7 @@ class CompilationMetrics:
# to install any guarded code. True means we actually decided to install
# a compiled frame
has_guarded_code: bool
possibly_missed_reinplacing_opportunities: Optional[int]
@dataclasses.dataclass

View File

@@ -46,7 +46,7 @@ def fx_graph_cse(fx_g: torch.fx.graph.Graph):
from torch._inductor.pattern_matcher import (
compute_mutation_region_ids,
same_mutation_regions,
safe_to_operate,
)
compute_mutation_region_ids(fx_g) # type: ignore[arg-type]
@@ -100,7 +100,7 @@ def fx_graph_cse(fx_g: torch.fx.graph.Graph):
overwrite_due_to_mutation = False
if hash_val_in_hash_env and token_map[hash_val] == token:
duplicate_n_prev = hash_env[hash_val]
if same_mutation_regions(n, duplicate_n_prev):
if safe_to_operate(fx_g, [n, duplicate_n_prev]):
env[n] = duplicate_n_prev
continue
else:

View File

@@ -31,7 +31,7 @@ from ..pattern_matcher import (
CallFunctionVarArgs,
filter_nodes,
get_arg_value,
get_mutation_region_id,
safe_to_operate,
Ignored,
init_once_fakemode,
KeywordArg,
@@ -163,8 +163,7 @@ def reorder_for_locality(graph: torch.fx.Graph):
other_node.op == "call_function"
and other_node.target != operator.getitem
and all((n in seen_nodes) for n in other_node.users)
and get_mutation_region_id(graph, node)
== get_mutation_region_id(graph, other_node)
and safe_to_operate(graph, [node, other_node])
):
# move node's producers right before it
node.prepend(other_node)

View File

@@ -1,12 +1,16 @@
# mypy: allow-untyped-defs
import itertools
import logging
import operator
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple
import torch
from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_functional
from torch._higher_order_ops.triton_kernel_wrap import (
kernel_side_table,
triton_kernel_wrapper_functional,
)
from torch._inductor import inductor_prims
from torch._inductor.fx_utils import get_node_storage, is_node_realized
from torch._inductor.lowering import (
@@ -18,6 +22,7 @@ from torch.fx.passes.reinplace import _is_view_op
from torch.utils import _pytree as pytree
log = logging.getLogger(__name__)
aten = torch.ops.aten
@@ -488,9 +493,10 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {}
def reinplace_and_refine_tensors_to_clone(old_tensors_to_clone, kwargs):
def reinplace_and_refine_tensors_to_clone(old_tensors_to_clone, kwargs, node_name):
tensors_to_clone: List[str] = []
storage_of_reinplaced_args = set()
possibly_missed_reinplacing_opportunities = []
def tensor_with_same_storage_already_reinplaced(arg):
if isinstance(arg, (list, tuple)):
@@ -502,20 +508,21 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
for arg in old_tensors_to_clone:
assert arg in kwargs
mutated_arg = kwargs[arg]
if (
# Let's say we have:
# - op(x, y) that mutates both x and y
# - new_x, new_y = functional_op(x, y) is the functional variant
# If we are presented with functional_op(x, x), we must not reinplace
# this into op(x, x), because then it would be writing to the same Tensor.
# Instead, it's OK to reinplace one of them and to clone the other:
# >>> y = x.clone()
# >>> op(x, y)
# This also applies if we have views: functional_op(x, x[0])
# should not reinplace into op(x, x[0]).
not tensor_with_same_storage_already_reinplaced(mutated_arg)
and can_inplace(node, mutated_arg)
):
# Let's say we have:
# - op(x, y) that mutates both x and y
# - new_x, new_y = functional_op(x, y) is the functional variant
# If we are presented with functional_op(x, x), we must not reinplace
# this into op(x, x), because then it would be writing to the same Tensor.
# Instead, it's OK to reinplace one of them and to clone the other:
# >>> y = x.clone()
# >>> op(x, y)
# This also applies if we have views: functional_op(x, x[0])
# should not reinplace into op(x, x[0]).
should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced(
mutated_arg
)
if should_attempt_reinplace and can_inplace(node, mutated_arg):
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
if copy_node is not None:
replace_dict[copy_node] = copy_node.args[0]
@@ -529,7 +536,22 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
else:
storage_of_reinplaced_args.add(get_node_storage(mutated_arg))
else:
if should_attempt_reinplace:
possibly_missed_reinplacing_opportunities.append(arg)
tensors_to_clone.append(arg)
log.info(
"For node %s, attempted to reinplace %s. We were unable to reinplace %s; "
"%s (if non-empty) are possible missed reinplacing opportunities that may be bad for "
"memory usage and performance.",
node_name,
old_tensors_to_clone,
tensors_to_clone,
possibly_missed_reinplacing_opportunities,
)
torch._dynamo.utils.counters["inductor"][
"possibly_missed_reinplacing_opportunities"
] += len(possibly_missed_reinplacing_opportunities)
return tensors_to_clone
for node in graph.nodes:
@@ -553,7 +575,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
t for t in tensors_to_clone if node.kwargs[t] is not None
]
tensors_to_clone = reinplace_and_refine_tensors_to_clone(
tensors_to_clone, node.kwargs
tensors_to_clone, node.kwargs, _mutable_op._name
)
# Stash the metadata. There is a pass later on where we decompose
@@ -561,12 +583,24 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
# tells the decomp to only clone the following inputs
node.meta["only_clone_these_tensors"] = tensors_to_clone
elif node.target in inplaceable_triton_ops:
kernel_idx = node.kwargs["kernel_idx"]
kernel = kernel_side_table.get_kernel(kernel_idx)
from triton.runtime.autotuner import Autotuner
from triton.runtime.jit import JITFunction
if isinstance(kernel, JITFunction):
kernel_name = kernel.fn.__name__
elif isinstance(kernel, Autotuner):
kernel_name = kernel.base_fn.__name__
else:
raise AssertionError("Unknown triton kernel type")
# inplaceable_triton_ops take an additional argument called
# tensors_to_clone, which contains a list of tensors to clone.
# This pass iterates over them and sees which ones are safe
# to eliminate (i.e. no longer need the clones)
tensors_to_clone = reinplace_and_refine_tensors_to_clone(
node.kwargs["tensors_to_clone"], node.kwargs["kwargs"]
node.kwargs["tensors_to_clone"], node.kwargs["kwargs"], kernel_name
)
kwargs = dict(node.kwargs)
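
To make the shared-storage case handled by reinplace_and_refine_tensors_to_clone
concrete, here is a sketch in the spirit of the new tests. (The _demo op and the
expected metadata are illustrative assumptions, not part of this PR.)

import torch
from functorch import make_fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.fx_passes.reinplace import reinplace_inplaceable_ops_core

# Hypothetical mutable op, analogous to sin_cos in the new tests.
@torch.library.custom_op("_demo::sin_cos_out", mutates_args={"out_sin", "out_cos"})
def sin_cos_out(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())

def f(x):
    buf = torch.empty_like(x)
    # Both mutated kwargs share one storage, so the pass may reinplace at
    # most one of them; the other must stay cloned (see the functional_op(x, x)
    # comment in the hunk above).
    result = auto_functionalized(sin_cos_out._opoverload, x=x, out_sin=buf, out_cos=buf)
    return result

gm = make_fx(f, tracing_mode="fake")(torch.randn(4))
reinplace_inplaceable_ops_core(gm.graph)
for node in gm.graph.nodes:
    if node.target is torch.ops.higher_order.auto_functionalized:
        # Expected: at most one of out_sin/out_cos is left here to clone.
        print(node.meta.get("only_clone_these_tensors"))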

View File

@@ -189,6 +189,16 @@ def get_storage(t: torch.Tensor) -> int:
return t.untyped_storage()._cdata
def get_node_storage_obj(node: torch.fx.Node) -> Optional[torch.UntypedStorage]:
if "val" not in node.meta:
return None
if not isinstance(node.meta["val"], torch.Tensor):
return None
if not torch._C._has_storage(node.meta["val"]):
return None
return node.meta["val"].untyped_storage()
def get_node_storage(node: torch.fx.Node) -> Optional[int]:
if "val" not in node.meta:
return None

View File

@@ -48,6 +48,7 @@ import os
import re
import textwrap
import typing
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict
from pathlib import Path
@@ -79,6 +80,7 @@ import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import counters
from torch._inductor.config import trace as trace_config
from torch._inductor.fx_utils import get_node_storage, get_node_storage_obj
from torch._prims_common import is_integer_dtype
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@@ -1604,7 +1606,71 @@ def is_start_of_fx_graph(graph: torch.fx.Graph, node: torch.fx.Node) -> bool:
_mutation_op_re = re.compile(r"(?<!_)(_$|_[.]|(\b|_)(set|enter|exit|seed)(\b|_))(?!_)")
def is_mutation_op(node: torch.fx.Node) -> bool:
def safe_to_operate(graph, nodes):
"""Is it safe to do graph transforms on nodes?
The provided graph may not be completely functional. This helper function
determines e.g. if we are allowed to reorder nodes or pattern match on them
in the presence of mutable operations.
There are two main types of mutable operations that can introduce
constraints on the nodes we can operate on:
- mutable operations
- auto_functionalized mutable operations
All of the metadata for this function gets computed in
`compute_node_constraints`
"""
# New nodes may have been added to the graph so we update all of them.
# NB: we assume the changes do not include the addition of
# any new barrier mutation or auto_functionalized nodes
# https://github.com/pytorch/pytorch/issues/132932
for node in nodes:
graph.meta["node_constraints"].update(node)
# [mutable operations]
# We treat all mutable operations in the graph as "barriers"; that is,
# a node before a mutable operation (like set_) is not allowed to
# interact with nodes after the mutable operation.
#
# Each node in the graph gets tagged with a "mutation_region_id"
# that is equal to the number of mutable ops before the node.
metas = [node.meta for node in nodes]
if len(set(meta['mutation_region_id'] for meta in metas)) > 1:
return False
# [auto_functionalized mutable operations]
# These are technically functional operations, but we still impose
# constraints to make it so that the reinplacing pass is able
# to reinplace them back to their mutable variants.
#
# For auto_functionalized nodes: all inputs that need reinplacing
# (and their views) are not allowed to interact with nodes that
# are after the auto_functionalized node.
#
# Each node in the graph is tagged with an "auto_functionalized_region_id"
# that is equal to the number of auto_functionalized nodes before the node.
#
# nodes that need reinplacing (and their views) are tagged with
# the auto_functionalized_region_id that they must be in or before.
key = 'must_be_in_or_before_auto_functionalized_region_id'
min_region_id_constraints = [meta[key] for meta in metas if key in meta]
if len(min_region_id_constraints) == 0:
return True
min_region_id_constraint = min(min_region_id_constraints)
region_ids = [meta['auto_functionalized_region_id'] for meta in metas]
for region_id in region_ids:
if region_id > min_region_id_constraint:
return False
return True
def compute_node_constraints(graph):
graph._meta["node_constraints"] = NodeConstraints(graph)
graph._meta["node_constraints"].compute_node_constraints()
def is_barrier_mutation_op(node: torch.fx.Node) -> bool:
if node.op == "call_function":
if _mutation_op_re.search(node.target.__name__): # type: ignore[union-attr]
return True
@@ -1614,35 +1680,76 @@ def is_mutation_op(node: torch.fx.Node) -> bool:
return node.kwargs.get("out") is not None
def same_mutation_regions(a: torch.fx.Node, b: torch.fx.Node) -> bool:
assert "mutation_region_id" in a.meta
assert "mutation_region_id" in b.meta
return a.meta["mutation_region_id"] == b.meta["mutation_region_id"]
def should_compute_node_constraints(graph: torch.fx.GraphModule) -> bool:
return "node_constraints" not in graph.meta
def get_mutation_region_id(graph: torch.fx.Graph, node: torch.fx.Node) -> int:
n = node
while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
n = n.prev
mutation_region_id = n.meta.get("mutation_region_id", 0)
while n is not node:
n = n.next
if is_mutation_op(n):
mutation_region_id += 1
n.meta["mutation_region_id"] = mutation_region_id
return mutation_region_id
class NodeConstraints:
def __init__(self, graph):
self.graph = weakref.proxy(graph)
# Maps storage to a list of nodes that have the same storage.
self.storage_to_nodes = weakref.WeakKeyDictionary()
def compute_node_constraints(self):
current_barrier_mutation_region_id = 0
current_auto_functionalized_region_id = 0
for node in self.graph.nodes:
storage = get_node_storage_obj(node)
if storage is not None:
if storage not in self.storage_to_nodes:
self.storage_to_nodes[storage] = []
self.storage_to_nodes[storage].append(weakref.ref(node))
def should_compute_mutation_region_ids(graph: torch.fx.GraphModule) -> bool:
return "mutation_region_id" not in next(iter(graph.nodes)).meta
for node in self.graph.nodes:
current_barrier_mutation_region_id, current_auto_functionalized_region_id = (
    self.process(
        node,
        current_barrier_mutation_region_id,
        current_auto_functionalized_region_id,
    )
)
def process(self, node, current_barrier_mutation_region_id, current_auto_functionalized_region_id):
if is_barrier_mutation_op(node):
current_barrier_mutation_region_id += 1
def compute_mutation_region_ids(graph: torch.fx.GraphModule) -> None:
mutation_region_id = 0
for nd in graph.nodes:
if is_mutation_op(nd):
mutation_region_id += 1
nd.meta["mutation_region_id"] = mutation_region_id
node.meta["auto_functionalized_region_id"] = current_auto_functionalized_region_id
if node.target is torch.ops.higher_order.auto_functionalized:
# Get all inputs to the auto_functionalized that are reinplacing candidates.
mutable_op = node.args[0]
mutable_args_names = torch._higher_order_ops.auto_functionalize.get_mutable_arg_names(mutable_op)
mutable_args = [node.kwargs[name] for name in mutable_args_names]
mutable_storages = set(get_node_storage_obj(arg) for arg in mutable_args)
# Get all nodes that are aliases of all reinplacing candidates
aliased_nodes = []
for storage in mutable_storages:
aliased_nodes.extend(self.storage_to_nodes[storage])
assert len(aliased_nodes) >= 1, "must at least alias self"
for ref in aliased_nodes:
actual_node = ref()
key = "must_be_in_or_before_auto_functionalized_region_id"
if actual_node is not None and key not in actual_node.meta:
actual_node.meta[key] = current_auto_functionalized_region_id
current_auto_functionalized_region_id += 1
node.meta["mutation_region_id"] = current_barrier_mutation_region_id
return current_barrier_mutation_region_id, current_auto_functionalized_region_id
def update(self, node):
"""If a node has no constraints metadata, compute some"""
n = node
while "mutation_region_id" not in n.meta and not is_start_of_fx_graph(graph, n):
storage = get_node_storage_obj(n)
if storage is not None:
if storage not in self.storage_to_nodes:
self.storage_to_nodes[storage] = []
self.storage_to_nodes[storage].append(weakref.ref(n))
n = n.prev
current_barrier_mutation_region_id = n.meta['mutation_region_id']
current_auto_functionalized_region_id = n.meta['auto_functionalized_region_id']
while n is not node:
n = n.next
current_barrier_mutation_region_id, current_auto_functionalized_region_id = (
    self.process(
        n,
        current_barrier_mutation_region_id,
        current_auto_functionalized_region_id,
    )
)
class PatternMatcherPass:
@@ -1671,11 +1778,8 @@ class PatternMatcherPass:
raise RuntimeError(
f"The input to PatternMatcherPass must be a GraphModule or a Graph, but got {type(gm)}"
)
if should_compute_mutation_region_ids(graph):
compute_mutation_region_ids(graph)
get_mutation_region_id_partial = functools.partial(
get_mutation_region_id, graph
)
if should_compute_node_constraints(graph):
compute_node_constraints(graph)
count = 0
nodes = []
has_call_module = False
@@ -1709,7 +1813,7 @@ class PatternMatcherPass:
# pattern match crosses mutation barrier - discard
if (
is_match(m)
and len(set(map(get_mutation_region_id_partial, m.nodes))) != 1 # type: ignore[possibly-undefined]
and not safe_to_operate(graph, m.nodes)
):
continue
if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name:
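
For intuition about the safe_to_operate call used above, here is a tiny
standalone illustration of the barrier-region half of the check (toy classes,
not the real torch.fx types; the auto_functionalized constraint works
similarly, with an extra "must be in or before" tag):

from dataclasses import dataclass, field

@dataclass
class ToyNode:
    name: str
    is_barrier_mutation: bool = False
    meta: dict = field(default_factory=dict)

def tag_mutation_regions(nodes):
    # Mirrors the barrier part of compute_node_constraints: the region id
    # counts how many barrier mutations precede (or are) the node.
    region_id = 0
    for node in nodes:
        if node.is_barrier_mutation:
            region_id += 1
        node.meta["mutation_region_id"] = region_id

def toy_safe_to_operate(nodes):
    # A transform may only touch nodes that all sit in the same region.
    return len({n.meta["mutation_region_id"] for n in nodes}) == 1

a, barrier, c = ToyNode("a"), ToyNode("set_", is_barrier_mutation=True), ToyNode("c")
tag_mutation_regions([a, barrier, c])
assert toy_safe_to_operate([a, a])        # same region: OK to transform
assert not toy_safe_to_operate([a, c])    # crosses the set_ barrier: skip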

View File

@@ -883,6 +883,7 @@ class Graph:
self._codegen = CodeGen()
self._co_fields : Dict[str, Any] = {}
self._find_nodes_lookup_table = _FindNodesLookupTable()
self._meta = {}
@property
def owning_module(self):