[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode (#161405)

During comms reordering , sink wait iterative observed previous runtime estimations pretty off for collectives and mms.

Adding optional usage of:
- c10d.time_estimator for collectives, which is based on NCCL estimator

Benchmark mode only for matmuls, as they are highly dependent on mm backend

- The logic mostly copied from Ruisi's PRs for inductor simple_fsdp https://github.com/pytorch/pytorch/pull/157572

This estimations corrections are in default `BaseSchedulerNode.estimate_runtime()`

Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161405
Approved by: https://github.com/eellison
This commit is contained in:
IvanKobzarev
2025-09-08 04:35:52 -07:00
committed by PyTorch MergeBot
parent 3f5993316e
commit 25c170b72e
8 changed files with 324 additions and 32 deletions

View File

@ -259,6 +259,11 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
"reorder_compute_for_overlap",
],
)
@patch.object(
torch._inductor.config,
"runtime_estimations_mms_benchmark",
False,
)
def test_reorder_compute_for_overlap(self):
def func(a, *, tag, ranks, group_size):
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)

View File

@ -22,8 +22,13 @@ from torch._inductor.comms import (
sink_waits_iterative,
)
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.scheduler import BaseSchedulerNode
from torch._inductor.utils import run_and_get_triton_code
from torch._inductor.scheduler import (
_get_mm_like_fn,
BaseSchedulerNode,
get_estimate_runtime_cache,
get_estimate_runtime_cache_key_from_snode,
)
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
from torch.distributed.distributed_c10d import GroupMember
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM80OrLater
@ -1568,11 +1573,21 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
correct = func(*inputs, **self.get_world_trs())
with torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "all",
"reorder_for_compute_comm_overlap": False,
}
with (
torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "all",
"reorder_for_compute_comm_overlap": False,
"runtime_estimations_mms_benchmark": True,
}
),
torch._inductor.config_comms.patch(
{
"runtime_estimations_align_across_all_distributed_ranks": True,
}
),
# Clearing cache to cover runtime_estimations_mms_benchmark that use LocalCache
fresh_inductor_cache(),
):
compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
@ -1801,6 +1816,17 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
def _reorder_communication_preserving_peak_memory(
snodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]:
if torch._inductor.config.runtime_estimations_mms_benchmark:
cache = get_estimate_runtime_cache()
for snode in snodes:
if _get_mm_like_fn(snode) is None:
continue
cache_key = get_estimate_runtime_cache_key_from_snode(snode)
assert cache.lookup(cache_key) is not None
if torch._inductor.config_comms.runtime_estimations_align_across_all_distributed_ranks:
for snode in snodes:
assert snode.override_estimated_runtime is not None
nonlocal node_stats
(
reordered_snodes,
@ -1808,20 +1834,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
) = _reorder_communication_preserving_peak_memory_internal(snodes)
return reordered_snodes
with torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "all",
"bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
"bucket_reduce_scatters_fx": "all",
"bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
"reorder_for_compute_comm_overlap": True,
"reorder_for_compute_comm_overlap_passes": [
sink_waits_iterative,
_reorder_communication_preserving_peak_memory,
],
"allow_buffer_reuse": False,
"test_configs.track_memory_lifecycle": "error",
}
with (
torch._inductor.config.patch(
{
"bucket_all_gathers_fx": "all",
"bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
"bucket_reduce_scatters_fx": "all",
"bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
"reorder_for_compute_comm_overlap": True,
"reorder_for_compute_comm_overlap_passes": [
sink_waits_iterative,
_reorder_communication_preserving_peak_memory,
],
"allow_buffer_reuse": False,
"test_configs.track_memory_lifecycle": "error",
"runtime_estimations_mms_benchmark": True,
}
),
torch._inductor.config_comms.patch(
{
"runtime_estimations_align_across_all_distributed_ranks": True,
}
),
# Clearing cache to cover runtime_estimations_mms_benchmark that use LocalCache
fresh_inductor_cache(),
):
compiled = torch.compile(func, fullgraph=True)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())

View File

