[reland] Fix estimate_nccl_collective_runtime (#118986)
`estimate_nccl_collective_runtime` has been broken, and the errors have been silently swallowed by inductor. This PR:
- Fixes the issues described in https://github.com/pytorch/pytorch/issues/118497.
- Adds white-box testing so future issues can be surfaced in tests.
- Adds support for native funcol IRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118986
Approved by: https://github.com/yf225
ghstack dependencies: #119102
Committed by: PyTorch MergeBot
Parent: b2043c0543
Commit: 27ffede878
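As a rough, self-contained sketch of the white-box approach this change introduces (distilled from the new TestCommAnalysis test in the diff below; the CUDA device, tensor shape, and world size of 8 are illustrative choices, and the body of count_bytes_inductor is not shown in this diff, so it is reproduced here only as an assumption about how a compile wrapper could feed metrics.node_runtimes):

# Sketch: call the estimator directly on collective scheduler nodes instead of
# letting inductor swallow exceptions raised during runtime estimation.
import torch
import torch.distributed as dist
from torch._inductor import metrics
from torch._inductor.comm_analysis import estimate_nccl_collective_runtime
from torch._inductor.compile_fx import compile_fx, count_bytes_inner
from torch._inductor.utils import is_collective
from torch.testing._internal.distributed.fake_pg import FakeStore


def count_bytes_inductor(gm, example_inputs):
    # Assumed helper (body not shown in this diff): compile through
    # count_bytes_inner, which is what populates metrics.node_runtimes.
    return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner)


def fn(x):
    # Legacy functional-collective op, exactly as used by test_legacy_all_reduce.
    r = torch.ops.c10d_functional.all_reduce(x, "sum", "", list(range(8)), 8)
    return torch.ops.c10d_functional.wait_tensor(r)


dist.init_process_group(backend="fake", rank=0, world_size=8, store=FakeStore())
try:
    metrics.reset()
    torch._dynamo.optimize(count_bytes_inductor)(fn)(torch.randn(10, 10, device="cuda"))
    for snode, runtime in metrics.node_runtimes:
        if is_collective(snode.node):
            # Any exception inside the estimator now fails loudly here,
            # instead of being swallowed by inductor's try/except.
            assert estimate_nccl_collective_runtime(snode.node) > 0
finally:
    dist.destroy_process_group()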
@@ -278,9 +278,14 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
        self.assertTrue(same(out, correct))

    def test_nccl_heuristics(self):
        assert list(baseLat.shape) == [len(NCCL_ALGO), len(NCCL_PROTO)]
        assert list(hwLat.shape) == [len(NCCL_HW), len(NCCL_ALGO), len(NCCL_PROTO)]
        assert llMaxBws.shape[0] == len(NVIDIA_GPU_TYPE)
        assert len(baseLat) == len(NCCL_ALGO)
        assert all(len(x) == len(NCCL_PROTO) for x in baseLat)

        assert len(hwLat) == len(NCCL_HW)
        assert all(len(x) == len(NCCL_ALGO) for x in hwLat)
        assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x)

        assert len(llMaxBws) == len(NVIDIA_GPU_TYPE)


if __name__ == "__main__":
@@ -1,13 +1,21 @@
# Owner(s): ["module: inductor"]


from unittest import skipIf

import torch
import torch.distributed as dist

from torch._inductor import metrics
from torch._inductor.comm_analysis import estimate_nccl_collective_runtime
from torch._inductor.compile_fx import compile_fx, count_bytes_inner
from torch._inductor.utils import is_collective
from torch.testing._internal.common_utils import TestCase as TorchTestCase
from torch.testing._internal.inductor_utils import HAS_CUDA

aten = torch.ops.aten
c10d = torch.ops.c10d_functional
_c10d = torch.ops._c10d_functional


def count_bytes_inductor(gm, example_inputs):

@@ -165,6 +173,135 @@ class MemoryBoundedTests(TestCase):
        self.assertNotZero(calculate_runtime(f, *inp))


@skipIf(not dist.is_available(), "requires distributed")
class TestCommAnalysis(TestCase):
    WORLD_SIZE: int = 8
    RANKS = list(range(8))

    def _verify_runtime_estimation(self, fn, inps):
        from torch.testing._internal.distributed.fake_pg import FakeStore

        store = FakeStore()
        dist.init_process_group(
            backend="fake", rank=0, world_size=self.WORLD_SIZE, store=store
        )
        try:
            metrics.reset()
            torch._dynamo.optimize(count_bytes_inductor)(fn)(*inps)
            found_collective = False
            for snode, runtime in metrics.node_runtimes:
                if not is_collective(snode.node):
                    continue
                found_collective = True
                # Inductor swallows errors from snode runtime estimations.
                # We call estimate_nccl_collective_runtime in a white-box
                # fashion here so potential issues can be surfaced in tests.
                est = estimate_nccl_collective_runtime(snode.node)
                self.assertNotZero(est)
                # Also make sure estimate_nccl_collective_runtime works
                # correctly in inductor.
                self.assertNotZero(runtime)
            # Make sure a collective kernel is found in graph
            self.assertTrue(found_collective)
        finally:
            dist.destroy_process_group()

    def test_legacy_all_reduce(self):
        def fn(x):
            r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_legacy_all_reduce_coalesced(self):
        def fn(x):
            rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE)
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_legacy_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = c10d.all_gather_into_tensor_coalesced(
                x,
                "",
                self.RANKS,
                self.WORLD_SIZE,
            )
            return [c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_reduce(self):
        def fn(x):
            r = _c10d.all_reduce(x, "sum", "0")
            return _c10d.wait_tensor(r)

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_reduce_coalesced(self):
        def fn(x):
            rs = _c10d.all_reduce_coalesced(x, "sum", "0")
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_gather_into_tensor(self):
        def fn(x):
            rs = _c10d.all_gather_into_tensor(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = T(10, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_all_gather_into_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.all_gather_into_tensor_coalesced(
                x,
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(10, 10), T(15, 15)]
        self._verify_runtime_estimation(fn, (inp,))

    def test_reduce_scatter_tensor(self):
        def fn(x):
            rs = _c10d.reduce_scatter_tensor(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = T(self.WORLD_SIZE, 10)
        self._verify_runtime_estimation(fn, (inp,))

    def test_reduce_scatter_tensor_coalesced(self):
        def fn(x):
            rs = _c10d.reduce_scatter_tensor_coalesced(
                x,
                "sum",
                self.WORLD_SIZE,
                "0",
            )
            return [_c10d.wait_tensor(r) for r in rs]

        inp = [T(self.WORLD_SIZE, 10), T(self.WORLD_SIZE, 15)]
        self._verify_runtime_estimation(fn, (inp,))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
@@ -1,7 +1,7 @@
import math
from enum import IntEnum

from typing import TYPE_CHECKING
import sympy

import torch
from . import ir
@@ -9,9 +9,6 @@ from . import ir
from .utils import get_dtype_size, sympy_product
from .virtualized import V

if TYPE_CHECKING:
    from torch._inductor.scheduler import BaseSchedulerNode


class NCCL_COLL(IntEnum):
    ALL_REDUCE = 0
@@ -26,7 +23,7 @@ class NVIDIA_GPU_TYPE(IntEnum):


def get_gpu_type() -> NVIDIA_GPU_TYPE:
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run)
    gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or ""
    if "V100" in gpu_info:
        return NVIDIA_GPU_TYPE.VOLTA
    elif "A100" in gpu_info:
@@ -38,19 +35,52 @@ def get_gpu_type() -> NVIDIA_GPU_TYPE:
        return NVIDIA_GPU_TYPE.AMPERE


def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL:
    if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)):
def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
    if isinstance(node, ir._CollectiveKernel):
        kernel_name = node.python_kernel_name
        assert kernel_name is not None
        if "all_reduce" in kernel_name:
            return NCCL_COLL.ALL_REDUCE
    elif isinstance(
        snode.node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)
    ):
        elif "all_gather" in kernel_name:
            return NCCL_COLL.ALL_GATHER
    elif isinstance(
        snode.node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)
    ):
        elif "reduce_scatter" in kernel_name:
            return NCCL_COLL.REDUCE_SCATTER
    else:
        raise Exception(f"Unsupported collective type: {snode.node}")
            raise Exception(f"Unsupported collective kernel: {kernel_name}")

    if isinstance(node, (ir.AllReduce, ir.AllReduceCoalesced)):
        return NCCL_COLL.ALL_REDUCE
    elif isinstance(node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)):
        return NCCL_COLL.ALL_GATHER
    elif isinstance(node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)):
        return NCCL_COLL.REDUCE_SCATTER
    else:
        raise Exception(f"Unsupported collective type: {node}")


def get_collective_input_size_bytes(node: ir.IRNode) -> int:
    sz_bytes = 0
    for inp in node.inputs:  # type: ignore[attr-defined]
        shape = inp.layout.size
        numel = sympy_product(inp.layout.size)
        if isinstance(numel, sympy.Integer):
            # For ease of testing
            numel = int(numel)
        else:
            numel = V.graph.sizevars.size_hint(numel)
        sz_bytes += numel * get_dtype_size(inp.layout.dtype)
    return sz_bytes


def get_collective_group_size(node: ir.IRNode) -> int:
    if type(node) == ir._CollectiveKernel:
        from torch.distributed.distributed_c10d import _get_group_size_by_name

        return _get_group_size_by_name(node.constant_args[-1])
    elif isinstance(node, ir.CollectiveKernel):
        return node.constant_args[2]  # type: ignore[attr-defined]
    else:
        raise TypeError(f"Unsupported collective type: {node}")


####################################################################################################################
@@ -80,8 +110,8 @@ class NCCL_PROTO(IntEnum):

# Latencies in us
# len(NCCL_ALGO) x len(NCCL_PROTO)
baseLat = torch.tensor(
    [
# NOTE: use array instead of tensor to prevent incompatibility with fake mode
baseLat = [
    # Tree
    [
        6.8,  # LL
@@ -90,13 +120,11 @@ baseLat = torch.tensor(
    [
        6.6,  # LL
    ],
    ]
)
]

# Latencies in us
# len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO)
hwLat = torch.tensor(
    [
hwLat = [
    # NVLINK
    [
        [0.6],  # Tree (LL)
@@ -112,13 +140,11 @@ hwLat = torch.tensor(
        [5.0],  # Tree (LL)
        [2.7],  # Ring (LL)
    ],
    ]
)
]


# LL128 max BW per channel
llMaxBws = torch.tensor(
    [
llMaxBws = [
    # Volta-N1/Intel-N2/Intel-N4
    [
        39.0,
@@ -137,11 +163,10 @@ llMaxBws = torch.tensor(
        22.5,  # avg of ring & tree
        19.0,
    ],
    ]
)
]


def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
    """
    Returns estimated NCCL collective runtime in nanoseconds (ns).

@@ -154,16 +179,14 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
    - 8 gpus per node  # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    - collective is one of: allreduce, reducescatter, allgather
    """
    tensor_numel = V.graph.sizevars.size_hint(sympy_product(snode.node.layout.size))
    tensor_dtype = snode.node.layout.dtype
    tensor_storage_size_bytes = tensor_numel * get_dtype_size(tensor_dtype)
    tensor_storage_size_bytes = get_collective_input_size_bytes(node)
    # Convert bytes to GB
    tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024

    # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus.
    # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info.
    num_gpus_per_node = 8
    _, _, group_size = snode.node.constant_args  # type: ignore[attr-defined]
    group_size = get_collective_group_size(node)
    nNodes = math.ceil(group_size / num_gpus_per_node)
    nRanks = group_size  # this is total # of gpus globally that participate in this collective op

@@ -173,7 +196,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
    # Assumes ring algorithm
    nccl_algo = NCCL_ALGO.RING
    nccl_proto = NCCL_PROTO.LL
    coll = get_collective_type(snode)
    coll = get_collective_type(node)

    # =============== bandwidth computation ===============
    # First compute bandwidth in GB/s; then at the end, convert it to GB/ns
@@ -185,7 +208,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
    index2 = nNodes - 1 if nNodes <= 2 else 2
    # LL: for single node, we look at GPU type; for multi-node, we look at CPU type
    index1 = compCapIndex if nNodes == 1 else 0
    llMaxBw = llMaxBws[index1][index2].item()
    llMaxBw = llMaxBws[index1][index2]

    # NOTE: each step of ring algorithm is synchronized,
    # and is bottlenecked by the slowest link which is the inter-node interconnect.
@@ -227,9 +250,9 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float:
    nInterSteps = nNodes - 1

    # First compute latency in us; then at the end, convert it to ns
    latency = baseLat[nccl_algo][nccl_proto].item()
    intraLat = hwLat[intraHw][nccl_algo][nccl_proto].item()
    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto].item()
    latency = baseLat[nccl_algo][nccl_proto]
    intraLat = hwLat[intraHw][nccl_algo][nccl_proto]
    interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto]

    # Inter-node rings still have to launch nsteps * net overhead.
    netOverhead = 0.0
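A note on why plain nested lists are a drop-in replacement for the old tensors above: the IntEnum members are ordinary ints, so the same indexing expressions keep working, just without the trailing .item(), and no tensors need to be materialized under fake mode. A minimal sketch, assuming the tables and enums are importable from torch._inductor.comm_analysis as the tests in this PR do; the variable names are illustrative:

# Sketch: table lookups after the tensor -> list change in comm_analysis.
from torch._inductor.comm_analysis import NCCL_ALGO, NCCL_HW, NCCL_PROTO, baseLat, hwLat

algo, proto = NCCL_ALGO.RING, NCCL_PROTO.LL
# baseLat is len(NCCL_ALGO) x len(NCCL_PROTO); hwLat adds a leading NCCL_HW dimension.
base_latency_us = baseLat[algo][proto]
inter_latency_us = hwLat[NCCL_HW.NET][algo][proto]
print(base_latency_us, inter_latency_us)  # plain Python floats, safe under fake mode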
@@ -7494,6 +7494,8 @@ class _CollectiveKernel(FallbackKernel):
    def create_inplace(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ) -> None:
        cpp_kernel_name = kernel._name
        python_kernel_name = cpp_kernel_name.replace("::", ".")
        with V.graph.fake_mode:
            (
                example_output,
@@ -7511,6 +7513,8 @@ class _CollectiveKernel(FallbackKernel):
            non_tensor_args,
            unflatten_args,
        )
        packed.cpp_kernel_name = cpp_kernel_name
        packed.python_kernel_name = python_kernel_name

        def mark_mutation(x):
            if isinstance(x.data, BaseView):
@@ -7545,6 +7549,8 @@ class _CollectiveKernel(FallbackKernel):
    def create_out_of_place(
        cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
    ):
        cpp_kernel_name = kernel._name
        python_kernel_name = cpp_kernel_name.replace("::", ".")
        with V.graph.fake_mode:
            (
                example_output,
@@ -7564,6 +7570,8 @@ class _CollectiveKernel(FallbackKernel):
                non_tensor_args,
                unflatten_args,
            )
            packed.cpp_kernel_name = cpp_kernel_name
            packed.python_kernel_name = python_kernel_name
            packed.outputs = [
                MultiOutput(
                    cls.tensor_to_layout(tensor),
@@ -7581,6 +7589,8 @@ class _CollectiveKernel(FallbackKernel):
                non_tensor_args,
                unflatten_args,
            )
            packed.cpp_kernel_name = cpp_kernel_name
            packed.python_kernel_name = python_kernel_name
            packed.outputs = [packed]
            return packed
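These _CollectiveKernel hunks matter for the estimator because get_collective_type (earlier in this diff) classifies native funcol kernels by substring-matching python_kernel_name. A small sketch of that dependency; the classify helper and the example kernel name string are illustrative only, not part of the patch:

# Sketch: with python_kernel_name recorded on the packed _CollectiveKernel,
# comm analysis can classify native funcol collectives by name rather than
# by legacy IR subclass.
from torch._inductor.comm_analysis import NCCL_COLL


def classify(python_kernel_name: str) -> NCCL_COLL:
    # Mirrors the substring checks in get_collective_type.
    if "all_reduce" in python_kernel_name:
        return NCCL_COLL.ALL_REDUCE
    elif "all_gather" in python_kernel_name:
        return NCCL_COLL.ALL_GATHER
    elif "reduce_scatter" in python_kernel_name:
        return NCCL_COLL.REDUCE_SCATTER
    raise Exception(f"Unsupported collective kernel: {python_kernel_name}")


# cpp_kernel_name with "::" replaced by "." would yield something like
# "_c10d_functional.all_reduce" (exact string depends on the registered op name).
assert classify("_c10d_functional.all_reduce") == NCCL_COLL.ALL_REDUCE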
@@ -44,6 +44,8 @@ from .utils import (
    get_dtype_size,
    get_gpu_dram_gbps,
    green_text,
    is_collective,
    is_wait,
    red_text,
    sympy_product,
)
@@ -584,6 +586,16 @@ class BaseSchedulerNode:
            # default to no reordering based on runtime
            return 0

        # Collective kernels
        if is_collective(self.node):
            return estimate_nccl_collective_runtime(self.node)
        elif is_wait(self.node):
            # ir.Wait is only used for collective ops.
            # The time needed for the collective op is already estimated and considered
            # when we are processing the collective op IR node, so ir.Wait takes 0 time
            # since it doesn't take extra time to get the result after the collective is completed.
            return 0

        try:
            gpu_memory_bandwidth = get_gpu_dram_gbps()
            gpu_flops = get_device_tflops(dtype) * 10**12
@@ -629,16 +641,6 @@ class BaseSchedulerNode:
            # Return estimated runtime in nanoseconds (bytes / gbps)
            return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth

        # Collective kernels
        if isinstance(self.node, ir.CollectiveKernel):
            return estimate_nccl_collective_runtime(self)
        elif isinstance(self.node, ir.Wait):
            # ir.Wait is only used for collective ops.
            # The time needed for the collective op is already estimated and considered
            # when we are processing the collective op IR node, so ir.Wait takes 0 time
            # since it doesn't take extra time to get the result after the collective is completed.
            return 0

        return 0
@@ -1261,3 +1261,15 @@ def pass_execution_and_save(func, gm, msg):
        t,
        time_elapsed,
    )


def is_collective(node):
    from . import ir

    return isinstance(node, ir.CollectiveKernel) or type(node) == ir._CollectiveKernel


def is_wait(node):
    from . import ir

    return isinstance(node, ir.Wait) or type(node) == ir._WaitKernel
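A short illustration of what these two helpers encode: both the legacy collective IR classes and the native funcol fallback kernels should be treated uniformly by the scheduler's runtime estimation. The describe helper below is illustrative only; the comments about which IR classes each branch covers restate the isinstance checks above:

# Sketch: is_collective/is_wait cover both IR flavors, so
# BaseSchedulerNode.get_estimated_runtime can dispatch on them uniformly.
from torch._inductor.utils import is_collective, is_wait


def describe(node) -> str:
    if is_collective(node):  # ir.CollectiveKernel subclasses or ir._CollectiveKernel
        return "collective"
    if is_wait(node):        # ir.Wait or ir._WaitKernel
        return "wait"
    return "compute"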
@@ -9,6 +9,12 @@ class FakeWork : public Work {
  bool wait(std::chrono::milliseconds timeout) override {
    return true;
  }

  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
    auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
    fut->markCompleted();
    return fut;
  }
};

class FakeProcessGroup : public Backend {