Compare commits

...

4 Commits

Author SHA1 Message Date
40c83dfe91 cherry pick PR 160967 Fix bucketing introducing cycles 2025-08-19 13:45:59 -07:00
359ade0e92 [NOT_FOR_LAND] Ruisi's real benchmark estimations
ghstack-source-id: 10255cfb7ac1159d7d1294154a3570cb4ac22b4c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160991
2025-08-19 13:45:59 -07:00
83505666a4 [inductor] Unsafe reordering collectives; Limit by runtime estimations
and a configured number of gemms.

ghstack-source-id: 3d6059df090f33061f6b69a4a168dffd4c9846e1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160527
2025-08-19 13:45:59 -07:00
324ba7e254 [inductor] Estimate peak memory allocfree and apply it to reordering
collectives

ghstack-source-id: d0eefb95966bb2eb2188b2c6e48dbcb8acea3247
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160113
2025-08-19 13:45:59 -07:00
7 changed files with 1509 additions and 262 deletions

File diff suppressed because it is too large

View File

@ -0,0 +1,112 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
from torch._logging import trace_structured
from .memory import estimate_peak_memory_allocfree
if TYPE_CHECKING:
from torch.utils._ordered_set import OrderedSet
from .memory import FreeableInputBuffer, SNodeMemory
from .scheduler import BaseSchedulerNode, SchedulerBuffer
def _debug_iterative_memory_recompute(
candidate: BaseSchedulerNode,
gns: list[BaseSchedulerNode],
group_names: str,
snodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
peak_memory: int,
iter_curr_memory: dict[BaseSchedulerNode, tuple[int, int]],
snodes_allocfree: dict[BaseSchedulerNode, SNodeMemory],
tlparse_name: str,
gn_to_bufs_last_use: dict[
BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
],
) -> bool:
iterative_recompute_error = False
candidate_allocfree = snodes_allocfree[candidate]
# Recompute from scratch; keep separate names so the iterative values passed in
# via `snodes_allocfree` are not clobbered and can still be printed below.
est_peak_memory, est_snodes_curr_memory, est_snodes_allocfree, _ = (
estimate_peak_memory_allocfree(
snodes, name_to_freeable_input_buf, graph_outputs
)
)
est_curr_memory = dict(zip(snodes, est_snodes_curr_memory))
iter_cm = iter_curr_memory[candidate]
new_cm = est_curr_memory[candidate]
log = ""
if est_peak_memory > peak_memory:
log = "ITERATIVE PEAK DOES NOT MATCH"
iterative_recompute_error = True
if iter_cm != new_cm:
log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH"
iterative_recompute_error = True
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
if iter_gnm != new_gnm:
log = f"ITERATIVE GN CURR MEMORY DOES NOT MATCH:{gn.get_name()}"
iterative_recompute_error = True
if iterative_recompute_error:
log += (
f"\nCANDIDATE:{candidate.get_name()}"
f"\nGROUP:{group_names}"
f"\nPEAK_MEMORY_BEFORE:{peak_memory}"
f"\nPEAK_MEMORY_AFTER_SWAP:{est_peak_memory}"
f"\nCANDIDATE:{candidate.debug_str()}"
f"\nCANDIDATE_ITER_CURR_MEMORY:{iter_cm}"
f"\nCANDIDATE_NEW__CURR_MEMORY:{new_cm}"
f"\nCANDIDATE_ITER_ALLOCFREE:{candidate_allocfree}"
f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}"
)
peak_log = ""
for i, (pre, post) in enumerate(est_snodes_curr_memory):
if est_peak_memory == pre:
n = snodes[i]
peak_log = (
f"\nNEW_PEAK:{est_peak_memory}(BASE:{peak_memory})"
f" @ SNODE[{i}/{len(snodes)}]:{n.get_name()} {n.debug_str()}"
)
break
group_log = ""
for i, gn in enumerate(gns):
iter_gnm = iter_curr_memory[gn]
new_gnm = est_curr_memory[gn]
group_log += (
f"\nGROUP_NODE[{i}]:{gn.debug_str()}"
f"\nGROUP_NODE[{i}] ITER_GNM[{gn.get_name()}]:{iter_gnm}"
f"\nGROUP_NODE[{i}] ESTM_GNM[{gn.get_name()}]:{new_gnm}"
f"\nGROUP_NODE[{i}] ITER_allocfree:{snodes_allocfree[gn]}"
f"\nGROUP_NODE[{i}] ESTM_allocfree:{snodes_allocfree[gn]}"
)
log += peak_log
log += group_log
log += f"\nGN_TO_BUFS_LAST_USE:{gn_to_bufs_last_use}"
log += "\n\n".join(
[
(
f"\nSNODE[{i}]\n{n.debug_str()}"
f"\nITER_cur_mem:{iter_curr_memory[n]}"
f"\nESTM_cur_mem:{est_curr_memory[n]}"
f"\nITER_allocfree:{snodes_allocfree[n]}"
f"\nESTM_allocfree:{snodes_allocfree[n]}"
)
for i, n in enumerate(snodes)
]
)
tname = f"{tlparse_name}_ITERATIVE_RECOMPUTE_ERROR"
print(f"{tname}:\n{log}")
trace_structured(
"artifact",
metadata_fn=lambda: {
"name": tname,
"encoding": "string",
},
payload_fn=lambda: log,
)
return iterative_recompute_error
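
The helper above follows a simple invariant-check pattern: after a swap, recompute the memory profile from scratch and compare it against the incrementally maintained values, dumping a tlparse artifact on any mismatch. A hedged standalone sketch of that pattern (illustrative names, not the inductor API):

# Sketch only: compare incrementally maintained values against a from-scratch recompute.
def check_incremental(iter_values, recompute_fn, log_fn=print):
    fresh = recompute_fn()  # ground truth, recomputed from scratch
    mismatches = {
        key: (iter_values.get(key), val)
        for key, val in fresh.items()
        if iter_values.get(key) != val
    }
    if mismatches:
        log_fn(f"ITERATIVE RECOMPUTE MISMATCH: {mismatches}")
    return bool(mismatches)  # True tells the caller to stop and investigate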

View File

@ -389,6 +389,22 @@ reorder_prefetch_limit: Optional[int] = None
# enable operator reordering for peak memory optimization
reorder_for_peak_memory = True
reorder_iterative_debug_memory_recompute: bool = False
reorder_iterative_debug_limit_to_reorder: Optional[int] = (
None
if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None
else int(env_str)
)
sink_waits_iterative_debug_limit_to_sink: Optional[int] = (
None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str)
)
reorder_iterative_swapped_gemm_like_limit: Optional[int] = None
sink_waits_iterative_swapped_gemm_like_limit: Optional[int] = None
reorder_iterative_unsafe_collectives_reorder: bool = False
sink_waits_iterative_unsafe_collectives_reorder: bool = False
bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none"
# By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used
bucket_all_gathers_fx_bucket_size_determinator: Optional[Callable[[int], int]] = None
@ -399,6 +415,8 @@ bucket_reduce_scatters_fx_bucket_size_determinator: Optional[Callable[[int], int
None
)
estimate_runtime_benchmark: bool = False
# runtime estimation function for ops
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
estimate_op_runtime = "default"
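
As a hedged sketch of how these knobs might be toggled (assuming this patch set is applied; the limits below are illustrative, not recommended values):

import os

# The env-driven limits are read when torch._inductor.config is imported, so set them first.
os.environ["PYTORCH_REORDER_COLLECTIVES_LIMIT"] = "8"
os.environ["PYTORCH_SINK_WAITS_LIMIT"] = "8"

import torch._inductor.config as inductor_config

inductor_config.reorder_for_compute_comm_overlap = True
inductor_config.estimate_runtime_benchmark = True  # use benchmark-based runtime estimations
inductor_config.reorder_iterative_unsafe_collectives_reorder = True
inductor_config.reorder_iterative_swapped_gemm_like_limit = 4
inductor_config.bucket_all_gathers_fx = "only_fsdp"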

View File

@ -0,0 +1,521 @@
import statistics
import time
from collections import defaultdict
from collections.abc import Sequence
from functools import reduce
from typing import Any, Callable, cast, Optional, Union
from sympy import Expr
import torch
import torch.distributed as c10d
import torch.utils._pytree as pytree
from torch._inductor.codecache import PyCodeCache
from torch.utils._mode_utils import no_dispatch
from . import ir
from .scheduler import BaseSchedulerNode, ExternKernelSchedulerNode, FusedSchedulerNode
from .utils import contains_collective, contains_wait
kernel_name_to_comm_op: dict[str, Callable[..., Any]] = {
"torch.ops._c10d_functional.all_gather_into_tensor.default": c10d.all_gather_into_tensor,
"torch.ops._c10d_functional.reduce_scatter_tensor.default": c10d.reduce_scatter_tensor,
"torch.ops._c10d_functional.all_gather_into_tensor_out.default": c10d.all_gather_into_tensor,
}
kernel_name_to_comp_op: dict[str, Callable[..., Any]] = {
"extern_kernels.mm": torch.ops.aten.mm,
"extern_kernels.bmm": torch.ops.aten.bmm,
"extern_kernels.addmm": torch.ops.aten.addmm,
}
OpType = Union[
torch._ops.OpOverload,
torch._ops.OpOverloadPacket,
torch._ops.HigherOrderOperator,
]
def _convert_str_to_op(full_name: str) -> OpType:
module_names = full_name.split(".")
target_kernel = torch
for module_name in module_names:
target_kernel = getattr(target_kernel, module_name)
assert isinstance(
target_kernel,
(
torch._ops.OpOverload,
torch._ops.OpOverloadPacket,
torch._ops.HigherOrderOperator,
),
)
return target_kernel
def _create_real_tensor(
size: Union[torch.Size, Sequence[Expr]],
dtype: torch.dtype,
device: Union[torch.device, None],
) -> torch.Tensor:
if dtype.is_floating_point:
out = torch.randn(size, dtype=dtype).to(device)
else:
out = torch.ones(size, dtype=dtype).to(device)
return out
def get_data_size(size):
# Total number of elements for a given shape.
return reduce(lambda x, y: x * y, size)
class CommPerfCache:
def __init__(self, threshold=3000):
# (input_size, output_size, comm_func_name) -> measured comm time in ns
self.cache = {}
# max element-count distance for reusing a cached entry of a nearby size
self.threshold = threshold
self.ag_max_inp_size = -1
self.rs_max_out_size = -1
def _calculate_distance(self, size1, size2):
# Distance between two shapes, measured as the difference in element count.
word_size1 = get_data_size(size1)
word_size2 = get_data_size(size2)
return abs(word_size1 - word_size2)
def _update_max_size(self):
for k in self.cache.keys():
if k[2] == "torch.ops._c10d_functional.all_gather_into_tensor.default":
self.ag_max_inp_size = max(
self.ag_max_inp_size, get_data_size(list(k[0]))
)
if k[2] == "torch.ops._c10d_functional.reduce_scatter_tensor.default":
self.rs_max_out_size = max(
self.rs_max_out_size, get_data_size(list(k[1]))
)
def add_comm_time(self, tensor_input_size, tensor_output_size, comm_func, value):
key = (tuple(tensor_input_size), tuple(tensor_output_size), comm_func)
self.cache[key] = value
if comm_func == "torch.ops._c10d_functional.all_gather_into_tensor.default":
self.ag_max_inp_size = max(
self.ag_max_inp_size, get_data_size(tensor_input_size)
)
if comm_func == "torch.ops._c10d_functional.reduce_scatter_tensor.default":
self.rs_max_out_size = max(
self.rs_max_out_size, get_data_size(tensor_output_size)
)
def get_comm_time(
self, tensor_input_size, tensor_output_size, comm_func, calibrated=False
):
key = (tuple(tensor_input_size), tuple(tensor_output_size), comm_func)
if key in self.cache:
return self.cache[key]
if calibrated:
threshold = float("inf")
else:
threshold = self.threshold
closest_key = None
closest_distance = float("inf")
for k in self.cache.keys():
if k[2] == comm_func:
input_distance = self._calculate_distance(tensor_input_size, k[0])
output_distance = self._calculate_distance(tensor_output_size, k[1])
total_distance = input_distance + output_distance
if (
input_distance <= threshold
and output_distance <= threshold
and total_distance < closest_distance
):
closest_distance = total_distance
closest_key = k
if closest_key:
return self.cache[closest_key]
return None
class CompPerfCache:
def __init__(self):
self.triton_cache = {}
self.extern_cache = {}
def add_triton_runtime(self, triton_code, runtime):
self.triton_cache[triton_code] = runtime
def add_extern_runtime(self, kernel_args, runtime):
self.extern_cache[kernel_args] = runtime
def get_runtime_by_triton(self, triton_code):
return self.triton_cache.get(triton_code, None)
def get_runtime_by_extern(self, kernel_args):
return self.extern_cache.get(kernel_args, None)
def estimate_runtime(
sched: "scheduler.Scheduler",
snodes: list["scheduler.BaseSchedulerNode"],
verbose: bool = False,
) -> dict[str, dict[str, float]]:
# The runtimes dict contains the estimated runtime of each node
# For each node, the key is the node name, and the value is a dict of {"COMM": comm_time, "COMP": comp_time}
# If the node is a collective node, the value is {"COMM": comm_time, "COMP": 0.}
# If the node is a compute node, the value is {"COMM": 0., "COMP": comp_time}
# If the node is a wait node, the value is {"COMM": 0., "COMP": 0.}
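# Example of the returned structure (illustrative node names and times, in ns):
# {
#     "op0": {"COMM": 0.0, "COMP": 152000.0},   # compute node
#     "op1": {"COMM": 431000.0, "COMP": 0.0},   # collective node
#     "op2": {"COMM": 0.0, "COMP": 0.0},        # wait node
# }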
runtimes = {}
# Estimate the runtime of each node on the local rank
mults = defaultdict(list)
for _, snode in enumerate(snodes):
# if not contains_collective(snode):
# continue
try:
print(f"XXX ESTIMATE SNODE:{snode.debug_str_short()}")
from .comms import estimate_op_runtime
def_est = estimate_op_runtime(snode)
print(f"XXX DEFAULT_EST:{def_est}")
new_est = _estimate_op_runtime(sched, snode, verbose=verbose)
runtimes[snode.get_name()] = new_est
print(f"XXX NEW_EST:{new_est}")
new_est_f = new_est["COMM"] + new_est["COMP"]
mult = 1.0
if def_est != 0.0:
mult = new_est_f / def_est
key = "other"
if contains_collective(snode):
key = "collective"
elif isinstance(snode, ExternKernelSchedulerNode):
key = "extern"
if def_est != 0.0:
mults[key].append(mult)
print(f"XXX MULT[{key}]:{mult}")
except Exception as e:
import traceback
print(f"XXX ERROR_ESTIMATION {e} of snode:{snode.get_name()}")
traceback.print_exception(type(e), e, e.__traceback__)
for key, key_mults in mults.items():
print(f"XXX MULTS_AVG[{key}]:{sum(key_mults) / len(key_mults)}")
# If world_size is larger than 1, gather runtimes from each rank and sync the median runtime across ranks
world_size = c10d.distributed_c10d.get_world_size()
median_runtimes = runtimes
if world_size > 1:
gathered_runtimes: list[dict[str, dict[str, float]]] = [
{} for _ in range(world_size)
]
c10d.all_gather_object(
gathered_runtimes,
runtimes,
group=c10d.distributed_c10d._get_default_group(),
)
assert all(len(gathered_runtime) > 0 for gathered_runtime in gathered_runtimes)
for key in list(runtimes.keys()):
comm_value = [
gathered_runtime[key]["COMM"] for gathered_runtime in gathered_runtimes
]
comp_value = [
gathered_runtime[key]["COMP"] for gathered_runtime in gathered_runtimes
]
median_runtimes[key] = {
"COMM": statistics.median(comm_value),
"COMP": statistics.median(comp_value),
}
return median_runtimes
def _estimate_op_runtime(
sched: "scheduler.Scheduler",
snode: "scheduler.BaseSchedulerNode",
verbose: bool = False,
) -> dict[str, float]:
runtime = {"COMM": 0.0, "COMP": 0.0}
if contains_collective(snode):
# benchmark communication node runtime
runtime["COMM"] = estimate_comm_time(sched, snode, verbose=verbose)
return runtime
elif contains_wait(snode):
# wait node
return runtime
runtime["COMP"] = estimate_comp_time(sched, snode, verbose=verbose)
return runtime
def estimate_comm_time(
sched: "scheduler.Scheduler",
snode: Union[tuple["ir.IRNode"], "scheduler.BaseSchedulerNode"],
estimate: bool = False,
verbose: bool = False,
comm_cache: "CommPerfCache" = None,
) -> float:
# TODO (ruisizhang123): add more types of collective communication.
# Currently, it only supports all_gather and reduce_scatter
# estimate set to True: return NCCL's estimated comm time (https://github.com/pytorch/pytorch/pull/149343)
# estimate set to False: run the collective communication and return the actual comm time
from .scheduler import BaseSchedulerNode
# for node with collective kernel estimation
if isinstance(snode, BaseSchedulerNode):
kernel = snode.node
py_kernel_name = (getattr(kernel, "python_kernel_name", ""),)
assert hasattr(kernel.inputs[0], "data")
if (
"torch.ops._c10d_functional.all_gather_into_tensor_out.default"
in py_kernel_name
):
input_layout = kernel.inputs[0].layout
output_layout = kernel.inputs[1].layout
tensor_input = _create_real_tensor(
input_layout.size, input_layout.dtype, input_layout.device
)
tensor_output = _create_real_tensor(
output_layout.size, output_layout.dtype, output_layout.device
)
else:
inputs = kernel.inputs[0]
output_layout = kernel.layout
tensor_input = _create_real_tensor(
inputs.data.get_size(),
inputs.data.get_dtype(),
inputs.data.get_device(),
)
tensor_output = _create_real_tensor(
output_layout.size, output_layout.dtype, output_layout.device
)
elif isinstance(snode, tuple):
node, inputs, outputs = snode
kernel = node.inputs[0]
tensor_input = _create_real_tensor(
inputs.layout.size, inputs.layout.dtype, inputs.layout.device
)
tensor_output = _create_real_tensor(
outputs.layout.size, outputs.layout.dtype, outputs.layout.device
)
else:
assert False, "NYI"
comm_time = benchmark_comm_func(
tensor_input,
tensor_output,
getattr(kernel, "python_kernel_name", ""),
comm_cache,
estimate,
verbose=verbose,
)
return comm_time
def benchmark_comm_func(
tensor_input, tensor_output, comm_func_name, comm_cache, estimate, verbose=False
):
rank = c10d.distributed_c10d.get_rank()
device = torch.device(f"cuda:{rank:d}")
process_group = c10d.distributed_c10d._get_default_group()
if (
comm_func_name == "torch.ops._c10d_functional.all_gather_into_tensor.default"
or comm_func_name
== "torch.ops._c10d_functional.all_gather_into_tensor_out.default"
):
input_args = {"input_tensor": tensor_input, "output_tensor": tensor_output}
elif comm_func_name == "torch.ops._c10d_functional.reduce_scatter_tensor.default":
input_args = {"input": tensor_input, "output": tensor_output}
if comm_cache is not None:
comm_time = comm_cache.get_comm_time(
tensor_input.size(), tensor_output.size(), comm_func_name
)
if comm_time is not None:
return comm_time
comm_func = kernel_name_to_comm_op.get(comm_func_name, None)
assert comm_func is not None, f"Unsupported comm op {comm_func_name}"
if estimate:
with c10d._time_estimator(group=process_group, device=device) as cm:
comm_func(**input_args)
comm_time = cm.estimated_time
else:
torch.cuda.synchronize()
comm_func(**input_args)
nruns = 2
comm_time_ms = 0
for _ in range(nruns):
c10d.barrier()
torch.cuda.synchronize()
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)
start_evt.record()
comm_func(**input_args)
end_evt.record()
end_evt.synchronize()
current_run_time = start_evt.elapsed_time(end_evt)
comm_time_ms += current_run_time
comm_time_ns = comm_time_ms / nruns * 1e6
if verbose:
print(
f"[COMM Node estimate:{estimate}]",
comm_func_name,
"time",
comm_time_ns,
)
if comm_cache is not None:
comm_cache.add_comm_time(
tensor_input.size(), tensor_output.size(), comm_func_name, comm_time_ns
)
del tensor_input, tensor_output
return comm_time_ns
def estimate_comp_time(
sched: "scheduler.Scheduler",
snode: "scheduler.BaseSchedulerNode",
verbose: bool = False,
comp_cache: "CompPerfCache" = None,
) -> float:
# Estimate the runtime of a compute node
# FusedSchedulerNode & BaseSchedulerNode: get the generated triton code and use `do_bench` mode to obtain runtime
# ExternKernelSchedulerNode: get python kernel and run the kernel to obtain runtime
device = cast(torch.device, snode.get_device())
if isinstance(snode, FusedSchedulerNode):
node_list = snode.snodes
elif isinstance(snode, ExternKernelSchedulerNode):
time = benchmark_extern_node(snode.node, comp_cache)
if verbose and time != 0:
print("[COMP Node] EXTERN", "time", time)
return time
elif isinstance(snode, BaseSchedulerNode):
node_list = [snode]
else:
raise ValueError(f"Unsupported snode type {type(snode)}")
# this part of the code is adapted from triton's benchmarking code:
# https://github.com/pytorch/pytorch/blob/85111cd165f108ffabb4a90083d59d7a867ebd9f/torch/_inductor/codegen/triton.py#L4234
src_code = sched.generate_kernel_code_from_nodes(node_list, benchmark_kernel=True)
if comp_cache is not None:
time = comp_cache.get_runtime_by_triton(src_code)
if time is not None:
return time
module = PyCodeCache.load(src_code)
time_ms, _ = sched.benchmark_codegened_module(module=module, device=device)
time_ns = time_ms * 1e6
if comp_cache is not None:
comp_cache.add_triton_runtime(src_code, time_ns)
if verbose and time_ns != 0:
print("[COMP Node] BASE/FUSE", "time", time_ns)
return time_ns
def benchmark_extern_node(
node: ir._NodeOrNodes, comp_cache: Optional["CompPerfCache"] = None
) -> float:
if isinstance(node, ir.MultiOutput):
return 0
python_kernel_name = getattr(node, "python_kernel_name", "")
if python_kernel_name.startswith("extern_kernels"):
func = kernel_name_to_comp_op.get(python_kernel_name, None)
elif python_kernel_name.startswith("torch.ops.aten"):
func = _convert_str_to_op(python_kernel_name)
else:
func = None
if func is None:
return 0
else:
if isinstance(node, ir.FallbackKernel):
args = node.export_extern_kernel_node()
ordered_kwargs_length = len(list(node.ordered_kwargs_for_cpp_kernel))
if ordered_kwargs_length > 0:
args, ordered_kwargs = (
args[: -1 * ordered_kwargs_length],
args[-1 * ordered_kwargs_length :],
)
ordered_kwargs = dict(
zip(node.ordered_kwargs_for_cpp_kernel, ordered_kwargs)
)
node.kwargs.update(ordered_kwargs)
elif isinstance(node, ir.ExternKernel):
args = node.inputs
args = node.fill_non_provided_args(
[*args, *node.constant_args], node.kwargs
)
else:
raise ValueError(f"Unsupported node type {type(node)}")
flat_args, args_property = pytree.tree_flatten((args, node.kwargs))
if comp_cache is not None:
flat_args_info = [
input.get_size()
if isinstance(input, ir.IRNode)
and not isinstance(input, ir.GeneratorState)
else input
for input in flat_args
]
flat_args_info = (
tuple(a) if isinstance(a, list) else a for a in flat_args_info
)
kernel_flat_args = (python_kernel_name,) + tuple(flat_args_info)
op_time = comp_cache.get_runtime_by_extern(kernel_flat_args)
if op_time is not None:
return op_time
flat_args = [
ir.ir_node_to_tensor(input, guard_shape=False)
if isinstance(input, ir.IRNode) and not isinstance(input, ir.GeneratorState)
else input
for input in flat_args
]
# this part of the code is adapted from https://fburl.com/3xpyoq93
with no_dispatch():
def to_real_tensor(e: Any) -> Any:
if not isinstance(e, torch.Tensor):
return e
out = _create_real_tensor(e.size(), e.dtype, e.device)
if e.is_sparse:
out._coalesced_(e.is_coalesced())
return out
flat_args = [to_real_tensor(a) for a in flat_args]
args, kwargs = pytree.tree_unflatten(flat_args, args_property)
func(*args, **kwargs)
num_iters = 3
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
cpu_start = time.time()
start_event.record(torch.cuda.current_stream())
for _ in range(num_iters):
func(*args, **kwargs)
end_event.record(torch.cuda.current_stream())
cpu_end = time.time()
torch.cuda.synchronize()
# elapsed_time() is in ms; convert the CPU-side launch overhead to ms before subtracting
cpu_time_ms = (cpu_end - cpu_start) * 1e3
total_op_time = start_event.elapsed_time(end_event) - cpu_time_ms
mean_op_time_ms = total_op_time / num_iters
del flat_args
mean_op_time_ns = mean_op_time_ms * 1e6
if comp_cache is not None:
comp_cache.add_extern_runtime(kernel_flat_args, mean_op_time_ns)
return mean_op_time_ns
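
As a hedged, standalone illustration of CommPerfCache's nearest-size lookup (assuming the module from this diff is importable as torch._inductor.estimator; the sizes and the 120000 ns timing are made up):

# Sketch only: CommPerfCache as defined in the diff above; values are illustrative.
from torch._inductor.estimator import CommPerfCache

AG = "torch.ops._c10d_functional.all_gather_into_tensor.default"
cache = CommPerfCache(threshold=3000)  # max element-count distance for a "close enough" hit

cache.add_comm_time([1024], [8192], AG, 120_000.0)  # record a measured time in ns

# Exact key: returns the recorded value.
print(cache.get_comm_time([1024], [8192], AG))            # 120000.0
# Nearby sizes (distance <= 3000 on both input and output): reuses the cached entry.
print(cache.get_comm_time([2048], [9216], AG))            # 120000.0
# Far-away sizes: no hit, so the caller benchmarks and adds a new entry.
print(cache.get_comm_time([1_000_000], [8_000_000], AG))  # None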

View File

@ -1,3 +1,4 @@
import collections
import logging
from collections import defaultdict
from typing import Any, Callable, Optional
@ -42,6 +43,7 @@ def bucket_all_gather(
ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
if len(ag_buckets) == 0:
return
merge_all_gather(gm, ag_buckets)
@ -86,6 +88,42 @@ def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]) # type: ignore[arg-type]
def collect_node_descendants(
graph: torch.fx.Graph,
) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]:
"""
Collects the descendants of each node in the graph.
Args:
graph (torch.fx.Graph): The graph to collect descendants from.
Returns:
dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: A dictionary mapping each node to its descendants.
"""
node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = (
collections.defaultdict(OrderedSet)
)
outdegree = collections.defaultdict(int)
queue = []
for node in graph.nodes:
n_outdegree = len(node.users)
if n_outdegree == 0:
queue.append(node)
else:
outdegree[node] = len(node.users)
while queue:
node = queue.pop()
for input_node in node.all_input_nodes:
node_descendants[input_node] |= node_descendants[node]
node_descendants[input_node].add(node)
outdegree[input_node] -= 1
if outdegree[input_node] == 0:
queue.append(input_node)
return node_descendants
def greedy_bucket_collective_by_mb(
gm: torch.fx.GraphModule,
bucket_cap_mb_by_bucket_idx: Callable[[int], float],
@ -93,59 +131,38 @@ def greedy_bucket_collective_by_mb(
node_group_key: Callable[[torch.fx.Node], Any],
filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> list[list[torch.fx.Node]]:
"""
Bucketing adjacent collectives with equal node_group_key.
We cannot bucket non-adjacent collectives,
as this would effectively change the order of collectives.
Reordering can lead to a different order on different ranks.
"""
g = gm.graph
found_candidates = False
for node in g.nodes:
if filter_node(node):
found_candidates = True
break
if not found_candidates:
if not gm.graph.find_nodes(
op="call_function", target=torch.ops._c10d_functional.wait_tensor.default
):
return []
nodes_successors: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = defaultdict(
OrderedSet
)
nodes_groups: list[list[torch.fx.Node]] = []
cur_group: list[torch.fx.Node] = []
cur_group_key = None
g = gm.graph
nodes_groups: dict[Any, list[torch.fx.Node]] = defaultdict(list)
# TODO: Pearce-Kelly algorithm for detecting cycles
node_descendents = collect_node_descendants(gm.graph)
for node in g.nodes:
for n, successors in nodes_successors.items():
if any(arg in successors for arg in node.args):
successors.add(n)
if is_wait_tensor(node) and filter_node(node.args[0]):
if (filter_wait_node is None) or filter_wait_node(node):
coll_node = node.args[0]
group_key = node_group_key(coll_node)
if group_key == cur_group_key:
cur_group.append(coll_node)
else:
if len(cur_group) > 1:
nodes_groups.append(cur_group)
cur_group = [coll_node]
cur_group_key = group_key
if len(cur_group) > 1:
nodes_groups.append(cur_group)
nodes_groups[group_key].append(coll_node)
buckets: list[list[torch.fx.Node]] = []
for nodes in nodes_groups:
for nodes in nodes_groups.values():
cur_bucket: list[torch.fx.Node] = []
cur_bucket_successors: OrderedSet[torch.fx.Node] = OrderedSet()
cur_bucket_descendents: OrderedSet[torch.fx.Node] = OrderedSet()
cur_bucket_size_bytes: int = 0
cur_bucket_id: int = 0
bucket_size_bytes = int(
bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
)
for node in nodes:
if node in cur_bucket_successors:
# We cannot bucket successors with the node
if node in cur_bucket_descendents:
# if there is a path from node to the current bucket, we cannot horizontally fuse (bucket)
continue
assert "val" in node.meta
n_val = node.meta["val"]
@ -160,10 +177,10 @@ def greedy_bucket_collective_by_mb(
cur_bucket = []
cur_bucket_size_bytes = 0
cur_bucket_id += 1
cur_bucket_successors = OrderedSet()
cur_bucket_descendents = OrderedSet()
cur_bucket_size_bytes += size_bytes
cur_bucket.append(node)
cur_bucket_successors |= nodes_successors[node]
cur_bucket_descendents |= node_descendents[node]
if len(cur_bucket) > 1:
buckets.append(cur_bucket)
return buckets
@ -180,7 +197,7 @@ def bucket_all_gather_by_mb(
Args:
gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers.
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
to specify different sizes of the buckets at the start,
as first all_gather is usually exposed. Interface of bucket_cap_mb_by_bucket_idx
@ -218,14 +235,14 @@ def bucket_reduce_scatter_by_mb(
Args:
gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters.
bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
bucket_cap_mb_bucket_idx (Callable[[int], float]): Callable to specify cap of the bucket
in megabytes by bucket idx. The idea of `bucket_cap_mb_by_bucket_idx` is to allow
to specify different sizes of the buckets.
filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
only reduce_scatter nodes with wait_node that satisfy `filter_wait_node` will be bucketed.
Returns:
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
"""
def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
@ -259,6 +276,8 @@ def reduce_scatter_merge_fn_to_trace(
new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
# TODO - either use torch.cat or make sure inductor foreach codegen
# fires more reliably
new_rs_out = torch.ops.c10d_functional.wait_tensor(
torch.ops._c10d_functional.reduce_scatter_tensor.default(
new_rs_in, reduce_op, group_size, group_name
@ -347,7 +366,13 @@ def _trace(fn, inps) -> torch.fx.GraphModule: # type: ignore[no-untyped-def]
fake_mode = detect_fake_mode(inps)
assert fake_mode is not None
with fake_mode, enable_python_dispatcher():
return make_fx(fn)(*inps)
out = make_fx(fn)(*inps)
for node in out.graph.find_nodes(
op="call_function", target=torch.ops.aten.detach.default
):
node.replace_all_uses_with(node.args[0])
out.graph.erase_node(node)
return out
def _insert_fn_trace_before_node( # type: ignore[no-untyped-def]
@ -488,8 +513,6 @@ def merge_all_gather(
)
n_buckets = len(ag_buckets)
ag_node_to_pre_nodes = defaultdict(list)
ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
for bucket_idx, ag_bucket in enumerate(ag_buckets):
@ -508,13 +531,6 @@ def merge_all_gather(
and ag_node.meta["val"].dtype == dtype
)
ag_node_in = ag_node.args[0]
if (
ag_node_in.op == "call_function" # type: ignore[union-attr]
and ag_node_in.target == torch.ops.prims.convert_element_type.default # type: ignore[union-attr]
and len(ag_node_in.users) == 1 # type: ignore[union-attr]
):
ag_node_to_pre_nodes[ag_node].append(ag_node_in)
ag_node_in = ag_node_in.args[0] # type: ignore[union-attr]
ag_ins[bucket_idx].append(ag_node_in) # type: ignore[union-attr, arg-type]
ag_waits[bucket_idx].append(wait_node)
@ -560,5 +576,3 @@ def merge_all_gather(
for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
g.erase_node(wait_n)
g.erase_node(ag_n)
for n in reversed(ag_node_to_pre_nodes[ag_n]):
g.erase_node(n) # type: ignore[arg-type]
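
A hedged toy example of collect_node_descendants on a small FX graph (assuming the patched module is importable at torch._inductor.fx_passes.bucketing; the toy module below is illustrative, not from the PR):

# Sketch only: build a tiny FX graph and print the descendant map.
import torch
import torch.fx as fx
from torch._inductor.fx_passes.bucketing import collect_node_descendants

class Toy(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        b = a * 2
        return a + b

gm = fx.symbolic_trace(Toy())
descendants = collect_node_descendants(gm.graph)
for node, ds in descendants.items():
    print(node.name, "->", sorted(d.name for d in ds))
# Each node maps to the nodes that transitively depend on it: the placeholder input
# lists every downstream node, while the graph output has no descendants.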

View File

@ -4,7 +4,7 @@ import collections
import dataclasses
import heapq
import logging
from typing import Callable, TYPE_CHECKING, TypedDict, Union
from typing import Callable, Optional, TYPE_CHECKING, TypedDict, Union
from torch._environment import is_fbcode
from torch._utils_internal import signpost_event
@ -76,7 +76,7 @@ def get_freeable_input_buf(
Create and keep track of all input buffers that can be freed during the program
Returns:
A dictionary containing all freeble input buffers, keyed by their names.
A dictionary containing all freeable input buffers, keyed by their names.
"""
def _dep_size_hint(dep: Dep) -> int:
@ -303,7 +303,11 @@ def compute_memory_timeline(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[list[BufferInfo], dict[BaseSchedulerNode, int]]:
) -> tuple[
list[BufferInfo],
dict[BaseSchedulerNode, int],
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
"""
Compute buffer allocation and deallocation sizes and map their
lifetime to the node schedule
@ -317,15 +321,33 @@ def compute_memory_timeline(
# get buffers' size and liveliness information
buf_info_list: list[BufferInfo] = []
buf_to_snode_last_use: dict[
Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode
] = {}
def _get_end_step_and_snode(
buf: Union[FreeableInputBuffer, SchedulerBuffer],
) -> tuple[int, Optional[BaseSchedulerNode]]:
max_step: int = -1
max_step_snode: Optional[BaseSchedulerNode] = None
succ_nodes = buf.mpi_buffer.succ_nodes
if succ_nodes:
for succ_node in succ_nodes:
step = node_to_step[succ_node]
if step > max_step:
max_step = step
max_step_snode = succ_node
assert max_step_snode is not None
return max_step, max_step_snode
# 1. for freeable input buffers
for buf_name, input_buf in name_to_freeable_input_buf.items():
end_step = (
len(nodes) - 1
if buf_name in graph_outputs
else max(
node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes
)
)
end_step = -1
if buf_name not in graph_outputs:
end_step, end_step_snode = _get_end_step_and_snode(input_buf)
assert end_step_snode is not None
buf_to_snode_last_use[input_buf] = end_step_snode
buf_info_list.append(
BufferInfo(
input_buf,
@ -342,17 +364,17 @@ def compute_memory_timeline(
# note: it is possible for a non-graph-output sched_buf to have no succ_nodes and
# to be only used by its defining op (e.g., due to fusion when all consumers of
# the buffer are fused with its defining op). In such cases, end_step is step.
end_step = (
len(nodes) - 1
if sched_buf.get_name() in graph_outputs
else max(
[
node_to_step[succ_node]
for succ_node in sched_buf.mpi_buffer.succ_nodes
],
default=step,
)
)
buf_name = sched_buf.get_name()
end_step = -1
if buf_name not in graph_outputs:
end_step, end_step_snode = _get_end_step_and_snode(sched_buf)
if end_step == -1:
end_step = step
buf_to_snode_last_use[sched_buf] = node
else:
assert end_step_snode is not None
buf_to_snode_last_use[sched_buf] = end_step_snode
buf_info_list.append(
BufferInfo(
sched_buf,
@ -363,7 +385,7 @@ def compute_memory_timeline(
)
)
return buf_info_list, node_to_step
return buf_info_list, node_to_step, buf_to_snode_last_use
def estimate_peak_memory(
@ -373,35 +395,84 @@ def estimate_peak_memory(
) -> tuple[int, list[int]]:
"""
Given a list of nodes in their execution order, estimate the peak memory, by
keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
keeping track of the liveness of SchedulerBuffers and FreeableInputBuffers.
Returns:
int: peak memory
List[int]: memory usage at each node (or each step).
"""
# Use estimate_peak_memory_allocfree to keep one impl.
peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
estimate_peak_memory_allocfree(nodes, name_to_freeable_input_buf, graph_outputs)
)
return peak_memory, [(curr_mem[0] + curr_mem[1]) for curr_mem in snodes_curr_memory]
buf_info_list, _ = compute_memory_timeline(
@dataclasses.dataclass
class SNodeMemory:
size_alloc: int
size_free: int
def estimate_peak_memory_allocfree(
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[
int,
list[tuple[int, int]],
dict[BaseSchedulerNode, SNodeMemory],
dict[Union[FreeableInputBuffer, SchedulerBuffer], BaseSchedulerNode],
]:
"""
Alternative version of estimate_peak_memory that respects the fact
that every SchedulerNode has multiple phases:
1. alloc (outputs)
2. run_kernel
3. dealloc last-use buffers
estimate_peak_memory collapses each node's memory into one value
(size_alloc - size_free), while the peak actually occurs right after alloc.
The code is duplicated so that all call sites do not have to migrate at once;
future users of estimate_peak_memory will migrate to this version.
"""
buf_info_list, _, buf_to_snode_last_use = compute_memory_timeline(
nodes, name_to_freeable_input_buf, graph_outputs
)
# incremental memory changes at each step
memory = [0 for _ in range(len(nodes) + 1)]
step_idx_allocfree = [SNodeMemory(0, 0) for _ in range(len(nodes))]
# for each buffer, update memory when created and when freed
for buf_info in buf_info_list:
memory[buf_info.start_step] += buf_info.size_alloc
memory[buf_info.end_step + 1] -= buf_info.size_free
step_idx_allocfree[buf_info.start_step].size_alloc += buf_info.size_alloc
if buf_info.end_step != -1:
step_idx_allocfree[buf_info.end_step].size_free += buf_info.size_free
snodes_allocfree = {}
for i, node in enumerate(nodes):
snodes_allocfree[node] = step_idx_allocfree[i]
# get peak memory by compute the cumulative memories
max_memory = 0
cur_memory = 0
memories_at_nodes = []
for t in range(len(nodes) + 1):
cur_memory += memory[t]
memories_at_nodes.append(cur_memory)
snodes_curr_memory = []
for t in range(len(nodes)):
alloc = step_idx_allocfree[t].size_alloc
free = step_idx_allocfree[t].size_free
cur_memory += alloc
post_alloc = cur_memory
max_memory = max(max_memory, cur_memory)
cur_memory -= free
post_free = cur_memory
snodes_curr_memory.append((post_alloc, post_free))
return (max_memory, memories_at_nodes)
return (
max_memory,
snodes_curr_memory,
snodes_allocfree,
buf_to_snode_last_use,
)
def topological_sort_lpmf(
@ -417,7 +488,7 @@ def topological_sort_lpmf(
Buffer memory optimization for video codec application modeled in Simulink
https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF
The algorithm maintain the max memory so far.
The algorithm maintains the max memory so far.
At every iteration, for each scheduleable node, it computes:
- how much memory needs to be allocated for the output buffers of this node;
- how much memory can be freed as a result of executing this node.
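
The alloc/free bookkeeping in estimate_peak_memory_allocfree above reduces to a small running-sum loop; a hedged standalone sketch of the same arithmetic (plain Python, not the inductor API):

# Sketch only: per-node (size_alloc, size_free) pairs -> peak and (post_alloc, post_free) per node.
def peak_from_allocfree(allocfree):
    cur = peak = 0
    curr_memory = []
    for alloc, free in allocfree:
        cur += alloc              # phase 1: allocate the node's outputs
        post_alloc = cur
        peak = max(peak, cur)     # the peak is observed after alloc, before frees
        cur -= free               # phase 3: free buffers whose last use is this node
        curr_memory.append((post_alloc, cur))
    return peak, curr_memory

print(peak_from_allocfree([(4, 0), (8, 4), (2, 8)]))
# -> (12, [(4, 4), (12, 8), (10, 2)])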

View File

@ -2159,7 +2159,20 @@ class Scheduler:
OrderedSet(V.graph.graph_inputs.keys()),
OrderedSet(V.graph.get_output_names()),
)
if config.estimate_runtime_benchmark:
from .estimator import estimate_runtime
verbose = True
estimate_runtime(self, self.nodes, verbose)
if config.reorder_for_compute_comm_overlap:
if not config.reorder_for_peak_memory:
from .memory import assign_memory_planning_info_for_scheduler_buffers
assign_memory_planning_info_for_scheduler_buffers(
self.nodes, self.name_to_buf
)
from torch._logging import trace_structured
trace_structured(
@ -2556,7 +2569,7 @@ class Scheduler:
)
graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
buf_info_list, _ = compute_memory_timeline(
buf_info_list, _, _ = compute_memory_timeline(
self.nodes,
name_to_freeable_input_buf,
graph_outputs,