diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py
index c05d5edae233..986fc2a0247d 100644
--- a/test/distributed/test_compute_comm_reordering.py
+++ b/test/distributed/test_compute_comm_reordering.py
@@ -259,6 +259,11 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             "reorder_compute_for_overlap",
         ],
     )
+    @patch.object(
+        torch._inductor.config,
+        "runtime_estimations_mms_benchmark",
+        False,
+    )
     def test_reorder_compute_for_overlap(self):
         def func(a, *, tag, ranks, group_size):
             ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py
index 656c03aa6cfd..a69628354e84 100644
--- a/test/distributed/test_inductor_collectives.py
+++ b/test/distributed/test_inductor_collectives.py
@@ -22,8 +22,13 @@ from torch._inductor.comms import (
     sink_waits_iterative,
 )
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
-from torch._inductor.scheduler import BaseSchedulerNode
-from torch._inductor.utils import run_and_get_triton_code
+from torch._inductor.scheduler import (
+    _get_mm_like_fn,
+    BaseSchedulerNode,
+    get_estimate_runtime_cache,
+    get_estimate_runtime_cache_key_from_snode,
+)
+from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
 from torch.distributed.distributed_c10d import GroupMember
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_cuda import SM80OrLater
@@ -1568,11 +1573,21 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
         correct = func(*inputs, **self.get_world_trs())
 
-        with torch._inductor.config.patch(
-            {
-                "bucket_all_gathers_fx": "all",
-                "reorder_for_compute_comm_overlap": False,
-            }
+        with (
+            torch._inductor.config.patch(
+                {
+                    "bucket_all_gathers_fx": "all",
+                    "reorder_for_compute_comm_overlap": False,
+                    "runtime_estimations_mms_benchmark": True,
+                }
+            ),
+            torch._inductor.config_comms.patch(
+                {
+                    "runtime_estimations_align_across_all_distributed_ranks": True,
+                }
+            ),
+            # Clear the cache to cover runtime_estimations_mms_benchmark, which uses LocalCache
+            fresh_inductor_cache(),
         ):
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
@@ -1801,6 +1816,17 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         def _reorder_communication_preserving_peak_memory(
             snodes: list[BaseSchedulerNode],
         ) -> list[BaseSchedulerNode]:
+            if torch._inductor.config.runtime_estimations_mms_benchmark:
+                cache = get_estimate_runtime_cache()
+                for snode in snodes:
+                    if _get_mm_like_fn(snode) is None:
+                        continue
+                    cache_key = get_estimate_runtime_cache_key_from_snode(snode)
+                    assert cache.lookup(cache_key) is not None
+
+            if torch._inductor.config_comms.runtime_estimations_align_across_all_distributed_ranks:
+                for snode in snodes:
+                    assert snode.override_estimated_runtime is not None
             nonlocal node_stats
             (
                 reordered_snodes,
@@ -1808,20 +1834,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
                 node_stats,
             ) = _reorder_communication_preserving_peak_memory_internal(snodes)
             return reordered_snodes
 
-        with torch._inductor.config.patch(
-            {
-                "bucket_all_gathers_fx": "all",
-                "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
-                "bucket_reduce_scatters_fx": "all",
-                "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
-                "reorder_for_compute_comm_overlap": True,
-                "reorder_for_compute_comm_overlap_passes": [
-                    sink_waits_iterative,
-                    _reorder_communication_preserving_peak_memory,
-                ],
-                "allow_buffer_reuse": False,
-                "test_configs.track_memory_lifecycle": "error",
-            }
+        with (
+            torch._inductor.config.patch(
+                {
+                    "bucket_all_gathers_fx": "all",
+                    "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
+                    "bucket_reduce_scatters_fx": "all",
+                    "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
+                    "reorder_for_compute_comm_overlap": True,
+                    "reorder_for_compute_comm_overlap_passes": [
+                        sink_waits_iterative,
+                        _reorder_communication_preserving_peak_memory,
+                    ],
+                    "allow_buffer_reuse": False,
+                    "test_configs.track_memory_lifecycle": "error",
+                    "runtime_estimations_mms_benchmark": True,
+                }
+            ),
+            torch._inductor.config_comms.patch(
+                {
+                    "runtime_estimations_align_across_all_distributed_ranks": True,
+                }
+            ),
+            # Clear the cache to cover runtime_estimations_mms_benchmark, which uses LocalCache
+            fresh_inductor_cache(),
         ):
             compiled = torch.compile(func, fullgraph=True)
             code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py
index 2a69a0531347..c24cf336e66a 100644
--- a/torch/_inductor/comm_analysis.py
+++ b/torch/_inductor/comm_analysis.py
@@ -1,20 +1,26 @@
 import functools
+import logging
 import math
 from enum import IntEnum
+from typing import Optional
 
 import sympy
 
 import torch
 
 from . import ir
-from .utils import get_dtype_size, sympy_product
+from .utils import get_dtype_size, snode_args_kwargs, sympy_product
 from .virtualized import V
 
 
+log = logging.getLogger(__name__)
+
+
 class NCCL_COLL(IntEnum):
     ALL_REDUCE = 0
     ALL_GATHER = 1
     REDUCE_SCATTER = 2
+    ALL_TO_ALL = 3
 
 
 class NVIDIA_GPU_TYPE(IntEnum):
@@ -49,6 +55,8 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
         return NCCL_COLL.ALL_GATHER
     elif "reduce_scatter" in kernel_name:
         return NCCL_COLL.REDUCE_SCATTER
+    elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name:
+        return NCCL_COLL.ALL_TO_ALL
     else:
         raise ValueError(f"Unsupported collective kernel: {kernel_name}")
 
@@ -158,9 +166,53 @@ llMaxBws = [
 ]
 
 
+def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]:  # type: ignore[no-untyped-def]
+    kernel = snode.node
+    assert kernel is not None
+    py_kernel_name = getattr(kernel, "python_kernel_name", "")
+    if not ("all_gather" in py_kernel_name or "reduce_scatter" in py_kernel_name):
+        # NCCL 2.27 sometimes fails unrecoverably for all_to_all and all_reduce
+        return None
+
+    from torch.distributed.distributed_c10d import _resolve_process_group
+
+    pg_name = kernel.constant_args[-1]  # type: ignore[attr-defined]
+    pg = _resolve_process_group(pg_name)
+    rank: int = torch.distributed.get_rank(pg)
+    # TODO(ivankobzarev): Figure out how we can use time estimations
+    # without cuda allocations.
+    device = torch.device(f"cuda:{rank}")
+
+    fn = eval(py_kernel_name)
+    args, kwargs = snode_args_kwargs(snode)
+
+    # TODO(ivankobzarev): fix out variants snode_args_kwargs
+    if "all_gather_into_tensor_out" in py_kernel_name:
+        args = args[1:] + args[0]
+
+    try:
+        with torch.distributed._time_estimator(
+            group=pg, device=device
+        ) as time_estimator:
+            w = fn(*args, **kwargs)
+            torch.ops._c10d_functional.wait_tensor.default(w)
+    except Exception as e:
+        # The NCCL estimator can fail
+        log.info(e)
+        return None
+
+    est_time_us = time_estimator.estimated_time
+    # NCCL returns the constant -1000 when an error occurs during estimation.
+    # Observed for all_to_all estimations.
+    if est_time_us < 0:
+        return None
+    est_time_ms = est_time_us / 1e3
+    return est_time_ms
+
+
 def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
     """
-    Returns estimated NCCL collective runtime in nanoseconds (ns).
+    Returns estimated NCCL collective runtime in milliseconds (ms).
 
     The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
     We aim to estimate the runtime as accurately as possible.
@@ -220,6 +272,8 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
 
     if coll == NCCL_COLL.ALL_REDUCE:
         nsteps = 2 * (nRanks - 1)
+    elif coll == NCCL_COLL.ALL_TO_ALL:
+        nsteps = 2 * (nRanks - 1)
     elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
         nsteps = nRanks - 1
 
@@ -237,7 +291,7 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
             nInterSteps = 2 * nNodes
         else:
             nInterSteps = 0
-    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
+    elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER, NCCL_COLL.ALL_TO_ALL):
         nInterSteps = nNodes - 1
 
     # First compute latency in us; then at the end, convert it to ns
@@ -256,7 +310,9 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
 
     # =============== final result ===============
     transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
-    return transport_ns + latency_ns
+    ns = transport_ns + latency_ns
+    ms = ns / 1e6
+    return ms
 
 
 ################################################################################################################
diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py
index af4651a42a8e..fa8bb30f238c 100644
--- a/torch/_inductor/comms.py
+++ b/torch/_inductor/comms.py
@@ -52,6 +52,28 @@ if TYPE_CHECKING:
     from torch._inductor.scheduler import BaseSchedulerNode
 
 
+def align_runtime_estimations_across_all_distributed_ranks(
+    snodes: list[BaseSchedulerNode],
+):
+    runtime_estimations = {}
+    for snode in snodes:
+        runtime_estimations[snode] = snode.get_estimated_runtime()
+    import torch.distributed as dist
+    from torch.distributed.distributed_c10d import _get_default_group
+
+    world_size = dist.get_world_size()
+    pg = _get_default_group()
+    gathered_runtime_estimations: list[list[float]] = [[] for _ in range(world_size)]
+    dist.all_gather_object(
+        gathered_runtime_estimations, list(runtime_estimations.values()), pg
+    )
+    median_runtime_estimations = torch.median(
+        torch.tensor(gathered_runtime_estimations), dim=0
+    ).values.tolist()
+    for i in range(len(snodes)):
+        snodes[i].override_estimated_runtime = median_runtime_estimations[i]
+
+
 def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
     """
     Greedily schedules waits as late as possible.
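For reference, the alignment pass added to comms.py above boils down to a rank-median reduction over per-node runtime estimates. The sketch below is a minimal standalone illustration rather than the Inductor pass itself; the function name and the plain list-of-floats interface are assumptions, and it requires an initialized default process group.

```python
# Minimal sketch of the rank-median alignment idea (illustrative, not the
# Inductor implementation): every rank contributes its local estimates and
# all ranks adopt the per-node median, so passes that consume the estimates
# make the same decisions on every rank.
import torch
import torch.distributed as dist


def median_across_ranks(local_estimates: list[float]) -> list[float]:
    world_size = dist.get_world_size()
    gathered: list[list[float]] = [[] for _ in range(world_size)]
    dist.all_gather_object(gathered, local_estimates)
    # (world_size, num_nodes) -> per-node median over ranks
    return torch.median(torch.tensor(gathered), dim=0).values.tolist()
```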
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index a3a4bb1db751..f6921a057ba0 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -416,6 +416,8 @@ bucket_reduce_scatters_fx_bucket_size_determinator: Optional[Callable[[int], int
 # for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
 estimate_op_runtime = "default"
 
+runtime_estimations_mms_benchmark: bool = False
+
 # unit: GB/s, uni-directional P2P bandwidth per card
 # default value is NVLink
 intra_node_bw = 300
diff --git a/torch/_inductor/config_comms.py b/torch/_inductor/config_comms.py
new file mode 100644
index 000000000000..b5dbf424f35b
--- /dev/null
+++ b/torch/_inductor/config_comms.py
@@ -0,0 +1,15 @@
+import sys
+
+from torch.utils._config_module import install_config_module
+
+
+# Whether to use c10d._time_estimator for collective runtime estimations.
+runtime_estimations_use_nccl_lib_estimations: bool = False
+
+# Config to enable syncing runtime estimations across distributed ranks,
+# to prevent passes that use these estimations from making different
+# decisions on different distributed ranks.
+runtime_estimations_align_across_all_distributed_ranks: bool = False
+
+# adds patch, save_config, etc.
+install_config_module(sys.modules[__name__])
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 6afcbde3e2a9..783614dd5132 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -26,6 +26,7 @@ import sympy
 
 import torch
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
+import torch.utils._pytree as pytree
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor.codecache import LambdaFuture, PyCodeCache
 from torch._inductor.ir import TritonTemplateCallerBase
@@ -35,10 +36,13 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
 from torch.utils._triton import has_triton
 
-from . import comms, config, dependencies, ir, metrics
+from . import comms, config, config_comms, dependencies, ir, metrics
 from .analyze_preserves_zero_mask import can_codegen_without_upcasts
 from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
-from .comm_analysis import estimate_nccl_collective_runtime
+from .comm_analysis import (
+    estimate_nccl_collective_runtime,
+    estimate_nccl_collective_runtime_nccl_estimator,
+)
 from .dependencies import Dep, MemoryDep, StarDep, WeakDep
 from .exc import GPUTooOldForTriton, TritonMissing
 from .fx_utils import count_flops_fx
@@ -212,6 +216,7 @@ class BaseSchedulerNode:
     min_order: int
     max_order: int
     mpi_node: MemoryPlanningInfoForNode
+    override_estimated_runtime: Optional[float] = None
 
     def __init__(self, scheduler: Scheduler) -> None:
         self.scheduler: Scheduler = scheduler
@@ -823,10 +828,16 @@ class BaseSchedulerNode:
             counters["inductor"]["flop_count"] += resolved_flops
         return resolved_flops
 
-    @cache_on_self
     def get_estimated_runtime(self) -> float:
+        if self.override_estimated_runtime is not None:
+            return self.override_estimated_runtime
+
+        return self._get_estimated_runtime()
+
+    @cache_on_self
+    def _get_estimated_runtime(self) -> float:
         """
-        Returns estimated op runtime in nanoseconds (ns)
+        Returns estimated op runtime in milliseconds (ms)
         """
         buf = self.get_nodes()[0].get_outputs()[0]
         layout = buf.node.get_output_spec()
@@ -838,6 +849,21 @@ class BaseSchedulerNode:
         if is_collective(self.node):
             assert isinstance(self.node, ir.IRNode)
             try:
+                if config_comms.runtime_estimations_use_nccl_lib_estimations:
+                    cache_key = get_estimate_runtime_cache_key_from_snode(self)
+                    cache = get_estimate_runtime_cache()
+                    cache_val = cache.lookup(cache_key)
+                    if cache_val is not None:
+                        assert isinstance(cache_val, float)
+                        return cache_val
+
+                    ms = estimate_nccl_collective_runtime_nccl_estimator(self)
+                    if ms is None:
+                        # If the NCCL estimator fails, fall back to the in-tree algorithmic estimation.
+                        ms = estimate_nccl_collective_runtime(self.node)
+
+                    cache.set_value(cache_key, value=ms)
+                    return ms
                 return estimate_nccl_collective_runtime(self.node)
             except ValueError as e:
                 # We don't know how to estimate runtime for this collective,
@@ -856,6 +882,10 @@ class BaseSchedulerNode:
             # since it doesn't take extra time to get the result after the collective is completed.
             return 0
 
+        ret = maybe_estimate_runtime_benchmark(self)
+        if ret is not None:
+            return ret
+
         dtype = buf.node.maybe_get_dtype()
         try:
             gpu_memory_bandwidth = get_gpu_dram_gbps()
@@ -876,7 +906,9 @@ class BaseSchedulerNode:
 
         if flops_est == 0 or flops_est is None:
             # no flops estimate, so fall back to memory estimate
-            return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
+            ns = self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
+            ms = ns / 1e6
+            return ms
 
         # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
         factor = 1.0
@@ -885,8 +917,10 @@ class BaseSchedulerNode:
         compute_time = (factor * flops_est / gpu_flops) * 1e9
         transfer_time = counted_bytes / gpu_memory_bandwidth
 
-        # Return estimated runtime in nanoseconds
-        return max(compute_time, transfer_time)
+        # Return estimated runtime in milliseconds
+        ns = max(compute_time, transfer_time)
+        ms = ns / 1e6
+        return ms
 
     def get_template_node(self) -> Optional[ir.TemplateBuffer]:
         return None
@@ -911,6 +945,77 @@ class BaseSchedulerNode:
         return prologue, template_node, epilogue
 
 
+@functools.cache
+def get_estimate_runtime_cache() -> torch._inductor.codecache.LocalCache:
+    return torch._inductor.codecache.LocalCache()
+
+
+def get_estimate_runtime_cache_key_from_snode(snode: BaseSchedulerNode) -> str:
+    python_kernel_name = getattr(snode.node, "python_kernel_name", "")
+    args = snode.node.inputs  # type: ignore[union-attr]
+    args = snode.node.fill_non_provided_args(  # type: ignore[union-attr]
+        [*args, *snode.node.constant_args],  # type: ignore[union-attr]
+        snode.node.kwargs,  # type: ignore[union-attr]
+    )
+    kwargs = snode.node.kwargs  # type: ignore[union-attr]
+    flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
+
+    def _is_tensor_ir(x) -> bool:  # type: ignore[no-untyped-def]
+        return isinstance(x, ir.IRNode) and not isinstance(x, ir.GeneratorState)
+
+    cache_key = str(
+        (python_kernel_name,)
+        + tuple(tuple(a.get_size()) if _is_tensor_ir(a) else None for a in flat_args)
+    )
+    return cache_key
+
+
+def _get_mm_like_fn(snode: BaseSchedulerNode) -> Optional[Callable[[Any], Any]]:
+    if not isinstance(snode, ExternKernelSchedulerNode):
+        return None
+    mms_fns = {
+        "extern_kernels.mm": torch.ops.aten.mm,
+        "extern_kernels.bmm": torch.ops.aten.bmm,
+        "extern_kernels.addmm": torch.ops.aten.addmm,
+    }
+    python_kernel_name = getattr(snode.node, "python_kernel_name", "")
+    if python_kernel_name not in mms_fns:
+        return None
+    if not isinstance(snode.node, ir.ExternKernel):
+        return None
+    return mms_fns[python_kernel_name]
+
+
+def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]:
+    bench_fn = None
+    args_kwargs_fn = None
+    if config.runtime_estimations_mms_benchmark:
+        mm_fn = _get_mm_like_fn(snode)
+        if mm_fn is None:
+            return None
+        bench_fn = mm_fn
+        args_kwargs_fn = lambda: snode_args_kwargs(snode)  # noqa: E731
+    else:
+        return None
+
+    cache_key = get_estimate_runtime_cache_key_from_snode(snode)
+    cache = get_estimate_runtime_cache()
+    cache_val = cache.lookup(cache_key)
+    if cache_val is not None:
+        assert isinstance(cache_val, float)
+        return cache_val
+
+    from .utils import snode_args_kwargs
+
+    args, kwargs = args_kwargs_fn()
+    from triton.testing import do_bench
+
+    ms = do_bench(lambda: bench_fn(*args, **kwargs))
+
+    cache.set_value(cache_key, value=ms)
+    return ms
+
+
 class WhyNoFuse:
     # TODO when we drop support for Python < 3.10, we can use
     # @dataclass(slots=True) instead of manually specifying __slots__.
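The scheduler helpers above follow a benchmark-once-then-cache pattern for mm-like extern kernels. Below is a sketch of that flow with a plain dict standing in for torch._inductor.codecache.LocalCache; the helper name and signature are illustrative, and triton's do_bench returns a runtime in milliseconds.

```python
# Sketch of the benchmark-and-cache flow used by maybe_estimate_runtime_benchmark;
# a plain dict stands in for Inductor's LocalCache.
from typing import Any, Callable

from triton.testing import do_bench


def cached_benchmark_ms(
    cache: dict[str, float],
    cache_key: str,
    fn: Callable[..., Any],
    args: list[Any],
    kwargs: dict[str, Any],
) -> float:
    hit = cache.get(cache_key)
    if hit is not None:
        return hit  # reuse the earlier measurement for the same kernel/arg sizes
    ms = do_bench(lambda: fn(*args, **kwargs))  # benchmark once, in milliseconds
    cache[cache_key] = ms
    return ms
```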
@@ -2094,6 +2199,10 @@ class NodeUser:
 _post_grad_graph_counter = itertools.count()
 
 
+def used_non_deterministic_runtime_estimations() -> bool:
+    return config.runtime_estimations_mms_benchmark
+
+
 class Scheduler:
     """
     A Scheduler is a graph of BaseSchedulerNodes. It is responsible for
@@ -2214,6 +2323,17 @@ class Scheduler:
             assign_memory_planning_info_for_scheduler_buffers(
                 self.nodes, self.name_to_buf
             )
+
+        if (
+            used_non_deterministic_runtime_estimations()
+            and config_comms.runtime_estimations_align_across_all_distributed_ranks
+        ):
+            from .comms import (
+                align_runtime_estimations_across_all_distributed_ranks,
+            )
+
+            align_runtime_estimations_across_all_distributed_ranks(self.nodes)
+
         from torch._logging import trace_structured
 
         trace_structured(
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index a7302381f9d3..abb850ea4cce 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -59,6 +59,7 @@ from unittest import mock
 import sympy
 
 import torch
+import torch.utils._pytree as pytree
 from torch._inductor.analysis.device_info import datasheet_tops
 from torch._inductor.runtime.hints import DeviceProperties
 from torch.utils._dtype_abbrs import dtype_abbrs
@@ -3666,3 +3667,38 @@ _unstable_customized_partition_wrapper = CUDAGraphWrapper()
 
 def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
     _unstable_customized_partition_wrapper.wrapper = wrapper
+
+
+def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
+    args = snode.node.inputs  # type: ignore[union-attr]
+    args = snode.node.fill_non_provided_args(  # type: ignore[union-attr]
+        [*args, *snode.node.constant_args],  # type: ignore[union-attr]
+        snode.node.kwargs,  # type: ignore[union-attr]
+    )
+    kwargs = snode.node.kwargs  # type: ignore[union-attr]
+    flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
+
+    def _is_tensor_ir(x) -> bool:  # type: ignore[no-untyped-def]
+        return isinstance(x, torch._inductor.ir.IRNode) and not isinstance(
+            x, torch._inductor.ir.GeneratorState
+        )
+
+    flat_args = [
+        torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False)
+        if _is_tensor_ir(a)
+        else a
+        for a in flat_args
+    ]
+
+    def _tensor(size, dtype, device) -> torch.Tensor:  # type: ignore[no-untyped-def]
+        return torch.empty(size, dtype=dtype, device=device)
+
+    def to_real_tensor(e: Any) -> Any:
+        if not isinstance(e, torch.Tensor):
+            return e
+        out = _tensor(e.size(), e.dtype, e.device)
+        return out
+
+    flat_args = [to_real_tensor(a) for a in flat_args]
+    args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec)
+    return args, kwargs
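Taken together, the new knobs are opted into like any other Inductor config. A sketch of how they could be combined is shown below; the model and input are placeholders, the align/NCCL options additionally assume an initialized distributed process group, and when the NCCL estimator cannot handle a collective the analytic estimate above is used as the fallback.

```python
# Sketch: enabling the runtime-estimation features introduced in this diff.
# `model` and `x` are placeholders, not part of the change.
import torch

model = torch.nn.Linear(128, 128).cuda()  # placeholder module
x = torch.randn(8, 128, device="cuda")  # placeholder input

with (
    torch._inductor.config.patch({"runtime_estimations_mms_benchmark": True}),
    torch._inductor.config_comms.patch(
        {
            "runtime_estimations_use_nccl_lib_estimations": True,
            "runtime_estimations_align_across_all_distributed_ranks": True,
        }
    ),
):
    compiled = torch.compile(model)
    out = compiled(x)
```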