[inductor] Runtime estimations: use nccl estimator; mm only benchmark mode (#161405)
During comms reordering, the sink-waits-iterative pass observed that the previous runtime estimations were quite far off for collectives and matmuls. This adds optional usage of:

- c10d.time_estimator for collectives, which is based on the NCCL estimator
- benchmark mode only for matmuls, as they are highly dependent on the mm backend

The logic is mostly copied from Ruisi's PRs for inductor simple_fsdp: https://github.com/pytorch/pytorch/pull/157572

These estimation corrections live in the default `BaseSchedulerNode.estimate_runtime()`.

Differential Revision: [D81152294](https://our.internmc.facebook.com/intern/diff/D81152294)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161405
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: 3f5993316e
Commit: 25c170b72e
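For intuition, "benchmark mode" for matmuls means actually running the kernel with concrete inputs and timing it on-device rather than trusting an analytical FLOPs/bandwidth model. Below is a minimal, hedged sketch of that kind of measurement using only public torch APIs; benchmark_call_ms, the warmup/iteration counts, and the shapes are illustrative choices, not the code added by this PR, and a CUDA device is assumed to be available.

import torch

def benchmark_call_ms(fn, *args, iters: int = 10, **kwargs) -> float:
    # Warm up so one-time costs (lazy init, autotuning) are excluded.
    for _ in range(3):
        fn(*args, **kwargs)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds for the whole loop; average per call.
    return start.elapsed_time(end) / iters

# Example: time a 4096x4096x4096 fp16 matmul, i.e. measure instead of
# estimating analytically, which is the idea behind mm-only benchmark mode.
a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
mm_ms = benchmark_call_ms(torch.mm, a, b)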
@@ -59,6 +59,7 @@ from unittest import mock
 import sympy
 
 import torch
+import torch.utils._pytree as pytree
 from torch._inductor.analysis.device_info import datasheet_tops
 from torch._inductor.runtime.hints import DeviceProperties
 from torch.utils._dtype_abbrs import dtype_abbrs
@@ -3666,3 +3667,38 @@ _unstable_customized_partition_wrapper = CUDAGraphWrapper()
 
 def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
     _unstable_customized_partition_wrapper.wrapper = wrapper
+
+
+def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
+    args = snode.node.inputs  # type: ignore[union-attr]
+    args = snode.node.fill_non_provided_args(  # type: ignore[union-attr]
+        [*args, *snode.node.constant_args],  # type: ignore[union-attr]
+        snode.node.kwargs,  # type: ignore[union-attr]
+    )
+    kwargs = snode.node.kwargs  # type: ignore[union-attr]
+    flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
+
+    def _is_tensor_ir(x) -> bool:  # type: ignore[no-untyped-def]
+        return isinstance(x, torch._inductor.ir.IRNode) and not isinstance(
+            x, torch._inductor.ir.GeneratorState
+        )
+
+    flat_args = [
+        torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False)
+        if _is_tensor_ir(a)
+        else a
+        for a in flat_args
+    ]
+
+    def _tensor(size, dtype, device) -> torch.Tensor:  # type: ignore[no-untyped-def]
+        return torch.empty(size, dtype=dtype, device=device)
+
+    def to_real_tensor(e: Any) -> Any:
+        if not isinstance(e, torch.Tensor):
+            return e
+        out = _tensor(e.size(), e.dtype, e.device)
+        return out
+
+    flat_args = [to_real_tensor(a) for a in flat_args]
+    args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec)
+    return args, kwargs
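Taken together, snode_args_kwargs flattens the scheduler node's IR inputs and kwargs with pytree, converts tensor IR nodes to tensors via ir_node_to_tensor, and then swaps every tensor for a freshly allocated torch.empty of the same size, dtype, and device, so the node's kernel can be invoked with concrete inputs when benchmarking. A hypothetical consumer, not part of this diff, where op stands in for the node's callable and benchmark_call_ms is the sketch helper shown earlier, might look like:

args, kwargs = snode_args_kwargs(snode)
estimated_ms = benchmark_call_ms(op, *args, **kwargs)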