Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[reland] Fix estimate_nccl_collective_runtime (#118986)
`estimate_nccl_collective_runtime` has been broken and the errors have been silently swallowed by inductor. This PR:
- Fixes the issues described in https://github.com/pytorch/pytorch/issues/118497.
- Adds white-box testing so future issues can be surfaced in tests.
- Adds support for native funcol IRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118986
Approved by: https://github.com/yf225
ghstack dependencies: #119102
Committed by: PyTorch MergeBot
Parent: b2043c0543
Commit: 27ffede878
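For background (not part of the commit): the estimator being fixed models a ring collective as a bandwidth term plus a per-step latency term, which is the general shape of the code in the diff below. The following standalone sketch illustrates that kind of cost model only; the function name, traffic factors, and default constants are illustrative assumptions, not the values or API used by inductor.

from enum import IntEnum


class CollType(IntEnum):
    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


def ring_collective_time_ns(
    size_bytes: int,
    group_size: int,
    coll: CollType,
    bus_bw_gb_per_s: float = 20.0,  # assumed bottleneck (inter-node) bandwidth
    per_step_latency_us: float = 6.6,  # assumed per-step latency
) -> float:
    # A ring all-reduce moves ~2*(n-1)/n of the payload per rank;
    # all-gather and reduce-scatter move ~(n-1)/n.
    if coll == CollType.ALL_REDUCE:
        traffic_factor = 2 * (group_size - 1) / group_size
    else:
        traffic_factor = (group_size - 1) / group_size

    transport_s = size_bytes * traffic_factor / (bus_bw_gb_per_s * 1e9)
    latency_s = (group_size - 1) * per_step_latency_us * 1e-6  # one step per peer
    return (transport_s + latency_s) * 1e9  # nanoseconds


# Example: a 64 MiB all-reduce across 8 ranks
print(ring_collective_time_ns(64 * 2**20, 8, CollType.ALL_REDUCE))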
@@ -1,7 +1,7 @@
 import math
 from enum import IntEnum
-from typing import TYPE_CHECKING
+import sympy

 import torch
 from . import ir
@@ -9,9 +9,6 @@ from . import ir
 from .utils import get_dtype_size, sympy_product
 from .virtualized import V

-if TYPE_CHECKING:
-    from torch._inductor.scheduler import BaseSchedulerNode
-

 class NCCL_COLL(IntEnum):
     ALL_REDUCE = 0
@@ -26,7 +23,7 @@ class NVIDIA_GPU_TYPE(IntEnum):
 def get_gpu_type() -> NVIDIA_GPU_TYPE:
-    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run)
+    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
     if "V100" in gpu_info:
         return NVIDIA_GPU_TYPE.VOLTA
     elif "A100" in gpu_info:
@@ -38,19 +35,52 @@ def get_gpu_type() -> NVIDIA_GPU_TYPE:
         return NVIDIA_GPU_TYPE.AMPERE


-def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:
-    if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)):
+def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
+    if isinstance(node, ir._CollectiveKernel):
+        kernel_name = node.python_kernel_name
+        assert kernel_name is not None
+        if "all_reduce" in kernel_name:
+            return NCCL_COLL.ALL_REDUCE
+        elif "all_gather" in kernel_name:
+            return NCCL_COLL.ALL_GATHER
+        elif "reduce_scatter" in kernel_name:
+            return NCCL_COLL.REDUCE_SCATTER
+        else:
+            raise Exception(f"Unsupported collective kernel: {kernel_name}")
+
+    if isinstance(node, (ir.AllReduce, ir.AllReduceCoalesced)):
         return NCCL_COLL.ALL_REDUCE
-    elif isinstance(
-        snode.node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)
-    ):
+    elif isinstance(node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)):
         return NCCL_COLL.ALL_GATHER
-    elif isinstance(
-        snode.node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)
-    ):
+    elif isinstance(node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)):
         return NCCL_COLL.REDUCE_SCATTER
     else:
-        raise Exception(f"Unsupported collective type: {snode.node}")
+        raise Exception(f"Unsupported collective type: {node}")
+
+
+def get_collective_input_size_bytes(node: ir.IRNode) -> int:
+    sz_bytes = 0
+    for inp in node.inputs:  # type: ignore[attr-defined]
+        shape = inp.layout.size
+        numel = sympy_product(inp.layout.size)
+        if isinstance(numel, sympy.Integer):
+            # For ease of testing
+            numel = int(numel)
+        else:
+            numel = V.graph.sizevars.size_hint(numel)
+        sz_bytes += numel * get_dtype_size(inp.layout.dtype)
+    return sz_bytes
+
+
+def get_collective_group_size(node: ir.IRNode) -> int:
+    if type(node) == ir._CollectiveKernel:
+        from torch.distributed.distributed_c10d import _get_group_size_by_name
+
+        return _get_group_size_by_name(node.constant_args[-1])
+    elif isinstance(node, ir.CollectiveKernel):
+        return node.constant_args[2]  # type: ignore[attr-defined]
+    else:
+        raise TypeError(f"Unsupported collective type: {node}")


 ####################################################################################################################
@@ -80,68 +110,63 @@ class NCCL_PROTO(IntEnum):
 # Latencies in us
 # len(NCCL_ALGO) x len(NCCL_PROTO)
-baseLat = torch.tensor(
-    [
-        # Tree
-        [
-            6.8,  # LL
-        ],
-        # Ring
-        [
-            6.6,  # LL
-        ],
-    ]
-)
+# NOTE: use array instead of tensor to prevent incompatibility with fake mode
+baseLat = [
+    # Tree
+    [
+        6.8,  # LL
+    ],
+    # Ring
+    [
+        6.6,  # LL
+    ],
+]

 # Latencies in us
 # len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
-hwLat = torch.tensor(
-    [
-        # NVLINK
-        [
-            [0.6],  # Tree (LL)
-            [0.6],  # Ring (LL)
-        ],
-        # PCI
-        [
-            [1.0],  # Tree (LL)
-            [1.0],  # Ring (LL)
-        ],
-        # NET
-        [
-            [5.0],  # Tree (LL)
-            [2.7],  # Ring (LL)
-        ],
-    ]
-)
+hwLat = [
+    # NVLINK
+    [
+        [0.6],  # Tree (LL)
+        [0.6],  # Ring (LL)
+    ],
+    # PCI
+    [
+        [1.0],  # Tree (LL)
+        [1.0],  # Ring (LL)
+    ],
+    # NET
+    [
+        [5.0],  # Tree (LL)
+        [2.7],  # Ring (LL)
+    ],
+]

 # LL128 max BW per channel
-llMaxBws = torch.tensor(
-    [
-        # Volta-N1/Intel-N2/Intel-N4
-        [
-            39.0,
-            39.0,
-            20.4,
-        ],
-        # Ampere-N1/AMD-N2/AMD-N4
-        [
-            87.7,
-            22.5,  # avg of ring & tree
-            19.0,
-        ],
-        # Hopper-N1/AMD-N2/AMD-N4
-        [
-            87.7,
-            22.5,  # avg of ring & tree
-            19.0,
-        ],
-    ]
-)
+llMaxBws = [
+    # Volta-N1/Intel-N2/Intel-N4
+    [
+        39.0,
+        39.0,
+        20.4,
+    ],
+    # Ampere-N1/AMD-N2/AMD-N4
+    [
+        87.7,
+        22.5,  # avg of ring & tree
+        19.0,
+    ],
+    # Hopper-N1/AMD-N2/AMD-N4
+    [
+        87.7,
+        22.5,  # avg of ring & tree
+        19.0,
+    ],
+]


-def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
+def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
     """
     Returns estimated NCCL collective runtime in nanoseconds (ns).
@@ -154,16 +179,14 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
     - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
     - collective is one of: allreduce, reducescatter, allgather
     """
-    tensor_numel = V.graph.sizevars.size_hint(sympy_product(snode.node.layout.size))
-    tensor_dtype = snode.node.layout.dtype
-    tensor_storage_size_bytes = tensor_numel * get_dtype_size(tensor_dtype)
+    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
     # Convert bytes to GB
     tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

     # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
     # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
     num_gpus_per_node = 8
-    _, _, group_size = snode.node.constant_args  # type: ignore[attr-defined]
+    group_size = get_collective_group_size(node)
     nNodes = math.ceil(group_size / num_gpus_per_node)
     nRanks = group_size  # this is total # of gpus globally that participate in this collective op
@@ -173,7 +196,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
     # Assumes ring algorithm
     nccl_algo = NCCL_ALGO.RING
     nccl_proto = NCCL_PROTO.LL
-    coll = get_collective_type(snode)
+    coll = get_collective_type(node)

     # =============== bandwidth computation ===============
     # First compute bandwidth in GB/s; then at the end, convert it to GB/ns
@@ -185,7 +208,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
     index2 = nNodes - 1 if nNodes <= 2 else 2
     # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
     index1 = compCapIndex if nNodes == 1 else 0
-    llMaxBw = llMaxBws[index1][index2].item()
+    llMaxBw = llMaxBws[index1][index2]

     # NOTE: each step of ring algorithm is synchronized,
     # and is bottlenecked by the slowest link which is the inter-node interconnect.
@@ -227,9 +250,9 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
     nInterSteps = nNodes - 1

     # First compute latency in us; then at the end, convert it to ns
-    latency = baseLat[nccl_algo][nccl_proto].item()
-    intraLat = hwLat[intraHw][nccl_algo][nccl_proto].item()
-    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto].item()
+    latency = baseLat[nccl_algo][nccl_proto]
+    intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
+    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]

     # Inter-node rings still have to launch nsteps * net overhead.
     netOverhead = 0.0
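One detail worth noting in the hunks above: the lookup tables change from `torch.tensor(...)` constants to plain nested Python lists, and the `.item()` calls on the lookups are dropped. Per the NOTE added in the diff, this avoids materializing tensors at compile time, which was incompatible with fake mode. Because `IntEnum` members are `int`s, plain list indexing works directly. A minimal standalone sketch of that indexing pattern (values copied from the diff; the enum member values are assumed to match the table layout shown above):

from enum import IntEnum


class NCCL_ALGO(IntEnum):  # assumed ordering: matches the Tree/Ring rows above
    TREE = 0
    RING = 1


class NCCL_PROTO(IntEnum):
    LL = 0


# len(NCCL_ALGO) x len(NCCL_PROTO), indexed as baseLat[algo][proto]
baseLat = [
    [6.8],  # Tree: LL
    [6.6],  # Ring: LL
]

latency = baseLat[NCCL_ALGO.RING][NCCL_PROTO.LL]  # IntEnum members index like ints
assert latency == 6.6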