[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode (#161405)

During comms reordering and iterative wait sinking, we observed that the previous runtime estimations were pretty far off for collectives and mms.

Adding optional usage of:
- c10d.time_estimator for collectives, which is based on NCCL estimator

Benchmark mode only for matmuls, as they are highly dependent on mm backend

- The logic mostly copied from Ruisi's PRs for inductor simple_fsdp https://github.com/pytorch/pytorch/pull/157572

These estimation corrections are applied in the default `BaseSchedulerNode.estimate_runtime()`

Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161405
Approved by: https://github.com/eellison
This commit is contained in:
IvanKobzarev
2025-09-08 04:35:52 -07:00
committed by PyTorch MergeBot
parent 3f5993316e
commit 25c170b72e
8 changed files with 324 additions and 32 deletions

View File

@ -259,6 +259,11 @@ class TestComputeCommReorderingMultiProc(DynamoDistributedMultiProcTestCase):
"reorder_compute_for_overlap", "reorder_compute_for_overlap",
], ],
) )
@patch.object(
torch._inductor.config,
"runtime_estimations_mms_benchmark",
False,
)
def test_reorder_compute_for_overlap(self): def test_reorder_compute_for_overlap(self):
def func(a, *, tag, ranks, group_size): def func(a, *, tag, ranks, group_size):
ar = _functional_collectives.all_reduce(a, "sum", ranks, tag) ar = _functional_collectives.all_reduce(a, "sum", ranks, tag)

View File

@ -22,8 +22,13 @@ from torch._inductor.comms import (
sink_waits_iterative, sink_waits_iterative,
) )
from torch._inductor.compile_fx import compile_fx as inductor_compile_fx from torch._inductor.compile_fx import compile_fx as inductor_compile_fx
from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.scheduler import (
from torch._inductor.utils import run_and_get_triton_code _get_mm_like_fn,
BaseSchedulerNode,
get_estimate_runtime_cache,
get_estimate_runtime_cache_key_from_snode,
)
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
from torch.distributed.distributed_c10d import GroupMember from torch.distributed.distributed_c10d import GroupMember
from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_cuda import SM80OrLater
@ -1568,11 +1573,21 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
inputs = [x, w, ag_0, ag_1, ag_2, ag_3] inputs = [x, w, ag_0, ag_1, ag_2, ag_3]
correct = func(*inputs, **self.get_world_trs()) correct = func(*inputs, **self.get_world_trs())
with torch._inductor.config.patch( with (
{ torch._inductor.config.patch(
"bucket_all_gathers_fx": "all", {
"reorder_for_compute_comm_overlap": False, "bucket_all_gathers_fx": "all",
} "reorder_for_compute_comm_overlap": False,
"runtime_estimations_mms_benchmark": True,
}
),
torch._inductor.config_comms.patch(
{
"runtime_estimations_align_across_all_distributed_ranks": True,
}
),
# Clearing cache to cover runtime_estimations_mms_benchmark that use LocalCache
fresh_inductor_cache(),
): ):
compiled = torch.compile(func) compiled = torch.compile(func)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
@ -1801,6 +1816,17 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
def _reorder_communication_preserving_peak_memory( def _reorder_communication_preserving_peak_memory(
snodes: list[BaseSchedulerNode], snodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]: ) -> list[BaseSchedulerNode]:
if torch._inductor.config.runtime_estimations_mms_benchmark:
cache = get_estimate_runtime_cache()
for snode in snodes:
if _get_mm_like_fn(snode) is None:
continue
cache_key = get_estimate_runtime_cache_key_from_snode(snode)
assert cache.lookup(cache_key) is not None
if torch._inductor.config_comms.runtime_estimations_align_across_all_distributed_ranks:
for snode in snodes:
assert snode.override_estimated_runtime is not None
nonlocal node_stats nonlocal node_stats
( (
reordered_snodes, reordered_snodes,
@ -1808,20 +1834,30 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
) = _reorder_communication_preserving_peak_memory_internal(snodes) ) = _reorder_communication_preserving_peak_memory_internal(snodes)
return reordered_snodes return reordered_snodes
with torch._inductor.config.patch( with (
{ torch._inductor.config.patch(
"bucket_all_gathers_fx": "all", {
"bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2, "bucket_all_gathers_fx": "all",
"bucket_reduce_scatters_fx": "all", "bucket_all_gathers_fx_bucket_size_determinator": lambda _: 2,
"bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2, "bucket_reduce_scatters_fx": "all",
"reorder_for_compute_comm_overlap": True, "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2,
"reorder_for_compute_comm_overlap_passes": [ "reorder_for_compute_comm_overlap": True,
sink_waits_iterative, "reorder_for_compute_comm_overlap_passes": [
_reorder_communication_preserving_peak_memory, sink_waits_iterative,
], _reorder_communication_preserving_peak_memory,
"allow_buffer_reuse": False, ],
"test_configs.track_memory_lifecycle": "error", "allow_buffer_reuse": False,
} "test_configs.track_memory_lifecycle": "error",
"runtime_estimations_mms_benchmark": True,
}
),
torch._inductor.config_comms.patch(
{
"runtime_estimations_align_across_all_distributed_ranks": True,
}
),
# Clearing cache to cover runtime_estimations_mms_benchmark that use LocalCache
fresh_inductor_cache(),
): ):
compiled = torch.compile(func, fullgraph=True) compiled = torch.compile(func, fullgraph=True)
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs()) code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())

