""" Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of NVIDIA CORPORATION, Lawrence Berkeley National Laboratory, the U.S. Department of Energy, nor the names of their contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. The U.S. Department of Energy funded the development of this software under subcontract 7078610 with Lawrence Berkeley National Laboratory. This code also includes files from the NVIDIA Tools Extension SDK project. See: https://github.com/NVIDIA/NVTX for more information and license details. """ import math from enum import Enum import torch from . 
from .utils import get_dtype_size, sympy_product
from .virtualized import V


class NCCL_COLL(IntEnum):
    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


class NCCL_HW(IntEnum):
    NVLINK = 0
    PCI = 1
    NET = 2


class NCCL_ALGO(IntEnum):
    TREE = 0
    RING = 1


class NCCL_PROTO(IntEnum):
    # Ordering follows NCCL's tuning tables; only the LL (low-latency) column
    # is carried in the tables below.
    SIMPLE = 0
    LL = 1
    LL128 = 2


class NVIDIA_GPU_TYPE(IntEnum):
    VOLTA = 0
    AMPERE = 1
    HOPPER = 2


# Latencies in us
# len(NCCL_ALGO) x 1 (only the LL protocol column is kept)
baseLat = torch.tensor(
    [
        # Tree
        [
            6.8,  # LL
        ],
        # Ring
        [
            6.6,  # LL
        ],
    ]
)

# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x 1 (only the LL protocol column is kept)
hwLat = torch.tensor(
    [
        # NVLINK
        [
            [0.6],  # Tree (LL)
            [0.6],  # Ring (LL)
        ],
        # PCI
        [
            [1.0],  # Tree (LL)
            [1.0],  # Ring (LL)
        ],
        # NET
        [
            [5.0],  # Tree (LL)
            [2.7],  # Ring (LL)
        ],
    ]
)


# LL128 max BW per channel
llMaxBws = torch.tensor(
    [
        # Volta-N1/Intel-N2/Intel-N4
        [
            39.0,
            39.0,
            20.4,
        ],
        # Ampere-N1/AMD-N2/AMD-N4
        [
            87.7,
            22.5,  # avg of ring & tree
            19.0,
        ],
        # Hopper-N1/AMD-N2/AMD-N4
        [
            87.7,
            22.5,  # avg of ring & tree
            19.0,
        ],
    ]
)


def get_gpu_type() -> NVIDIA_GPU_TYPE:
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run)
    if "V100" in gpu_info:
        return NVIDIA_GPU_TYPE.VOLTA
    elif "A100" in gpu_info:
        return NVIDIA_GPU_TYPE.AMPERE
    elif "H100" in gpu_info:
        return NVIDIA_GPU_TYPE.HOPPER
    else:
        # For other GPU types, assume Ampere
        return NVIDIA_GPU_TYPE.AMPERE


def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:  # type: ignore[name-defined]
    if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)):
        return NCCL_COLL.ALL_REDUCE
    elif isinstance(
        snode.node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)
    ):
        return NCCL_COLL.ALL_GATHER
    elif isinstance(
        snode.node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)
    ):
        return NCCL_COLL.REDUCE_SCATTER
    else:
        raise Exception(f"Unsupported collective type: {snode.node}")


def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:  # type: ignore[name-defined]
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).

    The following heuristics are copied from
    https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
    We aim to estimate the runtime as accurately as possible.

    Assumptions:
    - only the ring algorithm (NCCL_ALGO_RING) is used
    - only the Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. neither Simple nor LL128 is used
    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    - collective is one of: allreduce, reducescatter, allgather
    """
    tensor_numel = V.graph.sizevars.size_hint(sympy_product(snode.node.layout.size))
    tensor_dtype = snode.node.layout.dtype
    tensor_storage_size_bytes = tensor_numel * get_dtype_size(tensor_dtype)
    # Convert bytes to GB
    tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

    # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
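    # Illustrative example (assumed numbers): with group_size=16 and 8 GPUs per
    # node this yields nNodes=2 and nRanks=16; a group_size of 20 would round up
    # to nNodes=3 via math.ceil even though the last node is only partially used.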
    num_gpus_per_node = 8
    _, _, group_size = snode.node.constant_args
    nNodes = math.ceil(group_size / num_gpus_per_node)
    nRanks = group_size  # this is the total # of gpus globally that participate in this collective op

    if nRanks <= 1:
        return 0

    # Assumes ring algorithm
    nccl_algo = NCCL_ALGO.RING
    coll = get_collective_type(snode)

    # =============== bandwidth computation ===============
    # First compute bandwidth in GB/s; then at the end, convert it to GB/ns

    bwIntra = torch._inductor.config.intra_node_bw
    bwInter = torch._inductor.config.inter_node_bw

    compCapIndex = get_gpu_type()
    index2 = nNodes - 1 if nNodes <= 2 else 2
    # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
    index1 = compCapIndex if nNodes == 1 else 0
    llMaxBw = llMaxBws[index1][index2]

    # NOTE: each step of the ring algorithm is synchronized,
    # and is bottlenecked by the slowest link, which is the inter-node interconnect.
    # Hence when nNodes >= 2, bw is the inter-node bandwidth.
    # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
    # has this as `if nNodes <= 2`, which seems wrong. Corrected it here.
    bw = bwIntra if nNodes == 1 else bwInter
    nChannels = 2  # Assume # channels is 2
    busBw = nChannels * bw

    # Various model refinements
    busBw = min(
        llMaxBw,
        busBw * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0),
    )

    if coll == NCCL_COLL.ALL_REDUCE:
        nsteps = 2 * (nRanks - 1)
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nsteps = nRanks - 1

    # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
    ratio = (1.0 * nRanks) / nsteps
    bandwidth = busBw * ratio
    # Convert GB/s to GB/ns
    bandwidth_GB_per_ns = bandwidth / 1e9

    # =============== latency computation ===============
    intraHw = NCCL_HW.NVLINK
    hw = intraHw if nNodes == 1 else NCCL_HW.NET

    if coll == NCCL_COLL.ALL_REDUCE:
        if nNodes > 1:
            nInterSteps = 2 * nNodes
        else:
            nInterSteps = 0
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nInterSteps = nNodes - 1

    # First compute latency in us; then at the end, convert it to ns
    latency = baseLat[nccl_algo]
    intraLat = hwLat[intraHw][nccl_algo]
    interLat = hwLat[NCCL_HW.NET][nccl_algo]
    lat = hwLat[hw][nccl_algo]  # kept for parity with the NCCL reference; unused in the ring-only path

    # Inter-node rings still have to launch nsteps * net overhead.
    netOverhead = 0.0
    if nNodes > 1:
        netOverhead = 1.0  # getNetOverhead(comm);
    intraLat = max(intraLat, netOverhead)
    latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat
    # Convert us to ns
    latency_ns = latency * 1e3

    # =============== final result ===============
    transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
    return transport_ns + latency_ns
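
# ---------------------------------------------------------------------------
# Worked example (illustrative only; not used by the scheduler).
# This is a minimal sketch that re-traces the ring/LL allreduce arithmetic from
# estimate_nccl_collective_runtime for one concrete, assumed configuration:
# 16 ranks across 2 nodes, a 1 GiB tensor, and an assumed inter-node bandwidth
# of 25 GB/s. The helper name and all numbers below are hypothetical; they only
# exist to make the bus-BW -> algo-BW -> runtime conversion easier to follow.
# ---------------------------------------------------------------------------
def _example_ring_allreduce_estimate_ns() -> float:
    nRanks, nNodes = 16, 2
    tensor_GB = 1.0  # 1 GiB payload, expressed in the same GB units used above
    bwInter = 25.0  # assumed inter-node bandwidth in GB/s
    nChannels = 2

    # Bandwidth term: bus BW is capped by llMaxBws[0][nNodes - 1] = 39.0 for
    # multi-node, scaled by 1/4 for allreduce, then converted to algorithm BW.
    busBw = min(39.0, nChannels * bwInter * (1.0 / 4.0))  # = 12.5 GB/s
    nsteps = 2 * (nRanks - 1)  # = 30 ring steps for allreduce
    algoBw = busBw * nRanks / nsteps  # ~6.67 GB/s
    transport_ns = tensor_GB / (algoBw / 1e9)  # ~1.5e8 ns for 1 GiB

    # Latency term: ring/LL base latency plus per-step hop latencies
    # (6.6 us base, 1.0 us intra-node after the net-overhead floor, 2.7 us inter-node).
    nInterSteps = 2 * nNodes  # = 4
    latency_us = 6.6 + (nsteps - nInterSteps) * 1.0 + nInterSteps * 2.7  # = 43.4 us
    return transport_ns + latency_us * 1e3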