Compare commits


6 Commits

Author SHA1 Message Date
fe1145581d Update on "[inductor] Respect pg topology for nccl collectives estimations"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-29 05:35:08 -07:00
a608181919 Update on "[inductor] Respect pg topology for nccl collectives estimations"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-29 05:24:51 -07:00
b8cefde1cc Update on "[inductor] Respect pg topology for nccl collectives estimations"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-29 05:22:48 -07:00
87d8918af4 Update on "[inductor] Respect pg topology for nccl collectives estimations"
cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-29 05:10:27 -07:00
7953a54ce7 Update on "[inductor] Respect pg topology for nccl collectives estimations"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
2025-10-29 05:06:08 -07:00
c730750df2 [inductor] Respect pg topology for nccl collectives estimations
[ghstack-poisoned]
2025-10-28 06:39:05 -07:00
2 changed files with 229 additions and 11 deletions

View File

@@ -23,6 +23,7 @@ from torch._inductor.comms import (
    sink_waits_iterative,
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.fx_passes.bucketing import is_all_gather_into_tensor
from torch._inductor.scheduler import (
    _get_mm_like_fn,
    BaseSchedulerNode,
@@ -2197,6 +2198,47 @@ class TestSyncDecisionCrossRanks(MultiProcessTestCase):
        saved_values = _sync_decision_cross_ranks(test_graph, saved_values)
        self.assertEqual(saved_values, [wt1])

    @skip_if_lt_x_gpu(2)
    def test_comm_analysis(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        torch.cuda.set_device(self.rank)
        c10d.init_process_group(
            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
        )
        group = c10d.distributed_c10d._get_default_group()
        group_name = "default"
        torch._C._distributed_c10d._register_process_group(
            group_name, torch.distributed.group.WORLD
        )
        group_size = group.size()

        def func(inp, group_size, group_name):
            ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor(
                inp, group_size, group_name
            )
            ag_0_wait = torch.ops.c10d_functional.wait_tensor(ag_0_out)
            ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor(
                ag_0_wait, group_size, group_name
            )
            ag_1_wait = torch.ops.c10d_functional.wait_tensor(ag_1_out)
            return ag_1_wait

        from torch._inductor.comm_analysis import nccl_pg_connectivity

        conn = nccl_pg_connectivity(group)
        assert len(conn) > 0

        gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name)
        g = gm.graph
        for n in g.nodes:
            if is_all_gather_into_tensor(n):
                from torch._inductor.comm_analysis import (
                    estimate_nccl_collective_runtime_from_fx_node,
                )

                est_ms = estimate_nccl_collective_runtime_from_fx_node(n)
                assert est_ms > 0


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

View File

@@ -1,12 +1,16 @@
import functools
# mypy: disable-error-code=attr-defined
import logging
import math
from enum import IntEnum
from typing import Optional
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from functools import lru_cache
from typing import ClassVar, Optional
import sympy
import torch
from torch.distributed import ProcessGroup
from torch.fx.operator_schemas import normalize_function
from . import ir
@@ -30,7 +34,7 @@ class NVIDIA_GPU_TYPE(IntEnum):
    HOPPER = 2


@functools.lru_cache
@lru_cache
def get_gpu_type() -> NVIDIA_GPU_TYPE:
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
    if "V100" in gpu_info:
@@ -216,8 +220,164 @@ def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]:
    return est_time_ms


class ConnectionType(Enum):
    CURRENT_DEVICE = "current_device"
    NVLINK = "nvlink"
    INFINIBAND = "infiniband"


class Connection(ABC):
    """Base class for GPU interconnect connections"""

    @property
    @abstractmethod
    def type(self) -> ConnectionType:
        """Connection type"""

    @property
    @abstractmethod
    def bandwidth(self) -> float:
        """Bidirectional bandwidth in GB/s"""


@dataclass
class CurrentDevice(Connection):
    """Dummy Implementation"""

    type: ConnectionType = ConnectionType.CURRENT_DEVICE
    bandwidth: float = 0.0


@dataclass
class NVLinkConnection(Connection):
    """NVLink connection for intra-node GPU communication"""

    version: str
    num_links: int
    _bandwidth: float = field(init=False)

    # Bidirectional bandwidth per link (GB/s)
    BANDWIDTH_PER_LINK: ClassVar[dict[str, float]] = {
        "1.0": 40.0,  # 20 GB/s unidirectional × 2
        "2.0": 50.0,  # 25 GB/s unidirectional × 2
        "3.0": 50.0,  # 25 GB/s unidirectional × 2 (A100)
        "4.0": 50.0,  # 25 GB/s unidirectional × 2 (H100, H200)
        "5.0": 400.0,  # 200 GB/s unidirectional × 2 (GB200)
    }

    def __post_init__(self) -> None:
        if self.version not in self.BANDWIDTH_PER_LINK:
            raise ValueError(
                f"Unknown NVLink version: {self.version}. "
                f"Supported versions: {list(self.BANDWIDTH_PER_LINK.keys())}"
            )
        # Calculate total bidirectional bandwidth
        self._bandwidth = self.BANDWIDTH_PER_LINK[self.version] * self.num_links

    @property
    def bandwidth(self) -> float:
        """Bidirectional bandwidth in GB/s"""
        return self._bandwidth

    @property
    def type(self) -> ConnectionType:
        return ConnectionType.NVLINK

    def __str__(self) -> str:
        return f"NVLink v{self.version} ({self.num_links} links, {self.bandwidth:.0f} GB/s bidirectional)"

    def __repr__(self) -> str:
        return f"NVLinkConnection(version='{self.version}', num_links={self.num_links}, bandwidth={self.bandwidth:.0f})"

@dataclass
class InfiniBandConnection(Connection):
    """InfiniBand connection for inter-node communication"""

    rate: float = 200.0  # Default (4xHDR), ibstat Rate output, Gbps per direction
    num_ports: int = 4
    _bandwidth: float = field(init=False)

    def __post_init__(self) -> None:
        self._bandwidth = self.rate / 8 * self.num_ports * 2

    @property
    def type(self) -> ConnectionType:
        return ConnectionType.INFINIBAND

    @property
    def bandwidth(self) -> float:
        """Bidirectional bandwidth in GB/s"""
        return self._bandwidth

    def __str__(self) -> str:
        return f"InfiniBand Rate:{self.rate} ({self.bandwidth:.0f} GB/s bidirectional)"

    def __repr__(self) -> str:
        return (
            f"InfiniBandConnection(rate='{self.rate}', bandwidth={self.bandwidth:.0f})"
        )
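
# With the defaults above (200 Gb/s per direction per port, 4 ports), the
# bidirectional figure works out to 200 / 8 * 4 * 2 = 200 GB/s:
#
#   ib = InfiniBandConnection()   # rate=200.0, num_ports=4
#   assert ib.bandwidth == 200.0  # GB/s, bidirectional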

@lru_cache(maxsize=128)
def nccl_pg_connectivity(pg: ProcessGroup) -> list[Connection]:
    """
    Returns the NCCL ProcessGroup connectivity: a list of Connection objects,
    one per rank in the PG. Each Connection exposes a bandwidth property with
    its bidirectional bandwidth. For the caller's own rank, CurrentDevice is
    returned.

    Attention:
    Performs a collective operation to gather the device uuids in the PG,
    so the result should be cached.
    """
    rank = pg.rank()
    size = pg.size()

    from torch._C._autograd import DeviceType
    from torch._C._distributed_c10d import _detect_dma_connectivity

    nvlink_conn = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
    nvlink_matrix = nvlink_conn.matrix
    devices_uuids = [
        str(torch.cuda.get_device_properties(i).uuid)
        for i in range(torch.cuda.device_count())
    ]
    current_device_idx = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(current_device_idx)
    uuid = str(props.uuid)
    gathered: list[list[str]] = [[] for _ in range(size)]
    torch.distributed.all_gather_object(gathered, [uuid], pg)

    nvlink_n = len(nvlink_matrix)
    uuid_to_nvlinkconn = {}
    for dev_idx in range(nvlink_n):
        # TODO: retrieve nvlink version
        version = "4.0"
        num_links = nvlink_matrix[current_device_idx][dev_idx]
        uuid_to_nvlinkconn[devices_uuids[dev_idx]] = NVLinkConnection(
            version, num_links
        )

    conn: list[Connection] = []
    for r in range(size):
        uuid = gathered[r][0]
        if r == rank:
            conn.append(CurrentDevice())
            continue
        if uuid in uuid_to_nvlinkconn:
            conn.append(uuid_to_nvlinkconn[uuid])
            continue
        # TODO: get number of ports and ib rate, parsing ibstat?
        conn.append(InfiniBandConnection())
    return conn
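
# Illustrative usage (assumes an initialized NCCL process group, as in the
# test file above; every rank in the group must call this, since it performs
# an all_gather_object under the hood):
#
#   pg = torch.distributed.distributed_c10d._get_default_group()
#   for r, c in enumerate(nccl_pg_connectivity(pg)):
#       # own rank -> CurrentDevice, NVLink peers -> NVLinkConnection,
#       # everything else falls back to InfiniBandConnection
#       print(r, c.type, c.bandwidth)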

def estimate_nccl_collective_runtime_impl(
    tensor_storage_size_bytes: int, group_size: int, coll: NCCL_COLL
    tensor_storage_size_bytes: int,
    group_size: int,
    coll: NCCL_COLL,
    group_name: Optional[str] = None,
) -> float:
    """
    Returns estimated NCCL collective runtime in milliseconds (ms).
@@ -236,13 +396,25 @@ def estimate_nccl_collective_runtime_impl(
    # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    num_gpus_per_node = 8
    nNodes = math.ceil(group_size / num_gpus_per_node)
    nRanks = group_size  # this is total # of gpus globally that participate in this collective op

    if nRanks <= 1:
        return 0

    if group_name is not None:
        from torch._C._distributed_c10d import _resolve_process_group

        pg = _resolve_process_group(group_name)
        group_conn = nccl_pg_connectivity(pg)
        num_ib_conn = 1
        for c in group_conn:
            if isinstance(c, InfiniBandConnection):
                num_ib_conn += 1
        nNodes = num_ib_conn
    else:
        num_gpus_per_node = 8
        nNodes = math.ceil(group_size / num_gpus_per_node)

    # Assumes ring algorithm
    nccl_algo = NCCL_ALGO.RING
    nccl_proto = NCCL_PROTO.LL
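
    # Illustration of the topology-aware branch above: on a single
    # NVLink-connected node no InfiniBandConnection entries are found, so
    # nNodes stays at 1 and agrees with the old 8-GPUs-per-node heuristic
    # (values below are assumed, not taken from a real PG):
    #
    #   group_conn = [CurrentDevice()] + [NVLinkConnection("4.0", 18)] * 7
    #   num_ib_conn = 1 + sum(isinstance(c, InfiniBandConnection) for c in group_conn)
    #   assert num_ib_conn == 1 == math.ceil(8 / 8)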
@@ -341,8 +513,10 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
    group_size = get_collective_group_size(node)
    coll = get_collective_type(node)
    group_name: Optional[str] = None
    # TODO: retrieve group_name from node.constant_args
    return estimate_nccl_collective_runtime_impl(
        tensor_storage_size_bytes, group_size, coll
        tensor_storage_size_bytes, group_size, coll, group_name
    )
@@ -357,7 +531,8 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
def estimate_nccl_collective_runtime_from_fx_node(
    fx_node: torch.fx.Node, override_size: Optional[int] = None
    fx_node: torch.fx.Node,
    override_size: Optional[int] = None,
) -> float:
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).
@@ -388,10 +563,11 @@ def estimate_nccl_collective_runtime_from_fx_node(
    assert opt_args_kwargs is not None
    _, kwargs = opt_args_kwargs

    group_size = _get_group_size_by_name(kwargs["group_name"])
    group_name = kwargs["group_name"]
    group_size = _get_group_size_by_name(group_name)
    assert isinstance(fx_node.target, torch._ops.OpOverload)
    coll = get_collective_type_from_kernel_name(fx_node.target.name())
    return estimate_nccl_collective_runtime_impl(
        tensor_storage_size_bytes, group_size, coll
        tensor_storage_size_bytes, group_size, coll, group_name
    )