View File

@ -1,20 +1,26 @@
import functools import functools
import logging
import math import math
from enum import IntEnum from enum import IntEnum
from typing import Optional
import sympy import sympy
import torch import torch
from . import ir from . import ir
from .utils import get_dtype_size, sympy_product from .utils import get_dtype_size, snode_args_kwargs, sympy_product
from .virtualized import V from .virtualized import V
log = logging.getLogger(__name__)
class NCCL_COLL(IntEnum): class NCCL_COLL(IntEnum):
ALL_REDUCE = 0 ALL_REDUCE = 0
ALL_GATHER = 1 ALL_GATHER = 1
REDUCE_SCATTER = 2 REDUCE_SCATTER = 2
ALL_TO_ALL = 3
class NVIDIA_GPU_TYPE(IntEnum): class NVIDIA_GPU_TYPE(IntEnum):
@ -49,6 +55,8 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL:
return NCCL_COLL.ALL_GATHER return NCCL_COLL.ALL_GATHER
elif "reduce_scatter" in kernel_name: elif "reduce_scatter" in kernel_name:
return NCCL_COLL.REDUCE_SCATTER return NCCL_COLL.REDUCE_SCATTER
elif "torch.ops._dtensor.shard_dim_alltoall.default" in kernel_name:
return NCCL_COLL.ALL_TO_ALL
else: else:
raise ValueError(f"Unsupported collective kernel: {kernel_name}") raise ValueError(f"Unsupported collective kernel: {kernel_name}")
@ -158,9 +166,53 @@ llMaxBws = [
] ]
def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]:  # type: ignore[no-untyped-def]
    """
    Estimate the runtime of a collective with c10d's time estimator.

    Only kernels whose ``python_kernel_name`` contains ``all_gather`` or
    ``reduce_scatter`` are estimated; every other kernel yields ``None``.

    Args:
        snode: scheduler node whose ``node`` is the collective kernel. The
            last entry of ``node.constant_args`` is used as the process
            group name.

    Returns:
        The estimated runtime in milliseconds, or ``None`` when the kernel is
        not supported, the estimator raised, or it reported a negative time.
        ``None`` lets the caller use a different estimation instead.
    """
    kernel = snode.node
    assert kernel is not None
    py_kernel_name = getattr(kernel, "python_kernel_name", "")
    if not ("all_gather" in py_kernel_name or "reduce_scatter" in py_kernel_name):
        # NCCL of version 2.27 sometimes unrecoverably fail for all_to_all, all_reduce
        return None

    from torch.distributed.distributed_c10d import _resolve_process_group

    pg_name = kernel.constant_args[-1]  # type: ignore[attr-defined]
    pg = _resolve_process_group(pg_name)
    rank: int = torch.distributed.get_rank(pg)
    # TODO(ivankobzarev): Figure out how we can use time estimations,
    # without cuda allocations.
    # NOTE(review): the rank inside `pg` is used as the CUDA device index;
    # confirm this holds for multi-node runs and sub-groups.
    device = torch.device(f"cuda:{rank}")

    # NOTE: eval() resolves the kernel's dotted python name to a callable.
    # The name comes from the IR node, not from user input; never feed this
    # an externally supplied string.
    fn = eval(py_kernel_name)
    args, kwargs = snode_args_kwargs(snode)

    # TODO(ivankobzarev): fix out variants snode_args_kwargs
    if "all_gather_into_tensor_out" in py_kernel_name:
        # Move the first argument to the end. The moved element has to be
        # wrapped in a list: `args[1:] + args[0]` would try to add a single
        # argument object to a list instead of appending it.
        args = [*args[1:], args[0]]

    try:
        with torch.distributed._time_estimator(
            group=pg, device=device
        ) as time_estimator:
            w = fn(*args, **kwargs)
            torch.ops._c10d_functional.wait_tensor.default(w)
    except Exception as e:
        # NCCL estimator can fail
        log.info(e)
        return None

    # Treated as microseconds here (converted to ms below).
    est_time_us = time_estimator.estimated_time
    # -1000 constant is NCCL return in case of error during estimations.
    # Observed it for all_to_all estimations.
    if est_time_us < 0:
        return None
    est_time_ms = est_time_us / 1e3
    return est_time_ms
