Compare commits

...

1 Commit

Author SHA1 Message Date
57a76b44de fix all-to-all estimation 2025-11-14 17:05:47 -08:00
6 changed files with 234 additions and 30 deletions

View File

@@ -10,6 +10,7 @@ import torch._dynamo.test_case
# for some reason importing functional collectives after dynamo breaks collectives handling!
import torch.distributed._functional_collectives as _functional_collectives
import torch.fx as fx
from torch._C import FileCheck
from torch._dynamo.utils import counters, same
from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
@@ -238,6 +239,49 @@ graph():
self.assertTrue(same(out, correct))
self.assertEqual(counters["inductor"]["overlap_scheduling_exposed"], 0)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_patches())
def test_schedulable_wait(self):
"""Test that if a wait node is scheduable or not."""
from torch._inductor.fx_passes.bucketing import _schedulable_wait_node
def test_graph():
graph = fx.Graph()
inp = graph.placeholder("inp")
group_size = graph.placeholder("group_size")
group_name = graph.placeholder("group_name")
ag_0_out = graph.call_function(
torch.ops._c10d_functional.all_gather_into_tensor.default,
args=(inp, group_size, group_name),
)
ag_0_wait = graph.call_function(
torch.ops._c10d_functional.wait_tensor.default,
args=(ag_0_out,),
)
ag_1_out = graph.call_function(
torch.ops._c10d_functional.all_gather_into_tensor.default,
args=(ag_0_wait, group_size, group_name),
)
ag_1_wait = graph.call_function(
torch.ops._c10d_functional.wait_tensor.default,
args=(ag_1_out,),
)
ag_2_wait = graph.call_function(
torch.ops._c10d_functional.wait_tensor.default,
args=(ag_1_wait,),
)
graph.output(ag_2_wait)
return graph
graph = test_graph()
schedulable = {"wait_tensor_default", "wait_tensor_default_1"}
for node in list(graph.nodes):
expected = node.name in schedulable
assert _schedulable_wait_node(node) is expected
@torch._inductor.config.patch(get_patches())
def test_reorder_compute_for_overlap_mul(self):
def func(a, *, tag, ranks, group_size):

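For reference, the expected classification in test_schedulable_wait above follows from the bucketing and comm_analysis changes later in this diff: a wait is schedulable only when it waits directly on a callable NCCL collective.
# wait_tensor_default    waits on all_gather_into_tensor -> schedulable
# wait_tensor_default_1  waits on all_gather_into_tensor -> schedulable
# wait_tensor_default_2  waits on another wait_tensor; its producer's kernel name
#                        maps to NCCL_COLL.UNSUPPORTED   -> not schedulable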
View File

@@ -23,7 +23,12 @@ from torch._inductor.comms import (
sink_waits_iterative,
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.fx_passes.bucketing import is_all_gather_into_tensor
from torch._inductor.fx_passes.bucketing import (
is_all_gather_into_tensor,
is_all_reduce_tensor,
is_all_to_all_tensor,
is_reduce_scatter_tensor,
)
from torch._inductor.scheduler import (
_get_mm_like_fn,
BaseSchedulerNode,
@@ -2193,7 +2198,7 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
self.assertEqual(saved_values, [wt1])
@skip_if_lt_x_gpu(2)
def test_comm_analysis(self):
def test_all_gather_comm_analysis(self):
store = c10d.FileStore(self.file_name, self.world_size)
torch.cuda.set_device(self.rank)
c10d.init_process_group(
@@ -2234,6 +2239,140 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
)
assert est_ms_nccl > 0
@skip_if_lt_x_gpu(2)
def test_reduce_scatter_comm_analysis(self):
store = c10d.FileStore(self.file_name, self.world_size)
torch.cuda.set_device(self.rank)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
group = c10d.distributed_c10d._get_default_group()
group_name = "default"
torch._C._distributed_c10d._register_process_group(
group_name, torch.distributed.group.WORLD
)
group_size = group.size()
def func(inp, group_size, group_name):
rs_0_out = torch.ops._c10d_functional.reduce_scatter_tensor(
inp, "sum", group_size, group_name
)
rs_0_wait = torch.ops.c10d_functional.wait_tensor(rs_0_out)
rs_1_out = torch.ops._c10d_functional.reduce_scatter_tensor(
rs_0_wait, "sum", group_size, group_name
)
rs_1_wait = torch.ops.c10d_functional.wait_tensor(rs_1_out)
return rs_1_wait
gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name)
g = gm.graph
for n in g.nodes:
if is_reduce_scatter_tensor(n):
from torch._inductor.comm_analysis import (
estimate_nccl_collective_runtime_from_fx_node,
)
est_ms = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=False
)
assert est_ms > 0
est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=True
)
assert est_ms_nccl > 0
@skip_if_lt_x_gpu(2)
def test_all_reduce_comm_analysis(self):
store = c10d.FileStore(self.file_name, self.world_size)
torch.cuda.set_device(self.rank)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
group = c10d.distributed_c10d._get_default_group()
group_name = "default"
torch._C._distributed_c10d._register_process_group(
group_name, torch.distributed.group.WORLD
)
group_size = group.size()
def func(inp, group_size, group_name):
ar_0_out = torch.ops._c10d_functional.all_reduce(inp, "sum", group_name)
ar_0_wait = torch.ops.c10d_functional.wait_tensor(ar_0_out)
ar_1_out = torch.ops._c10d_functional.all_reduce(
ar_0_wait, "sum", group_name
)
ar_1_wait = torch.ops.c10d_functional.wait_tensor(ar_1_out)
return ar_1_wait
gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name)
g = gm.graph
for n in g.nodes:
if is_all_reduce_tensor(n):
from torch._inductor.comm_analysis import (
estimate_nccl_collective_runtime_from_fx_node,
)
est_ms = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=False
)
assert est_ms > 0
est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=True
)
assert est_ms_nccl > 0
@skip_if_lt_x_gpu(2)
def test_all_to_all_comm_analysis(self):
store = c10d.FileStore(self.file_name, self.world_size)
torch.cuda.set_device(self.rank)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
group = c10d.distributed_c10d._get_default_group()
group_name = "default"
torch._C._distributed_c10d._register_process_group(
group_name, torch.distributed.group.WORLD
)
group_size = group.size()
def func(inp, group_size, group_name):
chunk = inp.numel() // self.world_size
split_sizes = [chunk] * self.world_size
a2a_0_out = torch.ops._c10d_functional.all_to_all_single(
inp,
split_sizes,
split_sizes,
group_name,
)
a2a_0_wait = torch.ops.c10d_functional.wait_tensor(a2a_0_out)
a2a_1_out = torch.ops._c10d_functional.all_to_all_single(
a2a_0_wait,
split_sizes,
split_sizes,
group_name,
)
a2a_1_wait = torch.ops.c10d_functional.wait_tensor(a2a_1_out)
return a2a_1_wait
gm = make_fx(func)(
torch.ones(group_size * 4, 1, device=self.device), group_size, group_name
)
g = gm.graph
for n in g.nodes:
if is_all_to_all_tensor(n):
from torch._inductor.comm_analysis import (
estimate_nccl_collective_runtime_from_fx_node,
)
est_ms = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=False
)
assert est_ms > 0
est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node(
n, use_nccl_estimator=True
)
assert est_ms_nccl > 0
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

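Each of the three comm-analysis tests above checks the estimator twice per collective node: once with use_nccl_estimator=False, which exercises the analytical bandwidth model in comm_analysis.py (including the new all-to-all path), and once with use_nccl_estimator=True, which presumably defers to NCCL's own runtime estimator; both calls are only required to return a positive estimate.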
View File

@@ -23,6 +23,7 @@ class NCCL_COLL(IntEnum):
ALL_GATHER = 1
REDUCE_SCATTER = 2
ALL_TO_ALL = 3
UNSUPPORTED = 4
class NVIDIA_GPU_TYPE(IntEnum):
@@ -53,10 +54,10 @@ def get_collective_type_from_kernel_name(kernel_name: str) -> NCCL_COLL:
return NCCL_COLL.ALL_GATHER
elif "reduce_scatter" in kernel_name:
return NCCL_COLL.REDUCE_SCATTER
elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name:
elif any(comm in kernel_name for comm in ("all_to_all", "alltoall")):
return NCCL_COLL.ALL_TO_ALL
else:
raise ValueError(f"Unsupported collective kernel: {kernel_name}")
return NCCL_COLL.UNSUPPORTED
def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
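With this change an unrecognized kernel name maps to NCCL_COLL.UNSUPPORTED instead of raising, so callers such as the new _schedulable_wait_node (in bucketing.py below) can treat, for example, a wait_tensor producer as non-schedulable rather than erroring out. A rough sketch of the resulting mapping; the all_gather branch is cut off by the hunk above, so that line is an assumption:
# get_collective_type_from_kernel_name("all_gather_into_tensor") -> NCCL_COLL.ALL_GATHER
# get_collective_type_from_kernel_name("reduce_scatter_tensor")  -> NCCL_COLL.REDUCE_SCATTER
# get_collective_type_from_kernel_name("all_to_all_single")      -> NCCL_COLL.ALL_TO_ALL
# get_collective_type_from_kernel_name("wait_tensor")            -> NCCL_COLL.UNSUPPORTED (previously raised ValueError)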
@@ -347,13 +348,12 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
size = 0
sz_bytes = 0
for node in fx_node.all_input_nodes:
if (t := node.meta.get("val")) is not None:
size += t.numel() * t.element_size()
# TODO - symbolic
return size
numel = get_size_numel(t.size())
sz_bytes += numel * get_dtype_size(t.dtype)
return sz_bytes
def estimate_nccl_collective_runtime_from_fx_node(

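The estimate_fx_collective_size rewrite above is, as far as these hunks show, the sizing part of the "fix all-to-all estimation" change: the byte count now comes from get_size_numel(t.size()) and get_dtype_size(t.dtype) on the input's FakeTensor metadata instead of t.numel() * t.element_size(), and the old "# TODO - symbolic" note is dropped, presumably because those helpers handle symbolic shapes. A worked example, assuming static shapes as in the tests above:
# reduce_scatter / all_reduce tests: inp = torch.ones(4, 4), float32
#   numel = 4 * 4 = 16, size = 16 * 4 bytes = 64 bytes
# all_to_all test: inp = torch.ones(group_size * 4, 1), float32, e.g. world_size == 2
#   numel = 8 * 1 = 8, size = 8 * 4 bytes = 32 bytes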
View File

@@ -10,6 +10,10 @@ import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import detect_fake_mode
from torch._inductor.comm_analysis import (
get_collective_type_from_kernel_name,
NCCL_COLL,
)
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._logging import trace_structured
from torch.fx.experimental.proxy_tensor import make_fx
@@ -52,6 +56,23 @@ def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
return (group_name, reduce_op, dtype)
def _schedulable_wait_node(node: torch.fx.Node) -> bool:
"""
Additional checks on whether a wait node is schedulable.
We should not schedule an fx node that:
1. waits on a collective that is not callable
2. waits on a non-NCCL communication node
"""
if not is_wait_tensor(node):
return False
assert isinstance(node.args[0], torch.fx.Node)
assert isinstance(node.args[0].target.name(), str)
is_callable: bool = node.args[0].op == "call_function"
coll: NCCL_COLL = get_collective_type_from_kernel_name(node.args[0].target.name())
is_collective: bool = coll != NCCL_COLL.UNSUPPORTED
return is_callable and is_collective
def bucket_key(node: torch.fx.Node, mode: BucketMode | None = None) -> object | None:
if is_all_gather_into_tensor(node):
group_key_fn = (
@@ -138,7 +159,6 @@ def is_wait_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target is torch.ops._c10d_functional.wait_tensor.default
and node.args[0].op == "call_function"
)
@@ -149,6 +169,13 @@ def is_all_reduce_tensor(node: torch.fx.Node) -> bool:
)
def is_all_to_all_tensor(node: torch.fx.Node) -> bool:
return (
node.op == "call_function"
and node.target is torch.ops._c10d_functional.all_to_all_single.default
)
def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]) # type: ignore[arg-type]

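Together with the relaxed is_wait_tensor above (which no longer requires its input to be a call_function; that check moves into _schedulable_wait_node) and the new is_all_to_all_tensor predicate, schedulability can be queried directly on graph nodes. A minimal usage sketch, assuming a traced functional-collectives graph like the one built in the tests; the helper name schedulable_waits is illustrative, not part of this PR:

import torch
from torch._inductor.fx_passes.bucketing import _schedulable_wait_node

def schedulable_waits(graph: torch.fx.Graph) -> list[torch.fx.Node]:
    # Only waits whose producer is a call_function for a supported NCCL collective
    # (all_gather, reduce_scatter, all_reduce, all_to_all) are returned;
    # a wait_tensor chained onto another wait_tensor is skipped.
    return [n for n in graph.nodes if _schedulable_wait_node(n)]

Applied to the graph from test_schedulable_wait, this would return only the two all_gather waits.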
View File

@@ -8,9 +8,9 @@ import torch
import torch.fx as fx
from torch._dynamo.graph_deduplication import _stable_topological_sort
from torch._inductor.fx_passes.bucketing import (
_schedulable_wait_node,
is_all_gather_into_tensor as is_all_gather,
is_reduce_scatter_tensor as is_reduce_scatter,
is_wait_tensor,
merge_all_gather_bucket,
merge_reduce_scatter_bucket,
)
@@ -36,7 +36,10 @@ class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
"""
def __init__(
self, node_users: dict[fx.Node, OrderedSet[fx.Node]], *args: Any, **kwargs: Any
self,
node_users: dict[fx.Node, OrderedSet[fx.Node]],
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.node_users = node_users
@@ -97,7 +100,7 @@ class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
)
# Identify the new wait and start
new_waits = [n for n in new_nodes if is_wait_tensor(n)]
new_waits = [n for n in new_nodes if _schedulable_wait_node(n)]
assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}"
new_wait = new_waits[0]
new_start = new_wait.args[0]
@@ -186,7 +189,7 @@ class ManualOverlapScheduler(OverlapScheduler):
def _identify_collectives(self) -> None:
"""Identify all collective operations."""
for node in self.nodes:
if is_wait_tensor(node):
if _schedulable_wait_node(node):
start = node.args[0]
info = CollectiveInfo(
start_node=start,

View File

@@ -11,7 +11,8 @@ from typing import Any, Literal
import torch
import torch.fx as fx
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.fx_passes.bucketing import is_wait_tensor
from torch._inductor.comm_analysis import estimate_fx_collective_size
from torch._inductor.fx_passes.bucketing import _schedulable_wait_node, is_wait_tensor
from torch._inductor.fx_passes.memory_estimator import (
_is_releasable,
build_memory_profile,
@@ -67,16 +68,6 @@ def estimate_collective_time(
)
def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
size = 0
for node in fx_node.all_input_nodes:
if (t := node.meta.get("val")) is not None:
# todo - symbolic
size += t.numel() * t.element_size()
return size
def is_compute_node(n: fx.Node) -> bool:
"""
Should we consider this node computationally expensive?
@@ -318,7 +309,7 @@ class OverlapScheduler:
def _identify_collectives(self) -> None:
"""Identify all collective operations."""
for node in self.nodes:
if is_wait_tensor(node):
if _schedulable_wait_node(node):
start = node.args[0]
coll_time_ms = estimate_collective_time(
start, custom_runtime_estimation=self.custom_runtime_estimation
@@ -531,7 +522,7 @@ class OverlapScheduler:
self._handle_compute(node)
elif node in self.collective_info:
self._handle_collective_start(node)
elif is_wait_tensor(node):
elif _schedulable_wait_node(node):
self._handle_wait(node)
else:
self._handle_other(node)
@@ -596,7 +587,7 @@ class OverlapScheduler:
def _compute_score(self, node: fx.Node) -> object:
"""Compute priority score for a node"""
if is_wait_tensor(node):
if _schedulable_wait_node(node):
info = self.collective_info[self.wait_to_start[node]]
# defer waits locally if they are exposed.
compute_local_priority = int(info.is_exposed)
@@ -827,7 +818,7 @@ class OverlapScheduler:
# thus forcing it to be exposed.
# however, if it is already hidden or it cannot be possible hidden,
# it's fine to schedule it
if is_wait_tensor(node):
if _schedulable_wait_node(node):
info = self.collective_info[self.wait_to_start[node]]
if info.hiding_node and info.hiding_node != curr_compute_node:
continue
@@ -875,7 +866,7 @@ class OverlapScheduler:
assert all(n not in self.scheduled for n in path)
for node in sorted(path, key=lambda n: self.node_idx[n]):
assert not (is_compute_node(node) or node in self.unscheduled_collectives)
if is_wait_tensor(node):
if _schedulable_wait_node(node):
# When we schedule wait tensors, we also force realization of all
# collectives enqueued prior to their corresponding collective.
# It's possible the scheduling of one wait tensor here has forced