From 27ffede8784e43df4ed43ce37f86dbffb5b05f49 Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Sun, 11 Feb 2024 21:11:33 -0800 Subject: [PATCH] [reland] Fix estimate_nccl_collective_runtime (#118986) `estimate_nccl_collective_runtime` has been broken and the errors have been silently swallowed by inductor. This PR: - Fixes the issues described in https://github.com/pytorch/pytorch/issues/118497. - Adds white-box testing so future issues can be surfaced in tests. - Add support for native funcol IRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118986 Approved by: https://github.com/yf225 ghstack dependencies: #119102 --- .../test_compute_comm_reordering.py | 11 +- test/inductor/test_snode_runtime.py | 137 ++++++++++++++ torch/_inductor/comm_analysis.py | 171 ++++++++++-------- torch/_inductor/ir.py | 10 + torch/_inductor/scheduler.py | 22 ++- torch/_inductor/utils.py | 12 ++ .../distributed/c10d/FakeProcessGroup.hpp | 6 + 7 files changed, 282 insertions(+), 87 deletions(-) diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index a3563e59d7dc..198a5a912f0f 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -278,9 +278,14 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase): self.assertTrue(same(out, correct)) def test_nccl_heuristics(self): - assert list(baseLat.shape) == [len(NCCL_ALGO), len(NCCL_PROTO)] - assert list(hwLat.shape) == [len(NCCL_HW), len(NCCL_ALGO), len(NCCL_PROTO)] - assert llMaxBws.shape[0] == len(NVIDIA_GPU_TYPE) + assert len(baseLat) == len(NCCL_ALGO) + assert all(len(x) == len(NCCL_PROTO) for x in baseLat) + + assert len(hwLat) == len(NCCL_HW) + assert all(len(x) == len(NCCL_ALGO) for x in hwLat) + assert all(len(y) == len(NCCL_PROTO) for x in hwLat for y in x) + + assert len(llMaxBws) == len(NVIDIA_GPU_TYPE) if __name__ == "__main__": diff --git a/test/inductor/test_snode_runtime.py b/test/inductor/test_snode_runtime.py index b18a3e122d71..06059a6c0287 100644 --- a/test/inductor/test_snode_runtime.py +++ b/test/inductor/test_snode_runtime.py @@ -1,13 +1,21 @@ # Owner(s): ["module: inductor"] +from unittest import skipIf + import torch +import torch.distributed as dist + from torch._inductor import metrics +from torch._inductor.comm_analysis import estimate_nccl_collective_runtime from torch._inductor.compile_fx import compile_fx, count_bytes_inner +from torch._inductor.utils import is_collective from torch.testing._internal.common_utils import TestCase as TorchTestCase from torch.testing._internal.inductor_utils import HAS_CUDA aten = torch.ops.aten +c10d = torch.ops.c10d_functional +_c10d = torch.ops._c10d_functional def count_bytes_inductor(gm, example_inputs): @@ -165,6 +173,135 @@ class MemoryBoundedTests(TestCase): self.assertNotZero(calculate_runtime(f, *inp)) +@skipIf(not dist.is_available(), "requires distributed") +class TestCommAnalysis(TestCase): + WORLD_SIZE: int = 8 + RANKS = list(range(8)) + + def _verify_runtime_estimation(self, fn, inps): + from torch.testing._internal.distributed.fake_pg import FakeStore + + store = FakeStore() + dist.init_process_group( + backend="fake", rank=0, world_size=self.WORLD_SIZE, store=store + ) + try: + metrics.reset() + torch._dynamo.optimize(count_bytes_inductor)(fn)(*inps) + found_collective = False + for snode, runtime in metrics.node_runtimes: + if not is_collective(snode.node): + continue + found_collective = True + # Inductor swallows errors from 
snode runtime estimations. + # We call estimate_nccl_collective_runtime in a white-box + # fashion here so potential issues can be surfaced in tests. + est = estimate_nccl_collective_runtime(snode.node) + self.assertNotZero(est) + # Also make sure estimate_nccl_collective_runtime works + # correctly in inductor. + self.assertNotZero(runtime) + # Make sure a collective kernel is found in graph + self.assertTrue(found_collective) + finally: + dist.destroy_process_group() + + def test_legacy_all_reduce(self): + def fn(x): + r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE) + return c10d.wait_tensor(r) + + inp = T(10, 10) + self._verify_runtime_estimation(fn, (inp,)) + + def test_legacy_all_reduce_coalesced(self): + def fn(x): + rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE) + return [c10d.wait_tensor(r) for r in rs] + + inp = [T(10, 10), T(15, 15)] + self._verify_runtime_estimation(fn, (inp,)) + + def test_legacy_all_gather_into_tensor_coalesced(self): + def fn(x): + rs = c10d.all_gather_into_tensor_coalesced( + x, + "", + self.RANKS, + self.WORLD_SIZE, + ) + return [c10d.wait_tensor(r) for r in rs] + + inp = [T(10, 10), T(15, 15)] + self._verify_runtime_estimation(fn, (inp,)) + + def test_all_reduce(self): + def fn(x): + r = _c10d.all_reduce(x, "sum", "0") + return _c10d.wait_tensor(r) + + inp = T(10, 10) + self._verify_runtime_estimation(fn, (inp,)) + + def test_all_reduce_coalesced(self): + def fn(x): + rs = _c10d.all_reduce_coalesced(x, "sum", "0") + return [_c10d.wait_tensor(r) for r in rs] + + inp = [T(10, 10), T(15, 15)] + self._verify_runtime_estimation(fn, (inp,)) + + def test_all_gather_into_tensor(self): + def fn(x): + rs = _c10d.all_gather_into_tensor( + x, + self.WORLD_SIZE, + "0", + ) + return [_c10d.wait_tensor(r) for r in rs] + + inp = T(10, 10) + self._verify_runtime_estimation(fn, (inp,)) + + def test_all_gather_into_tensor_coalesced(self): + def fn(x): + rs = _c10d.all_gather_into_tensor_coalesced( + x, + self.WORLD_SIZE, + "0", + ) + return [_c10d.wait_tensor(r) for r in rs] + + inp = [T(10, 10), T(15, 15)] + self._verify_runtime_estimation(fn, (inp,)) + + def test_reduce_scatter_tensor(self): + def fn(x): + rs = _c10d.reduce_scatter_tensor( + x, + "sum", + self.WORLD_SIZE, + "0", + ) + return [_c10d.wait_tensor(r) for r in rs] + + inp = T(self.WORLD_SIZE, 10) + self._verify_runtime_estimation(fn, (inp,)) + + def test_reduce_scatter_tensor_coalesced(self): + def fn(x): + rs = _c10d.reduce_scatter_tensor_coalesced( + x, + "sum", + self.WORLD_SIZE, + "0", + ) + return [_c10d.wait_tensor(r) for r in rs] + + inp = [T(self.WORLD_SIZE, 10), T(self.WORLD_SIZE, 15)] + self._verify_runtime_estimation(fn, (inp,)) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index f1f555ff45ac..6ff48e5dc6d9 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -1,7 +1,7 @@ import math from enum import IntEnum -from typing import TYPE_CHECKING +import sympy import torch from . import ir @@ -9,9 +9,6 @@ from . 
import ir from .utils import get_dtype_size, sympy_product from .virtualized import V -if TYPE_CHECKING: - from torch._inductor.scheduler import BaseSchedulerNode - class NCCL_COLL(IntEnum): ALL_REDUCE = 0 @@ -26,7 +23,7 @@ class NVIDIA_GPU_TYPE(IntEnum): def get_gpu_type() -> NVIDIA_GPU_TYPE: - gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) + gpu_info = torch.utils.collect_env.get_gpu_info(torch.utils.collect_env.run) or "" if "V100" in gpu_info: return NVIDIA_GPU_TYPE.VOLTA elif "A100" in gpu_info: @@ -38,19 +35,52 @@ def get_gpu_type() -> NVIDIA_GPU_TYPE: return NVIDIA_GPU_TYPE.AMPERE -def get_collective_type(snode: "BaseSchedulerNode") -> NCCL_COLL: - if isinstance(snode.node, (ir.AllReduce, ir.AllReduceCoalesced)): +def get_collective_type(node: ir.IRNode) -> NCCL_COLL: + if isinstance(node, ir._CollectiveKernel): + kernel_name = node.python_kernel_name + assert kernel_name is not None + if "all_reduce" in kernel_name: + return NCCL_COLL.ALL_REDUCE + elif "all_gather" in kernel_name: + return NCCL_COLL.ALL_GATHER + elif "reduce_scatter" in kernel_name: + return NCCL_COLL.REDUCE_SCATTER + else: + raise Exception(f"Unsupported collective kernel: {kernel_name}") + + if isinstance(node, (ir.AllReduce, ir.AllReduceCoalesced)): return NCCL_COLL.ALL_REDUCE - elif isinstance( - snode.node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced) - ): + elif isinstance(node, (ir.AllGatherIntoTensor, ir.AllGatherIntoTensorCoalesced)): return NCCL_COLL.ALL_GATHER - elif isinstance( - snode.node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced) - ): + elif isinstance(node, (ir.ReduceScatterTensor, ir.ReduceScatterTensorCoalesced)): return NCCL_COLL.REDUCE_SCATTER else: - raise Exception(f"Unsupported collective type: {snode.node}") + raise Exception(f"Unsupported collective type: {node}") + + +def get_collective_input_size_bytes(node: ir.IRNode) -> int: + sz_bytes = 0 + for inp in node.inputs: # type: ignore[attr-defined] + shape = inp.layout.size + numel = sympy_product(inp.layout.size) + if isinstance(numel, sympy.Integer): + # For ease of testing + numel = int(numel) + else: + numel = V.graph.sizevars.size_hint(numel) + sz_bytes += numel * get_dtype_size(inp.layout.dtype) + return sz_bytes + + +def get_collective_group_size(node: ir.IRNode) -> int: + if type(node) == ir._CollectiveKernel: + from torch.distributed.distributed_c10d import _get_group_size_by_name + + return _get_group_size_by_name(node.constant_args[-1]) + elif isinstance(node, ir.CollectiveKernel): + return node.constant_args[2] # type: ignore[attr-defined] + else: + raise TypeError(f"Unsupported collective type: {node}") #################################################################################################################### @@ -80,68 +110,63 @@ class NCCL_PROTO(IntEnum): # Latencies in us # len(NCCL_ALGO) x len(NCCL_PROTO) -baseLat = torch.tensor( +# NOTE: use array instead of tensor to prevent incompatibility with fake mode +baseLat = [ + # Tree [ - # Tree - [ - 6.8, # LL - ], - # Ring - [ - 6.6, # LL - ], - ] -) + 6.8, # LL + ], + # Ring + [ + 6.6, # LL + ], +] # Latencies in us # len(NCCL_HW) x len(NCCL_ALGO) x len(NCCL_PROTO) -hwLat = torch.tensor( +hwLat = [ + # NVLINK [ - # NVLINK - [ - [0.6], # Tree (LL) - [0.6], # Ring (LL) - ], - # PCI - [ - [1.0], # Tree (LL) - [1.0], # Ring (LL) - ], - # NET - [ - [5.0], # Tree (LL) - [2.7], # Ring (LL) - ], - ] -) + [0.6], # Tree (LL) + [0.6], # Ring (LL) + ], + # PCI + [ + [1.0], # Tree (LL) + [1.0], # Ring (LL) + ], + # NET + [ + 
[5.0], # Tree (LL) + [2.7], # Ring (LL) + ], +] # LL128 max BW per channel -llMaxBws = torch.tensor( +llMaxBws = [ + # Volta-N1/Intel-N2/Intel-N4 [ - # Volta-N1/Intel-N2/Intel-N4 - [ - 39.0, - 39.0, - 20.4, - ], - # Ampere-N1/AMD-N2/AMD-N4 - [ - 87.7, - 22.5, # avg of ring & tree - 19.0, - ], - # Hopper-N1/AMD-N2/AMD-N4 - [ - 87.7, - 22.5, # avg of ring & tree - 19.0, - ], - ] -) + 39.0, + 39.0, + 20.4, + ], + # Ampere-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], + # Hopper-N1/AMD-N2/AMD-N4 + [ + 87.7, + 22.5, # avg of ring & tree + 19.0, + ], +] -def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float: +def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: """ Returns estimated NCCL collective runtime in nanoseconds (ns). @@ -154,16 +179,14 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float: - 8 gpus per node # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. - collective is one of: allreduce, reducescatter, allgather """ - tensor_numel = V.graph.sizevars.size_hint(sympy_product(snode.node.layout.size)) - tensor_dtype = snode.node.layout.dtype - tensor_storage_size_bytes = tensor_numel * get_dtype_size(tensor_dtype) + tensor_storage_size_bytes = get_collective_input_size_bytes(node) # Convert bytes to GB tensor_storage_size_GB = tensor_storage_size_bytes / 1024 / 1024 / 1024 # Currently assumes each node has 8 gpus. And when >1 node is used, assumes each node uses all 8 gpus. # TODO: Need to find a way to get accurate "gpus per node" and "# nodes" info. num_gpus_per_node = 8 - _, _, group_size = snode.node.constant_args # type: ignore[attr-defined] + group_size = get_collective_group_size(node) nNodes = math.ceil(group_size / num_gpus_per_node) nRanks = group_size # this is total # of gpus globally that participate in this collective op @@ -173,7 +196,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float: # Assumes ring algorithm nccl_algo = NCCL_ALGO.RING nccl_proto = NCCL_PROTO.LL - coll = get_collective_type(snode) + coll = get_collective_type(node) # =============== bandwidth computation =============== # First compute bandwidth in GB/s; then at the end, convert it to GB/ns @@ -185,7 +208,7 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float: index2 = nNodes - 1 if nNodes <= 2 else 2 # LL: for single node, we look at GPU type; for multi-node, we look at CPU type index1 = compCapIndex if nNodes == 1 else 0 - llMaxBw = llMaxBws[index1][index2].item() + llMaxBw = llMaxBws[index1][index2] # NOTE: each step of ring algorithm is synchronized, # and is bottlenecked by the slowest link which is the inter-node interconnect. @@ -227,9 +250,9 @@ def estimate_nccl_collective_runtime(snode: "BaseSchedulerNode") -> float: nInterSteps = nNodes - 1 # First compute latency in us; then at the end, convert it to ns - latency = baseLat[nccl_algo][nccl_proto].item() - intraLat = hwLat[intraHw][nccl_algo][nccl_proto].item() - interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto].item() + latency = baseLat[nccl_algo][nccl_proto] + intraLat = hwLat[intraHw][nccl_algo][nccl_proto] + interLat = hwLat[NCCL_HW.NET][nccl_algo][nccl_proto] # Inter-node rings still have to launch nsteps * net overhead. 
netOverhead = 0.0 diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 4c5294e1ec03..dd0e0b63baad 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -7494,6 +7494,8 @@ class _CollectiveKernel(FallbackKernel): def create_inplace( cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs ) -> None: + cpp_kernel_name = kernel._name + python_kernel_name = cpp_kernel_name.replace("::", ".") with V.graph.fake_mode: ( example_output, @@ -7511,6 +7513,8 @@ class _CollectiveKernel(FallbackKernel): non_tensor_args, unflatten_args, ) + packed.cpp_kernel_name = cpp_kernel_name + packed.python_kernel_name = python_kernel_name def mark_mutation(x): if isinstance(x.data, BaseView): @@ -7545,6 +7549,8 @@ class _CollectiveKernel(FallbackKernel): def create_out_of_place( cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs ): + cpp_kernel_name = kernel._name + python_kernel_name = cpp_kernel_name.replace("::", ".") with V.graph.fake_mode: ( example_output, @@ -7564,6 +7570,8 @@ class _CollectiveKernel(FallbackKernel): non_tensor_args, unflatten_args, ) + packed.cpp_kernel_name = cpp_kernel_name + packed.python_kernel_name = python_kernel_name packed.outputs = [ MultiOutput( cls.tensor_to_layout(tensor), @@ -7581,6 +7589,8 @@ class _CollectiveKernel(FallbackKernel): non_tensor_args, unflatten_args, ) + packed.cpp_kernel_name = cpp_kernel_name + packed.python_kernel_name = python_kernel_name packed.outputs = [packed] return packed diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 493491f429f3..11839aad9576 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -44,6 +44,8 @@ from .utils import ( get_dtype_size, get_gpu_dram_gbps, green_text, + is_collective, + is_wait, red_text, sympy_product, ) @@ -584,6 +586,16 @@ class BaseSchedulerNode: # default to no reordering based on runtime return 0 + # Collective kernels + if is_collective(self.node): + return estimate_nccl_collective_runtime(self.node) + elif is_wait(self.node): + # ir.Wait is only used for collective ops. + # The time needed for the collective op is already estimated and considered + # when we are processing the collective op IR node, so ir.Wait takes 0 time + # since it doesn't take extra time to get the result after the collective is completed. + return 0 + try: gpu_memory_bandwidth = get_gpu_dram_gbps() gpu_flops = get_device_tflops(dtype) * 10**12 @@ -629,16 +641,6 @@ class BaseSchedulerNode: # Return estimated runtime in nanoseconds (bytes / gbps) return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth - # Collective kernels - if isinstance(self.node, ir.CollectiveKernel): - return estimate_nccl_collective_runtime(self) - elif isinstance(self.node, ir.Wait): - # ir.Wait is only used for collective ops. - # The time needed for the collective op is already estimated and considered - # when we are processing the collective op IR node, so ir.Wait takes 0 time - # since it doesn't take extra time to get the result after the collective is completed. - return 0 - return 0 diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 254f6338e369..61d9a7907f56 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1261,3 +1261,15 @@ def pass_execution_and_save(func, gm, msg): t, time_elapsed, ) + + +def is_collective(node): + from . import ir + + return isinstance(node, ir.CollectiveKernel) or type(node) == ir._CollectiveKernel + + +def is_wait(node): + from . 
import ir
+
+    return isinstance(node, ir.Wait) or type(node) == ir._WaitKernel
diff --git a/torch/csrc/distributed/c10d/FakeProcessGroup.hpp b/torch/csrc/distributed/c10d/FakeProcessGroup.hpp
index c03e65c0ee63..2736e0e3538d 100644
--- a/torch/csrc/distributed/c10d/FakeProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/FakeProcessGroup.hpp
@@ -9,6 +9,12 @@ class FakeWork : public Work {
   bool wait(std::chrono::milliseconds timeout) override {
     return true;
   }
+
+  c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
+    auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
+    fut->markCompleted();
+    return fut;
+  }
 };
 
 class FakeProcessGroup : public Backend {
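
For reference, the cost model behind estimate_nccl_collective_runtime boils down to a bandwidth term (bytes pushed through the bottleneck link) plus a latency term (a base latency plus per-step hop latencies, with inter-node steps charged the higher network latency). The sketch below is a minimal, self-contained illustration of that shape for a ring all-reduce. The constants and the helper name rough_ring_allreduce_ns are placeholders invented for this example, and the step bookkeeping is simplified; the real values and per-collective logic live in baseLat/hwLat/llMaxBws and the estimator in torch/_inductor/comm_analysis.py.

import math

# Placeholder constants for illustration only. The real estimator selects its
# latency/bandwidth numbers from baseLat/hwLat/llMaxBws based on GPU
# generation, interconnect, NCCL algorithm, and NCCL protocol.
INTRA_NODE_BW_GBPS = 200.0   # assumed NVLink-class bus bandwidth
INTER_NODE_BW_GBPS = 25.0    # assumed network bus bandwidth
BASE_LAT_US = 6.6            # assumed ring/LL base latency
INTRA_HOP_LAT_US = 0.6       # assumed per-step latency within a node
INTER_HOP_LAT_US = 2.7       # assumed per-step latency across nodes


def rough_ring_allreduce_ns(size_bytes: int, group_size: int, gpus_per_node: int = 8) -> float:
    """Rough ring all-reduce estimate in nanoseconds: transport time + latency."""
    n_nodes = math.ceil(group_size / gpus_per_node)
    n_ranks = group_size

    # Bandwidth term: a ring all-reduce moves ~2 * (n - 1) / n of the data
    # through the slowest link (the network once more than one node is involved).
    size_gb = size_bytes / 1024**3
    bus_bw = INTER_NODE_BW_GBPS if n_nodes > 1 else INTRA_NODE_BW_GBPS
    transport_ns = size_gb * 2 * (n_ranks - 1) / n_ranks / bus_bw * 1e9

    # Latency term: every ring step pays a hop latency; steps that cross node
    # boundaries pay the higher network latency.
    n_steps = 2 * (n_ranks - 1)
    n_inter_steps = 2 * (n_nodes - 1)
    n_intra_steps = n_steps - n_inter_steps
    latency_us = (
        BASE_LAT_US
        + n_intra_steps * INTRA_HOP_LAT_US
        + n_inter_steps * INTER_HOP_LAT_US
    )
    return transport_ns + latency_us * 1e3


if __name__ == "__main__":
    # e.g. a 64 MiB bucket all-reduced across 16 ranks (2 nodes x 8 GPUs each)
    print(f"{rough_ring_allreduce_ns(64 * 1024 * 1024, group_size=16):,.0f} ns")

Comparing a hand-rolled estimate like this against the runtimes recorded in metrics.node_runtimes, as the new white-box test does, is a quick way to confirm that inductor is not silently swallowing estimation errors again.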