mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode (#161405)
During comms reordering , sink wait iterative observed previous runtime estimations pretty off for collectives and mms. Adding optional usage of: - c10d.time_estimator for collectives, which is based on NCCL estimator Benchmark mode only for matmuls, as they are highly dependent on mm backend - The logic mostly copied from Ruisi's PRs for inductor simple_fsdp https://github.com/pytorch/pytorch/pull/157572 This estimations corrections are in default `BaseSchedulerNode.estimate_runtime()` Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294) Pull Request resolved: https://github.com/pytorch/pytorch/pull/161405 Approved by: https://github.com/eellison
This commit is contained in:
committed by
PyTorch MergeBot
parent
3f5993316e
commit
25c170b72e
@ -259,6 +259,11 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
"reorder_compute_for_overlap",
|
||||
],
|
||||
)
|
||||
@patch.object(
|
||||
torch._inductor.config,
|
||||
"runtime_estimations_mms_benchmark",
|
||||
False,
|
||||
)
|
||||
def test_reorder_compute_for_overlap(self):
|
||||
def func(a, *, tag, ranks, group_size):
|
||||
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)
|
||||
|
@ -22,8 +22,13 @@ from torch._inductor.comms import (
|
||||
sink_waits_iterative,
|
||||
)
|
||||
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
|
||||
from torch._inductor.scheduler import BaseSchedulerNode
|
||||
from torch._inductor.utils import run_and_get_triton_code
|
||||
from torch._inductor.scheduler import (
|
||||
_get_mm_like_fn,
|
||||
BaseSchedulerNode,
|
||||
get_estimate_runtime_cache,
|
||||
get_estimate_runtime_cache_key_from_snode,
|
||||
)
|
||||
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
|
||||
from torch.distributed.distributed_c10d import GroupMember
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_cuda import SM80OrLater
|
||||
@ -1568,11 +1573,21 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
|
||||
correct = func(*inputs, **self.get_world_trs())
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"bucket_all_gathers_fx": "all",
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
}
|
||||
with (
|
||||
torch._inductor.config.patch(
|
||||
{
|
||||
"bucket_all_gathers_fx": "all",
|
||||
"reorder_for_compute_comm_overlap": False,
|
||||
"runtime_estimations_mms_benchmark": True,
|
||||
}
|
||||
),
|
||||
torch._inductor.config_comms.patch(
|
||||
{
|
||||
"runtime_estimations_align_across_all_distributed_ranks": True,
|
||||
}
|
||||
),
|
||||
# Clearing cache to cover runtime_estimations_mms_benchmark that use LocalCache
|
||||
fresh_inductor_cache(),
|
||||
):
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
@ -1801,6 +1816,17 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
def _reorder_communication_preserving_peak_memory(
|
||||
snodes: list[BaseSchedulerNode],
|
||||
) -> list[BaseSchedulerNode]:
|
||||
if torch._inductor.config.runtime_estimations_mms_benchmark:
|
||||
cache = get_estimate_runtime_cache()
|
||||
for snode in snodes:
|
||||
if _get_mm_like_fn(snode) is None:
|
||||
continue
|
||||
cache_key = get_estimate_runtime_cache_key_from_snode(snode)
|
||||
assert cache.lookup(cache_key) is not None
|
||||
|
||||
if torch._inductor.config_comms.runtime_estimations_align_across_all_distributed_ranks:
|
||||
for snode in snodes:
|
||||
assert snode.override_estimated_runtime is not None
|
||||
nonlocal node_stats
|
||||
(
|
||||
reordered_snodes,
|
||||
@ -1808,20 +1834,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
) = _reorder_communication_preserving_peak_memory_internal(snodes)
|
||||
return reordered_snodes
|
||||
|
||||
with torch._inductor.config.patch(
|
||||
{
|
||||
"bucket_all_gathers_fx": "all",
|
||||
"bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
|
||||
"bucket_reduce_scatters_fx": "all",
|
||||
"bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
|
||||
"reorder_for_compute_comm_overlap": True,
|
||||
"reorder_for_compute_comm_overlap_passes": [
|
||||
sink_waits_iterative,
|
||||
_reorder_communication_preserving_peak_memory,
|
||||
],
|
||||
"allow_buffer_reuse": False,
|
||||
"test_configs.track_memory_lifecycle": "error",
|
||||
}
|
||||
with (
|
||||
torch._inductor.config.patch(
|
||||
{
|
||||
"bucket_all_gathers_fx": "all",
|
||||
"bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
|
||||
"bucket_reduce_scatters_fx": "all",
|
||||
"bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
|
||||
"reorder_for_compute_comm_overlap": True,
|
||||
"reorder_for_compute_comm_overlap_passes": [
|
||||
sink_waits_iterative,
|
||||
_reorder_communication_preserving_peak_memory,
|
||||
],
|
||||
"allow_buffer_reuse": False,
|
||||
"test_configs.track_memory_lifecycle": "error",
|
||||
"runtime_estimations_mms_benchmark": True,
|
||||
}
|
||||
),
|
||||
torch._inductor.config_comms.patch(
|
||||
{
|
||||
"runtime_estimations_align_across_all_distributed_ranks": True,
|
||||
}
|
||||
),
|
||||
# Clearing cache to cover runtime_estimations_mms_benchmark that use LocalCache
|
||||
fresh_inductor_cache(),
|
||||
):
|
||||
compiled = torch.compile(func, fullgraph=True)
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
|
@ -1,20 +1,26 @@
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
from enum import IntEnum
|
||||
from typing import Optional
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
|
||||
from . import ir
|
||||
from .utils import get_dtype_size, sympy_product
|
||||
from .utils import get_dtype_size, snode_args_kwargs, sympy_product
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NCCL_COLL(IntEnum):
|
||||
ALL_REDUCE = 0
|
||||
ALL_GATHER = 1
|
||||
REDUCE_SCATTER = 2
|
||||
ALL_TO_ALL = 3
|
||||
|
||||
|
||||
class NVIDIA_GPU_TYPE(IntEnum):
|
||||
@ -49,6 +55,8 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
|
||||
return NCCL_COLL.ALL_GATHER
|
||||
elif "reduce_scatter" in kernel_name:
|
||||
return NCCL_COLL.REDUCE_SCATTER
|
||||
elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name:
|
||||
return NCCL_COLL.ALL_TO_ALL
|
||||
else:
|
||||
raise ValueError(f"Unsupported collective kernel: {kernel_name}")
|
||||
|
||||
@ -158,9 +166,53 @@ llMaxBws = [
|
||||
]
|
||||
|
||||
|
||||
def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]: # type: ignore[no-untyped-def]
|
||||
kernel = snode.node
|
||||
assert kernel is not None
|
||||
py_kernel_name = getattr(kernel, "python_kernel_name", "")
|
||||
if not ("all_gather" in py_kernel_name or "reduce_scatter" in py_kernel_name):
|
||||
# NCCL of version 2.27 sometimes unrecoverably fail for all_to_all, all_reduce
|
||||
return None
|
||||
|
||||
from torch.distributed.distributed_c10d import _resolve_process_group
|
||||
|
||||
pg_name = kernel.constant_args[-1] # type: ignore[attr-defined]
|
||||
pg = _resolve_process_group(pg_name)
|
||||
rank: int = torch.distributed.get_rank(pg)
|
||||
# TODO(ivankobzarev): Figure out how we can use time estimations,
|
||||
# without cuda allocations.
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
fn = eval(py_kernel_name)
|
||||
args, kwargs = snode_args_kwargs(snode)
|
||||
|
||||
# TODO(ivankobzarev): fix out variants snode_args_kwargs
|
||||
if "all_gather_into_tensor_out" in py_kernel_name:
|
||||
args = args[1:] + args[0]
|
||||
|
||||
try:
|
||||
with torch.distributed._time_estimator(
|
||||
group=pg, device=device
|
||||
) as time_estimator:
|
||||
w = fn(*args, **kwargs)
|
||||
torch.ops._c10d_functional.wait_tensor.default(w)
|
||||
except Exception as e:
|
||||
# NCCL estimator can fail
|
||||
log.info(e)
|
||||
return None
|
||||
|
||||
est_time_us = time_estimator.estimated_time
|
||||
# -1000 constant is NCCL return in case of error during estimations.
|
||||
# Observed it for all_to_all estimations.
|
||||
if est_time_us < 0:
|
||||
return None
|
||||
est_time_ms = est_time_us / 1e3
|
||||
return est_time_ms
|
||||
|
||||
|
||||
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
|
||||
"""
|
||||
Returns estimated NCCL collective runtime in nanoseconds (ns).
|
||||
Returns estimated NCCL collective runtime in milliseconds (ms).
|
||||
|
||||
The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
|
||||
We aim to estimate the runtime as accurately as possible.
|
||||
@ -220,6 +272,8 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
|
||||
|
||||
if coll == NCCL_COLL.ALL_REDUCE:
|
||||
nsteps = 2 * (nRanks - 1)
|
||||
elif coll == NCCL_COLL.ALL_TO_ALL:
|
||||
nsteps = 2 * (nRanks - 1)
|
||||
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
|
||||
nsteps = nRanks - 1
|
||||
|
||||
@ -237,7 +291,7 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
|
||||
nInterSteps = 2 * nNodes
|
||||
else:
|
||||
nInterSteps = 0
|
||||
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
|
||||
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER, NCCL_COLL.ALL_TO_ALL):
|
||||
nInterSteps = nNodes - 1
|
||||
|
||||
# First compute latency in us; then at the end, convert it to ns
|
||||
@ -256,7 +310,9 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
|
||||
|
||||
# =============== final result ===============
|
||||
transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
|
||||
return transport_ns + latency_ns
|
||||
ns = transport_ns + latency_ns
|
||||
ms = ns / 1e6
|
||||
return ms
|
||||
|
||||
|
||||
################################################################################################################
|
||||
|
@ -52,6 +52,28 @@ if TYPE_CHECKING:
|
||||
from torch._inductor.scheduler import BaseSchedulerNode
|
||||
|
||||
|
||||
def align_runtime_estimations_across_all_distributed_ranks(
|
||||
snodes: list[BaseSchedulerNode],
|
||||
):
|
||||
runtime_estimations = {}
|
||||
for snode in snodes:
|
||||
runtime_estimations[snode] = snode.get_estimated_runtime()
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
pg = _get_default_group()
|
||||
gathered_runtime_estimations: list[list[float]] = [[] for _ in range(world_size)]
|
||||
dist.all_gather_object(
|
||||
gathered_runtime_estimations, list(runtime_estimations.values()), pg
|
||||
)
|
||||
median_runtime_estimations = torch.median(
|
||||
torch.tensor(gathered_runtime_estimations), dim=0
|
||||
).values.tolist()
|
||||
for i in range(len(snodes)):
|
||||
snodes[i].override_estimated_runtime = median_runtime_estimations[i]
|
||||
|
||||
|
||||
def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Greedily schedules waits as late as possible.
|
||||
|
@ -416,6 +416,8 @@ bucket_reduce_scatters_fx_bucket_size_determinator: Optional[Callable[[int], int
|
||||
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
|
||||
estimate_op_runtime = "default"
|
||||
|
||||
runtime_estimations_mms_benchmark: bool = False
|
||||
|
||||
# unit: GB/s, uni-directional P2P bandwidth per card
|
||||
# default value is NVLink
|
||||
intra_node_bw = 300
|
||||
|
15
torch/_inductor/config_comms.py
Normal file
15
torch/_inductor/config_comms.py
Normal file
@ -0,0 +1,15 @@
|
||||
import sys
|
||||
|
||||
from torch.utils._config_module import install_config_module
|
||||
|
||||
|
||||
# Whether to use c10d._time_estimator for collectives runtime estimations.
|
||||
runtime_estimations_use_nccl_lib_estimations: bool = False
|
||||
|
||||
# Config to enable sync of runtime estimations across distributed ranks,
|
||||
# To prevent passes using this runtime estimations to make different
|
||||
# decisions on different distributed ranks.
|
||||
runtime_estimations_align_across_all_distributed_ranks: bool = False
|
||||
|
||||
# adds patch, save_config, etc
|
||||
install_config_module(sys.modules[__name__])
|
@ -26,6 +26,7 @@ import sympy
|
||||
|
||||
import torch
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._dynamo.utils import counters, dynamo_timed
|
||||
from torch._inductor.codecache import LambdaFuture, PyCodeCache
|
||||
from torch._inductor.ir import TritonTemplateCallerBase
|
||||
@ -35,10 +36,13 @@ from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
from . import comms, config, dependencies, ir, metrics
|
||||
from . import comms, config, config_comms, dependencies, ir, metrics
|
||||
from .analyze_preserves_zero_mask import can_codegen_without_upcasts
|
||||
from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
|
||||
from .comm_analysis import estimate_nccl_collective_runtime
|
||||
from .comm_analysis import (
|
||||
estimate_nccl_collective_runtime,
|
||||
estimate_nccl_collective_runtime_nccl_estimator,
|
||||
)
|
||||
from .dependencies import Dep, MemoryDep, StarDep, WeakDep
|
||||
from .exc import GPUTooOldForTriton, TritonMissing
|
||||
from .fx_utils import count_flops_fx
|
||||
@ -212,6 +216,7 @@ class BaseSchedulerNode:
|
||||
min_order: int
|
||||
max_order: int
|
||||
mpi_node: MemoryPlanningInfoForNode
|
||||
override_estimated_runtime: Optional[float] = None
|
||||
|
||||
def __init__(self, scheduler: Scheduler) -> None:
|
||||
self.scheduler: Scheduler = scheduler
|
||||
@ -823,10 +828,16 @@ class BaseSchedulerNode:
|
||||
counters["inductor"]["flop_count"] += resolved_flops
|
||||
return resolved_flops
|
||||
|
||||
@cache_on_self
|
||||
def get_estimated_runtime(self) -> float:
|
||||
if self.override_estimated_runtime is not None:
|
||||
return self.override_estimated_runtime
|
||||
|
||||
return self._get_estimated_runtime()
|
||||
|
||||
@cache_on_self
|
||||
def _get_estimated_runtime(self) -> float:
|
||||
"""
|
||||
Returns estimated op runtime in nanoseconds (ns)
|
||||
Returns estimated op runtime in milliseconds (ms)
|
||||
"""
|
||||
buf = self.get_nodes()[0].get_outputs()[0]
|
||||
layout = buf.node.get_output_spec()
|
||||
@ -838,6 +849,21 @@ class BaseSchedulerNode:
|
||||
if is_collective(self.node):
|
||||
assert isinstance(self.node, ir.IRNode)
|
||||
try:
|
||||
if config_comms.runtime_estimations_use_nccl_lib_estimations:
|
||||
cache_key = get_estimate_runtime_cache_key_from_snode(self)
|
||||
cache = get_estimate_runtime_cache()
|
||||
cache_val = cache.lookup(cache_key)
|
||||
if cache_val is not None:
|
||||
assert isinstance(cache_val, float)
|
||||
return cache_val
|
||||
|
||||
ms = estimate_nccl_collective_runtime_nccl_estimator(self)
|
||||
if ms is None:
|
||||
# NCCL estimations fail: fallback to in-tree algorithmic estimation.
|
||||
ms = estimate_nccl_collective_runtime(self.node)
|
||||
|
||||
cache.set_value(cache_key, value=ms)
|
||||
return ms
|
||||
return estimate_nccl_collective_runtime(self.node)
|
||||
except ValueError as e:
|
||||
# We don't know how to estimate runtime for this collective,
|
||||
@ -856,6 +882,10 @@ class BaseSchedulerNode:
|
||||
# since it doesn't take extra time to get the result after the collective is completed.
|
||||
return 0
|
||||
|
||||
ret = maybe_estimate_runtime_benchmark(self)
|
||||
if ret is not None:
|
||||
return ret
|
||||
|
||||
dtype = buf.node.maybe_get_dtype()
|
||||
try:
|
||||
gpu_memory_bandwidth = get_gpu_dram_gbps()
|
||||
@ -876,7 +906,9 @@ class BaseSchedulerNode:
|
||||
|
||||
if flops_est == 0 or flops_est is None:
|
||||
# no flops estimate, so fall back to memory estimate
|
||||
return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
|
||||
ns = self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
|
||||
ms = ns / 1e6
|
||||
return ms
|
||||
|
||||
# TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
|
||||
factor = 1.0
|
||||
@ -885,8 +917,10 @@ class BaseSchedulerNode:
|
||||
compute_time = (factor * flops_est / gpu_flops) * 1e9
|
||||
transfer_time = counted_bytes / gpu_memory_bandwidth
|
||||
|
||||
# Return estimated runtime in nanoseconds
|
||||
return max(compute_time, transfer_time)
|
||||
# Return estimated runtime in milliseconds
|
||||
ns = max(compute_time, transfer_time)
|
||||
ms = ns / 1e6
|
||||
return ms
|
||||
|
||||
def get_template_node(self) -> Optional[ir.TemplateBuffer]:
|
||||
return None
|
||||
@ -911,6 +945,77 @@ class BaseSchedulerNode:
|
||||
return prologue, template_node, epilogue
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_estimate_runtime_cache() -> torch._inductor.codecache.LocalCache:
|
||||
return torch._inductor.codecache.LocalCache()
|
||||
|
||||
|
||||
def get_estimate_runtime_cache_key_from_snode(snode: BaseSchedulerNode) -> str:
|
||||
python_kernel_name = getattr(snode.node, "python_kernel_name", "")
|
||||
args = snode.node.inputs # type: ignore[union-attr]
|
||||
args = snode.node.fill_non_provided_args( # type: ignore[union-attr]
|
||||
[*args, *snode.node.constant_args], # type: ignore[union-attr]
|
||||
snode.node.kwargs, # type: ignore[union-attr]
|
||||
)
|
||||
kwargs = snode.node.kwargs # type: ignore[union-attr]
|
||||
flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
|
||||
|
||||
def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def]
|
||||
return isinstance(x, ir.IRNode) and not isinstance(x, ir.GeneratorState)
|
||||
|
||||
cache_key = str(
|
||||
(python_kernel_name,)
|
||||
+ tuple(tuple(a.get_size()) if _is_tensor_ir(a) else None for a in flat_args)
|
||||
)
|
||||
return cache_key
|
||||
|
||||
|
||||
def _get_mm_like_fn(snode: BaseSchedulerNode) -> Optional[Callable[[Any], Any]]:
|
||||
if not isinstance(snode, ExternKernelSchedulerNode):
|
||||
return None
|
||||
mms_fns = {
|
||||
"extern_kernels.mm": torch.ops.aten.mm,
|
||||
"extern_kernels.bmm": torch.ops.aten.bmm,
|
||||
"extern_kernels.addmm": torch.ops.aten.addmm,
|
||||
}
|
||||
python_kernel_name = getattr(snode.node, "python_kernel_name", "")
|
||||
if python_kernel_name not in mms_fns:
|
||||
return None
|
||||
if not isinstance(snode.node, ir.ExternKernel):
|
||||
return None
|
||||
return mms_fns[python_kernel_name]
|
||||
|
||||
|
||||
def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]:
|
||||
bench_fn = None
|
||||
args_kwargs_fn = None
|
||||
if config.runtime_estimations_mms_benchmark:
|
||||
mm_fn = _get_mm_like_fn(snode)
|
||||
if mm_fn is None:
|
||||
return None
|
||||
bench_fn = mm_fn
|
||||
args_kwargs_fn = lambda: snode_args_kwargs(snode) # noqa: E731
|
||||
else:
|
||||
return None
|
||||
|
||||
cache_key = get_estimate_runtime_cache_key_from_snode(snode)
|
||||
cache = get_estimate_runtime_cache()
|
||||
cache_val = cache.lookup(cache_key)
|
||||
if cache_val is not None:
|
||||
assert isinstance(cache_val, float)
|
||||
return cache_val
|
||||
|
||||
from .utils import snode_args_kwargs
|
||||
|
||||
args, kwargs = args_kwargs_fn()
|
||||
from triton.testing import do_bench
|
||||
|
||||
ms = do_bench(lambda: bench_fn(*args, **kwargs))
|
||||
|
||||
cache.set_value(cache_key, value=ms)
|
||||
return ms
|
||||
|
||||
|
||||
class WhyNoFuse:
|
||||
# TODO when we drop support for Python < 3.10, we can use
|
||||
# @dataclass(slots=True) instead of manually specifying __slots__.
|
||||
@ -2094,6 +2199,10 @@ class NodeUser:
|
||||
_post_grad_graph_counter = itertools.count()
|
||||
|
||||
|
||||
def used_non_deterministic_runtime_estimations() -> bool:
|
||||
return config.runtime_estimations_mms_benchmark
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""
|
||||
A Scheduler is a graph of BaseSchedulerNodes. It is responsible for
|
||||
@ -2214,6 +2323,17 @@ class Scheduler:
|
||||
assign_memory_planning_info_for_scheduler_buffers(
|
||||
self.nodes, self.name_to_buf
|
||||
)
|
||||
|
||||
if (
|
||||
used_non_deterministic_runtime_estimations()
|
||||
and config_comms.runtime_estimations_align_across_all_distributed_ranks
|
||||
):
|
||||
from .comms import (
|
||||
align_runtime_estimations_across_all_distributed_ranks,
|
||||
)
|
||||
|
||||
align_runtime_estimations_across_all_distributed_ranks(self.nodes)
|
||||
|
||||
from torch._logging import trace_structured
|
||||
|
||||
trace_structured(
|
||||
|
@ -59,6 +59,7 @@ from unittest import mock
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._inductor.analysis.device_info import datasheet_tops
|
||||
from torch._inductor.runtime.hints import DeviceProperties
|
||||
from torch.utils._dtype_abbrs import dtype_abbrs
|
||||
@ -3666,3 +3667,38 @@ _unstable_customized_partition_wrapper = CUDAGraphWrapper()
|
||||
|
||||
def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
|
||||
_unstable_customized_partition_wrapper.wrapper = wrapper
|
||||
|
||||
|
||||
def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
|
||||
args = snode.node.inputs # type: ignore[union-attr]
|
||||
args = snode.node.fill_non_provided_args( # type: ignore[union-attr]
|
||||
[*args, *snode.node.constant_args], # type: ignore[union-attr]
|
||||
snode.node.kwargs, # type: ignore[union-attr]
|
||||
)
|
||||
kwargs = snode.node.kwargs # type: ignore[union-attr]
|
||||
flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
|
||||
|
||||
def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def]
|
||||
return isinstance(x, torch._inductor.ir.IRNode) and not isinstance(
|
||||
x, torch._inductor.ir.GeneratorState
|
||||
)
|
||||
|
||||
flat_args = [
|
||||
torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False)
|
||||
if _is_tensor_ir(a)
|
||||
else a
|
||||
for a in flat_args
|
||||
]
|
||||
|
||||
def _tensor(size, dtype, device) -> torch.Tensor: # type: ignore[no-untyped-def]
|
||||
return torch.empty(size, dtype=dtype, device=device)
|
||||
|
||||
def to_real_tensor(e: Any) -> Any:
|
||||
if not isinstance(e, torch.Tensor):
|
||||
return e
|
||||
out = _tensor(e.size(), e.dtype, e.device)
|
||||
return out
|
||||
|
||||
flat_args = [to_real_tensor(a) for a in flat_args]
|
||||
args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec)
|
||||
return args, kwargs
|
||||
|
Reference in New Issue
Block a user