Compare commits

...

1 Commits

Author SHA1 Message Date
5e1120505a [WIP] Misc changes, sending to Elias 2025-11-07 15:46:16 +00:00
2 changed files with 55 additions and 10 deletions

View File

@ -176,6 +176,7 @@ class OverlapPreservingBucketer:
head = None
prev_event = None
position = 0
hiding_nodes = OrderedSet()
for node in self.scheduled:
node_type = None
@ -183,11 +184,13 @@ class OverlapPreservingBucketer:
# Determine if this node is relevant for this PG
if node in self.collective_info and get_group_name(node) == pg:
node_type = "starts"
if hn := self.collective_info[node].hiding_node:
hiding_nodes.add(hn)
elif is_wait_tensor(node):
wait_input = node.args[0]
if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg:
node_type = "waits"
elif is_compute_node(node):
elif is_compute_node(node) or node in hiding_nodes:
node_type = "compute"
if node_type is None:
@ -205,7 +208,7 @@ class OverlapPreservingBucketer:
prev_event = event
position += 1
# from IPython import embed; embed(); exit()
return head
def _populate_node_to_event(self, pg: str) -> None:

View File

@ -4,9 +4,9 @@ import itertools
import logging
import sys
from collections import Counter, defaultdict
from collections.abc import Callable, Iterable
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any
from typing import Any, Callable
import torch
import torch.fx as fx
@ -44,12 +44,13 @@ def get_group_name(n: fx.Node) -> str:
def get_custom_estimation(
n: fx.Node,
override_size: int | None = None,
custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
) -> float | None:
if custom_runtime_estimation is None:
return None
return custom_runtime_estimation(n)
return custom_runtime_estimation(n, override_size)
def estimate_collective_time(
@ -58,7 +59,7 @@ def estimate_collective_time(
custom_runtime_estimation: Callable[[fx.Node], float | None] | None = None,
) -> float:
"""Estimate the runtime of a collective operation, optionally with an overridden size."""
if (est := get_custom_estimation(n, custom_runtime_estimation)) is not None:
if (est := get_custom_estimation(n, override_size, custom_runtime_estimation)) is not None:
return est
return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node(
@ -68,10 +69,13 @@ def estimate_collective_time(
def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int:
    """Return the byte size of the tensor passed as the node's first argument.

    Only ``fx_node.args[0]`` is considered; other inputs are ignored.

    Args:
        fx_node: The node whose first argument is measured.

    Returns:
        ``numel() * element_size()`` of the first argument's ``meta["val"]``,
        or ``0`` when the node has no positional arguments, the first argument
        is not an fx node, or it carries no ``"val"`` metadata.
    """
    if not fx_node.args:
        return 0
    first_input = fx_node.args[0]
    if not isinstance(first_input, torch.fx.Node):
        return 0
    val = first_input.meta.get("val")
    if val is None:
        return 0
    # todo - symbolic
    return val.numel() * val.element_size()
@ -419,6 +423,8 @@ class OverlapScheduler:
self._handle_collective_start(node)
elif is_wait_tensor(node):
self._handle_wait(node)
elif node.op == "placeholder":
self._schedule(node)
else:
self._handle_other(node)
@ -458,6 +464,28 @@ class OverlapScheduler:
preserve_node_ordering(self.graph, additional_deps)
def _handle_other(self, node: fx.Node) -> None:
    """Schedule ``node`` and spend its estimated runtime on hiding collectives.

    The node's estimate from ``self.custom_runtime_estimation`` is scaled by
    ``self.compute_overlap_multipler``. That budget is first subtracted from
    the exposed time of collectives in ``self.in_flight``; a collective whose
    exposed time reaches zero records ``node`` as its ``hiding_node``. Any
    budget left over is handed to ``_schedule_collectives_for_overlap``.
    Finally the node itself is scheduled.

    When no estimator is configured, or it returns ``None`` for this node,
    the budget is zero and the node is simply scheduled.

    Args:
        node: The node being scheduled.
    """
    estimator = self.custom_runtime_estimation
    compute_time = estimator(node) if estimator is not None else None
    if compute_time is None:
        # No estimate available: nothing can be hidden behind this node.
        available_compute = 0.0
    else:
        available_compute = compute_time * self.compute_overlap_multipler

    # TODO: separate overlap time per process group
    # First reduce exposed time of in-flight collectives
    for info in self.in_flight.values():
        if info.exposed_time_ms == 0:
            continue
        overlap_amount = min(info.exposed_time_ms, available_compute)
        info.exposed_time_ms -= overlap_amount
        available_compute -= overlap_amount
        if info.exposed_time_ms == 0:
            # This node supplied the compute that finished hiding the collective.
            info.hiding_node = node
        elif available_compute == 0:
            break

    # Then, look for unscheduled collectives we can overlap
    if available_compute:
        self._schedule_collectives_for_overlap(node, available_compute)

    self._schedule(node)
def _schedule(self, node: fx.Node) -> None:
@ -559,6 +587,10 @@ class OverlapScheduler:
assert node in self.wait_to_start
coll_start = self.wait_to_start[node]
assert coll_start in self.in_flight
# if str(node) == "wait_tensor_44":
# from IPython import embed; embed(); exit()
# if self.collective_info[coll_start].is_exposed:
# print(f"{coll_start=} is exposed with {self.collective_info[coll_start]}")
# Scheduling a wait of a collective also forces the wait
# of every node enqueued prior to the collective on the
@ -644,10 +676,14 @@ class OverlapScheduler:
)
for collective in possible_collectives:
# if str(compute_node) == 'convert_element_type_3':
# print(f"{available_compute_time=}, {collective=}")
if available_compute_time == 0:
break
info = self.collective_info[collective]
# if str(compute_node) == 'convert_element_type_3':
# print(f" 1 - {info=}")
# Skip if compute depends on collective or vice versa
if (
@ -681,10 +717,16 @@ class OverlapScheduler:
# Schedule path to this collective
self._schedule_path_to_collective(path, compute_node)
# if str(compute_node) == 'convert_element_type_3':
# print(f" 2 - {info=}")
self._handle_collective_start(collective)
# if str(compute_node) == 'convert_element_type_3':
# print(f" 3 - {info=}")
# Update the exposed time for this newly scheduled collective
# after scheduling, which will account for latency reduction of bucketing
# if str(compute_node) == 'convert_element_type_3':
# print(f" 4 - {info=}")
overlap_amount = min(available_compute_time, info.exposed_time_ms)
info.exposed_time_ms -= overlap_amount
if info.exposed_time_ms == 0: