[inductor] do comm compute overlap at aten fx level (#163215)

This is the first part of a stack that does comm/compute reordering at the aten FX level and then uses the exposure analysis to do bucketing.

Subsequent PRs will handle:
- using the exposure analysis to do bucketing
- making sure Inductor respects the comm/compute overlapping done at the FX level
- non-profiling mm estimation / rank broadcasting of profile results

Other misc:
- Validate the accuracy of NCCL estimations (use ruisi's profiling instead?)

For a Llama 2D parallelism test, in the forward we overlap all but 2 of the potentially hidden collectives. In the backward, we overlap 217/269 of the potentially hidden collectives. If you increase `compute_overlap_multipler` (a fudge factor for inaccurate comms estimation), the number of exposed collectives goes down, leaving all but 16 of the potentially hidden collectives overlapped.
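
These exposure stats come from counters that the new scheduling pass records (see `overlap_scheduling.py` below). A minimal sketch for reading them back after a compiled run with the pass enabled; the counter names are taken from this PR, the compiled step itself is a placeholder:

```python
from torch._dynamo.utils import counters

# ... run a compiled forward/backward with the overlap pass enabled (placeholder) ...

# number of collectives that remained exposed after reordering
print("exposed:", counters["inductor"]["overlap_scheduling_exposed"])
# exposed collectives that had a potential hiding compute node
print("bad exposed:", counters["inductor"]["overlap_scheduling_bad_exposed"])
# collectives that could potentially be hidden at all
print("potentially hidden:", counters["inductor"]["overlap_scheduling_potentially_hidden"])
```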

fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3

bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51
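
For reference, the new tests below wire the pass up by registering a runtime-estimation hook and running `schedule_overlap_bucketing` as a post-grad custom pass. A rough sketch of that setup (the model and inputs are placeholders; the cost function is the toy one from the tests):

```python
import torch
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing

def estimate_aten_runtime(fx_node):
    # toy cost model from the tests: collectives and matmuls each cost 1 unit
    if "c10" in str(fx_node.target):
        return 1.0
    if fx_node.target == torch.ops.aten.mm.default:
        return 1.0
    return None  # fall back to the built-in estimation

def reorder(graph):
    # reorder collective starts/waits around compute on the post-grad aten graph
    schedule_overlap_bucketing(graph.owning_module)

with torch._inductor.config.patch(
    {
        "test_configs.estimate_aten_runtime": estimate_aten_runtime,
        "post_grad_custom_post_pass": reorder,
    }
):
    out = torch.compile(model)(*inputs)  # model / inputs are placeholders
```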

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163215
Approved by: https://github.com/IvanKobzarev
Author: eellison
Date: 2025-09-29 15:36:43 -07:00
Committed by: PyTorch MergeBot
Commit: 0d7994ca97 (parent: c39357bab6)
8 changed files with 1104 additions and 10 deletions


@@ -26,6 +26,7 @@ if [[ "${SHARD_NUMBER:-2}" == "2" ]]; then
   time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
   time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
   time python test/run_test.py --verbose -i distributed/test_compute_comm_reordering
+  time python test/run_test.py --verbose -i distributed/test_aten_comm_compute_reordering
   time python test/run_test.py --verbose -i distributed/test_store
   time python test/run_test.py --verbose -i distributed/test_symmetric_memory
   time python test/run_test.py --verbose -i distributed/test_pg_wrapper


@@ -435,7 +435,7 @@ test_inductor_distributed() {
   # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported
   # with if required # gpus aren't available
-  python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives distributed/test_compute_comm_reordering --verbose
+  python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives distributed/test_aten_comm_compute_reordering distributed/test_compute_comm_reordering --verbose
   assert_git_not_dirty
 }


@@ -0,0 +1,353 @@
# flake8: noqa: B950
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
from torch._C import FileCheck
from torch._dynamo.utils import counters, same
from torch._inductor.utils import run_and_get_triton_code
from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init,
at_least_x_gpu,
DynamoDistributedMultiProcTestCase,
requires_accelerator_dist_backend,
)
aten = torch.ops.aten
import functools
from torch.testing._internal.common_fsdp import get_devtype
from torch.testing._internal.common_utils import skipIfRocm
from torch.testing._internal.inductor_utils import HAS_GPU
def estimate_aten_runtime(fx_node):
# for tests, assume a matmul can hide a single collective
if "c10" in str(fx_node.target):
return 1.0
elif fx_node.target == aten.mm.default:
return 1.0
else:
return None
device_type = str(get_devtype())
def apply_reordering_and_get_graph(graph, out_li) -> None:
gm = graph.owning_module
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing
schedule_overlap_bucketing(gm)
gm.graph.lint()
out_li.append(str(gm.graph))
def run_and_get_aten_graph(fn, *inputs):
li = []
apply = functools.partial(apply_reordering_and_get_graph, out_li=li)
with torch._inductor.config.patch(post_grad_custom_post_pass=apply):
out = fn(*inputs)
return out, li[0]
def get_patches():
return {
"test_configs.estimate_aten_runtime": estimate_aten_runtime,
"reorder_for_locality": False,
"reorder_for_compute_comm_overlap_passes": [],
"compile_threads": 1,
"force_disable_caches": True,
}
@requires_accelerator_dist_backend()
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
"""
Run correctness checks in multi-proc runner, mark with minimum # GPUs to run under
Note: these tests are a fork of test/distributed/test_compute_comm_reordering.py
"""
def setUp(self):
super().setUp()
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
def get_world_trs(self):
return {
"tag": "",
"ranks": list(range(self.world_size)),
"group_size": self.world_size,
}
@property
def world_size(self) -> int:
# hack: no matter whether we have 2 or 3 or 4 gpus, just run on 2
# works around issue with skipif<2 and workers with unpredictable #s gpu
return 2
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_patches())
def test_sink_waits(self):
def func(a):
ar = _functional_collectives.all_reduce(a, "sum", "0")
b = torch.matmul(a, a)
return torch.matmul(ar, b)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
out, aten_graph_str = run_and_get_aten_graph(torch.compile(func), inputs)
# Verify that the wait_tensor is sinked below the 1st matmul but
# above the 2nd matmul.
(
FileCheck()
.check("all_reduce.default")
.check("aten.mm.default")
.check("wait_tensor.default")
.check("aten.mm.default")
.run(aten_graph_str)
)
correct = func(inputs)
self.assertTrue(same(out, correct))
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
@torch._inductor.config.patch(get_patches())
def test_raise_comms(self):
def func(a):
b = torch.matmul(a, a)
c = torch.relu(b)
d = torch.matmul(c, c)
e = _functional_collectives.all_reduce((b + 1), "sum", "0")
return torch.matmul(d, e)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
out, aten_graph_str = run_and_get_aten_graph(torch.compile(func), inputs)
# Verify that the all_reduce_ has been raised above the 2nd matmul
# but below the 1st matmul. Note that the all_reduce_ directly
# writes to the output buffer of the 1st matmul, which is an input
# to the first relu. Therefore, the all_reduce_ should be scheduled
# after the first relu.
(
FileCheck()
.check("aten.mm")
.check("all_reduce.default")
.check("aten.mm")
.check("wait_tensor.default")
.check("aten.mm")
.run(aten_graph_str)
)
out = compiled(inputs)
correct = func(inputs)
self.assertTrue(same(out, correct))
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
@torch._inductor.config.patch(get_patches())
def test_sink_waits_raise_comms(self):
def func(a, *, tag, ranks, group_size):
b = torch.matmul(a, a)
c = torch.relu(b)
d = torch.matmul(c, c)
e = _functional_collectives.all_reduce(b, "sum", "0")
f = torch.relu(d)
g = torch.matmul(f, f)
return torch.mm(e, g)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(
4, 4, dtype=torch.float, device=device_type
) # + self.rank
kwargs = self.get_world_trs()
func = functools.partial(func, **kwargs)
compiled = torch.compile(func)
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
# Things to verify:
# - The all_reduce_ and its prologue should be raised above the 2nd
# matmul but below the 1st matmul.
# - The wait_tensor should be sinked below the 3rd matmul but above
# the 4th matmul.
self.assertExpectedInline(
aten_graph_str,
"""\
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%mm : [num_users=2] = call_function[target=torch.ops.aten.mm.default](args = (%arg0_1, %arg0_1), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mm,), kwargs = {})
%all_reduce : [num_users=1] = call_function[target=torch.ops._c10d_functional.all_reduce.default](args = (%mm, sum, 0), kwargs = {})
%mm_1 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%relu, %relu), kwargs = {})
%relu_1 : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mm_1,), kwargs = {})
%mm_2 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%relu_1, %relu_1), kwargs = {})
%wait_tensor : [num_users=1] = call_function[target=torch.ops._c10d_functional.wait_tensor.default](args = (%all_reduce,), kwargs = {})
%mm_3 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%wait_tensor, %mm_2), kwargs = {})
return (mm_3,)""",
)
# Note: this triggered an all_reduce_ bug
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
@torch._inductor.config.patch(get_patches())
def test_reorder_compute_for_overlap_mul(self):
def func(a, *, tag, ranks, group_size):
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
g = torch.matmul(a, a)
c = torch.relu(a)
d = torch.matmul(c, c)
f = d * c * ar
fr = _functional_collectives.all_reduce(f, "sum", ranks, tag)
e = torch.matmul(d + ar + fr, g)
return (e,)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
func_c = functools.partial(func, **self.get_world_trs())
compiled = torch.compile(func_c)
out_c, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
# Note: because we have given collectives and mms equal estimation,
# we overlap each collective with a single mm.
# Same schedule as in test_reorder_compute_for_overlap_custom_runtime_estimation
# although there is an exposed collective
(
FileCheck()
.check("all_reduce.default")
.check("aten.mm")
.check("aten.mm")
.check("wait_tensor.default")
.check("aten.mul")
.check("all_reduce.default")
.check("wait_tensor.default")
.check("aten.mm")
.run(aten_graph_str)
)
correct = func(inputs, **self.get_world_trs())
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 1)
self.assertTrue(same(out_c, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@skipIfRocm
# TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor
@patch.object(torch._inductor.config, "compile_threads", 1)
@unittest.skipIf(True, "Logic not yet implemented")
@torch._inductor.config.patch(get_patches())
def test_grouped_scheduler_node(self):
def func(a, *, tag, ranks, group_size):
add = a + a
div = add / a
ar = _functional_collectives.all_reduce(div, "sum", ranks, tag)
# Normally, we would fuse `add = a + a`, `div = add / a` and `mul = a * a` together into a single fused op,
# but here in this unit test, we intentionally put `add`, `div` and `ar` computation
# into a GroupedSchedulerNode, which prevents them from being fused with any other ops.
mul = a * a
mm = torch.matmul(mul, ar)
return (mm,)
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, inputs, **self.get_world_trs())
# Expectations:
# 1. `add = a + a` and `div = add / a` are still fused, which means fusion
# still happens among nodes within a GroupedSchedulerNode.
# 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within
# GroupedSchedulerNode and thus are prevented from being fused with any outside ops.
FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check(
"_c10d_functional.all_reduce_."
).check("triton_poi_fused_mul_1.").run(code)
out = compiled(inputs, **self.get_world_trs())
correct = func(inputs, **self.get_world_trs())
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_patches())
def test_inductor_default_comms_ordering(self):
pg_info = self.get_world_trs()
tag = pg_info["tag"]
ranks = pg_info["ranks"]
group_size = pg_info["group_size"]
g1 = torch.ones(10, 10, device=device_type)
g2 = torch.ones(11, 11, device=device_type)
g3 = torch.ones(12, 12, device=device_type)
@torch.compile
def fn(g1, g2, g3):
handle1 = torch.ops.c10d_functional.all_reduce(
g1, "avg", tag, ranks, group_size
)
handle2 = torch.ops.c10d_functional.all_reduce(
g2, "avg", tag, ranks, group_size
)
handle3 = torch.ops.c10d_functional.all_reduce(
g3, "avg", tag, ranks, group_size
)
# wait on them in a different order
grad3 = torch.ops._c10d_functional.wait_tensor.default(handle3)
grad2 = torch.ops._c10d_functional.wait_tensor.default(handle2)
grad1 = torch.ops._c10d_functional.wait_tensor.default(handle1)
return grad3, grad2, grad1
with _dynamo_dist_per_rank_init(
self.rank, self.world_size, self.backend(device_type), fake_pg=True
):
# all_reduces remain in order!
            # note: this isn't actually an invariant of the pass currently,
# but we should keep collectives stable without reordering opportunities
_, code = run_and_get_aten_graph(fn, g1, g2, g3)
FileCheck().check("all_reduce").check_same("arg0_1").check(
"all_reduce"
).check_same("arg1_1").check("all_reduce").check_same("arg2_1").run(code)
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 3)
# these have no overlap opportunities
self.assertEqual(counters["inductor"]["overlap_scheduling_bad_exposed"], 0)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()


@@ -7,6 +7,7 @@ from typing import Optional
 import sympy
 import torch
+from torch.fx.operator_schemas import normalize_function
 from . import ir
 from .utils import get_dtype_size, snode_args_kwargs, sympy_product
@@ -43,11 +44,7 @@ def get_gpu_type() -> NVIDIA_GPU_TYPE:
     return NVIDIA_GPU_TYPE.AMPERE
-def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
-    if not isinstance(node, ir._CollectiveKernel):
-        raise ValueError(f"node is not a collective kernel: {node}")
-
-    kernel_name = node.python_kernel_name
+def get_collective_type_from_kernel_name(kernel_name: str) -> NCCL_COLL:
     assert kernel_name is not None
     if "all_reduce" in kernel_name:
         return NCCL_COLL.ALL_REDUCE
@@ -61,6 +58,15 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
     raise ValueError(f"Unsupported collective kernel: {kernel_name}")
+
+def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
+    if not isinstance(node, ir._CollectiveKernel):
+        raise ValueError(f"node is not a collective kernel: {node}")
+
+    name = node.python_kernel_name
+    assert name is not None
+    return get_collective_type_from_kernel_name(name)
def get_collective_input_size_bytes(node: ir.IRNode) -> int: def get_collective_input_size_bytes(node: ir.IRNode) -> int:
sz_bytes = 0 sz_bytes = 0
for inp in node.inputs: # type: ignore[attr-defined] for inp in node.inputs: # type: ignore[attr-defined]
@@ -210,7 +216,9 @@ def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]:
     return est_time_ms
-def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
+def estimate_nccl_collective_runtime_impl(
+    tensor_storage_size_bytes: int, group_size: int, coll: NCCL_COLL
+) -> float:
     """
     Returns estimated NCCL collective runtime in milliseconds (ms).
@@ -223,14 +231,12 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
     - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
     - collective is one of: allreduce, reducescatter, allgather
     """
-    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
     # Convert bytes to GB
     tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024
     # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
     # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
     num_gpus_per_node = 8
-    group_size = get_collective_group_size(node)
     nNodes = math.ceil(group_size / num_gpus_per_node)
     nRanks = group_size  # this is total # of gpus globally that participate in this collective op
@@ -240,7 +246,6 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
     # Assumes ring algorithm
     nccl_algo = NCCL_ALGO.RING
     nccl_proto = NCCL_PROTO.LL
-    coll = get_collective_type(node)
     # =============== bandwidth computation ===============
     # First compute bandwidth in GB/s; then at the end, convert it to GB/ns
@@ -318,3 +323,70 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
 ################################################################################################################
 # The above code and constants are adapted from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc #
 ################################################################################################################
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).
The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
We aim to estimate the runtime as accurately as possible.
Assumptions:
- only ring algorithm (NCCL_ALGO_RING) is used
- only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
- 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
- collective is one of: allreduce, reducescatter, allgather
"""
tensor_storage_size_bytes = get_collective_input_size_bytes(node)
group_size = get_collective_group_size(node)
coll = get_collective_type(node)
return estimate_nccl_collective_runtime_impl(
tensor_storage_size_bytes, group_size, coll
)
def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
size = 0
for node in fx_node.all_input_nodes:
if (t := node.meta.get("val")) is not None:
size += t.numel() * t.element_size()
# TODO - symbolic
return size
def estimate_nccl_collective_runtime_from_fx_node(fx_node: torch.fx.Node) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).
The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
We aim to estimate the runtime as accurately as possible.
Assumptions:
- only ring algorithm (NCCL_ALGO_RING) is used
- only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
- 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
- collective is one of: allreduce, reducescatter, allgather
"""
from torch.distributed.distributed_c10d import _get_group_size_by_name
tensor_storage_size_bytes = estimate_fx_collective_size(fx_node)
assert not isinstance(fx_node.target, str)
opt_args_kwargs = normalize_function(
fx_node.target,
args=fx_node.args,
kwargs=fx_node.kwargs,
normalize_to_only_use_kwargs=True,
)
assert opt_args_kwargs is not None
_, kwargs = opt_args_kwargs
group_size = _get_group_size_by_name(kwargs["group_name"])
assert isinstance(fx_node.target, torch._ops.OpOverload)
coll = get_collective_type_from_kernel_name(fx_node.target.name())
return estimate_nccl_collective_runtime_impl(
tensor_storage_size_bytes, group_size, coll
)


@@ -2007,6 +2007,17 @@ class test_configs:
     # for unit testing
     use_libtorch = False
+
+    # to be migrated when ready for use
+    aten_fx_overlap_scheduling = False
+
+    # to be migrated when ready for use
+    # runtime estimation function for ops
+    # for user-defined estimation function, pass in the function handle
+    # TODO - need estimated and profile based version
+    estimate_aten_runtime: Union[
+        Literal["default"], Callable[[torch.fx.Node], Optional[float]]
+    ] = "default"
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403


@@ -0,0 +1,655 @@
import functools
import heapq
import itertools
import logging
import sys
from collections import Counter, defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import torch
import torch.fx as fx
from torch._dynamo.utils import counters, dynamo_timed
from torch.utils._mode_utils import no_dispatch
from torch.utils._ordered_set import OrderedSet
log = logging.getLogger(__name__)
from ..pattern_matcher import stable_topological_sort
def is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.wait_tensor.default
)
def get_custom_estimation(n: fx.Node) -> Optional[float]:
runtime_estimation = torch._inductor.config.test_configs.estimate_aten_runtime
if runtime_estimation == "default":
return None
assert callable(runtime_estimation)
return runtime_estimation(n)
def estimate_collective_time(n: fx.Node) -> float:
if (est := get_custom_estimation(n)) is not None:
return est
return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
n
)
def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
size = 0
for node in fx_node.all_input_nodes:
if (t := node.meta.get("val")) is not None:
# todo - symbolic
size += t.numel() * t.element_size()
return size
def is_compute_node(n: fx.Node) -> bool:
return (
getattr(n.target, "overloadpacket", None)
in torch.utils.flop_counter.flop_registry
)
def get_hint(x: Union[int, torch.SymInt]) -> Optional[int]:
if isinstance(x, int):
return x
assert isinstance(x, torch.SymInt)
if not x.node.has_hint():
return None
return x.node.hint
def get_collective_do_bench() -> Callable[[Callable[[], Any]], float]:
with dynamo_timed("collective_compute_do_bench"):
return functools.partial(
torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
warmup=5,
)
def benchmark_node(n: fx.Node) -> float:
if (est := get_custom_estimation(n)) is not None:
return est
from torch._dynamo.testing import rand_strided
# todo - skip unbacked, symbolic
success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(n)
if not success:
return 0
unbacked_tensor = False
key = f"{str(n.target)}: "
def to_real(t: torch.Tensor) -> Optional[torch.Tensor]:
shape = [get_hint(dim) for dim in t.shape]
stride = [get_hint(s) for s in t.stride()]
if any(s is None for s in itertools.chain(shape, stride)):
nonlocal unbacked_tensor
unbacked_tensor = True
return None
nonlocal key
key += f"T: {shape, stride, t.dtype} "
return rand_strided(shape, stride, device=t.device, dtype=t.dtype) # type: ignore[arg-type]
with no_dispatch():
args, kwargs = torch.utils._pytree.tree_map_only(
torch.Tensor,
lambda t: to_real(t),
(args, kwargs),
)
if val := get_cached_node_time(key):
return val
if unbacked_tensor:
return 0
bench = get_collective_do_bench()
out = bench(lambda: n.target(*args, **kwargs)) # type: ignore[operator]
set_cached_node_time(key, out)
return out
@functools.cache
def get_benchmark_cache() -> torch._inductor.codecache.LocalCache:
return torch._inductor.codecache.LocalCache()
def get_cached_node_time(key: str) -> float:
return get_benchmark_cache().lookup(key) # type: ignore[return-value]
def set_cached_node_time(key: str, value: float) -> None:
return get_benchmark_cache().set_value(key, value=value)
@dataclass
class CollectiveInfo:
"""Track info about a collective operation"""
start_node: fx.Node
wait_node: fx.Node
size_bytes: int
estimated_time_ms: float
exposed_time_ms: float # How much of this collective is still exposed
hiding_node: Optional[fx.Node] = None # Node that hides this collective
@property
def is_exposed(self) -> bool:
return self.exposed_time_ms != 0
class OverlapScheduler:
"""
Scheduler that reorders operations to maximize compute-collective overlap.
The reordering is done as a scheduling pass. We maintain a priority queue of
schedulable nodes. The nodes are ranked by:
    1) the compute node depth they dominate. This allows reordering locally, such as with
    parallel mms, and also allows overlapping reduce_scatter node outputs in the backward
    with compute by deferring their waits.
2) whether the current node is a collective or wait that is currently exposed but has a compute
node which it could be overlapped with.
3) original order in the graph for stability.
When we schedule compute nodes, we first overlap exposed in-flight collectives, then look for unscheduled
collectives that can be scheduled concurrently.
TODO:
- experiment with other priority scores / allow other mechanisms of reorder / more strict adherence to original graph
- memory limit for deferred scheduling of reduce_scatter nodes.
"""
def __init__(
self,
gm: torch.fx.GraphModule,
max_in_flight_gb: float = 2.0,
compute_overlap_multipler: float = 1.0,
max_coll_distance: int = 1000,
):
self.gm = gm
self.graph = gm.graph
self.compute_overlap_multipler = compute_overlap_multipler
self.max_node_distance = max_coll_distance
self.max_in_flight_bytes: int = int(max_in_flight_gb * 1024 * 1024 * 1024)
# Build structures
stable_topological_sort(self.graph)
self.nodes = list(self.graph.nodes)
self.node_idx = {n: i for i, n in enumerate(self.nodes)}
self.node_ancestors: dict[fx.Node, OrderedSet[fx.Node]] = (
self._collect_node_ancestors()
)
# Identify collectives and compute nodes
self.collective_info: dict[fx.Node, CollectiveInfo] = {}
self.unscheduled_collectives: OrderedSet[fx.Node] = OrderedSet()
self.wait_to_start: dict[fx.Node, fx.Node] = {}
self._identify_collectives()
self.compute_depth = self._calculate_compute_node_depth()
self.compute_nodes = [n for n in self.nodes if is_compute_node(n)]
# Scheduling state
self.potentially_hidden_collectives = (
self.compute_potential_hidden_collectives()
)
self.potentially_hidden_waits = self.compute_potential_hidden_waits()
self.in_degree = Counter(user for node in self.nodes for user in node.users)
self.ready: list[tuple[object, fx.Node]] = []
for node in self.nodes:
if self.in_degree[node] == 0:
heapq.heappush(self.ready, (self._compute_score(node), node))
self.in_flight: dict[fx.Node, CollectiveInfo] = {} # start -> info
self.in_flight_bytes = 0
self.scheduled: OrderedSet[fx.Node] = OrderedSet()
def _collect_node_ancestors(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
"""Collect all ancestors for each node."""
ancestors: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
for node in self.nodes:
for input_node in node.all_input_nodes:
ancestors[node].add(input_node)
ancestors[node] |= ancestors[input_node]
return ancestors
def _identify_collectives(self) -> None:
"""Identify all collective operations."""
for node in self.nodes:
if is_wait_tensor(node):
start = node.args[0]
coll_time_ms = estimate_collective_time(start)
info = CollectiveInfo(
start_node=start,
wait_node=node,
size_bytes=estimate_fx_collective_size(start),
estimated_time_ms=coll_time_ms,
exposed_time_ms=coll_time_ms, # Initially fully exposed
)
self.collective_info[start] = info
self.wait_to_start[node] = start
self.unscheduled_collectives.add(start)
def _calculate_compute_node_depth(self) -> dict[fx.Node, int]:
"""Compute forward depth and minimum dominance depth (infinity if blocks no compute)."""
# First pass: forward compute depth
in_degree: dict[fx.Node, int] = {}
compute_depth: dict[fx.Node, int] = {}
queue: list[fx.Node] = []
for node in self.graph.nodes:
num_inputs = len(node.all_input_nodes)
if num_inputs == 0:
queue.append(node)
else:
in_degree[node] = num_inputs
while queue:
node = queue.pop()
max_input_depth = max(
(compute_depth[inp] for inp in node.all_input_nodes), default=0
)
compute_depth[node] = max_input_depth + is_compute_node(node)
for use in node.users:
in_degree[use] -= 1
if in_degree[use] == 0:
queue.append(use)
# Second pass: minimum dominance (what's the earliest compute this blocks)
compute_depth_dominance: dict[fx.Node, int] = {}
for node in reversed(self.graph.nodes):
if is_compute_node(node):
# consider compute nodes to be at their own depth
dominance = compute_depth[node]
else:
# For non-compute nodes, find minimum compute they block
dominance = min(
(compute_depth_dominance[succ] for succ in node.users),
default=sys.maxsize,
)
compute_depth_dominance[node] = dominance
return compute_depth_dominance
def run(self) -> torch.fx.GraphModule:
"""Run the scheduling algorithm."""
while self.ready:
if self._should_force_wait_for_memory():
self._force_oldest_wait()
continue
_, node = heapq.heappop(self.ready)
# we don't always remove nodes from the heap when we schedule them
if node in self.scheduled:
continue
if is_compute_node(node):
self._handle_compute(node)
elif node in self.collective_info:
self._handle_collective_start(node)
elif is_wait_tensor(node):
self._handle_wait(node)
else:
self._handle_other(node)
self._reorder_graph()
return self.gm
def _handle_other(self, node: fx.Node) -> None:
self._schedule(node)
def _schedule(self, node: fx.Node) -> None:
"""Schedule a node."""
assert node not in self.scheduled
assert all(n in self.scheduled for n in node.all_input_nodes)
self.scheduled.add(node)
for user in node.users:
self.in_degree[user] -= 1
if self.in_degree[user] == 0:
heapq.heappush(self.ready, (self._compute_score(user), user))
def _compute_score(self, node: fx.Node) -> object:
"""Compute priority score for a node"""
if is_wait_tensor(node):
info = self.collective_info[self.wait_to_start[node]]
# TODO: we could consider even deferring waits that are not potentially hidden
            # so as to overlap comm with itself. although exposed comms should be bucketed with each other.
overlappable = info.is_exposed and node in self.potentially_hidden_waits
else:
overlappable = self.in_overlappable_collective_unary_chain(node)
return (
self.compute_depth[node], # what depth compute it blocks
overlappable, # Defer hideable collective ops
self.node_idx[node], # Original order for stability
)
@staticmethod
def is_cheap_fn(node: fx.Node) -> bool:
return getattr(node.target, "is_view", False) or torch.Tag.pointwise in getattr(
node.target, "tags", ()
)
def in_overlappable_collective_unary_chain(self, curr: fx.Node) -> bool:
while True:
if len(curr.users) != 1:
return False
user = next(iter(curr.users))
if len(user.all_input_nodes) != 1:
return False
if user in self.unscheduled_collectives:
return user in self.potentially_hidden_collectives
if not self.is_cheap_fn(user):
return False
curr = user
return False
def _should_force_wait_for_memory(self) -> bool:
"""Check if we need to force a wait due to memory pressure"""
return self.in_flight_bytes >= self.max_in_flight_bytes
def _force_oldest_wait(self) -> None:
"""Schedule the oldest in flight wait"""
self._handle_wait(self._get_oldest_wait())
def _handle_collective_start(self, node: fx.Node) -> None:
"""Handle scheduling a collective start."""
info = self.collective_info[node]
self.in_flight[node] = info
self.in_flight_bytes += info.size_bytes
self.unscheduled_collectives.discard(node)
self._schedule(node)
def _handle_wait(self, node: fx.Node) -> None:
"""Handle scheduling a wait."""
assert node in self.wait_to_start
coll_start = self.wait_to_start[node]
assert coll_start in self.in_flight
self.in_flight_bytes -= self.in_flight[coll_start].size_bytes
del self.in_flight[coll_start]
self._schedule(node)
def _handle_compute(self, node: fx.Node) -> None:
"""Handle scheduling compute and finding overlaps."""
compute_time = benchmark_node(node)
available_compute = compute_time * self.compute_overlap_multipler
# First reduce exposed time of in-flight collectives
for info in self.in_flight.values():
if info.exposed_time_ms == 0:
continue
overlap_amount = min(info.exposed_time_ms, available_compute)
info.exposed_time_ms -= overlap_amount
available_compute -= overlap_amount
if info.exposed_time_ms == 0:
info.hiding_node = node
elif available_compute == 0:
break
# Then, look for unscheduled collectives we can overlap
if available_compute:
self._schedule_collectives_for_overlap(node, available_compute)
self._schedule(node)
def _schedule_collectives_for_overlap(
self, compute_node: fx.Node, available_compute_time: float
) -> None:
"""Opportunistically schedule collectives that can be hidden by compute."""
compute_ancestors = self.node_ancestors[compute_node]
# copy unscheduled_collectives to local because we modify it during iteration
possible_collectives = []
for collective in self.unscheduled_collectives:
distance = abs(self.node_idx[compute_node] - self.node_idx[collective])
if distance > self.max_node_distance:
break
possible_collectives.append(collective)
for collective in possible_collectives:
if available_compute_time == 0:
break
info = self.collective_info[collective]
# Skip if compute depends on collective or vice versa
if (
collective in compute_ancestors
or compute_node in self.node_ancestors[collective]
):
continue
while (
self.in_flight
and (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes
and self._wait_is_hidden(self._get_oldest_wait(), compute_node)
):
self._force_oldest_wait()
if (self.max_in_flight_bytes - self.in_flight_bytes) < info.size_bytes:
continue
# Check if we can reach this collective without scheduling compute, other collectives, or waits
path = self._find_schedulable_path(collective, compute_node)
if path is None:
continue
# Schedule path to this collective
self._schedule_path_to_collective(path, compute_node)
# Update the exposed time for this newly scheduled collective
overlap_amount = min(info.estimated_time_ms, available_compute_time)
info.exposed_time_ms -= overlap_amount
if info.exposed_time_ms == 0:
info.hiding_node = compute_node
available_compute_time -= overlap_amount
self._handle_collective_start(collective)
def _find_schedulable_path(
self, target: fx.Node, curr_compute_node: Optional[fx.Node]
) -> Optional[OrderedSet[fx.Node]]:
"""Find path to target by collecting unscheduled dependencies."""
# TODO - following path faster than doing set difference here
unscheduled_ancestors = self.node_ancestors[target] - self.scheduled
# only schedule non distributed, non compute nodes
for node in unscheduled_ancestors:
if is_compute_node(node):
return None
if node in self.unscheduled_collectives:
return None
# if we schedule a wait tensor whose start collective is hidden by the
# current compute node we are scheduling, then we are effectively exposing it.
            # similarly, don't schedule a wait of a collective that could otherwise be hidden,
            # thus forcing it to be exposed.
            # however, if it is already hidden or it cannot possibly be hidden,
# it's fine to schedule it
if is_wait_tensor(node):
info = self.collective_info[self.wait_to_start[node]]
if info.hiding_node and info.hiding_node != curr_compute_node:
continue
elif node not in self.potentially_hidden_waits:
continue
return None
return unscheduled_ancestors
def _get_oldest_wait(self) -> fx.Node:
oldest_start = next(iter(self.in_flight))
return self.collective_info[oldest_start].wait_node
def _wait_is_hidden(
self, wait_node: fx.Node, compute_node: Optional[fx.Node] = None
) -> bool:
assert is_wait_tensor(wait_node)
info = self.collective_info[self.wait_to_start[wait_node]]
return not info.is_exposed and info.hiding_node != compute_node
def _schedule_path_to_collective(
self, path: OrderedSet[fx.Node], curr_compute_node: fx.Node
) -> None:
"""Schedule all nodes needed to reach a collective."""
for node in sorted(path, key=lambda n: self.node_idx[n]):
assert not (is_compute_node(node) or node in self.unscheduled_collectives)
if is_wait_tensor(node):
info = self.collective_info[self.wait_to_start[node]]
assert not info.hiding_node == curr_compute_node
self._handle_wait(node)
continue
self._schedule(node)
def reorder_graph(self) -> None:
output_node = self.graph.output_node()
for node in self.scheduled:
if node.op == "placeholder":
continue
output_node.prepend(node)
self.graph.lint()
def _reorder_graph(self) -> None:
"""Reorder graph based on schedule."""
exposed = [
c
for c in self.collective_info.values()
if c.exposed_time_ms == c.estimated_time_ms
]
potentially_hidden_collectives = self.compute_potential_hidden_collectives(
limit_coll_per_compute=True
)
bad_exposed = [
c for c in exposed if c.start_node in potentially_hidden_collectives
]
counters["inductor"]["overlap_scheduling_exposed"] += len(exposed)
counters["inductor"]["overlap_scheduling_bad_exposed"] += len(bad_exposed)
counters["inductor"]["overlap_scheduling_potentially_hidden"] += len(
potentially_hidden_collectives
)
log.info(
"Overlap scheduling: total exposed %s, total bad exposed %s, total potentially hidden %s",
len(exposed),
len(bad_exposed),
len(potentially_hidden_collectives),
)
self.reorder_graph()
def compute_potential_hidden_nodes(
self, nodes_to_check: Iterable[fx.Node], limit_coll_per_compute: bool = False
) -> dict[fx.Node, fx.Node]:
"""
Returns a dict containing a mapping of nodes which could potentially be hidden to their hiding node
"""
used_compute_nodes: OrderedSet[fx.Node] = OrderedSet()
def could_be_hidden(start: fx.Node) -> Optional[fx.Node]:
for compute_node in self.compute_nodes:
if limit_coll_per_compute and compute_node in used_compute_nodes:
continue
if (
start not in self.node_ancestors[compute_node]
and compute_node not in self.node_ancestors[start]
):
if limit_coll_per_compute:
used_compute_nodes.add(compute_node)
return compute_node
return None
# TODO: We could potentially limit compute nodes per overlap time,
# today, this is optimistic, and just serves to avoid deferring
# collectives/waits that have no possible overlap as well as for analysis of how
# successfully we hid compute
potentially_hidden = {}
for node in nodes_to_check:
if mm := could_be_hidden(node):
potentially_hidden[node] = mm
return potentially_hidden
def compute_potential_hidden_collectives(
self, limit_coll_per_compute: bool = False
) -> dict[fx.Node, fx.Node]:
"""Compute which collective operations could be hidden by compute."""
return self.compute_potential_hidden_nodes(
self.collective_info.keys(), limit_coll_per_compute
)
def compute_potential_hidden_waits(
self, limit_coll_per_compute: bool = False
) -> dict[fx.Node, fx.Node]:
"""Compute which wait operations could be hidden by compte."""
wait_nodes = [info.wait_node for info in self.collective_info.values()]
return self.compute_potential_hidden_nodes(wait_nodes, limit_coll_per_compute)
def schedule_overlap_bucketing(
gm: torch.fx.GraphModule,
max_in_flight_gb: float = 2.0,
compute_overlap_multipler: float = 1.0,
max_coll_distance: int = 1000,
) -> torch.fx.GraphModule:
"""Schedule nodes to maximize compute-collective overlap.
Args:
gm: Input graph module to optimize.
max_in_flight_gb: Maximum GB of concurrent collective data.
compute_overlap_multipler: Scale factor for compute time used to hide collectives.
max_coll_distance: Maximum node distance for overlap consideration.
"""
return OverlapScheduler(
gm,
compute_overlap_multipler=compute_overlap_multipler,
max_in_flight_gb=max_in_flight_gb,
max_coll_distance=max_coll_distance,
).run()


@@ -202,6 +202,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
            GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass)
 
    collectives_bucketing: bool = False
    if config.bucket_reduce_scatters_fx != "none":
        from torch._inductor.fx_passes.bucketing import bucket_reduce_scatter
        from torch._inductor.fx_passes.fsdp import bucket_fsdp_reduce_scatter


@@ -9112,6 +9112,7 @@ class _CollectiveKernel(FallbackKernel):
         assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
         for tensor_arg in tensor_args:
             tensor_arg.realize()
+            V.graph.mark_buffer_mutated(tensor_arg.get_name())
 
         device = tensor_args[0].get_device()
         packed = cls(