[reland] Fix estimate_nccl_collective_runtime (#118986)

`estimate_nccl_collective_runtime` has been broken and the errors have been silently swallowed by inductor. This PR:
- Fixes the issues described in https://github.com/pytorch/pytorch/issues/118497.
- Adds white-box testing so future issues can be surfaced in tests (see the sketch after this list).
- Adds support for native funcol IRs.
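
For reference, a minimal sketch of the kind of white-box check this enables, assuming the estimator is importable as `torch._inductor.comm_analysis.estimate_nccl_collective_runtime` and given an already-constructed collective IR node; this is illustrative only, not the actual test added in this PR:

import math

from torch._inductor.comm_analysis import estimate_nccl_collective_runtime  # module path assumed

def check_collective_estimate(collective_ir_node) -> None:
    # The estimator now raises on unsupported nodes instead of inductor
    # silently swallowing a broken estimate; a supported node should yield
    # a finite, positive runtime in nanoseconds.
    est_ns = estimate_nccl_collective_runtime(collective_ir_node)
    assert math.isfinite(est_ns) and est_ns > 0
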

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118986
Approved by: https://github.com/yf225
ghstack dependencies: #119102
Author: Yifu Wang
Date: 2024-02-11 21:11:33 -08:00
Committed by: PyTorch MergeBot
parent b2043c0543
commit 27ffede878
7 changed files with 282 additions and 87 deletions


@@ -1,7 +1,7 @@
import math
from enum import IntEnum
from typing import TYPE_CHECKING
import sympy
import torch
from . import ir
@@ -9,9 +9,6 @@ from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V
if TYPE_CHECKING:
from torch._inductor.scheduler import BaseSchedulerNode
class NCCL_COLL(IntEnum):
ALL_REDUCE = 0
@@ -26,7 +23,7 @@ class NVIDIA_GPU_TYPE(IntEnum):
def get_gpu_type() -> NVIDIA_GPU_TYPE:
gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run)
gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
if "V100" in gpu_info:
return NVIDIA_GPU_TYPE.VOLTA
elif "A100" in gpu_info:
@@ -38,19 +35,52 @@ def get_gpu_type() -> NVIDIA_GPU_TYPE:
return NVIDIA_GPU_TYPE.AMPERE
def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:
if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)):
def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
if isinstance(node, ir._CollectiveKernel):
kernel_name = node.python_kernel_name
assert kernel_name is not None
if "all_reduce" in kernel_name:
return NCCL_COLL.ALL_REDUCE
elif "all_gather" in kernel_name:
return NCCL_COLL.ALL_GATHER
elif "reduce_scatter" in kernel_name:
return NCCL_COLL.REDUCE_SCATTER
else:
raise Exception(f"Unsupported collective kernel: {kernel_name}")
if isinstance(node, (ir.AllReduce, ir.AllReduceCoalesced)):
return NCCL_COLL.ALL_REDUCE
elif isinstance(
snode.node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)
):
elif isinstance(node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)):
return NCCL_COLL.ALL_GATHER
elif isinstance(
snode.node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)
):
elif isinstance(node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)):
return NCCL_COLL.REDUCE_SCATTER
else:
raise Exception(f"Unsupported collective type: {snode.node}")
raise Exception(f"Unsupported collective type: {node}")
def get_collective_input_size_bytes(node: ir.IRNode) -> int:
sz_bytes = 0
for inp in node.inputs: # type: ignore[attr-defined]
shape = inp.layout.size
numel = sympy_product(inp.layout.size)
if isinstance(numel, sympy.Integer):
# For ease of testing
numel = int(numel)
else:
numel = V.graph.sizevars.size_hint(numel)
sz_bytes += numel * get_dtype_size(inp.layout.dtype)
return sz_bytes
def get_collective_group_size(node: ir.IRNode) -> int:
if type(node) == ir._CollectiveKernel:
from torch.distributed.distributed_c10d import _get_group_size_by_name
return _get_group_size_by_name(node.constant_args[-1])
elif isinstance(node, ir.CollectiveKernel):
return node.constant_args[2] # type: ignore[attr-defined]
else:
raise TypeError(f"Unsupported collective type: {node}")
####################################################################################################################
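
For context on the new code path above: native funcol collectives arrive as `ir._CollectiveKernel` nodes, so `get_collective_type` classifies them by substring-matching `python_kernel_name`, and `get_collective_group_size` resolves the process-group size from the group name via `_get_group_size_by_name`. Below is a standalone sketch of that classification rule, using plain strings instead of real inductor IR nodes; the helper, the stand-in enum, and the example kernel name are illustrative assumptions, not the module's API.

from enum import IntEnum

class _Coll(IntEnum):  # stand-in for NCCL_COLL above
    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2

def classify_kernel_name(kernel_name: str) -> _Coll:
    # Mirrors the substring dispatch used for _CollectiveKernel nodes.
    if "all_reduce" in kernel_name:
        return _Coll.ALL_REDUCE
    elif "all_gather" in kernel_name:
        return _Coll.ALL_GATHER
    elif "reduce_scatter" in kernel_name:
        return _Coll.REDUCE_SCATTER
    raise ValueError(f"Unsupported collective kernel: {kernel_name}")

# A native funcol all_reduce kernel name contains "all_reduce":
assert classify_kernel_name("torch.ops._c10d_functional.all_reduce_.default") is _Coll.ALL_REDUCE
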
@@ -80,68 +110,63 @@ class NCCL_PROTO(IntEnum):
# Latencies in us
# len(NCCL_ALGO) x len(NCCL_PROTO)
baseLat = torch.tensor(
# NOTE: use array instead of tensor to prevent incompatibility with fake mode
baseLat = [
# Tree
[
# Tree
[
6.8, # LL
],
# Ring
[
6.6, # LL
],
]
)
6.8, # LL
],
# Ring
[
6.6, # LL
],
]
# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
hwLat = torch.tensor(
hwLat = [
# NVLINK
[
# NVLINK
[
[0.6], # Tree (LL)
[0.6], # Ring (LL)
],
# PCI
[
[1.0], # Tree (LL)
[1.0], # Ring (LL)
],
# NET
[
[5.0], # Tree (LL)
[2.7], # Ring (LL)
],
]
)
[0.6], # Tree (LL)
[0.6], # Ring (LL)
],
# PCI
[
[1.0], # Tree (LL)
[1.0], # Ring (LL)
],
# NET
[
[5.0], # Tree (LL)
[2.7], # Ring (LL)
],
]
# LL128 max BW per channel
llMaxBws = torch.tensor(
llMaxBws = [
# Volta-N1/Intel-N2/Intel-N4
[
# Volta-N1/Intel-N2/Intel-N4
[
39.0,
39.0,
20.4,
],
# Ampere-N1/AMD-N2/AMD-N4
[
87.7,
22.5, # avg of ring & tree
19.0,
],
# Hopper-N1/AMD-N2/AMD-N4
[
87.7,
22.5, # avg of ring & tree
19.0,
],
]
)
39.0,
39.0,
20.4,
],
# Ampere-N1/AMD-N2/AMD-N4
[
87.7,
22.5, # avg of ring & tree
19.0,
],
# Hopper-N1/AMD-N2/AMD-N4
[
87.7,
22.5, # avg of ring & tree
19.0,
],
]
def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).
@@ -154,16 +179,14 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
- 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
- collective is one of: allreduce, reducescatter, allgather
"""
tensor_numel = V.graph.sizevars.size_hint(sympy_product(snode.node.layout.size))
tensor_dtype = snode.node.layout.dtype
tensor_storage_size_bytes = tensor_numel * get_dtype_size(tensor_dtype)
tensor_storage_size_bytes = get_collective_input_size_bytes(node)
# Convert bytes to GB
tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024
# Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
# TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
num_gpus_per_node = 8
_, _, group_size = snode.node.constant_args # type: ignore[attr-defined]
group_size = get_collective_group_size(node)
nNodes = math.ceil(group_size / num_gpus_per_node)
nRanks = group_size # this is total # of gpus globally that participate in this collective op
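
To make the conversions just above concrete, here is a small worked example of the size and node-count bookkeeping the estimator performs before applying its bandwidth/latency model; the numbers are hypothetical.

import math

tensor_storage_size_bytes = 256 * 1024 * 1024 * 2  # e.g. 256M fp16 elements summed over inputs
tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024  # 0.5 GB

num_gpus_per_node = 8  # hard-coded assumption, per the docstring
group_size = 16        # total ranks participating in the collective
nNodes = math.ceil(group_size / num_gpus_per_node)  # -> 2 nodes, so inter-node links matter
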
@@ -173,7 +196,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
# Assumes ring algorithm
nccl_algo = NCCL_ALGO.RING
nccl_proto = NCCL_PROTO.LL
coll = get_collective_type(snode)
coll = get_collective_type(node)
# =============== bandwidth computation ===============
# First compute bandwidth in GB/s; then at the end, convert it to GB/ns
@@ -185,7 +208,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
index2 = nNodes - 1 if nNodes <= 2 else 2
# LL: for single node, we look at GPU type; for multi-node, we look at CPU type
index1 = compCapIndex if nNodes == 1 else 0
llMaxBw = llMaxBws[index1][index2].item()
llMaxBw = llMaxBws[index1][index2]
# NOTE: each step of ring algorithm is synchronized,
# and is bottlenecked by the slowest link which is the inter-node interconnect.
@@ -227,9 +250,9 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
nInterSteps = nNodes - 1
# First compute latency in us; then at the end, convert it to ns
latency = baseLat[nccl_algo][nccl_proto].item()
intraLat = hwLat[intraHw][nccl_algo][nccl_proto].item()
interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto].item()
latency = baseLat[nccl_algo][nccl_proto]
intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]
# Inter-node rings still have to launch nsteps * net overhead.
netOverhead = 0.0
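
The hunk is cut off here. For orientation, a hedged sketch of how an estimate of this shape combines the pieces computed above, i.e. a bandwidth-bound transport term plus an accumulated latency term; this is illustrative and not a verbatim copy of the remainder of the function.

def combine_estimate(tensor_storage_size_GB: float,
                     bandwidth_GB_per_ns: float,
                     latency_us: float) -> float:
    # Transport time: message size divided by effective bandwidth (already in GB/ns).
    transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
    # Latency was accumulated in microseconds; convert to nanoseconds.
    latency_ns = latency_us * 1e3
    return transport_ns + latency_ns
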