@ -1,20 +1,26 @@
import functools
import logging
import math
from enum import IntEnum
from typing import Optional
import sympy
import torch
from . import ir
from .utils import get_dtype_size, sympy_product
from .utils import get_dtype_size, snode_args_kwargs, sympy_product
from .virtualized import V
log = logging.getLogger(__name__)
class NCCL_COLL(IntEnum):
ALL_REDUCE = 0
ALL_GATHER = 1
REDUCE_SCATTER = 2
ALL_TO_ALL = 3
class NVIDIA_GPU_TYPE(IntEnum):
@ -49,6 +55,8 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
return NCCL_COLL.ALL_GATHER
elif "reduce_scatter" in kernel_name:
return NCCL_COLL.REDUCE_SCATTER
elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name:
return NCCL_COLL.ALL_TO_ALL
else:
raise ValueError(f"Unsupported collective kernel: {kernel_name}")
@ -158,9 +166,53 @@ llMaxBws = [
]
def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]: # type: ignore[no-untyped-def]
kernel = snode.node
assert kernel is not None
py_kernel_name = getattr(kernel, "python_kernel_name", "")
if not ("all_gather" in py_kernel_name or "reduce_scatter" in py_kernel_name):
# NCCL of version 2.27 sometimes unrecoverably fail for all_to_all, all_reduce
return None
from torch.distributed.distributed_c10d import _resolve_process_group
pg_name = kernel.constant_args[-1] # type: ignore[attr-defined]
pg = _resolve_process_group(pg_name)
rank: int = torch.distributed.get_rank(pg)
# TODO(ivankobzarev): Figure out how we can use time estimations,
# without cuda allocations.
device = torch.device(f"cuda:{rank}")
fn = eval(py_kernel_name)
args, kwargs = snode_args_kwargs(snode)
# TODO(ivankobzarev): fix out variants snode_args_kwargs
if "all_gather_into_tensor_out" in py_kernel_name:
args = args[1:] + args[0]
try:
with torch.distributed._time_estimator(
group=pg, device=device
) as time_estimator:
w = fn(*args, **kwargs)
torch.ops._c10d_functional.wait_tensor.default(w)
except Exception as e:
# NCCL estimator can fail
log.info(e)
return None
est_time_us = time_estimator.estimated_time
# -1000 constant is NCCL return in case of error during estimations.
# Observed it for all_to_all estimations.
if est_time_us < 0:
return None
est_time_ms = est_time_us / 1e3
return est_time_ms
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).
Returns estimated NCCL collective runtime in milliseconds (ms).
The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
We aim to estimate the runtime as accurately as possible.
@ -220,6 +272,8 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
if coll == NCCL_COLL.ALL_REDUCE:
nsteps = 2 * (nRanks - 1)
elif coll == NCCL_COLL.ALL_TO_ALL:
nsteps = 2 * (nRanks - 1)
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
nsteps = nRanks - 1
@ -237,7 +291,7 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
nInterSteps = 2 * nNodes
else:
nInterSteps = 0
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER, NCCL_COLL.ALL_TO_ALL):
nInterSteps = nNodes - 1
# First compute latency in us; then at the end, convert it to ns
@ -256,7 +310,9 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
# =============== final result ===============
transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
return transport_ns + latency_ns
ns = transport_ns + latency_ns
ms = ns / 1e6
return ms
################################################################################################################

View File

