[reland] Fix estimate_nccl_collective_runtime (#118986)

`estimate_nccl_collective_runtime` has been broken, and the resulting errors have been silently swallowed by inductor. This PR:
- Fixes the issues described in https://github.com/pytorch/pytorch/issues/118497.
- Adds white-box testing so future issues can be surfaced in tests (a condensed sketch of the pattern is shown below).
- Adds support for native funcol IRs.
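
For context, here is a condensed sketch of the white-box pattern used by the new `TestCommAnalysis` test. It is not the verbatim test: the `count_bytes_inductor` body, the input construction, and the CUDA device are assumptions based on the surrounding test file.

```python
import torch
import torch.distributed as dist
from torch._inductor import metrics
from torch._inductor.comm_analysis import estimate_nccl_collective_runtime
from torch._inductor.compile_fx import compile_fx, count_bytes_inner
from torch._inductor.utils import is_collective
from torch.testing._internal.distributed.fake_pg import FakeStore


def count_bytes_inductor(gm, example_inputs):
    # Assumed body: compile with the byte-counting inner compile used elsewhere in the test file.
    return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner)


def fn(x):
    # Native funcol all_reduce; "0" names the process group, as in the new test.
    r = torch.ops._c10d_functional.all_reduce(x, "sum", "0")
    return torch.ops._c10d_functional.wait_tensor(r)


dist.init_process_group(backend="fake", rank=0, world_size=8, store=FakeStore())
try:
    metrics.reset()
    inp = torch.randn(10, 10, device="cuda")  # assumes a CUDA build
    torch._dynamo.optimize(count_bytes_inductor)(fn)(inp)
    for snode, runtime in metrics.node_runtimes:
        if not is_collective(snode.node):
            continue
        # Called directly ("white-box") so exceptions propagate instead of
        # being swallowed inside inductor's runtime estimation.
        assert estimate_nccl_collective_runtime(snode.node) > 0
        assert runtime > 0
finally:
    dist.destroy_process_group()
```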

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118986
Approved by: https://github.com/yf225
ghstack dependencies: #119102
Yifu Wang
2024-02-11 21:11:33 -08:00
committed by PyTorch MergeBot
parent b2043c0543
commit 27ffede878
7 changed files with 282 additions and 87 deletions

View File

@@ -278,9 +278,14 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
self.assertTrue(same(out, correct))
def test_nccl_heuristics(self):
assert list(baseLat.shape) == [len(NCCL_ALGO), len(NCCL_PROTO)]
assert list(hwLat.shape) == [len(NCCL_HW), len(NCCL_ALGO), len(NCCL_PROTO)]
assert llMaxBws.shape[0] == len(NVIDIA_GPU_TYPE)
assert len(baseLat) == len(NCCL_ALGO)
assert all(len(x) == len(NCCL_PROTO) for x in baseLat)
assert len(hwLat) == len(NCCL_HW)
assert all(len(x) == len(NCCL_ALGO) for x in hwLat)
assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x)
assert len(llMaxBws) == len(NVIDIA_GPU_TYPE)
if __name__ == "__main__":

View File

@@ -1,13 +1,21 @@
# Owner(s): ["module: inductor"]
from unittest import skipIf
import torch
import torch.distributed as dist
from torch._inductor import metrics
from torch._inductor.comm_analysis import estimate_nccl_collective_runtime
from torch._inductor.compile_fx import compile_fx, count_bytes_inner
from torch._inductor.utils import is_collective
from torch.testing._internal.common_utils import TestCase as TorchTestCase
from torch.testing._internal.inductor_utils import HAS_CUDA
aten = torch.ops.aten
c10d = torch.ops.c10d_functional
_c10d = torch.ops._c10d_functional
def count_bytes_inductor(gm, example_inputs):
@@ -165,6 +173,135 @@ class MemoryBoundedTests(TestCase):
self.assertNotZero(calculate_runtime(f, *inp))
@skipIf(not dist.is_available(), "requires distributed")
class TestCommAnalysis(TestCase):
WORLD_SIZE: int = 8
RANKS = list(range(8))
def _verify_runtime_estimation(self, fn, inps):
from torch.testing._internal.distributed.fake_pg import FakeStore
store = FakeStore()
dist.init_process_group(
backend="fake", rank=0, world_size=self.WORLD_SIZE, store=store
)
try:
metrics.reset()
torch._dynamo.optimize(count_bytes_inductor)(fn)(*inps)
found_collective = False
for snode, runtime in metrics.node_runtimes:
if not is_collective(snode.node):
continue
found_collective = True
# Inductor swallows errors from snode runtime estimations.
# We call estimate_nccl_collective_runtime in a white-box
# fashion here so potential issues can be surfaced in tests.
est = estimate_nccl_collective_runtime(snode.node)
self.assertNotZero(est)
# Also make sure estimate_nccl_collective_runtime works
# correctly in inductor.
self.assertNotZero(runtime)
# Make sure a collective kernel is found in graph
self.assertTrue(found_collective)
finally:
dist.destroy_process_group()
def test_legacy_all_reduce(self):
def fn(x):
r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE)
return c10d.wait_tensor(r)
inp = T(10, 10)
self._verify_runtime_estimation(fn, (inp,))
def test_legacy_all_reduce_coalesced(self):
def fn(x):
rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE)
return [c10d.wait_tensor(r) for r in rs]
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))
def test_legacy_all_gather_into_tensor_coalesced(self):
def fn(x):
rs = c10d.all_gather_into_tensor_coalesced(
x,
"",
self.RANKS,
self.WORLD_SIZE,
)
return [c10d.wait_tensor(r) for r in rs]
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))
def test_all_reduce(self):
def fn(x):
r = _c10d.all_reduce(x, "sum", "0")
return _c10d.wait_tensor(r)
inp = T(10, 10)
self._verify_runtime_estimation(fn, (inp,))
def test_all_reduce_coalesced(self):
def fn(x):
rs = _c10d.all_reduce_coalesced(x, "sum", "0")
return [_c10d.wait_tensor(r) for r in rs]
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))
def test_all_gather_into_tensor(self):
def fn(x):
rs = _c10d.all_gather_into_tensor(
x,
self.WORLD_SIZE,
"0",
)
return [_c10d.wait_tensor(r) for r in rs]
inp = T(10, 10)
self._verify_runtime_estimation(fn, (inp,))
def test_all_gather_into_tensor_coalesced(self):
def fn(x):
rs = _c10d.all_gather_into_tensor_coalesced(
x,
self.WORLD_SIZE,
"0",
)
return [_c10d.wait_tensor(r) for r in rs]
inp = [T(10, 10), T(15, 15)]
self._verify_runtime_estimation(fn, (inp,))
def test_reduce_scatter_tensor(self):
def fn(x):
rs = _c10d.reduce_scatter_tensor(
x,
"sum",
self.WORLD_SIZE,
"0",
)
return [_c10d.wait_tensor(r) for r in rs]
inp = T(self.WORLD_SIZE, 10)
self._verify_runtime_estimation(fn, (inp,))
def test_reduce_scatter_tensor_coalesced(self):
def fn(x):
rs = _c10d.reduce_scatter_tensor_coalesced(
x,
"sum",
self.WORLD_SIZE,
"0",
)
return [_c10d.wait_tensor(r) for r in rs]
inp = [T(self.WORLD_SIZE, 10), T(self.WORLD_SIZE, 15)]
self._verify_runtime_estimation(fn, (inp,))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@@ -1,7 +1,7 @@
import math
from enum import IntEnum
from typing import TYPE_CHECKING
import sympy
import torch
from . import ir
@@ -9,9 +9,6 @@ from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V
if TYPE_CHECKING:
from torch._inductor.scheduler import BaseSchedulerNode
class NCCL_COLL(IntEnum):
ALL_REDUCE = 0
@@ -26,7 +23,7 @@ class NVIDIA_GPU_TYPE(IntEnum):
def get_gpu_type() -> NVIDIA_GPU_TYPE:
gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run)
gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
if "V100" in gpu_info:
return NVIDIA_GPU_TYPE.VOLTA
elif "A100" in gpu_info:
@@ -38,19 +35,52 @@ def get_gpu_type() -> NVIDIA_GPU_TYPE:
return NVIDIA_GPU_TYPE.AMPERE
def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:
if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)):
def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
if isinstance(node, ir._CollectiveKernel):
kernel_name = node.python_kernel_name
assert kernel_name is not None
if "all_reduce" in kernel_name:
return NCCL_COLL.ALL_REDUCE
elif isinstance(
snode.node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)
):
elif "all_gather" in kernel_name:
return NCCL_COLL.ALL_GATHER
elif isinstance(
snode.node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)
):
elif "reduce_scatter" in kernel_name:
return NCCL_COLL.REDUCE_SCATTER
else:
raise Exception(f"Unsupported collective type: {snode.node}")
raise Exception(f"Unsupported collective kernel: {kernel_name}")
if isinstance(node, (ir.AllReduce, ir.AllReduceCoalesced)):
return NCCL_COLL.ALL_REDUCE
elif isinstance(node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)):
return NCCL_COLL.ALL_GATHER
elif isinstance(node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)):
return NCCL_COLL.REDUCE_SCATTER
else:
raise Exception(f"Unsupported collective type: {node}")
def get_collective_input_size_bytes(node: ir.IRNode) -> int:
sz_bytes = 0
for inp in node.inputs: # type: ignore[attr-defined]
shape = inp.layout.size
numel = sympy_product(inp.layout.size)
if isinstance(numel, sympy.Integer):
# For ease of testing
numel = int(numel)
else:
numel = V.graph.sizevars.size_hint(numel)
sz_bytes += numel * get_dtype_size(inp.layout.dtype)
return sz_bytes
def get_collective_group_size(node: ir.IRNode) -> int:
if type(node) == ir._CollectiveKernel:
from torch.distributed.distributed_c10d import _get_group_size_by_name
return _get_group_size_by_name(node.constant_args[-1])
elif isinstance(node, ir.CollectiveKernel):
return node.constant_args[2] # type: ignore[attr-defined]
else:
raise TypeError(f"Unsupported collective type: {node}")
####################################################################################################################
@@ -80,8 +110,8 @@ class NCCL_PROTO(IntEnum):
# Latencies in us
# len(NCCL_ALGO) x len(NCCL_PROTO)
baseLat = torch.tensor(
[
# NOTE: use array instead of tensor to prevent incompatibility with fake mode
baseLat = [
# Tree
[
6.8, # LL
@@ -90,13 +120,11 @@ baseLat = torch.tensor(
[
6.6, # LL
],
]
)
]
# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
hwLat = torch.tensor(
[
hwLat = [
# NVLINK
[
[0.6], # Tree (LL)
@@ -112,13 +140,11 @@ hwLat = torch.tensor(
[5.0], # Tree (LL)
[2.7], # Ring (LL)
],
]
)
]
# LL128 max BW per channel
llMaxBws = torch.tensor(
[
llMaxBws = [
# Volta-N1/Intel-N2/Intel-N4
[
39.0,
@@ -137,11 +163,10 @@ llMaxBws = torch.tensor(
22.5, # avg of ring & tree
19.0,
],
]
)
]
def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).
@@ -154,16 +179,14 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
- 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
- collective is one of: allreduce, reducescatter, allgather
"""
tensor_numel = V.graph.sizevars.size_hint(sympy_product(snode.node.layout.size))
tensor_dtype = snode.node.layout.dtype
tensor_storage_size_bytes = tensor_numel * get_dtype_size(tensor_dtype)
tensor_storage_size_bytes = get_collective_input_size_bytes(node)
# Convert bytes to GB
tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024
# Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
# TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
num_gpus_per_node = 8
_, _, group_size = snode.node.constant_args # type: ignore[attr-defined]
group_size = get_collective_group_size(node)
nNodes = math.ceil(group_size / num_gpus_per_node)
nRanks = group_size # this is total # of gpus globally that participate in this collective op
@@ -173,7 +196,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
# Assumes ring algorithm
nccl_algo = NCCL_ALGO.RING
nccl_proto = NCCL_PROTO.LL
coll = get_collective_type(snode)
coll = get_collective_type(node)
# =============== bandwidth computation ===============
# First compute bandwidth in GB/s; then at the end, convert it to GB/ns
@@ -185,7 +208,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
index2 = nNodes - 1 if nNodes <= 2 else 2
# LL: for single node, we look at GPU type; for multi-node, we look at CPU type
index1 = compCapIndex if nNodes == 1 else 0
llMaxBw = llMaxBws[index1][index2].item()
llMaxBw = llMaxBws[index1][index2]
# NOTE: each step of ring algorithm is synchronized,
# and is bottlenecked by the slowest link which is the inter-node interconnect.
@@ -227,9 +250,9 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
nInterSteps = nNodes - 1
# First compute latency in us; then at the end, convert it to ns
latency = baseLat[nccl_algo][nccl_proto].item()
intraLat = hwLat[intraHw][nccl_algo][nccl_proto].item()
interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto].item()
latency = baseLat[nccl_algo][nccl_proto]
intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]
# Inter-node rings still have to launch nsteps * net overhead.
netOverhead = 0.0
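
For intuition, a small worked example of the new size accounting (illustrative only; it mirrors `get_collective_input_size_bytes` and the GB conversion above, using the two fp32 inputs from the coalesced tests):

```python
# Two fp32 inputs, as in the coalesced tests: shapes (10, 10) and (15, 15).
# get_collective_input_size_bytes sums numel * dtype_size over all inputs:
dtype_size = 4  # bytes per torch.float32 element
sz_bytes = (10 * 10) * dtype_size + (15 * 15) * dtype_size  # 400 + 900 = 1300
assert sz_bytes == 1300
# estimate_nccl_collective_runtime then converts this to GB before combining it
# with the estimated bus bandwidth and latency terms:
sz_GB = sz_bytes / 1024 / 1024 / 1024
```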

View File

@@ -7494,6 +7494,8 @@ class _CollectiveKernel(FallbackKernel):
def create_inplace(
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
) -> None:
cpp_kernel_name = kernel._name
python_kernel_name = cpp_kernel_name.replace("::", ".")
with V.graph.fake_mode:
(
example_output,
@@ -7511,6 +7513,8 @@ class _CollectiveKernel(FallbackKernel):
non_tensor_args,
unflatten_args,
)
packed.cpp_kernel_name = cpp_kernel_name
packed.python_kernel_name = python_kernel_name
def mark_mutation(x):
if isinstance(x.data, BaseView):
@@ -7545,6 +7549,8 @@ class _CollectiveKernel(FallbackKernel):
def create_out_of_place(
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
):
cpp_kernel_name = kernel._name
python_kernel_name = cpp_kernel_name.replace("::", ".")
with V.graph.fake_mode:
(
example_output,
@@ -7564,6 +7570,8 @@ class _CollectiveKernel(FallbackKernel):
non_tensor_args,
unflatten_args,
)
packed.cpp_kernel_name = cpp_kernel_name
packed.python_kernel_name = python_kernel_name
packed.outputs = [
MultiOutput(
cls.tensor_to_layout(tensor),
@@ -7581,6 +7589,8 @@ class _CollectiveKernel(FallbackKernel):
non_tensor_args,
unflatten_args,
)
packed.cpp_kernel_name = cpp_kernel_name
packed.python_kernel_name = python_kernel_name
packed.outputs = [packed]
return packed
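
The recorded kernel names are what let `get_collective_type` classify native funcol nodes by substring. A hedged illustration follows; the concrete `kernel._name` value is an assumption based on the `_c10d_functional` ops used in the tests:

```python
# Hypothetical value of kernel._name for the native funcol all_reduce op.
cpp_kernel_name = "_c10d_functional::all_reduce"
python_kernel_name = cpp_kernel_name.replace("::", ".")
assert python_kernel_name == "_c10d_functional.all_reduce"
# get_collective_type in comm_analysis.py then matches on substrings such as
# "all_reduce", "all_gather", and "reduce_scatter".
assert "all_reduce" in python_kernel_name
```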

View File

@@ -44,6 +44,8 @@ from .utils import (
get_dtype_size,
get_gpu_dram_gbps,
green_text,
is_collective,
is_wait,
red_text,
sympy_product,
)
@@ -584,6 +586,16 @@ class BaseSchedulerNode:
# default to no reordering based on runtime
return 0
# Collective kernels
if is_collective(self.node):
return estimate_nccl_collective_runtime(self.node)
elif is_wait(self.node):
# ir.Wait is only used for collective ops.
# The time needed for the collective op is already estimated and considered
# when we are processing the collective op IR node, so ir.Wait takes 0 time
# since it doesn't take extra time to get the result after the collective is completed.
return 0
try:
gpu_memory_bandwidth = get_gpu_dram_gbps()
gpu_flops = get_device_tflops(dtype) * 10**12
@@ -629,16 +641,6 @@ class BaseSchedulerNode:
# Return estimated runtime in nanoseconds (bytes / gbps)
return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
# Collective kernels
if isinstance(self.node, ir.CollectiveKernel):
return estimate_nccl_collective_runtime(self)
elif isinstance(self.node, ir.Wait):
# ir.Wait is only used for collective ops.
# The time needed for the collective op is already estimated and considered
# when we are processing the collective op IR node, so ir.Wait takes 0 time
# since it doesn't take extra time to get the result after the collective is completed.
return 0
return 0

View File

@@ -1261,3 +1261,15 @@ def pass_execution_and_save(func, gm, msg):
t,
time_elapsed,
)
def is_collective(node):
from . import ir
return isinstance(node, ir.CollectiveKernel) or type(node) == ir._CollectiveKernel
def is_wait(node):
from . import ir
return isinstance(node, ir.Wait) or type(node) == ir._WaitKernel

View File

@@ -9,6 +9,12 @@ class FakeWork : public Work {
bool wait(std::chrono::milliseconds timeout) override {
return true;
}
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
fut->markCompleted();
return fut;
}
};
class FakeProcessGroup : public Backend {