def estimate_nccl_collective_runtime(node: ir.IRNode) -> float: def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
""" """
Returns estimated NCCL collective runtime in nanoseconds (ns). Returns estimated NCCL collective runtime in milliseconds (ms).
The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc. The following heuristics are copied from https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc.
We aim to estimate the runtime as accurately as possible. We aim to estimate the runtime as accurately as possible.
@ -220,6 +272,8 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
if coll == NCCL_COLL.ALL_REDUCE: if coll == NCCL_COLL.ALL_REDUCE:
nsteps = 2 * (nRanks - 1) nsteps = 2 * (nRanks - 1)
elif coll == NCCL_COLL.ALL_TO_ALL:
nsteps = 2 * (nRanks - 1)
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER):
nsteps = nRanks - 1 nsteps = nRanks - 1
@ -237,7 +291,7 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
nInterSteps = 2 * nNodes nInterSteps = 2 * nNodes
else: else:
nInterSteps = 0 nInterSteps = 0
elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER): elif coll in (NCCL_COLL.REDUCE_SCATTER, NCCL_COLL.ALL_GATHER, NCCL_COLL.ALL_TO_ALL):
nInterSteps = nNodes - 1 nInterSteps = nNodes - 1
# First compute latency in us; then at the end, convert it to ns # First compute latency in us; then at the end, convert it to ns
@ -256,7 +310,9 @@ def estimate_nccl_collective_runtime(node: ir.IRNode) -> float:
# =============== final result =============== # =============== final result ===============
transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns transport_ns = tensor_storage_size_GB / bandwidth_GB_per_ns
return transport_ns + latency_ns ns = transport_ns + latency_ns
ms = ns / 1e6
return ms
################################################################################################################ ################################################################################################################

View File

