Add Memory Estimation Tracker (#165059)

Add a MemoryTracker utility, which tracks live memory given an alternate ordering of graph nodes.
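
A rough usage sketch, mirroring the new test added below (MemoryTracker, schedule_node, get_current_memory_bytes, and peak_memory come from this diff; fn, the shapes, and the CUDA device are illustrative and assume a CUDA build, as in the test):

import torch
from torch._inductor.fx_passes.memory_estimator import MemoryTracker
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

def fn(x, w):
    h = torch.relu(torch.matmul(x, w))  # intermediate; its storage is freeable after its last use
    return h.sum()

with FakeTensorMode():
    x = torch.randn(32, 100, device="cuda")
    w = torch.randn(100, 10, device="cuda")
    gm = make_fx(fn)(x, w)

    tracker = MemoryTracker(gm.graph)  # defaults: track non-CPU devices; "primals*" inputs are never released
    for node in gm.graph.nodes:
        if node.op not in ("placeholder", "get_attr", "output"):
            tracker.schedule_node(node)  # nodes may be replayed in any valid alternate order
    print(tracker.get_current_memory_bytes(), tracker.peak_memory)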

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165059
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945
eellison
2025-10-15 09:33:33 -07:00
committed by PyTorch MergeBot
parent 8c4b528403
commit 2b71b62045
3 changed files with 340 additions and 28 deletions

View File

@@ -6,12 +6,13 @@ from collections import Counter
from typing import Callable, Optional
import torch
from torch._inductor.fx_passes.memory_estimator import build_memory_profile
from torch._inductor.fx_passes.memory_estimator import (
build_memory_profile,
MemoryTracker,
)
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary
@@ -168,6 +169,180 @@ class TestMemoryProfilingResNet(InductorTestCase):
self.assertEqual(fx_peak, runtime_peak)
class TestMemoryTracker(InductorTestCase):
def test_memory_tracker_original_order(self):
"""Test that MemoryTracker works correctly with original scheduling order and matches runtime profiling."""
def create_inputs_and_weights():
"""Create inputs and weights on CUDA."""
x = torch.randn(32, 100, device="cuda")
w1 = torch.randn(100, 50, device="cuda")
w2 = torch.randn(50, 10, device="cuda")
return x, w1, w2
def fn(x, w1, w2):
# Create a simple function that allocates intermediate tensors
h1 = torch.matmul(x, w1) # Allocates h1
h2 = torch.relu(h1) # h1 can be freed, h2 allocated
out = torch.matmul(h2, w2) # h2 can be freed, out allocated
return out
with FakeTensorMode():
# Create inputs
x, w1, w2 = create_inputs_and_weights()
# Trace the function
fx_graph = make_fx(fn)(x, w1, w2)
# Test MemoryTracker with original order
memory_tracker = MemoryTracker(fx_graph.graph, device_filter=device_filter)
# Schedule nodes in original order
compute_nodes = [
node
for node in fx_graph.graph.nodes
if node.op not in ("placeholder", "get_attr", "output")
]
for node in compute_nodes:
memory_tracker.schedule_node(node)
memory_tracker_peak = memory_tracker.peak_memory
# Compare with runtime profiling using FakeTensorMemoryProfilerMode
profiler = FakeTensorMemoryProfilerMode(device_filter=device_filter)
with profiler:
x_runtime, w1_runtime, w2_runtime = create_inputs_and_weights()
result = fn(x_runtime, w1_runtime, w2_runtime)
del result
runtime_peak = profiler.max_memory
# Verify both approaches track meaningful memory usage
self.assertGreater(
memory_tracker_peak, 0, "MemoryTracker should track memory usage"
)
self.assertGreater(
runtime_peak, 0, "Runtime profiler should track memory usage"
)
def test_memory_tracker_different_scheduling(self):
"""Test that different scheduling orders produce different memory usage patterns."""
def foo(primals_1):
zeros = torch.zeros_like(primals_1) # Create zeros tensor
add_result = zeros + 1 # Use zeros (first use)
sum_result = zeros.sum() # Use zeros (second use)
cpu = torch.zeros([20], device="cpu")
cpu_2 = cpu + 1
return add_result, sum_result, cpu_2
with FakeTensorMode():
# Create input
primals_1 = torch.randn(1000, 1000, device="cuda")
# Trace the function
fx_graph = make_fx(foo)(primals_1)
# Get compute nodes (excluding placeholders, get_attr, output)
compute_nodes = [
node
for node in fx_graph.graph.nodes
if node.op not in ("placeholder", "get_attr", "output")
]
if len(compute_nodes) < 3:
self.skipTest(
f"Need at least 3 compute nodes, got {len(compute_nodes)}"
)
# Test original order: zeros_like, add, sum
# zeros gets freed after sum (last use of zeros)
memory_tracker1 = MemoryTracker(fx_graph.graph, device_filter=device_filter)
memory_profile1 = []
initial_mem = memory_tracker1.get_current_memory_bytes()
for node in compute_nodes:
memory_tracker1.schedule_node(node)
memory_profile1.append(memory_tracker1.get_current_memory_bytes())
# use of primals should not deallocate
self.assertEqual(memory_profile1[0], initial_mem * 2)
# Test different order: zeros_like, sum, add
# zeros gets freed after add (last use of zeros in new order)
memory_tracker2 = MemoryTracker(fx_graph.graph, device_filter=device_filter)
memory_profile2 = []
# Alternative schedule: change which operation is the last use of zeros
# Original: zeros_like, add, sum (zeros freed after sum)
# Alternative: zeros_like, sum, add (zeros freed after add)
assert len(compute_nodes) == 5, (
f"Expected 3 compute nodes, got {len(compute_nodes)}"
)
reordered_nodes = [
compute_nodes[0], # zeros_like: zeros = torch.zeros_like(primals_1)
compute_nodes[2], # sum: sum_result = zeros.sum() (zeros still alive)
compute_nodes[1],  # add: add_result = zeros + 1 (last use, zeros freed here)
compute_nodes[3], # cpu = torch.zeros([20], device="cpu")
compute_nodes[4], # cpu_2 = cpu + 1
]
for node in reordered_nodes:
memory_tracker2.schedule_node(node)
memory_profile2.append(memory_tracker2.get_current_memory_bytes())
# Compare peak memories
peak1 = max(memory_profile1)
peak2 = max(memory_profile2)
# Both should end with the same final memory (all intermediate tensors freed)
self.assertEqual(memory_profile1[-1], memory_profile2[-1])
# The profiles should be different, showing different memory patterns
self.assertNotEqual(
memory_profile1,
memory_profile2,
"Different scheduling should produce different memory profiles",
)
# The different scheduling should produce different peak memory!
# Original: zeros + add_result both alive → higher peak
# Reordered: zeros freed before add_result created → lower peak
self.assertGreater(
peak1, peak2, "Original order should have higher peak memory"
)
# Specifically, original has both zeros and add_result alive simultaneously
self.assertGreater(
memory_profile1[1],
memory_profile2[1],
"Original order keeps more tensors alive simultaneously",
)
# The reordered version should have lower intermediate memory usage
self.assertLess(
peak2,
peak1,
"Reordered schedule reduces peak memory through better deallocation timing",
)
# Verify the MemoryTracker correctly tracks different scheduling
# The first tracker should match since we tested accuracy against FakeTensorMemoryProfilerMode
self.assertLessEqual(
abs(memory_tracker1.peak_memory - peak1),
8,
"First tracker peak should match profile peak",
)
# The key test: profiles show different peaks due to different deallocation timing
self.assertNotEqual(
peak1, peak2, "Different scheduling produces different peak memory"
)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA_AND_TRITON:
run_tests(needs="filelock")
run_tests(needs="filelock")

View File

@@ -14,14 +14,6 @@ from torch.utils._pytree import tree_map_only
log = logging.getLogger(__name__)
def _is_wait_tensor(node: fx.Node) -> bool:
"""Check if a node is a wait_tensor operation."""
return (
node.op == "call_function"
and node.target == torch.ops._c10d_functional.wait_tensor.default
)
@dataclass(frozen=True)
class StorageKey:
storage: torch.UntypedStorage
@@ -125,23 +117,12 @@ class GraphAliasTracker:
def _get_input_storages(self, node: fx.Node) -> OrderedSet[StorageKey]:
"""
Get all storages from a node's inputs.
For wait_tensor operations, this includes both the direct inputs (the collective handle)
and all inputs from the corresponding collective start operation, since the wait
is what actually allows those inputs to be freed.
"""
input_storages: OrderedSet[StorageKey] = OrderedSet()
for input_node in node.all_input_nodes:
input_storages.update(self.node_to_output_storages[input_node])
# Handle collective start/wait pairs: wait_tensor should also "use" all inputs
# from the collective start operation, since it's the wait that releases them
if _is_wait_tensor(node):
collective_start = node.args[0]
assert isinstance(collective_start, fx.Node)
input_storages.update(self.node_to_storage_uses[collective_start])
return input_storages
def get_fresh_allocations(self, node: fx.Node) -> OrderedSet[StorageKey]:
@@ -303,14 +284,15 @@ def get_fwd_bwd_interactions(
return bwd_baseline_memory, do_not_delete
def _is_releasable(n: fx.Node) -> bool:
# Storages of primals cannot be released during fwd or bwd pass.
return not n.name.startswith("primals")
def get_peak_memory(
fwd_graph: fx.Graph,
bwd_graph: fx.Graph,
) -> int:
def _is_releasable(n: fx.Node) -> bool:
# Storages of primals cannot be released during fwd or bwd pass.
return not n.name.startswith("primals")
fwd_peak_memory = max(build_memory_profile(fwd_graph, _is_releasable))
bwd_baseline_memory, bwd_do_not_delete = get_fwd_bwd_interactions(
@@ -330,3 +312,143 @@ def get_peak_memory(
fwd_peak_memory,
bwd_peak_memory,
)
class MemoryTracker:
"""
Tracks memory usage for alternative scheduling orders of an FX graph.
This class enables tracking memory usage as nodes are scheduled in a different
order than in the original graph.
"""
def __init__(
self,
graph: fx.Graph,
is_releasable: Optional[Callable[[fx.Node], bool]] = None,
device_filter: Optional[Callable[[torch.device], bool]] = None,
):
"""
Initialize memory tracker for alternative scheduling of the given graph.
Args:
graph: FX graph to track memory for under alternative scheduling
is_releasable: whether a node's storage may be released once its last use has been
scheduled, or must stay allocated for the duration of the graph.
By default, every node except those whose names start with "primals" is releasable.
device_filter: Function to determine which devices to track (default: non-CPU)
"""
self.graph = graph
self.nodes = list(graph.nodes)
self.device_filter = device_filter or (lambda device: device.type != "cpu")
self.scheduled: OrderedSet[fx.Node] = OrderedSet()
# Memory tracking using GraphAliasTracker
self.alias_tracker = GraphAliasTracker(self.nodes)
self.current_live_storages: OrderedSet[StorageKey] = OrderedSet()
self.current_memory_bytes = 0
self.is_releasable = _is_releasable if is_releasable is None else is_releasable
# Initialize live storages with placeholders and get_attr nodes
for node in self.nodes:
if node.op in ("placeholder", "get_attr"):
fresh_allocations = self.alias_tracker.get_fresh_allocations(node)
for storage_key in fresh_allocations:
if self.device_filter(storage_key.device):
self.current_live_storages.add(storage_key)
self.current_memory_bytes += self._get_storage_size(storage_key)
self.peak_memory = self.current_memory_bytes
log.debug(
"Memory tracker initialized with initial memory: %d MB",
self.current_memory_bytes // (1024 * 1024),
)
def schedule_node(self, node: fx.Node) -> None:
"""
Schedule a node and update memory tracking for the new scheduling order.
Args:
node: The node being scheduled (potentially out of original order)
"""
assert node not in self.scheduled, "should not schedule node twice"
self.scheduled.add(node)
self._update_memory_for_node(node)
def get_current_memory_bytes(self) -> int:
"""Get current live memory in bytes under the current scheduling."""
return self.current_memory_bytes
def _get_storage_size(self, storage_key: StorageKey) -> int:
"""Get the size of a storage in bytes, handling symbolic shapes."""
size_bytes = storage_key.storage.nbytes()
return hint_int(
size_bytes, fallback=torch._inductor.config.unbacked_symint_fallback
)
def _get_storages_freed_by_node(self, node: fx.Node) -> OrderedSet[StorageKey]:
"""Get storages that would be freed if we schedule this node."""
freed_storages: OrderedSet[StorageKey] = OrderedSet()
input_storages = self.alias_tracker.get_storage_uses(node)
for storage_key in input_storages:
if not self.device_filter(storage_key.device):
continue
# Invariant: if a node uses a storage, it must be live
assert storage_key in self.current_live_storages, (
"all input storages should be currently allocated"
)
if not self.is_releasable(
self.alias_tracker.storage_to_allocator[storage_key]
):
continue
all_uses = self.alias_tracker.storage_to_uses[storage_key]
# If no more unscheduled uses remain, the storage can be freed
if all(u in self.scheduled for u in all_uses):
freed_storages.add(storage_key)
return freed_storages
def _update_memory_for_node(self, node: fx.Node) -> None:
"""Update memory tracking when a node is scheduled."""
if node.op in ("placeholder", "get_attr", "output"):
return
# Add fresh allocations
fresh_allocations = self.alias_tracker.get_fresh_allocations(node)
alloc_bytes = 0
for storage_key in fresh_allocations:
if (
self.device_filter(storage_key.device)
and storage_key not in self.current_live_storages
):
size = self._get_storage_size(storage_key)
self.current_live_storages.add(storage_key)
self.current_memory_bytes += size
alloc_bytes += size
self.peak_memory = max(self.current_memory_bytes, self.peak_memory)
# Remove storages that are no longer used
storages_to_free = self._get_storages_freed_by_node(node)
freed_bytes = 0
for storage_key in storages_to_free:
if storage_key in self.current_live_storages:
size = self._get_storage_size(storage_key)
self.current_live_storages.remove(storage_key)
self.current_memory_bytes -= size
freed_bytes += size
log.debug(
"Scheduled %s: memory change %d allocs, %d frees, current memory: %d MB",
node.name,
len(fresh_allocations),
len(storages_to_free),
self.current_memory_bytes // (1024 * 1024),
)
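
For context, a minimal sketch of the comparison that the scheduler change below performs (build_memory_profile, _is_releasable, MemoryTracker, and peak_memory come from this file; gm is assumed to be a module traced under FakeTensorMode as in the tests, and alternate_order is a hypothetical topologically valid reordering of its compute nodes):

from torch._inductor.fx_passes.memory_estimator import (
    _is_releasable,
    build_memory_profile,
    MemoryTracker,
)

def peak_for_order(graph, order):
    # Replay `order` (a topologically valid permutation of compute nodes)
    # through a fresh tracker and return the resulting peak, in bytes.
    tracker = MemoryTracker(graph)
    for node in order:
        tracker.schedule_node(node)
    return tracker.peak_memory

# e.g., compare against the original program order, as the overlap scheduler's counters do:
#   original_peak = max(build_memory_profile(gm.graph, _is_releasable))
#   rescheduled_peak = peak_for_order(gm.graph, alternate_order)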

View File

@@ -12,6 +12,11 @@ import torch
import torch.fx as fx
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.fx_passes.bucketing import is_wait_tensor
from torch._inductor.fx_passes.memory_estimator import (
_is_releasable,
build_memory_profile,
MemoryTracker,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._ordered_set import OrderedSet
@@ -217,6 +222,12 @@ class OverlapScheduler:
self.collective_info: dict[fx.Node, CollectiveInfo] = {}
self.unscheduled_collectives: OrderedSet[fx.Node] = OrderedSet()
# Memory tracking using abstracted MemoryTracker
self.original_peak_memory = max(
build_memory_profile(self.graph, _is_releasable)
)
self.memory_tracker = MemoryTracker(self.graph)
self.wait_to_start: dict[fx.Node, fx.Node] = {}
self._identify_collectives()
@@ -422,6 +433,7 @@ class OverlapScheduler:
assert node not in self.scheduled
assert all(n in self.scheduled for n in node.all_input_nodes)
self.scheduled.add(node)
self.memory_tracker.schedule_node(node)
for user in node.users:
self.in_degree[user] -= 1
@@ -661,6 +673,9 @@ class OverlapScheduler:
potentially_hidden_collectives
)
counters["inductor"]["overlap_original_mem"] = self.original_peak_memory
counters["inductor"]["rescheduled_mem"] = self.memory_tracker.peak_memory
log.info(
"Overlap scheduling: total exposed %s, total bad exposed %s, total potentially hidden %s",
len(exposed),