pytorch/test/inductor/test_control_deps.py
eellison 4a39820e5e Add hop for additional control dependencies (#164568)
Adds a [control_deps](https://en.wikipedia.org/wiki/Control_dependency) higher-order operator to enforce explicit scheduling dependencies in FX graphs. This prevents unwanted reordering or fusion of operations by giving nodes additional dependencies, which Inductor also respects by emitting WeakDeps for the additional dependencies.
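
For intuition, `control_deps` is value-transparent: it runs the wrapped subgraph on its real inputs and only constrains scheduling. A minimal sketch of the calling convention used in the example below (illustrative only, not the actual HOP registration):

```
# control_deps(additional_deps, subgraph, *args):
#   additional_deps - tuple of tensors that must be scheduled before this node
#   subgraph        - the wrapped original operation
#   *args           - the subgraph's real inputs
# Semantically it just returns subgraph(*args); additional_deps only order
# scheduling and do not feed the computation.
def control_deps(additional_deps, subgraph, *args):
    return subgraph(*args)
```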

This can be generally useful (for example, for ordering collectives), but here I am using it so that fusions do not interfere with comm-compute overlap planned at the aten level.

There's definitely some similarity with the `with_effects` hop. I talked with @angelayi - when @zou3519 is back we will figure out how we want to consolidate.

The implementation needs to be a subgraph (as opposed to `with_effects`) because Inductor relies on `V.graph.current_node`. Wrapping the node with `with_effects` changes its signature, which breaks this and also breaks striding constraints on the wrapped node - see this [TODO](aed66248a0/torch/fx/experimental/proxy_tensor.py (L1246-L1249)). By keeping the node with its original calling structure inside a subgraph, all of this works.
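
Concretely (abbreviated, illustrative graph dumps): `with_effects` rewrites the call site itself by threading a token through it, while `control_deps` leaves the original call intact one level down:

```
# with_effects changes the node's signature, so the node Inductor sees at
# V.graph.current_node is no longer the original mm call:
token_out, mm = torch.ops.higher_order.with_effects(token_in, torch.ops.aten.mm.default, arg1_1, arg1_1)

# control_deps keeps the original mm call unchanged inside subgraph_mm, so
# current_node and the striding constraints still line up:
mm = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
```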

Example transformation:

Before:
```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
%mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
```
After:
```
add: "f32[256, 256]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[256, 256]" = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
mul: "f32[256, 256]" = torch.ops.higher_order.control_deps((mm,), subgraph_mul, add)
```

The mm operation now explicitly depends on add completing first, and mul depends on mm, with the original operations preserved in subgraphs.
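
The rewrite is driven from a post-grad FX pass; a minimal sketch of wiring it up (mirroring the test below, where `d_node` must run after `c_node` and `e_node` after `d_node`):

```
from torch.utils._ordered_set import OrderedSet
from torch._inductor.fx_passes.control_dependencies import preserve_node_ordering

# deps_map: node -> nodes that must be scheduled before it
deps_map = {d_node: OrderedSet([c_node]), e_node: OrderedSet([d_node])}
preserve_node_ordering(graph, deps_map)
```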

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164568
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
2025-10-06 15:47:55 +00:00

# Owner(s): ["module: inductor"]
import torch
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import (
    GPU_TYPE,
    HAS_CUDA_AND_TRITON,
    requires_gpu,
)


class TestControlDeps(InductorTestCase):
    @config.patch(reorder_for_locality=False)
    @requires_gpu()
    def test_control_deps_prevents_fusion(self):
        def fn(a, b):
            c = a + 1
            d = b @ b
            e = c * 2
            return d, e

        # Custom pass to add control dependencies: d must run after c, and
        # e after d.
        def add_control_deps(graph):
            nodes = [n for n in graph.nodes if n.op == "call_function"]
            assert len(nodes) == 3
            c_node, d_node, e_node = nodes
            assert d_node.target == torch.ops.aten.mm.default

            from torch.utils._ordered_set import OrderedSet

            deps_map = {d_node: OrderedSet([c_node]), e_node: OrderedSet([d_node])}
            torch._inductor.fx_passes.control_dependencies.preserve_node_ordering(
                graph, deps_map
            )

            # The pass should have wrapped mm and mul in control_deps nodes
            # while preserving their output metadata.
            sub_g = graph.find_nodes(
                op="call_function", target=torch.ops.higher_order.control_deps
            )
            assert len(sub_g) == 2
            assert list(sub_g[0].meta["val"].shape) == [256, 256]
            assert list(sub_g[1].meta["val"].shape) == [256, 256]

            # Every node in the wrapped subgraphs should carry fake-tensor
            # metadata with the original shapes.
            for attr in graph.find_nodes(op="get_attr"):
                for n in getattr(graph.owning_module, attr.target).graph.nodes:
                    assert list(n.meta["val"].shape) == [256, 256]
            return graph

        with torch._inductor.config.patch(
            post_grad_custom_post_pass=add_control_deps,
        ):
            compiled_fn = torch.compile(fn)
            a = torch.rand([256, 256], device=GPU_TYPE)
            b = torch.rand([256, 256], device=GPU_TYPE)
            result, code = run_and_get_code(compiled_fn, a, b)

            # With the added dependencies, add and mul can no longer fuse
            # across the mm: expect a generated kernel, then the extern mm
            # call, then another generated kernel.
            FileCheck().check(".run(").check("extern_kernels.mm(").check(".run(").run(
                code[0]
            )
            expected = fn(a, b)
            torch.testing.assert_close(result, expected)


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA_AND_TRITON:
        run_tests(needs="filelock")