Add Memory Estimation Tracker (#165059)

Add a Memory Tracker utility, which tracks live memory given an alternate ordering of a graph's nodes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165059
Approved by: https://github.com/ezyang, https://github.com/IvanKobzarev
ghstack dependencies: #164738, #164783, #164944, #164945

committed by: PyTorch MergeBot
parent: 8c4b528403
commit: 2b71b62045
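
For orientation, the sketch below (illustrative only, not part of the diff) shows the intended use of the new MemoryTracker, mirroring the test added in this commit: trace a function under FakeTensorMode, feed the graph's compute nodes to the tracker in a chosen order, then read back the live and peak memory. It assumes a CUDA device, as the test itself is gated on HAS_CUDA_AND_TRITON.

# Illustrative sketch: trace a function under FakeTensorMode, then replay its
# compute nodes through MemoryTracker in some chosen order.
import torch
from torch._inductor.fx_passes.memory_estimator import MemoryTracker
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x, w):
    h = torch.matmul(x, w)  # intermediate allocation
    return torch.relu(h)  # h becomes freeable after this use


with FakeTensorMode():
    x = torch.randn(32, 100, device="cuda")
    w = torch.randn(100, 50, device="cuda")
    gm = make_fx(fn)(x, w)

    tracker = MemoryTracker(gm.graph)
    for node in gm.graph.nodes:
        if node.op not in ("placeholder", "get_attr", "output"):
            tracker.schedule_node(node)  # feed nodes in the (re)ordered schedule
    print(tracker.get_current_memory_bytes(), tracker.peak_memory)
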
@@ -6,12 +6,13 @@ from collections import Counter
 from typing import Callable, Optional

 import torch
-from torch._inductor.fx_passes.memory_estimator import build_memory_profile
+from torch._inductor.fx_passes.memory_estimator import (
+    build_memory_profile,
+    MemoryTracker,
+)
 from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing._internal.common_utils import IS_LINUX
 from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map_only
 from torch.utils.weak import WeakIdKeyDictionary
@@ -168,6 +169,180 @@ class TestMemoryProfilingResNet(InductorTestCase):
         self.assertEqual(fx_peak, runtime_peak)


+class TestMemoryTracker(InductorTestCase):
+    def test_memory_tracker_original_order(self):
+        """Test that MemoryTracker works correctly with original scheduling order and matches runtime profiling."""
+
+        def create_inputs_and_weights():
+            """Create inputs and weights on CUDA."""
+            x = torch.randn(32, 100, device="cuda")
+            w1 = torch.randn(100, 50, device="cuda")
+            w2 = torch.randn(50, 10, device="cuda")
+            return x, w1, w2
+
+        def fn(x, w1, w2):
+            # Create a simple function that allocates intermediate tensors
+            h1 = torch.matmul(x, w1)  # Allocates h1
+            h2 = torch.relu(h1)  # h1 can be freed, h2 allocated
+            out = torch.matmul(h2, w2)  # h2 can be freed, out allocated
+            return out
+
+        with FakeTensorMode():
+            # Create inputs
+            x, w1, w2 = create_inputs_and_weights()
+
+            # Trace the function
+            fx_graph = make_fx(fn)(x, w1, w2)
+
+            # Test MemoryTracker with original order
+            memory_tracker = MemoryTracker(fx_graph.graph, device_filter=device_filter)
+
+            # Schedule nodes in original order
+            compute_nodes = [
+                node
+                for node in fx_graph.graph.nodes
+                if node.op not in ("placeholder", "get_attr", "output")
+            ]
+
+            for node in compute_nodes:
+                memory_tracker.schedule_node(node)
+
+            memory_tracker_peak = memory_tracker.get_current_memory_bytes()
+
+            # Compare with runtime profiling using FakeTensorMemoryProfilerMode
+            profiler = FakeTensorMemoryProfilerMode(device_filter=device_filter)
+
+            with profiler:
+                x_runtime, w1_runtime, w2_runtime = create_inputs_and_weights()
+                result = fn(x_runtime, w1_runtime, w2_runtime)
+                del result
+
+            runtime_peak = profiler.max_memory
+
+            # Verify both approaches track meaningful memory usage
+            self.assertGreater(
+                memory_tracker_peak, 0, "MemoryTracker should track memory usage"
+            )
+            self.assertGreater(
+                runtime_peak, 0, "Runtime profiler should track memory usage"
+            )
+
+    def test_memory_tracker_different_scheduling(self):
+        """Test that different scheduling orders produce different memory usage patterns."""
+
+        def foo(primals_1):
+            zeros = torch.zeros_like(primals_1)  # Create zeros tensor
+            add_result = zeros + 1  # Use zeros (first use)
+            sum_result = zeros.sum()  # Use zeros (second use)
+            cpu = torch.zeros([20], device="cpu")
+            cpu_2 = cpu + 1
+            return add_result, sum_result, cpu_2
+
+        with FakeTensorMode():
+            # Create input
+            primals_1 = torch.randn(1000, 1000, device="cuda")
+
+            # Trace the function
+            fx_graph = make_fx(foo)(primals_1)
+
+            # Get compute nodes (excluding placeholders, get_attr, output)
+            compute_nodes = [
+                node
+                for node in fx_graph.graph.nodes
+                if node.op not in ("placeholder", "get_attr", "output")
+            ]
+
+            if len(compute_nodes) < 3:
+                self.skipTest(
+                    f"Need at least 3 compute nodes, got {len(compute_nodes)}"
+                )
+
+            # Test original order: zeros_like, add, sum
+            # zeros gets freed after sum (last use of zeros)
+            memory_tracker1 = MemoryTracker(fx_graph.graph, device_filter=device_filter)
+            memory_profile1 = []
+            initial_mem = memory_tracker1.get_current_memory_bytes()
+
+            for node in compute_nodes:
+                memory_tracker1.schedule_node(node)
+                memory_profile1.append(memory_tracker1.get_current_memory_bytes())
+
+            # use of primals should not deallocate
+            self.assertEqual(memory_profile1[0], initial_mem * 2)
+
+            # Test different order: zeros_like, sum, add
+            # zeros gets freed after add (last use of zeros in new order)
+            memory_tracker2 = MemoryTracker(fx_graph.graph, device_filter=device_filter)
+            memory_profile2 = []
+
+            # Alternative schedule: change which operation is the last use of zeros
+            # Original: zeros_like, add, sum (zeros freed after sum)
+            # Alternative: zeros_like, sum, add (zeros freed after add)
+            assert len(compute_nodes) == 5, (
+                f"Expected 5 compute nodes, got {len(compute_nodes)}"
+            )
+            reordered_nodes = [
+                compute_nodes[0],  # zeros_like: zeros = torch.zeros_like(primals_1)
+                compute_nodes[2],  # sum: sum_result = zeros.sum() (zeros still alive)
+                compute_nodes[
+                    1
+                ],  # add: add_result = zeros + 1 (last use, zeros freed here)
+                compute_nodes[3],  # cpu = torch.zeros([20], device="cpu")
+                compute_nodes[4],  # cpu_2 = cpu + 1
+            ]
+
+            for node in reordered_nodes:
+                memory_tracker2.schedule_node(node)
+                memory_profile2.append(memory_tracker2.get_current_memory_bytes())
+
+            # Compare peak memories
+            peak1 = max(memory_profile1)
+            peak2 = max(memory_profile2)
+
+            # Both should end with the same final memory (all intermediate tensors freed)
+            self.assertEqual(memory_profile1[-1], memory_profile2[-1])
+
+            # The profiles should be different, showing different memory patterns
+            self.assertNotEqual(
+                memory_profile1,
+                memory_profile2,
+                "Different scheduling should produce different memory profiles",
+            )
+
+            # The different scheduling should produce different peak memory!
+            # Original: zeros + add_result both alive → higher peak
+            # Reordered: zeros freed before add_result created → lower peak
+            self.assertGreater(
+                peak1, peak2, "Original order should have higher peak memory"
+            )
+
+            # Specifically, original has both zeros and add_result alive simultaneously
+            self.assertGreater(
+                memory_profile1[1],
+                memory_profile2[1],
+                "Original order keeps more tensors alive simultaneously",
+            )
+
+            # The reordered version should have lower intermediate memory usage
+            self.assertLess(
+                peak2,
+                peak1,
+                "Reordered schedule reduces peak memory through better deallocation timing",
+            )
+
+            # Verify the MemoryTracker correctly tracks different scheduling
+            # The first tracker should match since we tested accuracy against FakeTensorMemoryProfilerMode
+            self.assertLessEqual(
+                abs(memory_tracker1.peak_memory - peak1),
+                8,
+                "First tracker peak should match profile peak",
+            )
+
+            # The key test: profiles show different peaks due to different deallocation timing
+            self.assertNotEqual(
+                peak1, peak2, "Different scheduling produces different peak memory"
+            )
+
+
 if __name__ == "__main__":
     if IS_LINUX and HAS_CUDA_AND_TRITON:
         run_tests(needs="filelock")
@@ -14,14 +14,6 @@ from torch.utils._pytree import tree_map_only
 log = logging.getLogger(__name__)


-def _is_wait_tensor(node: fx.Node) -> bool:
-    """Check if a node is a wait_tensor operation."""
-    return (
-        node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.wait_tensor.default
-    )
-
-
 @dataclass(frozen=True)
 class StorageKey:
     storage: torch.UntypedStorage
@@ -125,23 +117,12 @@ class GraphAliasTracker:
     def _get_input_storages(self, node: fx.Node) -> OrderedSet[StorageKey]:
         """
         Get all storages from a node's inputs.
-
-        For wait_tensor operations, this includes both the direct inputs (the collective handle)
-        and all inputs from the corresponding collective start operation, since the wait
-        is what actually allows those inputs to be freed.
         """
         input_storages: OrderedSet[StorageKey] = OrderedSet()

         for input_node in node.all_input_nodes:
             input_storages.update(self.node_to_output_storages[input_node])

-        # Handle collective start/wait pairs: wait_tensor should also "use" all inputs
-        # from the collective start operation, since it's the wait that releases them
-        if _is_wait_tensor(node):
-            collective_start = node.args[0]
-            assert isinstance(collective_start, fx.Node)
-            input_storages.update(self.node_to_storage_uses[collective_start])
-
         return input_storages

     def get_fresh_allocations(self, node: fx.Node) -> OrderedSet[StorageKey]:
@@ -303,14 +284,15 @@ def get_fwd_bwd_interactions(
     return bwd_baseline_memory, do_not_delete


+def _is_releasable(n: fx.Node) -> bool:
+    # Storages of primals cannot be released during fwd or bwd pass.
+    return not n.name.startswith("primals")
+
+
 def get_peak_memory(
     fwd_graph: fx.Graph,
     bwd_graph: fx.Graph,
 ) -> int:
-    def _is_releasable(n: fx.Node) -> bool:
-        # Storages of primals cannot be released during fwd or bwd pass.
-        return not n.name.startswith("primals")
-
     fwd_peak_memory = max(build_memory_profile(fwd_graph, _is_releasable))

     bwd_baseline_memory, bwd_do_not_delete = get_fwd_bwd_interactions(
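
As a point of reference (not part of the diff), the profile-based path used by get_peak_memory above derives a peak by taking the max over a per-node memory profile, with the _is_releasable convention now factored out to module level. A minimal sketch, where fwd_graph is an assumed, pre-built fx.Graph:

# Sketch: peak memory via the existing profile-based API used in get_peak_memory.
# `fwd_graph` is an assumed fx.Graph, e.g. obtained from make_fx as in the tests.
from torch._inductor.fx_passes.memory_estimator import (
    _is_releasable,
    build_memory_profile,
)

memory_profile = build_memory_profile(fwd_graph, _is_releasable)  # live bytes per step
fwd_peak_memory = max(memory_profile)
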
@@ -330,3 +312,143 @@
         fwd_peak_memory,
         bwd_peak_memory,
     )
+
+
+class MemoryTracker:
+    """
+    Tracks memory usage for alternative scheduling orders of an FX graph.
+
+    This class enables tracking memory usage as nodes are scheduled in a different
+    order than the original graph.
+    """
+
+    def __init__(
+        self,
+        graph: fx.Graph,
+        is_releasable: Optional[Callable[[fx.Node], bool]] = None,
+        device_filter: Optional[Callable[[torch.device], bool]] = None,
+    ):
"""
|
||||
Initialize memory tracker for alternative scheduling of the given graph.
|
||||
|
||||
Args:
|
||||
graph: FX graph to track memory for under alternative scheduling
|
||||
is_releaseable: do we consider this input to the graph to release memory
|
||||
upon final use, or is allocated for the duration of the graph ?
|
||||
by default, we assume all nodes but those that start with "primals" to be releasable
|
||||
device_filter: Function to determine which devices to track (default: non-CPU)
|
||||
"""
|
||||
|
||||
+        self.graph = graph
+        self.nodes = list(graph.nodes)
+        self.device_filter = device_filter or (lambda device: device.type != "cpu")
+        self.scheduled: OrderedSet[fx.Node] = OrderedSet()
+
+        # Memory tracking using GraphAliasTracker
+        self.alias_tracker = GraphAliasTracker(self.nodes)
+        self.current_live_storages: OrderedSet[StorageKey] = OrderedSet()
+        self.current_memory_bytes = 0
+        self.is_releasable = _is_releasable if is_releasable is None else is_releasable
+
+        # Initialize live storages with placeholders and get_attr nodes
+        for node in self.nodes:
+            if node.op in ("placeholder", "get_attr"):
+                fresh_allocations = self.alias_tracker.get_fresh_allocations(node)
+                for storage_key in fresh_allocations:
+                    if self.device_filter(storage_key.device):
+                        self.current_live_storages.add(storage_key)
+                        self.current_memory_bytes += self._get_storage_size(storage_key)
+
+        self.peak_memory = self.current_memory_bytes
+
+        log.debug(
+            "Memory tracker initialized with initial memory: %d MB",
+            self.current_memory_bytes // (1024 * 1024),
+        )
+
+    def schedule_node(self, node: fx.Node) -> None:
+        """
+        Schedule a node and update memory tracking for the new scheduling order.
+
+        Args:
+            node: The node being scheduled (potentially out of original order)
+        """
+        assert node not in self.scheduled, "should not schedule node twice"
+        self.scheduled.add(node)
+        self._update_memory_for_node(node)
+
+    def get_current_memory_bytes(self) -> int:
+        """Get current live memory in bytes under the current scheduling."""
+        return self.current_memory_bytes
+
+    def _get_storage_size(self, storage_key: StorageKey) -> int:
+        """Get the size of a storage in bytes, handling symbolic shapes."""
+        size_bytes = storage_key.storage.nbytes()
+        return hint_int(
+            size_bytes, fallback=torch._inductor.config.unbacked_symint_fallback
+        )
+
+    def _get_storages_freed_by_node(self, node: fx.Node) -> OrderedSet[StorageKey]:
+        """Get storages that would be freed if we schedule this node."""
+        freed_storages: OrderedSet[StorageKey] = OrderedSet()
+
+        input_storages = self.alias_tracker.get_storage_uses(node)
+        for storage_key in input_storages:
+            if not self.device_filter(storage_key.device):
+                continue
+
+            # Invariant: if a node uses a storage, it must be live
+            assert storage_key in self.current_live_storages, (
+                "all input storages should be currently allocated"
+            )
+
+            if not self.is_releasable(
+                self.alias_tracker.storage_to_allocator[storage_key]
+            ):
+                continue
+
+            all_uses = self.alias_tracker.storage_to_uses[storage_key]
+
+            # If no more unscheduled uses remain, the storage can be freed
+            if all(u in self.scheduled for u in all_uses):
+                freed_storages.add(storage_key)
+
+        return freed_storages
+
+    def _update_memory_for_node(self, node: fx.Node) -> None:
+        """Update memory tracking when a node is scheduled."""
+        if node.op in ("placeholder", "get_attr", "output"):
+            return
+
+        # Add fresh allocations
+        fresh_allocations = self.alias_tracker.get_fresh_allocations(node)
+        alloc_bytes = 0
+        for storage_key in fresh_allocations:
+            if (
+                self.device_filter(storage_key.device)
+                and storage_key not in self.current_live_storages
+            ):
+                size = self._get_storage_size(storage_key)
+                self.current_live_storages.add(storage_key)
+                self.current_memory_bytes += size
+                alloc_bytes += size
+
+        self.peak_memory = max(self.current_memory_bytes, self.peak_memory)
+
+        # Remove storages that are no longer used
+        storages_to_free = self._get_storages_freed_by_node(node)
+        freed_bytes = 0
+        for storage_key in storages_to_free:
+            if storage_key in self.current_live_storages:
+                size = self._get_storage_size(storage_key)
+                self.current_live_storages.remove(storage_key)
+                self.current_memory_bytes -= size
+                freed_bytes += size
+
+        log.debug(
+            "Scheduled %s: memory change %d allocs, %d frees, current memory: %d MB",
+            node.name,
+            len(fresh_allocations),
+            len(storages_to_free),
+            self.current_memory_bytes // (1024 * 1024),
+        )
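
A usage note on the constructor above (illustrative, not from the diff): both policies can be passed explicitly; the lambdas below simply restate the documented defaults, and gm is an assumed fx.GraphModule.

# Sketch: explicitly overriding the MemoryTracker policies defined above.
tracker = MemoryTracker(
    gm.graph,
    # Default behavior: every storage is releasable except those allocated by
    # nodes whose names start with "primals".
    is_releasable=lambda n: not n.name.startswith("primals"),
    # Default behavior: track storages on every device except CPU.
    device_filter=lambda device: device.type != "cpu",
)
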
@@ -12,6 +12,11 @@ import torch
 import torch.fx as fx
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor.fx_passes.bucketing import is_wait_tensor
+from torch._inductor.fx_passes.memory_estimator import (
+    _is_releasable,
+    build_memory_profile,
+    MemoryTracker,
+)
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._ordered_set import OrderedSet

@@ -217,6 +222,12 @@ class OverlapScheduler:
         self.collective_info: dict[fx.Node, CollectiveInfo] = {}
         self.unscheduled_collectives: OrderedSet[fx.Node] = OrderedSet()
+
+        # Memory tracking using abstracted MemoryTracker
+        self.original_peak_memory = max(
+            build_memory_profile(self.graph, _is_releasable)
+        )
+        self.memory_tracker = MemoryTracker(self.graph)

         self.wait_to_start: dict[fx.Node, fx.Node] = {}
         self._identify_collectives()

@@ -422,6 +433,7 @@ class OverlapScheduler:
         assert node not in self.scheduled
         assert all(n in self.scheduled for n in node.all_input_nodes)
         self.scheduled.add(node)
+        self.memory_tracker.schedule_node(node)

         for user in node.users:
             self.in_degree[user] -= 1
@@ -661,6 +673,9 @@ class OverlapScheduler:
             potentially_hidden_collectives
         )
+
+        counters["inductor"]["overlap_original_mem"] = self.original_peak_memory
+        counters["inductor"]["rescheduled_mem"] = self.memory_tracker.peak_memory

         log.info(
             "Overlap scheduling: total exposed %s, total bad exposed %s, total potentially hidden %s",
             len(exposed),
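
After the pass runs, the two counters recorded above can be compared to see how the new schedule changed the estimated peak memory; a hedged sketch of reading them:

# Sketch: inspecting the counters set by OverlapScheduler above.
from torch._dynamo.utils import counters

original_peak = counters["inductor"]["overlap_original_mem"]
rescheduled_peak = counters["inductor"]["rescheduled_mem"]
print(f"estimated peak memory (bytes): original={original_peak}, rescheduled={rescheduled_peak}")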