[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode (#161405)

During comms reordering with sink-wait-iterative, we observed that the previous runtime estimations were significantly off for collectives and matmuls.

Adding optional usage of:
- `c10d.time_estimator` for collectives, which is based on the NCCL estimator
- benchmark mode for matmuls only, as they are highly dependent on the mm backend (a minimal sketch follows below)

The logic is mostly copied from Ruisi's PR for inductor simple_fsdp: https://github.com/pytorch/pytorch/pull/157572
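
As a rough illustration of the benchmark-mode idea, a matmul can be timed directly with `torch.utils.benchmark`. This is a minimal sketch under the assumption of an available CUDA device; the helper name `benchmark_mm_runtime` is illustrative and not part of this PR:

```python
# Minimal sketch (assumed code, not from the PR): time an (m, k) @ (k, n)
# matmul with torch.utils.benchmark. Assumes a CUDA device is available.
import torch
from torch.utils import benchmark


def benchmark_mm_runtime(m: int, k: int, n: int, dtype: torch.dtype = torch.float16) -> float:
    """Return the median matmul runtime in milliseconds."""
    a = torch.randn(m, k, device="cuda", dtype=dtype)
    b = torch.randn(k, n, device="cuda", dtype=dtype)
    # Timer auto-imports torch into globals, so only a and b need passing.
    timer = benchmark.Timer(stmt="torch.mm(a, b)", globals={"a": a, "b": b})
    # blocked_autorange picks the number of runs automatically and returns
    # a Measurement whose .median is in seconds.
    return timer.blocked_autorange(min_run_time=0.1).median * 1e3
```

Measuring rather than modeling matmuls sidesteps the large runtime spread between mm backends (cuBLAS, Triton, etc.) for the same shapes.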

These estimation corrections are applied in the default `BaseSchedulerNode.estimate_runtime()`.
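
Conceptually, the per-node dispatch can be thought of as below. This is a hypothetical sketch: every predicate and estimator is passed in as a stub callable, since the real helpers and config flags live in this PR's diff:

```python
# Hypothetical sketch (not the PR's code): pick an estimation strategy per
# scheduler-node kind; collectives use the NCCL-based estimator, matmuls use
# benchmark mode, and everything else keeps the analytic estimate.
from typing import Any, Callable


def estimate_runtime(
    snode: Any,
    is_collective: Callable[[Any], bool],
    is_matmul: Callable[[Any], bool],
    nccl_estimate: Callable[[Any], float],
    benchmark_estimate: Callable[[Any], float],
    analytic_estimate: Callable[[Any], float],
) -> float:
    if is_collective(snode):
        return nccl_estimate(snode)  # c10d / NCCL-based time estimator
    if is_matmul(snode):
        return benchmark_estimate(snode)  # benchmark mode, backend-accurate
    return analytic_estimate(snode)  # previous default analytic path
```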

Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161405
Approved by: https://github.com/eellison
Author: IvanKobzarev
Date: 2025-09-08 04:35:52 -07:00
Committed by: PyTorch MergeBot
Parent: 3f5993316e
Commit: 25c170b72e
8 changed files with 324 additions and 32 deletions


@@ -59,6 +59,7 @@ from unittest import mock
import sympy
import torch
import torch.utils._pytree as pytree
from torch._inductor.analysis.device_info import datasheet_tops
from torch._inductor.runtime.hints import DeviceProperties
from torch.utils._dtype_abbrs import dtype_abbrs
@@ -3666,3 +3667,38 @@ _unstable_customized_partition_wrapper = CUDAGraphWrapper()
def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
_unstable_customized_partition_wrapper.wrapper = wrapper
def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
    # Reconstruct the full positional args of the underlying IR node,
    # including constant args and any non-provided (defaulted) ones.
    args = snode.node.inputs  # type: ignore[union-attr]
    args = snode.node.fill_non_provided_args(  # type: ignore[union-attr]
        [*args, *snode.node.constant_args],  # type: ignore[union-attr]
        snode.node.kwargs,  # type: ignore[union-attr]
    )
    kwargs = snode.node.kwargs  # type: ignore[union-attr]
    flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))

    def _is_tensor_ir(x) -> bool:  # type: ignore[no-untyped-def]
        return isinstance(x, torch._inductor.ir.IRNode) and not isinstance(
            x, torch._inductor.ir.GeneratorState
        )

    # Convert tensor IR nodes to tensors carrying size/dtype/device metadata
    # (no shape guards are installed).
    flat_args = [
        torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False)
        if _is_tensor_ir(a)
        else a
        for a in flat_args
    ]

    def _tensor(size, dtype, device) -> torch.Tensor:  # type: ignore[no-untyped-def]
        return torch.empty(size, dtype=dtype, device=device)

    def to_real_tensor(e: Any) -> Any:
        # Materialize a real (uninitialized) tensor with matching
        # size/dtype/device so the op can actually be run and timed.
        if not isinstance(e, torch.Tensor):
            return e
        out = _tensor(e.size(), e.dtype, e.device)
        return out

    flat_args = [to_real_tensor(a) for a in flat_args]
    args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec)
    return args, kwargs
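
The materialized args/kwargs can then be fed straight into a benchmarking helper. A hypothetical usage sketch, not part of the diff (`benchmark_snode_op` and the `op` argument are illustrative):

```python
# Hypothetical usage (assumed code): materialize a scheduler node's
# args/kwargs as real tensors and time the target callable with them.
from torch.utils import benchmark


def benchmark_snode_op(op, snode) -> float:  # `op` is the kernel/aten callable
    """Median runtime in milliseconds of op(*args, **kwargs) for `snode`."""
    args, kwargs = snode_args_kwargs(snode)  # real uninitialized tensors
    timer = benchmark.Timer(
        stmt="op(*args, **kwargs)",
        globals={"op": op, "args": args, "kwargs": kwargs},
    )
    return timer.blocked_autorange(min_run_time=0.1).median * 1e3
```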