Add hop for additional control dependencies (#164568)

Adds a [control_deps](https://en.wikipedia.org/wiki/Control_dependency) higher-order operator (HOP) that enforces explicit scheduling dependencies in FX graphs. It prevents unwanted operation reordering and fusion by giving nodes additional dependencies, which Inductor also respects by adding weak deps (`WeakDep`) on those additional dependencies.

This can be generally useful (for example, for ordering collectives), but in this case I am using it so that fusions do not interfere with comm-compute overlap planned at the aten level.

There's definitely some similarity with the `with_effects` hop. Talked with @angelayi; when @zou3519 is back we will figure out how we want to consolidate.

The implementation needs to be a subgraph (unlike `with_effects`) because Inductor relies on `V.graph.current_node`. Changing the signature of the node, as `with_effects` does, breaks this and also breaks striding constraints on the wrapped node; see this [TODO](aed66248a0/torch/fx/experimental/proxy_tensor.py (L1246-L1249)). Keeping the node with its original calling structure inside a subgraph makes all of this work.

Example transformation:

Before:
```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
%mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
```
After:
```
add: "f32[256, 256]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[256, 256]" = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
mul: "f32[256, 256]" = torch.ops.higher_order.control_deps((mm,), subgraph_mul, add)
```

The `mm` operation now explicitly depends on `add` completing first, and `mul` depends on `mm`; the original operations are preserved in subgraphs.
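To make the effect of the extra edges concrete, here is a toy model in plain Python. It does not use PyTorch or the real `control_deps` operator; the dictionaries and the helpers `valid_orders` and `merge` are hypothetical names introduced only for illustration. It enumerates the orderings a scheduler would be free to choose for the three nodes above, first with only the data dependencies and then with the control dependencies added.

```python
from itertools import permutations

# Data dependencies from the "Before" graph: mul reads add,
# while add and mm only read graph inputs.
data_deps = {"add": set(), "mm": set(), "mul": {"add"}}

# Ordering-only edges mirroring the "After" graph (toy stand-in for control_deps).
control_deps = {"mm": {"add"}, "mul": {"mm"}}


def valid_orders(deps):
    """Return every ordering in which each node follows all of its dependencies."""
    orders = []
    for order in permutations(sorted(deps)):
        position = {name: i for i, name in enumerate(order)}
        if all(position[d] < position[n] for n in deps for d in deps[n]):
            orders.append(order)
    return orders


def merge(data, control):
    """Union the data edges with the ordering-only control edges."""
    return {n: data[n] | control.get(n, set()) for n in data}


before = valid_orders(data_deps)
after = valid_orders(merge(data_deps, control_deps))
print(len(before), "legal orders before:", before)
print(len(after), "legal order after:", after)
```

With only data dependencies, `mm` may run before, between, or after `add` and `mul`, so three orderings are legal. Once the control edges are merged in, only `add`, `mm`, `mul` remains, which is the ordering the commit wants to pin down.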

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164568
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
This commit is contained in:
eellison
2025-10-03 08:09:51 -07:00
committed by PyTorch MergeBot
parent 600267ea56
commit 4a39820e5e
5 changed files with 367 additions and 0 deletions

@@ -2684,6 +2684,10 @@ class Scheduler:
)
add_user(other_name, node, is_weak=True)
for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
    add_user(add_dep, node, is_weak=True)
    node.add_fake_dep(WeakDep(add_dep, node.get_name()))
# add normal non-mutation dependencies
for read in node.read_writes.reads:
    if not isinstance(read, WeakDep):