Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode (#161405)
During comms reordering, the sink_waits_iterative pass showed that the previous runtime estimations were fairly inaccurate for collectives and matmuls. This change adds optional usage of:
- c10d.time_estimator for collectives, which is based on the NCCL estimator
- benchmark mode for matmuls only, as their runtime is highly dependent on the mm backend

The logic is mostly copied from Ruisi's PRs for inductor simple_fsdp: https://github.com/pytorch/pytorch/pull/157572

These estimation corrections are applied in the default `BaseSchedulerNode.estimate_runtime()`.

Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161405
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: 3f5993316e
Commit: 25c170b72e
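
Editorial note (not part of the commit): the sketch below shows roughly how the knobs introduced here could be enabled together, mirroring the config patches used in the tests further down. `build_model()` is a hypothetical stand-in for user code.

    import torch
    import torch._inductor.config
    import torch._inductor.config_comms

    with (
        # Benchmark extern mm/bmm/addmm kernels instead of relying on analytic estimates.
        torch._inductor.config.patch({"runtime_estimations_mms_benchmark": True}),
        torch._inductor.config_comms.patch(
            {
                # Use c10d._time_estimator (NCCL estimator) for collective runtime estimations.
                "runtime_estimations_use_nccl_lib_estimations": True,
                # Sync estimations across ranks so reordering decisions match everywhere.
                "runtime_estimations_align_across_all_distributed_ranks": True,
            }
        ),
    ):
        compiled = torch.compile(build_model())  # build_model() is hypothetical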
@@ -259,6 +259,11 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
             "reorder_compute_for_overlap",
         ],
     )
+    @patch.object(
+        torch._inductor.config,
+        "runtime_estimations_mms_benchmark",
+        False,
+    )
     def test_reorder_compute_for_overlap(self):
         def func(a, *, tag, ranks, group_size):
             ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
@@ -22,8 +22,13 @@ from torch._inductor.comms import (
     sink_waits_iterative,
 )
 from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
-from torch._inductor.scheduler import BaseSchedulerNode
-from torch._inductor.utils import run_and_get_triton_code
+from torch._inductor.scheduler import (
+    _get_mm_like_fn,
+    BaseSchedulerNode,
+    get_estimate_runtime_cache,
+    get_estimate_runtime_cache_key_from_snode,
+)
+from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
 from torch.distributed.distributed_c10d import GroupMember
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_cuda import SM80OrLater
@@ -1568,11 +1573,21 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
         correct = func(*inputs, **self.get_world_trs())

-        with torch._inductor.config.patch(
-            {
-                "bucket_all_gathers_fx": "all",
-                "reorder_for_compute_comm_overlap": False,
-            }
+        with (
+            torch._inductor.config.patch(
+                {
+                    "bucket_all_gathers_fx": "all",
+                    "reorder_for_compute_comm_overlap": False,
+                    "runtime_estimations_mms_benchmark": True,
+                }
+            ),
+            torch._inductor.config_comms.patch(
+                {
+                    "runtime_estimations_align_across_all_distributed_ranks": True,
+                }
+            ),
+            # Clearing cache to cover runtime_estimations_mms_benchmark that use LocalCache
+            fresh_inductor_cache(),
         ):
             compiled = torch.compile(func)
             code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
@@ -1801,6 +1816,17 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
         def _reorder_communication_preserving_peak_memory(
             snodes: list[BaseSchedulerNode],
         ) -> list[BaseSchedulerNode]:
+            if torch._inductor.config.runtime_estimations_mms_benchmark:
+                cache = get_estimate_runtime_cache()
+                for snode in snodes:
+                    if _get_mm_like_fn(snode) is None:
+                        continue
+                    cache_key = get_estimate_runtime_cache_key_from_snode(snode)
+                    assert cache.lookup(cache_key) is not None
+
+            if torch._inductor.config_comms.runtime_estimations_align_across_all_distributed_ranks:
+                for snode in snodes:
+                    assert snode.override_estimated_runtime is not None
             nonlocal node_stats
             (
                 reordered_snodes,
@@ -1808,20 +1834,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
             ) = _reorder_communication_preserving_peak_memory_internal(snodes)
             return reordered_snodes

-        with torch._inductor.config.patch(
-            {
-                "bucket_all_gathers_fx": "all",
-                "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
-                "bucket_reduce_scatters_fx": "all",
-                "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
-                "reorder_for_compute_comm_overlap": True,
-                "reorder_for_compute_comm_overlap_passes": [
-                    sink_waits_iterative,
-                    _reorder_communication_preserving_peak_memory,
-                ],
-                "allow_buffer_reuse": False,
-                "test_configs.track_memory_lifecycle": "error",
-            }
+        with (
+            torch._inductor.config.patch(
+                {
+                    "bucket_all_gathers_fx": "all",
+                    "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
+                    "bucket_reduce_scatters_fx": "all",
+                    "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
+                    "reorder_for_compute_comm_overlap": True,
+                    "reorder_for_compute_comm_overlap_passes": [
+                        sink_waits_iterative,
+                        _reorder_communication_preserving_peak_memory,
+                    ],
+                    "allow_buffer_reuse": False,
+                    "test_configs.track_memory_lifecycle": "error",
+                    "runtime_estimations_mms_benchmark": True,
+                }
+            ),
+            torch._inductor.config_comms.patch(
+                {
+                    "runtime_estimations_align_across_all_distributed_ranks": True,
+                }
+            ),
+            # Clearing cache to cover runtime_estimations_mms_benchmark that use LocalCache
+            fresh_inductor_cache(),
         ):
             compiled = torch.compile(func, fullgraph=True)
             code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
@@ -1,20 +1,26 @@
 import functools
+import logging
 import math
 from enum import IntEnum
+from typing import Optional

 import sympy

 import torch

 from . import ir
-from .utils import get_dtype_size, sympy_product
+from .utils import get_dtype_size, snode_args_kwargs, sympy_product
 from .virtualized import V


+log = logging.getLogger(__name__)
+
+
 class NCCL_COLL(IntEnum):
     ALL_REDUCE = 0
     ALL_GATHER = 1
     REDUCE_SCATTER = 2
+    ALL_TO_ALL = 3


 class NVIDIA_GPU_TYPE(IntEnum):
@@ -49,6 +55,8 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
         return NCCL_COLL.ALL_GATHER
     elif "reduce_scatter" in kernel_name:
         return NCCL_COLL.REDUCE_SCATTER
+    elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name:
+        return NCCL_COLL.ALL_TO_ALL
     else:
         raise ValueError(f"Unsupported collective kernel: {kernel_name}")

@@ -158,9 +166,53 @@ llMaxBws = [
 ]


+def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]:  # type: ignore[no-untyped-def]
+    kernel = snode.node
+    assert kernel is not None
+    py_kernel_name = getattr(kernel, "python_kernel_name", "")
+    if not ("all_gather" in py_kernel_name or "reduce_scatter" in py_kernel_name):
+        # NCCL of version 2.27 sometimes unrecoverably fail for all_to_all, all_reduce
+        return None
+
+    from torch.distributed.distributed_c10d import _resolve_process_group
+
+    pg_name = kernel.constant_args[-1]  # type: ignore[attr-defined]
+    pg = _resolve_process_group(pg_name)
+    rank: int = torch.distributed.get_rank(pg)
+    # TODO(ivankobzarev): Figure out how we can use time estimations,
+    # without cuda allocations.
+    device = torch.device(f"cuda:{rank}")
+
+    fn = eval(py_kernel_name)
+    args, kwargs = snode_args_kwargs(snode)
+
+    # TODO(ivankobzarev): fix out variants snode_args_kwargs
+    if "all_gather_into_tensor_out" in py_kernel_name:
+        args = args[1:] + args[0]
+
+    try:
+        with torch.distributed._time_estimator(
+            group=pg, device=device
+        ) as time_estimator:
+            w = fn(*args, **kwargs)
+            torch.ops._c10d_functional.wait_tensor.default(w)
+    except Exception as e:
+        # NCCL estimator can fail
+        log.info(e)
+        return None
+
+    est_time_us = time_estimator.estimated_time
+    # -1000 constant is NCCL return in case of error during estimations.
+    # Observed it for all_to_all estimations.
+    if est_time_us < 0:
+        return None
+    est_time_ms = est_time_us / 1e3
+    return est_time_ms
+
+
 def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
     """
-    Returns estimated NCCL collective runtime in nanoseconds (ns).
+    Returns estimated NCCL collective runtime in milliseconds (ms).

     The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
     We aim to estimate the runtime as accurately as possible.
|
|||||||
|
|
||||||
if coll == NCCL_COLL.ALL_REDUCE:
|
if coll == NCCL_COLL.ALL_REDUCE:
|
||||||
nsteps = 2 * (nRanks - 1)
|
nsteps = 2 * (nRanks - 1)
|
||||||
|
elif coll == NCCL_COLL.ALL_TO_ALL:
|
||||||
|
nsteps = 2 * (nRanks - 1)
|
||||||
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
|
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
|
||||||
nsteps = nRanks - 1
|
nsteps = nRanks - 1
|
||||||
|
|
||||||
@ -237,7 +291,7 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
|
|||||||
nInterSteps = 2 * nNodes
|
nInterSteps = 2 * nNodes
|
||||||
else:
|
else:
|
||||||
nInterSteps = 0
|
nInterSteps = 0
|
||||||
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
|
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER, NCCL_COLL.ALL_TO_ALL):
|
||||||
nInterSteps = nNodes - 1
|
nInterSteps = nNodes - 1
|
||||||
|
|
||||||
# First compute latency in us; then at the end, convert it to ns
|
# First compute latency in us; then at the end, convert it to ns
|
||||||
@ -256,7 +310,9 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
|
|||||||
|
|
||||||
# =============== final result ===============
|
# =============== final result ===============
|
||||||
transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
|
transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
|
||||||
return transport_ns + latency_ns
|
ns = transport_ns + latency_ns
|
||||||
|
ms = ns / 1e6
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
################################################################################################################
|
################################################################################################################
|
||||||
|
@@ -52,6 +52,28 @@ if TYPE_CHECKING:
     from torch._inductor.scheduler import BaseSchedulerNode


+def align_runtime_estimations_across_all_distributed_ranks(
+    snodes: list[BaseSchedulerNode],
+):
+    runtime_estimations = {}
+    for snode in snodes:
+        runtime_estimations[snode] = snode.get_estimated_runtime()
+    import torch.distributed as dist
+    from torch.distributed.distributed_c10d import _get_default_group
+
+    world_size = dist.get_world_size()
+    pg = _get_default_group()
+    gathered_runtime_estimations: list[list[float]] = [[] for _ in range(world_size)]
+    dist.all_gather_object(
+        gathered_runtime_estimations, list(runtime_estimations.values()), pg
+    )
+    median_runtime_estimations = torch.median(
+        torch.tensor(gathered_runtime_estimations), dim=0
+    ).values.tolist()
+    for i in range(len(snodes)):
+        snodes[i].override_estimated_runtime = median_runtime_estimations[i]
+
+
 def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
     """
     Greedily schedules waits as late as possible.
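
The alignment pass added above boils down to a per-node median across ranks. A minimal standalone sketch of the same idea (editorial, assuming an initialized default process group):

    import torch
    import torch.distributed as dist

    def align_estimates_across_ranks(local_estimates: list[float]) -> list[float]:
        # Gather every rank's estimates, then take the element-wise median so all
        # ranks see identical numbers and make identical reordering decisions.
        world_size = dist.get_world_size()
        gathered: list[list[float]] = [[] for _ in range(world_size)]
        dist.all_gather_object(gathered, local_estimates)
        return torch.median(torch.tensor(gathered), dim=0).values.tolist()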
@@ -416,6 +416,8 @@ bucket_reduce_scatters_fx_bucket_size_determinator: Optional[Callable[[int], int
 # for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
 estimate_op_runtime = "default"

+runtime_estimations_mms_benchmark: bool = False
+
 # unit: GB/s, uni-directional P2P bandwidth per card
 # default value is NVLink
 intra_node_bw = 300
torch/_inductor/config_comms.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+import sys
+
+from torch.utils._config_module import install_config_module
+
+
+# Whether to use c10d._time_estimator for collectives runtime estimations.
+runtime_estimations_use_nccl_lib_estimations: bool = False
+
+# Config to enable sync of runtime estimations across distributed ranks,
+# To prevent passes using this runtime estimations to make different
+# decisions on different distributed ranks.
+runtime_estimations_align_across_all_distributed_ranks: bool = False
+
+# adds patch, save_config, etc
+install_config_module(sys.modules[__name__])
@@ -26,6 +26,7 @@ import sympy

 import torch
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
+import torch.utils._pytree as pytree
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor.codecache import LambdaFuture, PyCodeCache
 from torch._inductor.ir import TritonTemplateCallerBase
@@ -35,10 +36,13 @@ from torch.utils._ordered_set import OrderedSet
 from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
 from torch.utils._triton import has_triton

-from . import comms, config, dependencies, ir, metrics
+from . import comms, config, config_comms, dependencies, ir, metrics
 from .analyze_preserves_zero_mask import can_codegen_without_upcasts
 from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
-from .comm_analysis import estimate_nccl_collective_runtime
+from .comm_analysis import (
+    estimate_nccl_collective_runtime,
+    estimate_nccl_collective_runtime_nccl_estimator,
+)
 from .dependencies import Dep, MemoryDep, StarDep, WeakDep
 from .exc import GPUTooOldForTriton, TritonMissing
 from .fx_utils import count_flops_fx
@@ -212,6 +216,7 @@ class BaseSchedulerNode:
     min_order: int
     max_order: int
    mpi_node: MemoryPlanningInfoForNode
+    override_estimated_runtime: Optional[float] = None

     def __init__(self, scheduler: Scheduler) -> None:
         self.scheduler: Scheduler = scheduler
@@ -823,10 +828,16 @@ class BaseSchedulerNode:
             counters["inductor"]["flop_count"] += resolved_flops
         return resolved_flops

-    @cache_on_self
     def get_estimated_runtime(self) -> float:
+        if self.override_estimated_runtime is not None:
+            return self.override_estimated_runtime
+
+        return self._get_estimated_runtime()
+
+    @cache_on_self
+    def _get_estimated_runtime(self) -> float:
         """
-        Returns estimated op runtime in nanoseconds (ns)
+        Returns estimated op runtime in milliseconds (ms)
         """
         buf = self.get_nodes()[0].get_outputs()[0]
         layout = buf.node.get_output_spec()
@@ -838,6 +849,21 @@ class BaseSchedulerNode:
         if is_collective(self.node):
             assert isinstance(self.node, ir.IRNode)
             try:
+                if config_comms.runtime_estimations_use_nccl_lib_estimations:
+                    cache_key = get_estimate_runtime_cache_key_from_snode(self)
+                    cache = get_estimate_runtime_cache()
+                    cache_val = cache.lookup(cache_key)
+                    if cache_val is not None:
+                        assert isinstance(cache_val, float)
+                        return cache_val
+
+                    ms = estimate_nccl_collective_runtime_nccl_estimator(self)
+                    if ms is None:
+                        # NCCL estimations fail: fallback to in-tree algorithmic estimation.
+                        ms = estimate_nccl_collective_runtime(self.node)
+
+                    cache.set_value(cache_key, value=ms)
+                    return ms
                 return estimate_nccl_collective_runtime(self.node)
             except ValueError as e:
                 # We don't know how to estimate runtime for this collective,
@@ -856,6 +882,10 @@ class BaseSchedulerNode:
             # since it doesn't take extra time to get the result after the collective is completed.
             return 0

+        ret = maybe_estimate_runtime_benchmark(self)
+        if ret is not None:
+            return ret
+
         dtype = buf.node.maybe_get_dtype()
         try:
             gpu_memory_bandwidth = get_gpu_dram_gbps()
@@ -876,7 +906,9 @@ class BaseSchedulerNode:

         if flops_est == 0 or flops_est is None:
             # no flops estimate, so fall back to memory estimate
-            return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
+            ns = self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
+            ms = ns / 1e6
+            return ms

         # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
         factor = 1.0
@@ -885,8 +917,10 @@ class BaseSchedulerNode:
         compute_time = (factor * flops_est / gpu_flops) * 1e9
         transfer_time = counted_bytes / gpu_memory_bandwidth

-        # Return estimated runtime in nanoseconds
-        return max(compute_time, transfer_time)
+        # Return estimated runtime in milliseconds
+        ns = max(compute_time, transfer_time)
+        ms = ns / 1e6
+        return ms

     def get_template_node(self) -> Optional[ir.TemplateBuffer]:
         return None
@@ -911,6 +945,77 @@ class BaseSchedulerNode:
         return prologue, template_node, epilogue


+@functools.cache
+def get_estimate_runtime_cache() -> torch._inductor.codecache.LocalCache:
+    return torch._inductor.codecache.LocalCache()
+
+
+def get_estimate_runtime_cache_key_from_snode(snode: BaseSchedulerNode) -> str:
+    python_kernel_name = getattr(snode.node, "python_kernel_name", "")
+    args = snode.node.inputs  # type: ignore[union-attr]
+    args = snode.node.fill_non_provided_args(  # type: ignore[union-attr]
+        [*args, *snode.node.constant_args],  # type: ignore[union-attr]
+        snode.node.kwargs,  # type: ignore[union-attr]
+    )
+    kwargs = snode.node.kwargs  # type: ignore[union-attr]
+    flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
+
+    def _is_tensor_ir(x) -> bool:  # type: ignore[no-untyped-def]
+        return isinstance(x, ir.IRNode) and not isinstance(x, ir.GeneratorState)
+
+    cache_key = str(
+        (python_kernel_name,)
+        + tuple(tuple(a.get_size()) if _is_tensor_ir(a) else None for a in flat_args)
+    )
+    return cache_key
+
+
+def _get_mm_like_fn(snode: BaseSchedulerNode) -> Optional[Callable[[Any], Any]]:
+    if not isinstance(snode, ExternKernelSchedulerNode):
+        return None
+    mms_fns = {
+        "extern_kernels.mm": torch.ops.aten.mm,
+        "extern_kernels.bmm": torch.ops.aten.bmm,
+        "extern_kernels.addmm": torch.ops.aten.addmm,
+    }
+    python_kernel_name = getattr(snode.node, "python_kernel_name", "")
+    if python_kernel_name not in mms_fns:
+        return None
+    if not isinstance(snode.node, ir.ExternKernel):
+        return None
+    return mms_fns[python_kernel_name]
+
+
+def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]:
+    bench_fn = None
+    args_kwargs_fn = None
+    if config.runtime_estimations_mms_benchmark:
+        mm_fn = _get_mm_like_fn(snode)
+        if mm_fn is None:
+            return None
+        bench_fn = mm_fn
+        args_kwargs_fn = lambda: snode_args_kwargs(snode)  # noqa: E731
+    else:
+        return None
+
+    cache_key = get_estimate_runtime_cache_key_from_snode(snode)
+    cache = get_estimate_runtime_cache()
+    cache_val = cache.lookup(cache_key)
+    if cache_val is not None:
+        assert isinstance(cache_val, float)
+        return cache_val
+
+    from .utils import snode_args_kwargs
+
+    args, kwargs = args_kwargs_fn()
+    from triton.testing import do_bench
+
+    ms = do_bench(lambda: bench_fn(*args, **kwargs))
+
+    cache.set_value(cache_key, value=ms)
+    return ms
+
+
 class WhyNoFuse:
     # TODO when we drop support for Python < 3.10, we can use
     # @dataclass(slots=True) instead of manually specifying __slots__.
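
Condensed to its essentials, the benchmark path above does the following for a plain mm (editorial sketch, not part of the commit; the cache key mirrors the shape-based key built by `get_estimate_runtime_cache_key_from_snode`):

    import torch
    from triton.testing import do_bench

    _mm_ms_cache: dict[str, float] = {}

    def benchmark_mm_ms(a: torch.Tensor, b: torch.Tensor) -> float:
        # Key on kernel name and input shapes so each unique shape is benchmarked once.
        key = str(("extern_kernels.mm", tuple(a.shape), tuple(b.shape)))
        if key not in _mm_ms_cache:
            _mm_ms_cache[key] = do_bench(lambda: torch.ops.aten.mm(a, b))  # returns ms
        return _mm_ms_cache[key]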
@@ -2094,6 +2199,10 @@ class NodeUser:
 _post_grad_graph_counter = itertools.count()


+def used_non_deterministic_runtime_estimations() -> bool:
+    return config.runtime_estimations_mms_benchmark
+
+
 class Scheduler:
     """
     A Scheduler is a graph of BaseSchedulerNodes. It is responsible for
@@ -2214,6 +2323,17 @@ class Scheduler:
         assign_memory_planning_info_for_scheduler_buffers(
             self.nodes, self.name_to_buf
         )
+
+        if (
+            used_non_deterministic_runtime_estimations()
+            and config_comms.runtime_estimations_align_across_all_distributed_ranks
+        ):
+            from .comms import (
+                align_runtime_estimations_across_all_distributed_ranks,
+            )
+
+            align_runtime_estimations_across_all_distributed_ranks(self.nodes)
+
         from torch._logging import trace_structured

         trace_structured(
@@ -59,6 +59,7 @@ from unittest import mock
 import sympy

 import torch
+import torch.utils._pytree as pytree
 from torch._inductor.analysis.device_info import datasheet_tops
 from torch._inductor.runtime.hints import DeviceProperties
 from torch.utils._dtype_abbrs import dtype_abbrs
@@ -3666,3 +3667,38 @@ _unstable_customized_partition_wrapper = CUDAGraphWrapper()

 def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
     _unstable_customized_partition_wrapper.wrapper = wrapper
+
+
+def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
+    args = snode.node.inputs  # type: ignore[union-attr]
+    args = snode.node.fill_non_provided_args(  # type: ignore[union-attr]
+        [*args, *snode.node.constant_args],  # type: ignore[union-attr]
+        snode.node.kwargs,  # type: ignore[union-attr]
+    )
+    kwargs = snode.node.kwargs  # type: ignore[union-attr]
+    flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
+
+    def _is_tensor_ir(x) -> bool:  # type: ignore[no-untyped-def]
+        return isinstance(x, torch._inductor.ir.IRNode) and not isinstance(
+            x, torch._inductor.ir.GeneratorState
+        )
+
+    flat_args = [
+        torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False)
+        if _is_tensor_ir(a)
+        else a
+        for a in flat_args
+    ]
+
+    def _tensor(size, dtype, device) -> torch.Tensor:  # type: ignore[no-untyped-def]
+        return torch.empty(size, dtype=dtype, device=device)
+
+    def to_real_tensor(e: Any) -> Any:
+        if not isinstance(e, torch.Tensor):
+            return e
+        out = _tensor(e.size(), e.dtype, e.device)
+        return out
+
+    flat_args = [to_real_tensor(a) for a in flat_args]
+    args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec)
+    return args, kwargs