Add hop for additional control dependencies (#164568)

Adds a [control_deps](https://en.wikipedia.org/wiki/Control_dependency) higher-order operator that enforces explicit scheduling dependencies in FX graphs. It prevents unwanted operation reordering and fusion by giving nodes additional dependencies, which Inductor also respects by adding weak deps on those additional dependencies during scheduling.

This can be generally useful (for example, for ordering collectives), but in this case I am using it so that fusions do not interfere with the comm-compute overlap planned at the aten level.

There's definitely some similarity with the `with_effects` hop. I talked with @angelayi; when @zou3519 is back, we will figure out how we want to consolidate.

The implementation needs to use a subgraph (as opposed to `with_effects`) because Inductor relies on `V.graph.current_node`. Changing the node's signature, as `with_effects` does, breaks this, and it also breaks striding constraints on the wrapped node - see this [TODO](aed66248a0/torch/fx/experimental/proxy_tensor.py (L1246-L1249)). By keeping the node with its original calling structure inside a subgraph, all of this continues to work.

Example transformation:

Before:
```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
%mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
```
After:
```
add: "f32[256, 256]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[256, 256]" = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
mul: "f32[256, 256]" = torch.ops.higher_order.control_deps((mm,), subgraph_mul, add)
```

The `mm` operation now explicitly depends on `add` completing first, and `mul` depends on `mm`, with the original operations preserved in subgraphs.
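For reference, a minimal sketch of how a custom post-grad pass can register these dependencies, distilled from the test added in this PR (`fn` and `find_add_and_mm` are illustrative placeholders, not part of this change):

```
import torch
from torch.utils._ordered_set import OrderedSet
from torch._inductor.fx_passes.control_dependencies import preserve_node_ordering

def add_control_deps(graph):
    # Locate the nodes to order; find_add_and_mm is a hypothetical helper.
    add_node, mm_node = find_add_and_mm(graph)
    # Force mm to wait for add; the scheduler will see this as a weak dep.
    preserve_node_ordering(graph, {mm_node: OrderedSet([add_node])})
    return graph

with torch._inductor.config.patch(post_grad_custom_post_pass=add_control_deps):
    compiled_fn = torch.compile(fn)
```

During lowering, each `control_deps` call is inlined from its subgraph and the extra deps are recorded in `V.graph.additional_buffer_deps`, which the scheduler turns into `WeakDep`s.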

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164568
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
Author: eellison
Date: 2025-10-03 08:09:51 -07:00
Committed by: PyTorch MergeBot
Commit: 4a39820e5e (parent 600267ea56)
5 changed files with 367 additions and 0 deletions

View File

@@ -0,0 +1,78 @@
# Owner(s): ["module: inductor"]
import torch
from torch._inductor import config
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_CUDA_AND_TRITON,
requires_gpu,
)
class TestControlDeps(InductorTestCase):
@config.patch(reorder_for_locality=False)
@requires_gpu()
def test_control_deps_prevents_fusion(self):
def fn(a, b):
c = a + 1
d = b @ b
e = c * 2
return d, e
# Custom pass to add control dependency from d -> c
def add_control_deps(graph):
nodes = [n for n in graph.nodes if n.op == "call_function"]
assert len(nodes) == 3
c_node = nodes[0]
d_node = nodes[1]
e_node = nodes[2]
assert d_node.target == torch.ops.aten.mm.default
from torch.utils._ordered_set import OrderedSet
deps_map = {d_node: OrderedSet([c_node]), e_node: OrderedSet([d_node])}
torch._inductor.fx_passes.control_dependencies.preserve_node_ordering(
graph, deps_map
)
sub_g = graph.find_nodes(
op="call_function", target=torch.ops.higher_order.control_deps
)
assert len(sub_g) == 2
assert list(sub_g[0].meta["val"].shape) == [256, 256]
assert list(sub_g[1].meta["val"].shape) == [256, 256]
for attr in graph.find_nodes(op="get_attr"):
for n in getattr(graph.owning_module, attr.target).graph.nodes:
assert list(n.meta["val"].shape) == [256, 256]
return graph
with torch._inductor.config.patch(
post_grad_custom_post_pass=add_control_deps,
):
compiled_fn = torch.compile(fn)
a = torch.rand([256, 256], device=GPU_TYPE)
b = torch.rand([256, 256], device=GPU_TYPE)
_, code = run_and_get_code(torch.compile(fn), a, b)
result = compiled_fn(a, b)
FileCheck().check(".run(").check("extern_kernels.mm(").check(".run(").run(
code[0]
)
expected = fn(a, b)
torch.testing.assert_close(result, expected)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA_AND_TRITON:
run_tests(needs="filelock")

View File

@@ -0,0 +1,226 @@
# mypy: allow-untyped-defs
"""
Effect ordering pass for inductor.
This pass adds ordering dependencies to FX graphs using the control_deps HOP
for precise control over scheduling constraints. When you need exact ordering between
operations (e.g., collective_start -> mm -> wait), this pass wraps operations
with control_deps to make dependencies explicit.
"""
from typing import Any
import torch.fx as fx
from torch._higher_order_ops.utils import register_fake
from torch._ops import HigherOrderOperator
from torch.utils._ordered_set import OrderedSet
class ControlDeps(HigherOrderOperator):
"""
Higher-order operator that enforces ordering by making dependencies explicit.
Schema: control_deps(additional_deps, subgraph, *args, **kwargs) -> result
where:
- additional_deps: tuple of tensors that must be computed before this op
- subgraph: GraphModule containing the exact operation to execute
- args/kwargs: arguments for the target function
This ensures all tensors in additional_deps are computed before the target
executes, creating explicit scheduling dependencies.
"""
def __init__(self) -> None:
super().__init__("control_deps")
def __call__(self, additional_deps, subgraph, *args, **kwargs):
"""Call the operator with dependencies and subgraph.
Args:
additional_deps: Tuple of tensors that must be computed first
subgraph: GraphModule containing the exact operation to execute
*args: Arguments to pass to the subgraph
"""
if not isinstance(additional_deps, (tuple, list)):
raise TypeError(
f"additional_deps must be tuple/list, got {type(additional_deps).__name__}"
)
if not (isinstance(subgraph, fx.GraphModule) or callable(subgraph)):
raise TypeError(
f"subgraph must be GraphModule or callable, got {type(subgraph).__name__}"
)
return super().__call__(additional_deps, subgraph, *args, **kwargs)
control_deps = ControlDeps()
# Register fake implementation for tracing
@register_fake(control_deps)
def _(additional_deps, subgraph, *args, **kwargs):
"""Fake tensor implementation - execute the subgraph."""
return subgraph(*args, **kwargs)
def get_subgraph_name(gm: fx.GraphModule, name):
name = f"subgraph_{name}"
if not hasattr(gm, name):
return name
i = 0
while hasattr(gm, f"{name}_{i}"):
i += 1
return f"{name}_{i}"
def preserve_node_ordering(
graph: fx.Graph,
additional_deps_map: dict[fx.Node, OrderedSet[fx.Node]],
verbose: bool = False,
) -> None:
"""
Preserve node ordering using control_deps HOP with subgraph.
This function wraps each operation in a control_deps call that:
1. makes the additional dependencies explicit (first argument)
2. creates a subgraph internally to preserve the exact original operation
3. preserves the original node name
Args:
graph: The FX graph to modify
additional_deps_map: Mapping from dependent nodes to their dependencies
verbose: If True, print debug information
"""
if not additional_deps_map:
return
# Track replacements so we can update dependencies
replacements: dict[fx.Node, fx.Node] = {}
# Process each node that needs additional dependencies
for dependent_node, dep_nodes in additional_deps_map.items():
assert dependent_node.op == "call_function", dependent_node.op
original_name = dependent_node.name
original_args = dependent_node.args
original_kwargs = dependent_node.kwargs
original_meta = dependent_node.meta.copy()
updated_dep_nodes = [replacements.get(dep, dep) for dep in dep_nodes]
# Create a subgraph that preserves the exact original operation
subgraph_module = _create_subgraph_for_node(graph, dependent_node)
owning_mod = graph.owning_module
assert owning_mod is not None
subgraph_attr_name = get_subgraph_name(owning_mod, original_name)
setattr(graph.owning_module, subgraph_attr_name, subgraph_module)
# Create control_deps call with:
# 1. Additional dependencies as first arg (explicit)
# 2. Subgraph via get_attr (like b2b gemm pass)
# 3. Original arguments (only fx.Node args and kwargs are passed)
with graph.inserting_before(dependent_node):
# Create get_attr node for the subgraph
get_subgraph = graph.get_attr(subgraph_attr_name)
# add additional args
node_args = [a for a in original_args if isinstance(a, fx.Node)]
for value in original_kwargs.values():
if isinstance(value, fx.Node):
node_args.append(value)
# Create with temporary name first
ordered_node = graph.call_function(
control_deps,
args=(
tuple(updated_dep_nodes), # additional_deps
get_subgraph, # subgraph via get_attr (like b2b gemm)
*node_args, # original node arguments (from both args and kwargs)
),
kwargs={},
name=f"__temp_{original_name}", # Temporary name to avoid conflict
)
# Copy metadata from original node
ordered_node.meta = original_meta
# eager_input_vals will be constrained on the target node inside the subgraph, if it exists
ordered_node.meta.pop("eager_input_vals", None)
# Replace all uses of the original node with the ordered version
dependent_node.replace_all_uses_with(ordered_node)
# Remove the original node from the graph
graph.erase_node(dependent_node)
# Now rename the ordered node to the original name
ordered_node.name = original_name # PRESERVE ORIGINAL NAME
# Track the replacement for future dependencies
replacements[dependent_node] = ordered_node
def _create_subgraph_for_node(graph: fx.Graph, node: fx.Node) -> fx.GraphModule:
"""
Create a subgraph that exactly recreates a node's operation.
The subgraph takes only the fx.Node arguments and recreates the operation
with the exact target, args structure, and kwargs.
Args:
graph: The parent graph
node: The node to wrap in a subgraph
Returns:
A GraphModule containing the subgraph
"""
# Get the owning module
owning_module = graph.owning_module
# Create a new graph for the subgraph
subgraph = fx.Graph(owning_module)
new_args: list[Any] = []
placeholder_idx = 0
for arg in node.args:
if not isinstance(arg, fx.Node):
new_args.append(arg)
continue
placeholder = subgraph.placeholder(f"arg_{placeholder_idx}")
placeholder_idx += 1
if "val" in arg.meta:
placeholder.meta.update(arg.meta)
new_args.append(placeholder) # type: ignore[arg-type]
new_kwargs: dict[str, Any] = {}
for key, value in node.kwargs.items():
if not isinstance(value, fx.Node):
new_kwargs[key] = value
continue
placeholder = subgraph.placeholder(f"kwarg_{key}")
if "val" in value.meta:
placeholder.meta.update(value.meta)
new_kwargs[key] = placeholder # type: ignore[assignment]
# Recreate the exact original operation in the subgraph
assert callable(node.target)
result = subgraph.call_function(
node.target,
tuple(new_args),
new_kwargs, # type: ignore[arg-type]
)
# Copy metadata from the original node
result.meta.update(node.meta)
out = subgraph.output(result)
if "val" in result.meta:
out.meta["val"] = result.meta["val"]
return fx.GraphModule(owning_module, subgraph)

View File

@@ -385,6 +385,9 @@ class GraphLowering(torch.fx.Interpreter):
const_module.device_idxs if const_module else OrderedSet()
)
self.device_type = "cpu"
self.additional_buffer_deps: dict[str, OrderedSet[str]] = defaultdict(
OrderedSet
)
# Inplace padding may require Inductor to allocate slightly larger
# tensor for padding.

View File

@@ -7226,6 +7226,62 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
return list(map(TensorBox.create, result)) # type: ignore[call-overload]
# Import the control_deps HOP for lowering
from torch._inductor.fx_passes.control_dependencies import control_deps
@register_lowering(control_deps, type_promotion_kind=None)
def control_deps_op_lowering(additional_deps, subgraph_fn, *args):
"""
Lower control_deps by ensuring dependencies are realized and tracking them.
The control_deps HOP makes dependencies explicit in the graph. During lowering:
1. Realize all additional dependencies to ensure they're computed
2. Execute the target operation normally
3. Track the dependencies for the scheduler
"""
# Realize all additional dependencies
dep_names = []
for dep in additional_deps:
if not isinstance(dep, IRNode):
continue
dep.realize()
dep_names.append(dep.get_name())
original_args = V.graph.current_node.args
arg_offset = 2 # first two args (additional_deps, subgraph)
assert len(args) + arg_offset == len(original_args)
output = None
assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args)
for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
if node.op == "placeholder":
V.graph.env[node] = args[i]
continue
elif node.op == "output":
args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
else:
V.graph.env[node] = V.graph.run_node(node)
assert output is not None and additional_deps
output_list = output if isinstance(output, (list, tuple)) else [output]
for out in output_list:
if not isinstance(out, IRNode):
continue
# need to realize in order to add the dep
out.realize()
out_name = out.get_name()
for dep_name in dep_names:
V.graph.additional_buffer_deps[out_name].add(dep_name)
return output
@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None)
def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None):
output = None

View File

@@ -2684,6 +2684,10 @@ class Scheduler:
)
add_user(other_name, node, is_weak=True)
for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
add_user(add_dep, node, is_weak=True)
node.add_fake_dep(WeakDep(add_dep, node.get_name()))
# add normal non-mutation dependencies
for read in node.read_writes.reads:
if not isinstance(read, WeakDep):