mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add hop for additional control dependencies (#164568)
Adds [control_deps](https://en.wikipedia.org/wiki/Control_dependency) higher-order operator to enforce explicit scheduling dependencies in FX graphs. This prevents unwanted operation reordering/fusion by giving nodes additional dependencies, which we also respect in inductor by adding weakdeps on the additional dependencies.
This can be generally useful (such as for ordering collectives) but in this case I am using it so that fusions do not interfere with aten planned comm-compute overlap.
There's definitely some similarity with the `with_effects` hop. Talked with @angelayi - when @zou3519 is back we will figure out how we want to consolidate.
The implementation needs to be a subgraph (as opposed to `with_effects`) because inductor relies on `V.graph.current_node`. Changing the signature of the node with `with_effects` breaks this, and additionally, also breaks striding constraints on the wrapped node - see this [TODO](aed66248a0/torch/fx/experimental/proxy_tensor.py (L1246-L1249)
). By maintaining the node with its original calling structure in subgraph this all works.
Example transformation:
Before:
```
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, 1), kwargs = {})
%mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %arg1_1), kwargs = {})
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, 2), kwargs = {})
```
After:
```
add: "f32[256, 256]" = torch.ops.aten.add.Tensor(arg0_1, 1)
mm: "f32[256, 256]" = torch.ops.higher_order.control_deps((add,), subgraph_mm, arg1_1, arg1_1)
mul: "f32[256, 256]" = torch.ops.higher_order.control_deps((mm,), subgraph_mul, add)
```
The mm operation now explicitly depends on add completing first, and mul depends on mm, with original operations preserved in subgraphs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164568
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
This commit is contained in:
committed by
PyTorch MergeBot
parent
600267ea56
commit
4a39820e5e
78
test/inductor/test_control_deps.py
Normal file
78
test/inductor/test_control_deps.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import IS_LINUX
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
HAS_CUDA_AND_TRITON,
|
||||
requires_gpu,
|
||||
)
|
||||
|
||||
|
||||
class TestControlDeps(InductorTestCase):
|
||||
@config.patch(reorder_for_locality=False)
|
||||
@requires_gpu()
|
||||
def test_control_deps_prevents_fusion(self):
|
||||
def fn(a, b):
|
||||
c = a + 1
|
||||
d = b @ b
|
||||
e = c * 2
|
||||
return d, e
|
||||
|
||||
# Custom pass to add control dependency from d -> c
|
||||
def add_control_deps(graph):
|
||||
nodes = list(graph.nodes)
|
||||
|
||||
nodes = [n for n in graph.nodes if n.op == "call_function"]
|
||||
assert len(nodes) == 3
|
||||
c_node = nodes[0]
|
||||
d_node = nodes[1]
|
||||
e_node = nodes[2]
|
||||
|
||||
assert d_node.target == torch.ops.aten.mm.default
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
deps_map = {d_node: OrderedSet([c_node]), e_node: OrderedSet([d_node])}
|
||||
torch._inductor.fx_passes.control_dependencies.preserve_node_ordering(
|
||||
graph, deps_map
|
||||
)
|
||||
sub_g = graph.find_nodes(
|
||||
op="call_function", target=torch.ops.higher_order.control_deps
|
||||
)
|
||||
assert len(sub_g) == 2
|
||||
|
||||
assert list(sub_g[0].meta["val"].shape) == [256, 256]
|
||||
assert list(sub_g[1].meta["val"].shape) == [256, 256]
|
||||
|
||||
for attr in graph.find_nodes(op="get_attr"):
|
||||
for n in getattr(graph.owning_module, attr.target).graph.nodes:
|
||||
assert list(n.meta["val"].shape) == [256, 256]
|
||||
|
||||
return graph
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
post_grad_custom_post_pass=add_control_deps,
|
||||
):
|
||||
compiled_fn = torch.compile(fn)
|
||||
a = torch.rand([256, 256], device=GPU_TYPE)
|
||||
b = torch.rand([256, 256], device=GPU_TYPE)
|
||||
|
||||
_, code = run_and_get_code(torch.compile(fn), a, b)
|
||||
result = compiled_fn(a, b)
|
||||
|
||||
FileCheck().check(".run(").check("extern_kernels.mm(").check(".run(").run(
|
||||
code[0]
|
||||
)
|
||||
|
||||
expected = fn(a, b)
|
||||
torch.testing.assert_close(result, expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if IS_LINUX and HAS_CUDA_AND_TRITON:
|
||||
run_tests(needs="filelock")
|
226
torch/_inductor/fx_passes/control_dependencies.py
Normal file
226
torch/_inductor/fx_passes/control_dependencies.py
Normal file
@ -0,0 +1,226 @@
|
||||
# mypy: allow-untyped-defs
|
||||
"""
|
||||
Effect ordering pass for inductor.
|
||||
|
||||
This pass adds ordering dependencies to FX graphs using the control_deps HOP
|
||||
for precise control over scheduling constraints. When you need exact ordering between
|
||||
operations (e.g., collective_start -> mm -> wait), this pass wraps operations
|
||||
with control_deps to make dependencies explicit.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch.fx as fx
|
||||
from torch._higher_order_ops.utils import register_fake
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
class ControlDeps(HigherOrderOperator):
|
||||
"""
|
||||
Higher-order operator that enforces ordering by making dependencies explicit.
|
||||
|
||||
Schema: control_deps(additional_deps, target, *args, **kwargs) -> result
|
||||
where:
|
||||
- additional_deps: tuple of tensors that must be computed before this op
|
||||
- subgraph: GraphModule containing the exact operation to execute
|
||||
- args/kwargs: arguments for the target function
|
||||
|
||||
This ensures all tensors in additional_deps are computed before the target
|
||||
executes, creating explicit scheduling dependencies.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__("control_deps")
|
||||
|
||||
def __call__(self, additional_deps, subgraph, *args, **kwargs):
|
||||
"""Call the operator with dependencies and subgraph.
|
||||
|
||||
Args:
|
||||
additional_deps: Tuple of tensors that must be computed first
|
||||
subgraph: GraphModule containing the exact operation to execute
|
||||
*args: Arguments to pass to the subgraph
|
||||
"""
|
||||
if not isinstance(additional_deps, (tuple, list)):
|
||||
raise TypeError(
|
||||
f"additional_deps must be tuple/list, got {type(additional_deps).__name__}"
|
||||
)
|
||||
if not (isinstance(subgraph, fx.GraphModule) or callable(subgraph)):
|
||||
raise TypeError(
|
||||
f"subgraph must be GraphModule or callable, got {type(subgraph).__name__}"
|
||||
)
|
||||
return super().__call__(additional_deps, subgraph, *args, **kwargs)
|
||||
|
||||
|
||||
control_deps = ControlDeps()
|
||||
|
||||
|
||||
# Register fake implementation for tracing
|
||||
@register_fake(control_deps)
|
||||
def _(additional_deps, subgraph, *args, **kwargs):
|
||||
"""Fake tensor implementation - execute the subgraph."""
|
||||
return subgraph(*args, **kwargs)
|
||||
|
||||
|
||||
def get_subgraph_name(gm: fx.GraphModule, name):
|
||||
name = f"subgraph_{name}"
|
||||
|
||||
if not hasattr(gm, name):
|
||||
return name
|
||||
|
||||
i = 0
|
||||
while hasattr(gm, f"{name}_{i}"):
|
||||
i += 1
|
||||
|
||||
return f"{name}_{i}"
|
||||
|
||||
|
||||
def preserve_node_ordering(
|
||||
graph: fx.Graph,
|
||||
additional_deps_map: dict[fx.Node, OrderedSet[fx.Node]],
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Preserve node ordering using control_deps HOP with subgraph.
|
||||
|
||||
This function wraps operations with control_deps that:
|
||||
1. Makes additional dependencies explicit (first argument)
|
||||
2. Creates a subgraph internally to preserve the exact original operation
|
||||
3. Preserves the original node names
|
||||
|
||||
Args:
|
||||
graph: The FX graph to modify
|
||||
additional_deps_map: Mapping from dependent nodes to their dependencies
|
||||
verbose: If True, print debug information
|
||||
"""
|
||||
if not additional_deps_map:
|
||||
return
|
||||
|
||||
# Track replacements so we can update dependencies
|
||||
replacements: dict[fx.Node, fx.Node] = {}
|
||||
|
||||
# Process each node that needs additional dependencies
|
||||
for dependent_node, dep_nodes in additional_deps_map.items():
|
||||
assert dependent_node.op == "call_function", dependent_node.op
|
||||
|
||||
original_name = dependent_node.name
|
||||
original_args = dependent_node.args
|
||||
original_kwargs = dependent_node.kwargs
|
||||
original_meta = dependent_node.meta.copy()
|
||||
|
||||
updated_dep_nodes = [replacements.get(dep, dep) for dep in dep_nodes]
|
||||
|
||||
# Create a subgraph that preserves the exact original operation
|
||||
subgraph_module = _create_subgraph_for_node(graph, dependent_node)
|
||||
|
||||
owning_mod = graph.owning_module
|
||||
assert owning_mod is not None
|
||||
subgraph_attr_name = get_subgraph_name(owning_mod, original_name)
|
||||
setattr(graph.owning_module, subgraph_attr_name, subgraph_module)
|
||||
|
||||
# Create control_deps call with:
|
||||
# 1. Additional dependencies as first arg (explicit)
|
||||
# 2. Subgraph via get_attr (like b2b gemm pass)
|
||||
# 3. Original arguments (only fx.Node args and kwargs are passed)
|
||||
with graph.inserting_before(dependent_node):
|
||||
# Create get_attr node for the subgraph
|
||||
get_subgraph = graph.get_attr(subgraph_attr_name)
|
||||
|
||||
# add additional args
|
||||
node_args = [a for a in original_args if isinstance(a, fx.Node)]
|
||||
for value in original_kwargs.values():
|
||||
if isinstance(value, fx.Node):
|
||||
node_args.append(value)
|
||||
|
||||
# Create with temporary name first
|
||||
ordered_node = graph.call_function(
|
||||
control_deps,
|
||||
args=(
|
||||
tuple(updated_dep_nodes), # additional_deps
|
||||
get_subgraph, # subgraph via get_attr (like b2b gemm)
|
||||
*node_args, # original node arguments (from both args and kwargs)
|
||||
),
|
||||
kwargs={},
|
||||
name=f"__temp_{original_name}", # Temporary name to avoid conflict
|
||||
)
|
||||
|
||||
# Copy metadata from original node
|
||||
ordered_node.meta = original_meta
|
||||
# this will be constrained on the target node in subgraph if it exists
|
||||
ordered_node.meta.pop("eager_input_vals", None)
|
||||
|
||||
# Replace all uses of the original node with the ordered version
|
||||
dependent_node.replace_all_uses_with(ordered_node)
|
||||
|
||||
# Remove the original node from the graph
|
||||
graph.erase_node(dependent_node)
|
||||
|
||||
# Now rename the ordered node to the original name
|
||||
ordered_node.name = original_name # PRESERVE ORIGINAL NAME
|
||||
|
||||
# Track the replacement for future dependencies
|
||||
replacements[dependent_node] = ordered_node
|
||||
|
||||
|
||||
def _create_subgraph_for_node(graph: fx.Graph, node: fx.Node) -> fx.GraphModule:
|
||||
"""
|
||||
Create a subgraph that exactly recreates a node's operation.
|
||||
|
||||
The subgraph takes only the fx.Node arguments and recreates the operation
|
||||
with the exact target, args structure, and kwargs.
|
||||
|
||||
Args:
|
||||
graph: The parent graph
|
||||
node: The node to wrap in a subgraph
|
||||
|
||||
Returns:
|
||||
A GraphModule containing the subgraph
|
||||
"""
|
||||
# Get the owning module
|
||||
# torch.distributed.breakpoint(0)
|
||||
owning_module = graph.owning_module
|
||||
|
||||
# Create a new graph for the subgraph
|
||||
subgraph = fx.Graph(owning_module)
|
||||
|
||||
new_args: list[Any] = []
|
||||
placeholder_idx = 0
|
||||
for _, arg in enumerate(node.args):
|
||||
if not isinstance(arg, fx.Node):
|
||||
new_args.append(arg)
|
||||
continue
|
||||
|
||||
placeholder = subgraph.placeholder(f"arg_{placeholder_idx}")
|
||||
placeholder_idx += 1
|
||||
if "val" in arg.meta:
|
||||
placeholder.meta.update(arg.meta)
|
||||
new_args.append(placeholder) # type: ignore[arg-type]
|
||||
|
||||
new_kwargs: dict[str, Any] = {}
|
||||
for key, value in node.kwargs.items():
|
||||
if not isinstance(value, fx.Node):
|
||||
new_kwargs[key] = value
|
||||
continue
|
||||
|
||||
placeholder = subgraph.placeholder(f"kwarg_{key}")
|
||||
if "val" in value.meta:
|
||||
placeholder.meta.update(value.meta)
|
||||
|
||||
new_kwargs[key] = placeholder # type: ignore[assignment]
|
||||
|
||||
# Recreate the exact original operation in the subgraph
|
||||
assert callable(node.target)
|
||||
result = subgraph.call_function(
|
||||
node.target,
|
||||
tuple(new_args),
|
||||
new_kwargs, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Copy metadata from the original node
|
||||
result.meta.update(node.meta)
|
||||
|
||||
out = subgraph.output(result)
|
||||
if "val" in result.meta:
|
||||
out.meta["val"] = result.meta["val"]
|
||||
|
||||
return fx.GraphModule(owning_module, subgraph)
|
@ -385,6 +385,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
const_module.device_idxs if const_module else OrderedSet()
|
||||
)
|
||||
self.device_type = "cpu"
|
||||
self.additional_buffer_deps: dict[str, OrderedSet[str]] = defaultdict(
|
||||
OrderedSet
|
||||
)
|
||||
|
||||
# Inplace padding may require Inductor to allocate slightly larger
|
||||
# tensor for padding.
|
||||
|
@ -7226,6 +7226,62 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
|
||||
return list(map(TensorBox.create, result)) # type: ignore[call-overload]
|
||||
|
||||
|
||||
# Import the control_deps_op HOP for lowering
|
||||
from torch._inductor.fx_passes.control_dependencies import control_deps
|
||||
|
||||
|
||||
@register_lowering(control_deps, type_promotion_kind=None)
|
||||
def control_deps_op_lowering(additional_deps, subgraph_fn, *args):
|
||||
"""
|
||||
Lower control_deps_op by ensuring dependencies are realized and tracking them.
|
||||
|
||||
The control_deps_op HOP makes dependencies explicit in the graph. During lowering:
|
||||
1. Realize all additional dependencies to ensure they're computed
|
||||
2. Execute the target operation normally
|
||||
3. Track the dependencies for the scheduler
|
||||
"""
|
||||
# Realize all additional dependencies
|
||||
dep_names = []
|
||||
for dep in additional_deps:
|
||||
if not isinstance(dep, IRNode):
|
||||
continue
|
||||
|
||||
dep.realize()
|
||||
dep_names.append(dep.get_name())
|
||||
|
||||
original_args = V.graph.current_node.args
|
||||
arg_offset = 2 # first two args (additional_deps, subgraph)
|
||||
assert len(args) + arg_offset == len(original_args)
|
||||
|
||||
output = None
|
||||
|
||||
assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args)
|
||||
for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
|
||||
if node.op == "placeholder":
|
||||
V.graph.env[node] = args[i]
|
||||
continue
|
||||
elif node.op == "output":
|
||||
args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
|
||||
output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
|
||||
else:
|
||||
V.graph.env[node] = V.graph.run_node(node)
|
||||
|
||||
assert output is not None and additional_deps
|
||||
output_list = output if isinstance(output, (list, tuple)) else [output]
|
||||
|
||||
for out in output_list:
|
||||
if not isinstance(out, IRNode):
|
||||
continue
|
||||
|
||||
# need to realize in order to add the dep
|
||||
out.realize()
|
||||
out_name = out.get_name()
|
||||
for dep_name in dep_names:
|
||||
V.graph.additional_buffer_deps[out_name].add(dep_name)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None)
|
||||
def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None):
|
||||
output = None
|
||||
|
@ -2684,6 +2684,10 @@ class Scheduler:
|
||||
)
|
||||
add_user(other_name, node, is_weak=True)
|
||||
|
||||
for add_dep in V.graph.additional_buffer_deps[buf.get_name()]:
|
||||
add_user(add_dep, node, is_weak=True)
|
||||
node.add_fake_dep(WeakDep(add_dep, node.get_name()))
|
||||
|
||||
# add normal non-mutation dependencies
|
||||
for read in node.read_writes.reads:
|
||||
if not isinstance(read, WeakDep):
|
||||
|
Reference in New Issue
Block a user