pytorch/torch/_inductor/comm_analysis.py
Will Feng e9804aaacc Fix unit tests and add logging for Inductor intra-graph reordering (#111981)
1. Fix code to make unit tests pass (incl. collect_env issue called out by @int3  in https://github.com/pytorch/pytorch/pull/108091#discussion_r1362901686).
2. Add logging for Inductor intra-graph reordering passes (`TORCH_LOGS="overlap"`), for easier debugging. Example log:
```
[rank0]:[2023-10-24 16:28:26,446] [0/0] torch._inductor.comms.__overlap: [DEBUG] ==== Visualize overlap before reordering pass <function reorder_compute_for_overlap at 0x7fa68c5568e0> ====
[rank0]:[2023-10-24 16:28:26,446] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf0)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf1)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] InPlaceHint (size=[4, 4], stride=[4, 1]) (buf2)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] AllReduce (size=[4, 4], stride=[4, 1]) (buf3)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] Wait (size=[4, 4], stride=[4, 1]) (buf4)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf5)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] InPlaceHint (size=[4, 4], stride=[4, 1]) (buf6)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] AllReduce (size=[4, 4], stride=[4, 1]) (buf7)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] Wait (size=[4, 4], stride=[4, 1]) (buf8)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf9)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf10)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf11)
[rank0]:[2023-10-24 16:28:26,447] [0/0] torch._inductor.comms.__overlap: [DEBUG] Est. runtime (ms): 0.000228

[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] ==== Visualize overlap after reordering pass <function reorder_compute_for_overlap at 0x7fa68c5568e0> ====
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] InPlaceHint (size=[4, 4], stride=[4, 1]) (buf2)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] AllReduce (size=[4, 4], stride=[4, 1]) (buf3)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] | ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf0)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] | ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf1)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] | ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf9)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] Wait (size=[4, 4], stride=[4, 1]) (buf4)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf5)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] InPlaceHint (size=[4, 4], stride=[4, 1]) (buf6)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] AllReduce (size=[4, 4], stride=[4, 1]) (buf7)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] Wait (size=[4, 4], stride=[4, 1]) (buf8)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf10)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf11)
[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] Est. runtime (ms): 0.000217
```
A line prefixed with `| ` (e.g. `| SomeComputeOp`) means that compute op is overlapped with the comm op above it.
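To make the `| ` convention concrete, here is a minimal sketch that groups overlapped compute ops under the comm op they follow. The helper `group_overlapped_ops`, the `COMM_OPS` tuple, and the assumption that every message follows a `[DEBUG] ` tag are hypothetical and not part of PyTorch; the sample lines are taken from the "after reordering" log above.

```python
from typing import Dict, List, Optional

# Comm op names as they might appear at the start of a log message (assumed list).
COMM_OPS = ("AllReduce", "AllGather", "ReduceScatter")


def group_overlapped_ops(log_lines: List[str]) -> Dict[str, List[str]]:
    """Map each comm op's buffer name to the buffers of compute ops overlapped with it."""
    groups: Dict[str, List[str]] = {}
    current_comm: Optional[str] = None
    for line in log_lines:
        # Keep only the message part after the "[DEBUG] " tag.
        _, sep, msg = line.partition("[DEBUG] ")
        if not sep or "(buf" not in msg:
            continue  # header lines, runtime estimates, blank lines
        # The buffer name is the last parenthesised token, e.g. "(buf3)".
        buf = msg.rsplit("(", 1)[1].rstrip(")\n ")
        if msg.startswith("| "):
            if current_comm is not None:
                groups[current_comm].append(buf)
        elif msg.startswith(COMM_OPS):
            current_comm = buf
            groups[current_comm] = []
    return groups


PREFIX = "[rank0]:[2023-10-24 16:28:26,448] [0/0] torch._inductor.comms.__overlap: [DEBUG] "
sample = [
    PREFIX + "InPlaceHint (size=[4, 4], stride=[4, 1]) (buf2)",
    PREFIX + "AllReduce (size=[4, 4], stride=[4, 1]) (buf3)",
    PREFIX + "| ComputedBuffer (size=[4, 4], stride=[4, 1]) (buf0)",
    PREFIX + "| ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf1)",
    PREFIX + "| ExternKernelOut (extern_kernels.mm) (size=[4, 4], stride=[4, 1]) (buf9)",
    PREFIX + "Wait (size=[4, 4], stride=[4, 1]) (buf4)",
    PREFIX + "AllReduce (size=[4, 4], stride=[4, 1]) (buf7)",
    PREFIX + "Wait (size=[4, 4], stride=[4, 1]) (buf8)",
    PREFIX + "Est. runtime (ms): 0.000217",
]
overlap = group_overlapped_ops(sample)
print(overlap)  # {'buf3': ['buf0', 'buf1', 'buf9'], 'buf7': []}
```

In this sample the first all-reduce (`buf3`) is overlapped with three compute ops, while the second (`buf7`) has none.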

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111981
Approved by: https://github.com/wanchaol
2023-10-25 18:19:43 +00:00


"""
Copyright (c) 2015-2020, NVIDIA CORPORATION. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of NVIDIA CORPORATION, Lawrence Berkeley National
Laboratory, the U.S. Department of Energy, nor the names of their
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
The U.S. Department of Energy funded the development of this software
under subcontract 7078610 with Lawrence Berkeley National Laboratory.
This code also includes files from the NVIDIA Tools Extension SDK project.
See:
https://github.com/NVIDIA/NVTX
for more information and license details.
"""
import math
from enum import IntEnum
import torch
from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V


class NCCL_COLL(IntEnum):
    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


class NCCL_HW(IntEnum):
    NVLINK = 0
    PCI = 1
    NET = 2


class NCCL_ALGO(IntEnum):
    TREE = 0
    RING = 1


class NCCL_PROTO(IntEnum):
    # The ordering and enum values here match the original in
    # https://github.com/NVIDIA/nccl/blob/0b083e52096c387bad7a5c5c65b26a9dca54de8c/src/include/devcomm.h#L28
    # For the differences between these protocols, see https://github.com/NVIDIA/nccl/issues/281#issuecomment-571816990
    LL = 0  # Low-latency
    # LL128 = 1   # Low-latency 128-byte
    # SIMPLE = 2


class NVIDIA_GPU_TYPE(IntEnum):
    VOLTA = 0
    AMPERE = 1
    HOPPER = 2


# Latencies in us
# len(NCCL_ALGO) x len(NCCL_PROTO)
baseLat = torch.tensor(
    [
        # Tree
        [
            6.8,  # LL
        ],
        # Ring
        [
            6.6,  # LL
        ],
    ]
)

# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
hwLat = torch.tensor(
    [
        # NVLINK
        [
            [0.6],  # Tree (LL)
            [0.6],  # Ring (LL)
        ],
        # PCI
        [
            [1.0],  # Tree (LL)
            [1.0],  # Ring (LL)
        ],
        # NET
        [
            [5.0],  # Tree (LL)
            [2.7],  # Ring (LL)
        ],
    ]
)
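
# LL max BW, looked up as llMaxBws[index1][index2]:
# index1 is the GPU type (single node) or CPU type (multi-node),
# index2 is the node-count bucket (1 node, 2 nodes, more than 2 nodes)
llMaxBws = torch.tensor(
    [
        # Volta-N1/Intel-N2/Intel-N4
        [
            39.0,
            39.0,
            20.4,
        ],
        # Ampere-N1/AMD-N2/AMD-N4
        [
            87.7,
            22.5,  # avg of ring & tree
            19.0,
        ],
        # Hopper-N1/AMD-N2/AMD-N4
        [
            87.7,
            22.5,  # avg of ring & tree
            19.0,
        ],
    ]
)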


def get_gpu_type() -> NVIDIA_GPU_TYPE:
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run)
    if "V100" in gpu_info:
        return NVIDIA_GPU_TYPE.VOLTA
    elif "A100" in gpu_info:
        return NVIDIA_GPU_TYPE.AMPERE
    elif "H100" in gpu_info:
        return NVIDIA_GPU_TYPE.HOPPER
    else:
        # for other gpu types, assume Ampere
        return NVIDIA_GPU_TYPE.AMPERE


def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:  # type: ignore[name-defined]
    if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)):
        return NCCL_COLL.ALL_REDUCE
    elif isinstance(
        snode.node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)
    ):
        return NCCL_COLL.ALL_GATHER
    elif isinstance(
        snode.node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)
    ):
        return NCCL_COLL.REDUCE_SCATTER
    else:
        raise Exception(f"Unsupported collective type: {snode.node}")


def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:  # type: ignore[name-defined]
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).

    The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
    We aim to estimate the runtime as accurately as possible.

    Assumptions:
    - only ring algorithm (NCCL_ALGO_RING) is used
    - only Low-Latency protocol (NCCL_PROTO_LL) is used, i.e. Simple or LL128 is not used
    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    - collective is one of: allreduce, reducescatter, allgather
    """
    tensor_numel = V.graph.sizevars.size_hint(sympy_product(snode.node.layout.size))
    tensor_dtype = snode.node.layout.dtype
    tensor_storage_size_bytes = tensor_numel * get_dtype_size(tensor_dtype)
    # Convert bytes to GB
    tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

    # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    num_gpus_per_node = 8
    _, _, group_size = snode.node.constant_args
    nNodes = math.ceil(group_size / num_gpus_per_node)
    nRanks = group_size  # this is total # of gpus globally that participate in this collective op

    if nRanks <= 1:
        return 0

    # Assumes ring algorithm
    nccl_algo = NCCL_ALGO.RING
    nccl_proto = NCCL_PROTO.LL
    coll = get_collective_type(snode)

    # =============== bandwidth computation ===============
    # First compute bandwidth in GB/s; then at the end, convert it to GB/ns

    bwIntra = torch._inductor.config.intra_node_bw
    bwInter = torch._inductor.config.inter_node_bw

    compCapIndex = get_gpu_type()
    index2 = nNodes - 1 if nNodes <= 2 else 2
    # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
    index1 = compCapIndex if nNodes == 1 else 0
    llMaxBw = llMaxBws[index1][index2].item()

    # NOTE: each step of the ring algorithm is synchronized,
    # and is bottlenecked by the slowest link, which is the inter-node interconnect.
    # Hence when nNodes >= 2, bw is the inter-node bandwidth.
    # NOTE: the original code in https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc
    # has this as `if nNodes <= 2`, which seems wrong. Corrected it here.
    bw = bwIntra if nNodes == 1 else bwInter
    nChannels = 2  # Assume # channels is 2
    busBw = nChannels * bw

    # Various model refinements
    busBw = min(
        llMaxBw,
        busBw
        * (1.0 / 4.0 if (nNodes > 1 or coll == NCCL_COLL.ALL_REDUCE) else 1.0 / 3.0),
    )

    if coll == NCCL_COLL.ALL_REDUCE:
        nsteps = 2 * (nRanks - 1)
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nsteps = nRanks - 1

    # Convert bus BW to algorithm BW (tensor bytes / algoBW = actual execution time)
    ratio = (1.0 * nRanks) / nsteps
    bandwidth = busBw * ratio
    # Convert GB/s to GB/ns
    bandwidth_GB_per_ns = bandwidth / 1e9

    # =============== latency computation ===============
    intraHw = NCCL_HW.NVLINK
    hw = intraHw if nNodes == 1 else NCCL_HW.NET

    if coll == NCCL_COLL.ALL_REDUCE:
        if nNodes > 1:
            nInterSteps = 2 * nNodes
        else:
            nInterSteps = 0
    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
        nInterSteps = nNodes - 1

    # First compute latency in us; then at the end, convert it to ns
    latency = baseLat[nccl_algo][nccl_proto].item()
    intraLat = hwLat[intraHw][nccl_algo][nccl_proto].item()
    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto].item()

    # Inter-node rings still have to launch nsteps * net overhead.
    netOverhead = 0.0
    if nNodes > 1:
        netOverhead = 1.0  # getNetOverhead(comm);
    intraLat = max(intraLat, netOverhead)
    latency += (nsteps - nInterSteps) * intraLat + nInterSteps * interLat
    # Convert us to ns
    latency_ns = latency * 1e3

    # =============== final result ===============
    transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
    return transport_ns + latency_ns
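
The cost model in `estimate_nccl_collective_runtime` can be tried without a scheduler node by restating it over plain numbers. The following is a minimal sketch, not the Inductor implementation: the function `ring_ll_estimate_ns` and its parameters are hypothetical, the collective is passed as a string instead of an IR node, and the bandwidth values in the usage lines (300 and 25 GB/s) are illustrative assumptions rather than confirmed config defaults. It uses Python floats, so results may differ slightly from the float32 tensor lookups above.

```python
import math

# Latency constants in microseconds, copied from the baseLat / hwLat tables
# above (ring algorithm, LL protocol only).
BASE_LAT_RING_LL = 6.6
HW_LAT_RING_LL = {"nvlink": 0.6, "net": 2.7}


def ring_ll_estimate_ns(
    size_bytes: int,
    coll: str,  # "all_reduce", "all_gather" or "reduce_scatter"
    group_size: int,
    intra_bw: float,  # GB/s, caller-supplied assumption
    inter_bw: float,  # GB/s, caller-supplied assumption
    ll_max_bw: float,  # GB/s, the llMaxBws entry for the target hardware
    gpus_per_node: int = 8,
) -> float:
    """Estimated runtime in ns for a ring/LL collective (illustrative sketch)."""
    if group_size <= 1:
        return 0.0
    n_nodes = math.ceil(group_size / gpus_per_node)
    n_ranks = group_size

    if coll == "all_reduce":
        n_steps = 2 * (n_ranks - 1)
        n_inter_steps = 2 * n_nodes if n_nodes > 1 else 0
    elif coll in ("all_gather", "reduce_scatter"):
        n_steps = n_ranks - 1
        n_inter_steps = n_nodes - 1
    else:
        raise ValueError(f"Unsupported collective type: {coll}")

    # Bandwidth term: the slowest link bounds every synchronized ring step.
    bw = intra_bw if n_nodes == 1 else inter_bw
    bus_bw = 2 * bw  # two channels, as assumed above
    scale = 1.0 / 4.0 if (n_nodes > 1 or coll == "all_reduce") else 1.0 / 3.0
    bus_bw = min(ll_max_bw, bus_bw * scale)
    algo_bw = bus_bw * n_ranks / n_steps  # GB/s
    size_gb = size_bytes / 1024**3
    transport_ns = size_gb / (algo_bw / 1e9)

    # Latency term: base latency plus per-step intra-node and inter-node hops.
    intra_lat = HW_LAT_RING_LL["nvlink"]
    inter_lat = HW_LAT_RING_LL["net"]
    if n_nodes > 1:
        intra_lat = max(intra_lat, 1.0)  # net overhead of 1.0 us, as above
    latency_us = (
        BASE_LAT_RING_LL
        + (n_steps - n_inter_steps) * intra_lat
        + n_inter_steps * inter_lat
    )
    return transport_ns + latency_us * 1e3


# 64 MiB all-reduce over 8 GPUs on one node, using the Ampere single-node
# llMaxBws entry (87.7) and assumed link bandwidths.
est = ring_ll_estimate_ns(64 * 1024**2, "all_reduce", 8, 300.0, 25.0, 87.7)
print(f"estimated runtime: {est / 1e6:.3f} ms")
```

With a zero-byte tensor the estimate reduces to the latency term alone, which makes it easy to check the step counts by hand: a single-node 8-GPU all-reduce has 14 ring steps, giving 6.6 + 14 * 0.6 = 15.0 us.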