@ -52,6 +52,28 @@ if TYPE_CHECKING:
from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.scheduler import BaseSchedulerNode
def align_runtime_estimations_across_all_distributed_ranks(
    snodes: list[BaseSchedulerNode],
):
    """
    Make every rank use the same runtime estimation for each scheduler node.

    Each rank computes its local estimations, all ranks exchange them over the
    default process group, and every node's ``override_estimated_runtime`` is
    set to the per-node median across ranks.

    This issues a collective (``all_gather_object``), so it must be called on
    every rank of the default process group with node lists of equal length.

    Args:
        snodes: scheduler nodes to align; modified in place.
    """
    # Keep estimations positional: the i-th value belongs to snodes[i]. A dict
    # keyed by node would collapse repeated nodes and shift later values away
    # from their index, and would require nodes to be hashable.
    local_estimations = [snode.get_estimated_runtime() for snode in snodes]

    import torch.distributed as dist
    from torch.distributed.distributed_c10d import _get_default_group

    world_size = dist.get_world_size()
    pg = _get_default_group()
    gathered_runtime_estimations: list[list[float]] = [[] for _ in range(world_size)]
    dist.all_gather_object(gathered_runtime_estimations, local_estimations, pg)

    # Rows are ranks, columns are nodes: reduce over ranks (dim=0).
    median_runtime_estimations = torch.median(
        torch.tensor(gathered_runtime_estimations), dim=0
    ).values.tolist()
    for snode, median_estimation in zip(snodes, median_runtime_estimations):
        snode.override_estimated_runtime = median_estimation
def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
""" """
Greedily schedules waits as late as possible. Greedily schedules waits as late as possible.

View File