@ -52,6 +52,28 @@ if TYPE_CHECKING:
from torch._inductor.scheduler import BaseSchedulerNode
def align_runtime_estimations_across_all_distributed_ranks(
snodes: list[BaseSchedulerNode],
):
runtime_estimations = {}
for snode in snodes:
runtime_estimations[snode] = snode.get_estimated_runtime()
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
world_size = dist.get_world_size()
pg = _get_default_group()
gathered_runtime_estimations: list[list[float]] = [[] for _ in range(world_size)]
dist.all_gather_object(
gathered_runtime_estimations, list(runtime_estimations.values()), pg
)
median_runtime_estimations = torch.median(
torch.tensor(gathered_runtime_estimations), dim=0
).values.tolist()
for i in range(len(snodes)):
snodes[i].override_estimated_runtime = median_runtime_estimations[i]
def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
"""
Greedily schedules waits as late as possible.

View File

@ -416,6 +416,8 @@ bucket_reduce_scatters_fx_bucket_size_determinator: Optional[Callable[[int], int
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default"
runtime_estimations_mms_benchmark: bool = False
# unit: GB/s, uni-directional P2P bandwidth per card
# default value is NVLink
intra_node_bw = 300

View File

@ -0,0 +1,15 @@
import sys
from torch.utils._config_module import install_config_module
# Whether to use c10d._time_estimator for collectives runtime estimations.
runtime_estimations_use_nccl_lib_estimations: bool = False
# Config to enable sync of runtime estimations across distributed ranks,
# To prevent passes using this runtime estimations to make different
# decisions on different distributed ranks.
runtime_estimations_align_across_all_distributed_ranks: bool = False
# adds patch, save_config, etc
install_config_module(sys.modules[__name__])

View File

@ -26,6 +26,7 @@ import sympy
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch.utils._pytree as pytree
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.codecache import LambdaFuture, PyCodeCache
from torch._inductor.ir import TritonTemplateCallerBase
@ -35,10 +36,13 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._triton import has_triton
from . import comms, config, dependencies, ir, metrics
from . import comms, config, config_comms, dependencies, ir, metrics
from .analyze_preserves_zero_mask import can_codegen_without_upcasts
from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
from .comm_analysis import estimate_nccl_collective_runtime
from .comm_analysis import (
estimate_nccl_collective_runtime,
estimate_nccl_collective_runtime_nccl_estimator,
)
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
from .exc import GPUTooOldForTriton, TritonMissing
from .fx_utils import count_flops_fx
@ -212,6 +216,7 @@ class BaseSchedulerNode:
min_order: int
max_order: int
mpi_node: MemoryPlanningInfoForNode
override_estimated_runtime: Optional[float] = None
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
@ -823,10 +828,16 @@ class BaseSchedulerNode:
counters["inductor"]["flop_count"] += resolved_flops
return resolved_flops
@cache_on_self
def get_estimated_runtime(self) -> float:
if self.override_estimated_runtime is not None:
return self.override_estimated_runtime
return self._get_estimated_runtime()
@cache_on_self
def _get_estimated_runtime(self) -> float:
"""
Returns estimated op runtime in nanoseconds (ns)
Returns estimated op runtime in milliseconds (ms)
"""
buf = self.get_nodes()[0].get_outputs()[0]
layout = buf.node.get_output_spec()
@ -838,6 +849,21 @@ class BaseSchedulerNode:
if is_collective(self.node):
assert isinstance(self.node, ir.IRNode)
try:
if config_comms.runtime_estimations_use_nccl_lib_estimations:
cache_key = get_estimate_runtime_cache_key_from_snode(self)
cache = get_estimate_runtime_cache()
cache_val = cache.lookup(cache_key)
if cache_val is not None:
assert isinstance(cache_val, float)
return cache_val
ms = estimate_nccl_collective_runtime_nccl_estimator(self)
if ms is None:
# NCCL estimations fail: fallback to in-tree algorithmic estimation.
ms = estimate_nccl_collective_runtime(self.node)
cache.set_value(cache_key, value=ms)
return ms
return estimate_nccl_collective_runtime(self.node)
except ValueError as e:
# We don't know how to estimate runtime for this collective,
@ -856,6 +882,10 @@ class BaseSchedulerNode:
# since it doesn't take extra time to get the result after the collective is completed.
return 0
ret = maybe_estimate_runtime_benchmark(self)
if ret is not None:
return ret
dtype = buf.node.maybe_get_dtype()
try:
gpu_memory_bandwidth = get_gpu_dram_gbps()
@ -876,7 +906,9 @@ class BaseSchedulerNode:
if flops_est == 0 or flops_est is None:
# no flops estimate, so fall back to memory estimate
return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
ns = self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
ms = ns / 1e6
return ms
# TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
factor = 1.0
@ -885,8 +917,10 @@ class BaseSchedulerNode:
compute_time = (factor * flops_est / gpu_flops) * 1e9
transfer_time = counted_bytes / gpu_memory_bandwidth
# Return estimated runtime in nanoseconds
return max(compute_time, transfer_time)
# Return estimated runtime in milliseconds
ns = max(compute_time, transfer_time)
ms = ns / 1e6
return ms
def get_template_node(self) -> Optional[ir.TemplateBuffer]:
return None
@ -911,6 +945,77 @@ class BaseSchedulerNode:
return prologue, template_node, epilogue
@functools.cache
def get_estimate_runtime_cache() -> torch._inductor.codecache.LocalCache:
return torch._inductor.codecache.LocalCache()
def get_estimate_runtime_cache_key_from_snode(snode: BaseSchedulerNode) -> str:
python_kernel_name = getattr(snode.node, "python_kernel_name", "")
args = snode.node.inputs # type: ignore[union-attr]
args = snode.node.fill_non_provided_args( # type: ignore[union-attr]
[*args, *snode.node.constant_args], # type: ignore[union-attr]
snode.node.kwargs, # type: ignore[union-attr]
)
kwargs = snode.node.kwargs # type: ignore[union-attr]
flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def]
return isinstance(x, ir.IRNode) and not isinstance(x, ir.GeneratorState)
cache_key = str(
(python_kernel_name,)
+ tuple(tuple(a.get_size()) if _is_tensor_ir(a) else None for a in flat_args)
)
return cache_key
def _get_mm_like_fn(snode: BaseSchedulerNode) -> Optional[Callable[[Any], Any]]:
if not isinstance(snode, ExternKernelSchedulerNode):
return None
mms_fns = {
"extern_kernels.mm": torch.ops.aten.mm,
"extern_kernels.bmm": torch.ops.aten.bmm,
"extern_kernels.addmm": torch.ops.aten.addmm,
}
python_kernel_name = getattr(snode.node, "python_kernel_name", "")
if python_kernel_name not in mms_fns:
return None
if not isinstance(snode.node, ir.ExternKernel):
return None
return mms_fns[python_kernel_name]
def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]:
bench_fn = None
args_kwargs_fn = None
if config.runtime_estimations_mms_benchmark:
mm_fn = _get_mm_like_fn(snode)
if mm_fn is None:
return None
bench_fn = mm_fn
args_kwargs_fn = lambda: snode_args_kwargs(snode) # noqa: E731
else:
return None
cache_key = get_estimate_runtime_cache_key_from_snode(snode)
cache = get_estimate_runtime_cache()
cache_val = cache.lookup(cache_key)
if cache_val is not None:
assert isinstance(cache_val, float)
return cache_val
from .utils import snode_args_kwargs
args, kwargs = args_kwargs_fn()
from triton.testing import do_bench
ms = do_bench(lambda: bench_fn(*args, **kwargs))
cache.set_value(cache_key, value=ms)
return ms
class WhyNoFuse:
# TODO when we drop support for Python < 3.10, we can use
# @dataclass(slots=True) instead of manually specifying __slots__.
@ -2094,6 +2199,10 @@ class NodeUser:
_post_grad_graph_counter = itertools.count()
def used_non_deterministic_runtime_estimations() -> bool:
return config.runtime_estimations_mms_benchmark
class Scheduler:
"""
A Scheduler is a graph of BaseSchedulerNodes. It is responsible for
@ -2214,6 +2323,17 @@ class Scheduler:
assign_memory_planning_info_for_scheduler_buffers(
self.nodes, self.name_to_buf
)
if (
used_non_deterministic_runtime_estimations()
and config_comms.runtime_estimations_align_across_all_distributed_ranks
):
from .comms import (
align_runtime_estimations_across_all_distributed_ranks,
)
align_runtime_estimations_across_all_distributed_ranks(self.nodes)
from torch._logging import trace_structured
trace_structured(

View File

@ -59,6 +59,7 @@ from unittest import mock
import sympy
import torch
import torch.utils._pytree as pytree
from torch._inductor.analysis.device_info import datasheet_tops
from torch._inductor.runtime.hints import DeviceProperties
from torch.utils._dtype_abbrs import dtype_abbrs
@ -3666,3 +3667,38 @@ _unstable_customized_partition_wrapper = CUDAGraphWrapper()
def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
_unstable_customized_partition_wrapper.wrapper = wrapper
def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
args = snode.node.inputs # type: ignore[union-attr]
args = snode.node.fill_non_provided_args( # type: ignore[union-attr]
[*args, *snode.node.constant_args], # type: ignore[union-attr]
snode.node.kwargs, # type: ignore[union-attr]
)
kwargs = snode.node.kwargs # type: ignore[union-attr]
flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def]
return isinstance(x, torch._inductor.ir.IRNode) and not isinstance(
x, torch._inductor.ir.GeneratorState
)
flat_args = [
torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False)
if _is_tensor_ir(a)
else a
for a in flat_args
]
def _tensor(size, dtype, device) -> torch.Tensor: # type: ignore[no-untyped-def]
return torch.empty(size, dtype=dtype, device=device)
def to_real_tensor(e: Any) -> Any:
if not isinstance(e, torch.Tensor):
return e
out = _tensor(e.size(), e.dtype, e.device)
return out
flat_args = [to_real_tensor(a) for a in flat_args]
args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec)
return args, kwargs