"""
|
|
Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
|
|
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions
|
|
are met:
|
|
* Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
* Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
* Neither the name of NVIDIA CORPORATION, Lawrence Berkeley National
|
|
Laboratory, the U.S. Department of Energy, nor the names of their
|
|
contributors may be used to endorse or promote products derived
|
|
from this software without specific prior written permission.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
|
|
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
|
PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
|
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
|
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
|
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
The U.S. Department of Energy funded the development of this software
|
|
under subcontract 7078610 with Lawrence Berkeley National Laboratory.
|
|
|
|
|
|
This code also includes files from the NVIDIA Tools Extension SDK project.
|
|
|
|
See:
|
|
|
|
https://github.com/NVIDIA/NVTX
|
|
|
|
for more information and license details.
|
|
"""

import math
from enum import IntEnum

import torch
import torch.utils.collect_env  # used by get_gpu_type() below

from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V


class NCCL_COLL(IntEnum):
    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


class NCCL_HW(IntEnum):
    NVLINK = 0
    PCI = 1
    NET = 2


class NCCL_ALGO(IntEnum):
    TREE = 0
    RING = 1


class NCCL_PROTO(IntEnum):
    # Only the Low-Latency (LL) protocol is modeled here: the latency/bandwidth
    # tables below have a single protocol column, so LL must stay at index 0
    # (see the assumptions in estimate_nccl_collective_runtime).
    LL = 0
    # LL128 = 1  # not modeled
    # SIMPLE = 2  # not modeled


class NVIDIA_GPU_TYPE(IntEnum):
    VOLTA = 0
    AMPERE = 1
    HOPPER = 2


# Latencies in us
# len(NCCL_ALGO) x len(NCCL_PROTO)
baseLat = torch.tensor(
    [
        # Tree
        [
            6.8,  # LL
        ],
        # Ring
        [
            6.6,  # LL
        ],
    ]
)

# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
hwLat = torch.tensor(
    [
        # NVLINK
        [
            [0.6],  # Tree (LL)
            [0.6],  # Ring (LL)
        ],
        # PCI
        [
            [1.0],  # Tree (LL)
            [1.0],  # Ring (LL)
        ],
        # NET
        [
            [5.0],  # Tree (LL)
            [2.7],  # Ring (LL)
        ],
    ]
)
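
# Illustrative lookups into the tables above (values in us), using the
# IntEnum indices defined earlier:
#   baseLat[NCCL_ALGO.RING][NCCL_PROTO.LL]            -> 6.6  (ring base latency)
#   hwLat[NCCL_HW.NET][NCCL_ALGO.RING][NCCL_PROTO.LL] -> 2.7  (per ring step over the network)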

# LL128 max BW per channel
llMaxBws = torch.tensor(
    [
        # Volta-N1/Intel-N2/Intel-N4
        [
            39.0,
            39.0,
            20.4,
        ],
        # Ampere-N1/AMD-N2/AMD-N4
        [
            87.7,
            22.5,  # avg of ring & tree
            19.0,
        ],
        # Hopper-N1/AMD-N2/AMD-N4
        [
            87.7,
            22.5,  # avg of ring & tree
            19.0,
        ],
    ]
)
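
# Indexing convention (as used in estimate_nccl_collective_runtime below):
# the row is picked by GPU type for single-node jobs and fixed to row 0 for
# multi-node jobs; the column is nNodes - 1 for nNodes <= 2, else 2.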


def get_gpu_type() -> NVIDIA_GPU_TYPE:
    # get_gpu_info may return None if no GPU is detected; fall back to "".
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
    if "V100" in gpu_info:
        return NVIDIA_GPU_TYPE.VOLTA
    elif "A100" in gpu_info:
        return NVIDIA_GPU_TYPE.AMPERE
    elif "H100" in gpu_info:
        return NVIDIA_GPU_TYPE.HOPPER
    else:
        # For other GPU types, assume Ampere.
        return NVIDIA_GPU_TYPE.AMPERE


def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:  # type: ignore[name-defined]
    if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)):
        return NCCL_COLL.ALL_REDUCE
    elif isinstance(
        snode.node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)
    ):
        return NCCL_COLL.ALL_GATHER
    elif isinstance(
        snode.node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)
    ):
        return NCCL_COLL.REDUCE_SCATTER
    else:
        raise ValueError(f"Unsupported collective type: {snode.node}")


def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:  # type: ignore[name-defined]
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).

    The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
    We aim to estimate the runtime as accurately as possible.

    Assumptions:
    - only the ring algorithm (NCCL_ALGO_RING) is used
    - only the Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple and LL128 are not used
    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    - collective is one of: allreduce, reducescatter, allgather
    """
    tensor_numel = V.graph.sizevars.size_hint(sympy_product(snode.node.layout.size))
    tensor_dtype = snode.node.layout.dtype
    tensor_storage_size_bytes = tensor_numel * get_dtype_size(tensor_dtype)
    # Convert bytes to GB
    tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

    # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    num_gpus_per_node = 8
    _, _, group_size = snode.node.constant_args
    nNodes = math.ceil(group_size / num_gpus_per_node)
    nRanks = group_size  # total number of GPUs globally that participate in this collective op

    if nRanks <= 1:
        return 0.0

    # Assumes ring algorithm and LL protocol
    nccl_algo = NCCL_ALGO.RING
    nccl_proto = NCCL_PROTO.LL
    coll = get_collective_type(snode)

    # =============== bandwidth computation ===============
    # First compute bandwidth in GB/s; then at the end, convert it to GB/ns

    bwIntra = torch._inductor.config.intra_node_bw
    bwInter = torch._inductor.config.inter_node_bw

    compCapIndex = get_gpu_type()
    index2 = nNodes - 1 if nNodes <= 2 else 2
    # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
    index1 = compCapIndex if nNodes == 1 else 0
    llMaxBw = llMaxBws[index1][index2].item()

    # NOTE: each step of the ring algorithm is synchronized,
    # and is bottlenecked by the slowest link, which is the inter-node interconnect.
    # Hence when nNodes >= 2, bw is the inter-node bandwidth.
    # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
    # has this as `if nNodes <= 2`, which seems wrong. Corrected it here.
    bw = bwIntra if nNodes == 1 else bwInter
    nChannels = 2  # Assume # channels is 2
    busBw = nChannels * bw

    # Various model refinements
    busBw = min(
        llMaxBw,
        busBw
        * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0),
    )
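
    # For example, with illustrative config values: on 2 nodes with
    # inter_node_bw = 25 GB/s, an allreduce gets
    # busBw = min(llMaxBw, 2 * 25 * 1/4) = min(llMaxBw, 12.5) GB/s,
    # whereas a single-node allgather would instead apply the 1/3 factor
    # to 2 * intra_node_bw.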

    if coll == NCCL_COLL.ALL_REDUCE:
        nsteps = 2 * (nRanks - 1)
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nsteps = nRanks - 1

    # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
    ratio = (1.0 * nRanks) / nsteps
    bandwidth = busBw * ratio
    # Convert GB/s to GB/ns
    bandwidth_GB_per_ns = bandwidth / 1e9

    # =============== latency computation ===============
    intraHw = NCCL_HW.NVLINK

    if coll == NCCL_COLL.ALL_REDUCE:
        if nNodes > 1:
            nInterSteps = 2 * nNodes
        else:
            nInterSteps = 0
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nInterSteps = nNodes - 1

    # First compute latency in us; then at the end, convert it to ns
    latency = baseLat[nccl_algo][nccl_proto].item()
    intraLat = hwLat[intraHw][nccl_algo][nccl_proto].item()
    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto].item()

    # Inter-node rings still have to launch nsteps * net overhead.
    netOverhead = 0.0
    if nNodes > 1:
        netOverhead = 1.0  # getNetOverhead(comm);
    intraLat = max(intraLat, netOverhead)
    latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat
    # Convert us to ns
    latency_ns = latency * 1e3

    # =============== final result ===============
    transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
    return transport_ns + latency_ns
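

# Hypothetical worked example of the model above, assuming illustrative
# bandwidths intra_node_bw = 300 GB/s and inter_node_bw = 25 GB/s, for an
# allreduce of a 1 GiB tensor across 16 ranks on 2 nodes (8 GPUs each):
#   bw = bwInter = 25 GB/s; busBw = 2 * 25 = 50 GB/s
#   busBw = min(llMaxBws[0][1], 50 * 1/4) = min(39.0, 12.5) = 12.5 GB/s
#   nsteps = 2 * (16 - 1) = 30; ratio = 16 / 30; bandwidth ~= 6.67 GB/s
#   transport_ns = 1.0 / 6.67e-9 ~= 1.5e8 ns (~150 ms)
#   nInterSteps = 2 * 2 = 4
#   latency = 6.6 + (30 - 4) * max(0.6, 1.0) + 4 * 2.7 = 43.4 us = 43,400 ns
# At this size the estimate is dominated by transport time; the latency term
# matters only for small tensors.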
|