@ -416,6 +416,8 @@ bucket_reduce_scatters_fx_bucket_size_determinator: Optional[Callable[[int], int
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle # for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default" estimate_op_runtime = "default"
runtime_estimations_mms_benchmark: bool = False
# unit: GB/s, uni-directional P2P bandwidth per card # unit: GB/s, uni-directional P2P bandwidth per card
# default value is NVLink # default value is NVLink
intra_node_bw = 300 intra_node_bw = 300

View File

@ -0,0 +1,15 @@
import sys
from torch.utils._config_module import install_config_module
# Whether to use c10d._time_estimator for runtime estimations of collectives.
# When the estimator yields no result, the scheduler falls back to the
# in-tree algorithmic estimation.
runtime_estimations_use_nccl_lib_estimations: bool = False

# Enables syncing runtime estimations across distributed ranks, to prevent
# passes that use these runtime estimations from making different decisions
# on different distributed ranks.
# The scheduler only performs the sync when non-deterministic estimations
# (mm benchmarking) are enabled as well.
runtime_estimations_align_across_all_distributed_ranks: bool = False

# adds patch, save_config, etc
install_config_module(sys.modules[__name__])

View File

@ -26,6 +26,7 @@ import sympy
import torch import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
import torch.utils._pytree as pytree
from torch._dynamo.utils import counters, dynamo_timed from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.codecache import LambdaFuture, PyCodeCache from torch._inductor.codecache import LambdaFuture, PyCodeCache
from torch._inductor.ir import TritonTemplateCallerBase from torch._inductor.ir import TritonTemplateCallerBase
@ -35,10 +36,13 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._triton import has_triton from torch.utils._triton import has_triton
from . import comms, config, dependencies, ir, metrics from . import comms, config, config_comms, dependencies, ir, metrics
from .analyze_preserves_zero_mask import can_codegen_without_upcasts from .analyze_preserves_zero_mask import can_codegen_without_upcasts
from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
from .comm_analysis import estimate_nccl_collective_runtime from .comm_analysis import (
estimate_nccl_collective_runtime,
estimate_nccl_collective_runtime_nccl_estimator,
)
from .dependencies import Dep, MemoryDep, StarDep, WeakDep from .dependencies import Dep, MemoryDep, StarDep, WeakDep
from .exc import GPUTooOldForTriton, TritonMissing from .exc import GPUTooOldForTriton, TritonMissing
from .fx_utils import count_flops_fx from .fx_utils import count_flops_fx
@ -212,6 +216,7 @@ class BaseSchedulerNode:
min_order: int min_order: int
max_order: int max_order: int
mpi_node: MemoryPlanningInfoForNode mpi_node: MemoryPlanningInfoForNode
override_estimated_runtime: Optional[float] = None
def __init__(self, scheduler: Scheduler) -> None: def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler self.scheduler: Scheduler = scheduler
@ -823,10 +828,16 @@ class BaseSchedulerNode:
counters["inductor"]["flop_count"] += resolved_flops counters["inductor"]["flop_count"] += resolved_flops
return resolved_flops return resolved_flops
@cache_on_self
def get_estimated_runtime(self) -> float: def get_estimated_runtime(self) -> float:
if self.override_estimated_runtime is not None:
return self.override_estimated_runtime
return self._get_estimated_runtime()
@cache_on_self
def _get_estimated_runtime(self) -> float:
""" """
Returns estimated op runtime in nanoseconds (ns) Returns estimated op runtime in milliseconds (ms)
""" """
buf = self.get_nodes()[0].get_outputs()[0] buf = self.get_nodes()[0].get_outputs()[0]
layout = buf.node.get_output_spec() layout = buf.node.get_output_spec()
@ -838,6 +849,21 @@ class BaseSchedulerNode:
if is_collective(self.node): if is_collective(self.node):
assert isinstance(self.node, ir.IRNode) assert isinstance(self.node, ir.IRNode)
try: try:
if config_comms.runtime_estimations_use_nccl_lib_estimations:
cache_key = get_estimate_runtime_cache_key_from_snode(self)
cache = get_estimate_runtime_cache()
cache_val = cache.lookup(cache_key)
if cache_val is not None:
assert isinstance(cache_val, float)
return cache_val
ms = estimate_nccl_collective_runtime_nccl_estimator(self)
if ms is None:
# NCCL estimations fail: fallback to in-tree algorithmic estimation.
ms = estimate_nccl_collective_runtime(self.node)
cache.set_value(cache_key, value=ms)
return ms
return estimate_nccl_collective_runtime(self.node) return estimate_nccl_collective_runtime(self.node)
except ValueError as e: except ValueError as e:
# We don't know how to estimate runtime for this collective, # We don't know how to estimate runtime for this collective,
@ -856,6 +882,10 @@ class BaseSchedulerNode:
# since it doesn't take extra time to get the result after the collective is completed. # since it doesn't take extra time to get the result after the collective is completed.
return 0 return 0
ret = maybe_estimate_runtime_benchmark(self)
if ret is not None:
return ret
dtype = buf.node.maybe_get_dtype() dtype = buf.node.maybe_get_dtype()
try: try:
gpu_memory_bandwidth = get_gpu_dram_gbps() gpu_memory_bandwidth = get_gpu_dram_gbps()
@ -876,7 +906,9 @@ class BaseSchedulerNode:
if flops_est == 0 or flops_est is None: if flops_est == 0 or flops_est is None:
# no flops estimate, so fall back to memory estimate # no flops estimate, so fall back to memory estimate
return self.get_read_write_buffers_sizes() / gpu_memory_bandwidth ns = self.get_read_write_buffers_sizes() / gpu_memory_bandwidth
ms = ns / 1e6
return ms
# TODO(xmfan): find a better heuristic to model FLOPS/latency relationship # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
factor = 1.0 factor = 1.0
@ -885,8 +917,10 @@ class BaseSchedulerNode:
compute_time = (factor * flops_est / gpu_flops) * 1e9 compute_time = (factor * flops_est / gpu_flops) * 1e9
transfer_time = counted_bytes / gpu_memory_bandwidth transfer_time = counted_bytes / gpu_memory_bandwidth
# Return estimated runtime in nanoseconds # Return estimated runtime in milliseconds
return max(compute_time, transfer_time) ns = max(compute_time, transfer_time)
ms = ns / 1e6
return ms
def get_template_node(self) -> Optional[ir.TemplateBuffer]: def get_template_node(self) -> Optional[ir.TemplateBuffer]:
return None return None
@ -911,6 +945,77 @@ class BaseSchedulerNode:
return prologue, template_node, epilogue return prologue, template_node, epilogue
@functools.cache
def get_estimate_runtime_cache() -> torch._inductor.codecache.LocalCache:
    """
    Return the cache that stores runtime estimations.

    The function is memoized with ``functools.cache``, so the ``LocalCache``
    object is created on the first call and the same instance is returned
    afterwards.
    """
    return torch._inductor.codecache.LocalCache()
def get_estimate_runtime_cache_key_from_snode(snode: BaseSchedulerNode) -> str:
    """
    Build the runtime-estimation cache key for a scheduler node.

    The key is the string form of a tuple holding the kernel's python name
    followed by one entry per flattened argument: the size tuple for tensor
    IR nodes, ``None`` for everything else.

    NOTE(review): dtypes and the values of non-tensor arguments are not part
    of the key, so kernels that differ only in those share one cache entry.
    """
    node = snode.node
    kernel_name = getattr(node, "python_kernel_name", "")
    full_args = node.fill_non_provided_args(  # type: ignore[union-attr]
        [*node.inputs, *node.constant_args],  # type: ignore[union-attr]
        node.kwargs,  # type: ignore[union-attr]
    )
    leaves, _ = pytree.tree_flatten((full_args, node.kwargs))  # type: ignore[union-attr]

    def _size_or_none(leaf):  # type: ignore[no-untyped-def]
        # Generator states are IR nodes but are not treated as tensors here.
        if isinstance(leaf, ir.IRNode) and not isinstance(leaf, ir.GeneratorState):
            return tuple(leaf.get_size())
        return None

    return str((kernel_name, *map(_size_or_none, leaves)))
def _get_mm_like_fn(snode: BaseSchedulerNode) -> Optional[Callable[[Any], Any]]:
    """
    Return the aten matmul-like op matching an extern-kernel scheduler node.

    Returns ``None`` unless ``snode`` is an ``ExternKernelSchedulerNode`` whose
    node is an ``ir.ExternKernel`` with a python kernel name of
    ``extern_kernels.mm``, ``extern_kernels.bmm`` or ``extern_kernels.addmm``.
    """
    if not isinstance(snode, ExternKernelSchedulerNode):
        return None

    aten_fn_by_kernel_name = {
        "extern_kernels.mm": torch.ops.aten.mm,
        "extern_kernels.bmm": torch.ops.aten.bmm,
        "extern_kernels.addmm": torch.ops.aten.addmm,
    }
    kernel_name = getattr(snode.node, "python_kernel_name", "")
    aten_fn = aten_fn_by_kernel_name.get(kernel_name)
    if aten_fn is None or not isinstance(snode.node, ir.ExternKernel):
        return None
    return aten_fn
def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]:
    """
    Estimate a node's runtime by benchmarking it, when that is enabled.

    Benchmarking applies only if ``config.runtime_estimations_mms_benchmark``
    is set and the node is matmul-like; otherwise ``None`` is returned.
    Results are stored in the runtime-estimation cache, so a kernel with the
    same cache key is benchmarked only once.

    Returns:
        The value reported by ``triton.testing.do_bench`` (or the cached
        value), or ``None`` when benchmarking does not apply.
    """
    if not config.runtime_estimations_mms_benchmark:
        return None

    bench_fn = _get_mm_like_fn(snode)
    if bench_fn is None:
        return None

    key = get_estimate_runtime_cache_key_from_snode(snode)
    estimations = get_estimate_runtime_cache()
    cached = estimations.lookup(key)
    if cached is not None:
        assert isinstance(cached, float)
        return cached

    # Imported lazily: only a cache miss needs real inputs and triton.
    from .utils import snode_args_kwargs

    args, kwargs = snode_args_kwargs(snode)

    from triton.testing import do_bench

    ms = do_bench(lambda: bench_fn(*args, **kwargs))
    estimations.set_value(key, value=ms)
    return ms
class WhyNoFuse: class WhyNoFuse:
# TODO when we drop support for Python < 3.10, we can use # TODO when we drop support for Python < 3.10, we can use
# @dataclass(slots=True) instead of manually specifying __slots__. # @dataclass(slots=True) instead of manually specifying __slots__.
@ -2094,6 +2199,10 @@ class NodeUser:
_post_grad_graph_counter = itertools.count() _post_grad_graph_counter = itertools.count()
def used_non_deterministic_runtime_estimations() -> bool:
    """
    Return whether a non-deterministic runtime estimation source is enabled.

    Currently this is exactly ``config.runtime_estimations_mms_benchmark``
    (benchmark-based estimation of matmuls).
    """
    return config.runtime_estimations_mms_benchmark
class Scheduler: class Scheduler:
""" """
A Scheduler is a graph of BaseSchedulerNodes. It is responsible for A Scheduler is a graph of BaseSchedulerNodes. It is responsible for
@ -2214,6 +2323,17 @@ class Scheduler:
assign_memory_planning_info_for_scheduler_buffers( assign_memory_planning_info_for_scheduler_buffers(
self.nodes, self.name_to_buf self.nodes, self.name_to_buf
) )
if (
used_non_deterministic_runtime_estimations()
and config_comms.runtime_estimations_align_across_all_distributed_ranks
):
from .comms import (
align_runtime_estimations_across_all_distributed_ranks,
)
align_runtime_estimations_across_all_distributed_ranks(self.nodes)
from torch._logging import trace_structured from torch._logging import trace_structured
trace_structured( trace_structured(

View File

@ -59,6 +59,7 @@ from unittest import mock
import sympy import sympy
import torch import torch
import torch.utils._pytree as pytree
from torch._inductor.analysis.device_info import datasheet_tops from torch._inductor.analysis.device_info import datasheet_tops
from torch._inductor.runtime.hints import DeviceProperties from torch._inductor.runtime.hints import DeviceProperties
from torch.utils._dtype_abbrs import dtype_abbrs from torch.utils._dtype_abbrs import dtype_abbrs
@ -3666,3 +3667,38 @@ _unstable_customized_partition_wrapper = CUDAGraphWrapper()
def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None: def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
_unstable_customized_partition_wrapper.wrapper = wrapper _unstable_customized_partition_wrapper.wrapper = wrapper
def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
    """
    Build concrete call arguments for the kernel behind a scheduler node.

    The node's inputs and constant args are completed with
    ``fill_non_provided_args``. Every tensor IR node among the arguments is
    converted with ``ir_node_to_tensor`` and then replaced by a fresh
    ``torch.empty`` tensor of the same size, dtype and device. Non-tensor
    arguments are passed through unchanged.

    Returns:
        ``(args, kwargs)`` with the same nesting as the node's arguments.
    """
    node = snode.node
    positional = node.fill_non_provided_args(  # type: ignore[union-attr]
        [*node.inputs, *node.constant_args],  # type: ignore[union-attr]
        node.kwargs,  # type: ignore[union-attr]
    )
    leaves, spec = pytree.tree_flatten((positional, node.kwargs))  # type: ignore[union-attr]

    def _ir_to_tensor(leaf: Any) -> Any:
        inductor_ir = torch._inductor.ir
        # Generator states are IR nodes but must not be turned into tensors.
        if isinstance(leaf, inductor_ir.IRNode) and not isinstance(
            leaf, inductor_ir.GeneratorState
        ):
            return inductor_ir.ir_node_to_tensor(leaf, guard_shape=False)
        return leaf

    def _fresh_like(leaf: Any) -> Any:
        if isinstance(leaf, torch.Tensor):
            return torch.empty(leaf.size(), dtype=leaf.dtype, device=leaf.device)
        return leaf

    # Two passes, as before: first convert IR nodes, then allocate tensors.
    leaves = [_ir_to_tensor(leaf) for leaf in leaves]
    leaves = [_fresh_like(leaf) for leaf in leaves]

    args, kwargs = pytree.tree_unflatten(leaves, spec)
    return args, kwargs