mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add hop for additional control dependencies (#164568)
Adds [control_deps](https://en.wikipedia.org/wiki/Control_dependency) higher-order operator to enforce explicit scheduling dependencies in FX graphs. This prevents unwanted operation reordering/fusion by giving nodes additional dependencies, which we also respect in inductor by adding weakdeps on the additional dependencies.
This can be generally useful (such as for ordering collectives) but in this case I am using it so that fusions do not interfere with aten planned comm-compute overlap.
There's definitely some similarity with the `with_effects` hop. Talked with @angelayi - when @zou3519 is back we will figure out how we want to consolidate.
The implementation needs to be a subgraph (as opposed to `with_effects`) because inductor relies on `V.graph.current_node`. Changing the signature of the node with `with_effects` breaks this, and additionally, also breaks striding constraints on the wrapped node - see this [TODO](aed66248a0/torch/fx/experimental/proxy_tensor.py (L1246-L1249)
). By maintaining the node with its original calling structure in subgraph this all works.
Example transformation:
Before:
```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
%mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
```
After:
```
add: "f32[256, 256]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[256, 256]" = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
mul: "f32[256, 256]" = torch.ops.higher_order.control_deps((mm,), subgraph_mul, add)
```
The mm operation now explicitly depends on add completing first, and mul depends on mm, with original operations preserved in subgraphs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164568
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
This commit is contained in:
committed by
PyTorch MergeBot
parent
600267ea56
commit
4a39820e5e
@ -2684,6 +2684,10 @@ class Scheduler:
|
||||
)
|
||||
add_user(other_name, node, is_weak=True)
|
||||
|
||||
for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
|
||||
add_user(add_dep, node, is_weak=True)
|
||||
node.add_fake_dep(WeakDep(add_dep, node.get_name()))
|
||||
|
||||
# add normal non-mutation dependencies
|
||||
for read in node.read_writes.reads:
|
||||
if not isinstance(read, WeakDep):
|
||||
|
Reference in New Issue
Block a user