Refactor memory estimator to use node storages, add test (#164783)

- Update the Memory Estimator to use node storages for analysis, which simplifies bookkeeping compared to manually inspecting operator schemas. This will also let me reuse this component elsewhere; a short usage sketch follows below.

- Factor the logic out into a separate class so that it can also be used in scheduling (node allocations / aliasing / uses).

- Add tests for correctness; for now they cover the forward and backward graphs individually, not the combined fwd+bwd estimate.
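
For reference, a minimal sketch of how the refactored estimator is driven, mirroring the usage in the new test. The toy `fn`, tensor shapes, and variable names here are illustrative only, and a CUDA device is assumed since CPU storages are filtered out of the profile:

```python
import torch
from torch._inductor.fx_passes.memory_estimator import build_memory_profile
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x, w):
    # toy two-op graph: matmul followed by relu
    return torch.relu(x @ w)


with FakeTensorMode():
    x = torch.randn(32, 64, device="cuda")
    w = torch.randn(64, 16, device="cuda")
    gm = make_fx(fn)(x, w)

    # graph inputs / attributes are treated as non-releasable, as in the test
    def is_releasable(node):
        return node.op not in ("placeholder", "get_attr")

    # list of memory totals (bytes) over execution; the peak is its max
    profile = build_memory_profile(gm.graph, is_releasable)
    peak = max(profile)
```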

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164783
Approved by: https://github.com/ruisizhang123
ghstack dependencies: #164738
Commit aed5ed1076 (parent af4c29fea8)
Authored by eellison on 2025-10-08 08:44:07 -07:00; committed by PyTorch MergeBot
2 changed files with 393 additions and 294 deletions


@@ -0,0 +1,173 @@
# Owner(s): ["module: inductor"]
import functools
import weakref
from collections import Counter
from typing import Callable, Optional
import torch
from torch._inductor.fx_passes.memory_estimator import build_memory_profile
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary
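# A tensor's storage is identified by the cdata pointer of its typed storage.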
def tensor_storage_id(tensor):
return tensor._typed_storage()._cdata
def device_filter(device):
return device.type == "cuda"
class FakeTensorMemoryProfilerMode(TorchDispatchMode):
def __init__(self, device_filter: Optional[Callable[[torch.device], bool]] = None):
# count of live tensor references for each storage id
self.storage_count: dict[int, int] = Counter()
# live fake tensors
self.live_tensors = WeakIdKeyDictionary()
self.memory_use = 0
self.max_memory = 0
self.device_filter = device_filter
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs is not None else {}
rs = func(*args, **kwargs)
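# account for any fake tensors newly created in this op's outputs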
tree_map_only(torch._subclasses.FakeTensor, self.increase_memory_use, rs)
return rs
def increase_memory_use(self, tensor):
# already accounted for
if tensor in self.live_tensors:
return
if self.device_filter is not None and not self.device_filter(tensor.device):
return
self.live_tensors[tensor] = True
nbytes = tensor.untyped_storage().nbytes()
storage_id = tensor_storage_id(tensor)
# new storage, add to memory
if storage_id not in self.storage_count:
self.change_memory(nbytes)
self.storage_count[storage_id] += 1
# when this tensor dies, we need to adjust memory
weakref.finalize(
tensor, functools.partial(self.tensor_cleanup, storage_id, nbytes)
)
def tensor_cleanup(self, storage_id, nbytes):
self.storage_count[storage_id] -= 1
if self.storage_count[storage_id] == 0:
del self.storage_count[storage_id]
self.change_memory(-nbytes)
def change_memory(self, delta):
self.memory_use += delta
self.max_memory = max(self.memory_use, self.max_memory)
class TestMemoryProfilingResNet(InductorTestCase):
def test_simple_linear_layers(self):
"""Test with a simple sequential model with explicit weights on CUDA."""
def create_inputs_and_weights():
"""Create inputs and weights on CUDA."""
x = torch.randn(32, 1000, device="cuda")
w1 = torch.randn(500, 1000, device="cuda")
w2 = torch.randn(100, 500, device="cuda")
w3 = torch.randn(10, 100, device="cuda")
return x, w1, w2, w3
def fn(x, w1, w2, w3):
h1 = torch.nn.functional.linear(x, w1)
h1 = torch.nn.functional.relu(h1)
h2 = torch.nn.functional.linear(h1, w2)
h2 = torch.nn.functional.relu(h2)
out = torch.nn.functional.linear(h2, w3)
return out
with FakeTensorMode():
# Trace with make_fx
x, w1, w2, w3 = create_inputs_and_weights()
fx_graph = make_fx(fn)(x, w1, w2, w3)
# Static analysis
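# Placeholders and get_attr nodes (graph inputs / parameters) are never released.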
def is_releasable(node):
return node.op not in ("placeholder", "get_attr")
fx_memory_profile = build_memory_profile(fx_graph.graph, is_releasable)
fx_peak = max(fx_memory_profile)
# Runtime profiling
profiler = FakeTensorMemoryProfilerMode()
with profiler:
x_runtime, w1_runtime, w2_runtime, w3_runtime = (
create_inputs_and_weights()
)
result = fn(x_runtime, w1_runtime, w2_runtime, w3_runtime)
del result
runtime_peak = profiler.max_memory
self.assertEqual(fx_peak, runtime_peak)
def test_conv_network(self):
"""Test with a convolutional network."""
def create_inputs_and_weights():
"""Create inputs and weights on CUDA."""
x = torch.randn(8, 3, 224, 224, device="cuda")
conv1_weight = torch.randn(64, 3, 3, 3, device="cuda")
conv2_weight = torch.randn(128, 64, 3, 3, device="cuda")
linear_weight = torch.randn(10, 128 * 56 * 56, device="cuda")
return x, conv1_weight, conv2_weight, linear_weight
def fn(x, conv1_weight, conv2_weight, linear_weight):
h = torch.nn.functional.conv2d(x, conv1_weight, padding=1)
h = torch.nn.functional.relu(h)
h = torch.nn.functional.max_pool2d(h, 2)
h = torch.nn.functional.conv2d(h, conv2_weight, padding=1)
h = torch.nn.functional.relu(h)
h = torch.nn.functional.max_pool2d(h, 2)
h = torch.flatten(h, 1)
out = torch.nn.functional.linear(h, linear_weight)
return out
with FakeTensorMode():
# Trace with make_fx
x, conv1_weight, conv2_weight, linear_weight = create_inputs_and_weights()
fx_graph = make_fx(fn)(x, conv1_weight, conv2_weight, linear_weight)
def is_releasable(node):
return node.op not in ("placeholder", "get_attr")
fx_memory_profile = build_memory_profile(fx_graph.graph, is_releasable)
fx_peak = max(fx_memory_profile)
# Runtime profiling
profiler = FakeTensorMemoryProfilerMode()
with profiler:
x_runtime, conv1_w, conv2_w, linear_w = create_inputs_and_weights()
result = fn(x_runtime, conv1_w, conv2_w, linear_w)
del result
runtime_peak = profiler.max_memory
self.assertEqual(fx_peak, runtime_peak)
if __name__ == "__main__":
if IS_LINUX and HAS_CUDA_AND_TRITON:
run_tests(needs="filelock")


@@ -1,19 +1,160 @@
import itertools
import logging
import operator
from typing import Any, Callable
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Optional, Union
import torch
import torch.fx as fx
from torch._functorch.partitioners import _size_of, get_default_op_list
from torch.fx.experimental.symbolic_shapes import hint_int
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map_only
log = logging.getLogger(__name__)
@dataclass(frozen=True)
class StorageKey:
storage: torch.UntypedStorage
device: torch.device
def __hash__(self) -> int:
return self.storage._cdata
def __eq__(self, other: object) -> bool:
if not isinstance(other, StorageKey):
return False
return (
self.storage._cdata == other.storage._cdata and self.device == other.device
)
class GraphAliasTracker:
"""
Tracks storage allocation and usage relationships in an FX graph.
Differentiates between:
- Fresh allocations: nodes that allocate new storage (not views/aliases)
- Uses: nodes that use a storage as input
"""
def __init__(self, nodes: list[fx.Node]):
# Map from node to the fresh storages it allocates (not views/aliases)
self.node_to_fresh_allocations: dict[fx.Node, OrderedSet[StorageKey]] = {}
# Map from storage to the node that originally allocated it
self.storage_to_allocator: dict[StorageKey, fx.Node] = {}
# Map from node to all storages it uses as inputs
self.node_to_storage_uses: dict[fx.Node, OrderedSet[StorageKey]] = {}
# Map from storage to all nodes that use it
self.storage_to_uses: dict[StorageKey, OrderedSet[fx.Node]] = defaultdict(
OrderedSet
)
# Map from storage to the last node that uses it
self.storage_to_last_user: dict[StorageKey, fx.Node] = {}
# Map from node to storages that have their last use at that node
self.node_to_storages_last_used: dict[fx.Node, OrderedSet[StorageKey]] = (
defaultdict(OrderedSet)
)
# Track all output storages for each node (for building usage graph)
self.node_to_output_storages: dict[fx.Node, OrderedSet[StorageKey]] = {}
# First pass: build storage allocations and track uses
for node in nodes:
# Get output storages
output_storages = self._get_output_storages(node)
self.node_to_output_storages[node] = output_storages
# Track fresh allocations
fresh_allocations: OrderedSet[StorageKey] = OrderedSet()
for storage_key in output_storages:
if storage_key not in self.storage_to_allocator:
self.storage_to_allocator[storage_key] = node
fresh_allocations.add(storage_key)
self.node_to_fresh_allocations[node] = fresh_allocations
# Track input storage uses (safe because inputs were already processed)
input_storages = self._get_input_storages(node)
self.node_to_storage_uses[node] = input_storages
for storage_key in input_storages:
self.storage_to_uses[storage_key].add(node)
# Second pass: find last users (iterate in reverse)
for node in reversed(nodes):
input_storages = self.node_to_storage_uses[node]
for storage_key in input_storages:
if storage_key not in self.storage_to_last_user:
self.storage_to_last_user[storage_key] = node
self.node_to_storages_last_used[node].add(storage_key)
@staticmethod
def _get_output_storages(node: fx.Node) -> OrderedSet[StorageKey]:
"""
Get all storages from a node's outputs.
Uses pytree to handle arbitrary nested structures.
"""
val = node.meta.get("val")
if val is None:
return OrderedSet()
storages: OrderedSet[StorageKey] = OrderedSet()
def collect_storage(tensor: torch._subclasses.FakeTensor) -> None:
storages.add(StorageKey(tensor.untyped_storage(), tensor.device))
# Use tree_map_only to handle FakeTensors in nested structures
tree_map_only(torch._subclasses.FakeTensor, collect_storage, val)
return storages
def _get_input_storages(self, node: fx.Node) -> OrderedSet[StorageKey]:
"""
Get all storages from a node's inputs.
"""
input_storages: OrderedSet[StorageKey] = OrderedSet()
for input_node in node.all_input_nodes:
input_storages.update(self.node_to_output_storages[input_node])
return input_storages
def get_fresh_allocations(self, node: fx.Node) -> OrderedSet[StorageKey]:
"""Get all fresh storage allocations by this node (not views/aliases)."""
return self.node_to_fresh_allocations[node]
def get_storage_uses(self, node: fx.Node) -> OrderedSet[StorageKey]:
"""Get all storages that this node uses as inputs."""
return self.node_to_storage_uses[node]
def get_storages_last_used(
self,
node: fx.Node,
) -> OrderedSet[StorageKey]:
"""
Get storages whose last use is at this node.
"""
return self.node_to_storages_last_used[node]
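# Default size function: resolves possibly-symbolic byte counts to concrete ints,
# using the configured unbacked-symint fallback when no hint is available.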
def _size_of_default(num_bytes: Union[int, torch.SymInt]) -> int:
return hint_int(num_bytes, fallback=torch._inductor.config.unbacked_symint_fallback)
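# Only non-CPU storages (e.g. CUDA) count toward the estimated profile.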
def device_filter(device: torch.device) -> bool:
return device.type != "cpu"
def build_memory_profile(
graph: fx.Graph,
size_of: Callable[[fx.Node], int],
is_releasable: Callable[[fx.Node], bool],
size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
) -> list[int]:
"""
Function to estimate the memory profile of an input FX graph.
@@ -21,11 +162,11 @@ def build_memory_profile
Args:
- graph (fx.Graph): The input FX graph for which the memory profile
is to be estimated.
- size_of (Callable[[fx.Node], int]): A function that returns
the size of a given node.
- is_releasable (Callable[[fx.Node], bool]): A function that
determines if a node's memory can be released (e.g. primal nodes
cannot be released).
- size_of (Callable[[Union[int, torch.SymInt]], int]): A function that converts
byte counts (possibly symbolic) to concrete integers.
Returns:
- List[int]: A list representing the memory profile over the execution
@@ -33,239 +174,49 @@ def build_memory_profile(
a particular point in the execution.
"""
size_of = size_of or _size_of_default
nodes = list(graph.nodes)
op_types = get_default_op_list()
alias_info = GraphAliasTracker(nodes)
class AliasInfo:
"""
Class for storing and accessing alias information of an FX graph.
# Build memory profile
current_memory = 0
Attributes:
- view_to_source: Maps view nodes to their source nodes
- getitem_to_source: Maps getitem nodes to (source_node, key) tuples
- source_to_getitems: Maps source nodes to dictionaries of
{key: getitem_node, "unclaimed": None}
- source_to_unclaimed_size: Maps source nodes to their storage size
unclaimed by any getitem_nodes
"""
for node in itertools.chain(
graph.find_nodes(op="placeholder"), graph.find_nodes(op="get_attr")
):
for storage_key in alias_info.get_fresh_allocations(node):
if device_filter(storage_key.device):
current_memory += size_of(storage_key.storage.nbytes())
def __init__(self, nodes: list[fx.Node]):
"""
Initialize the AliasInfo class with a list of FX graph nodes.
memory_profile = [current_memory]
Args:
- nodes (list[fx.Node]): A list of nodes from an FX graph,
ordered in execution order.
The constructor analyzes the relationships between nodes in the FX graph
to populate alias information. It identifies two types of alias nodes:
getitem and view. For each view, it maps it to its source. For each
getitem, it maps it to its source and key. It also populates mappings
for source nodes to their getitems and calculates unclaimed storage sizes.
"""
# For each view, we map it to its source.
# Note that we treat getitems of a view (e.g. aten.split) as views.
self.view_to_source: dict[fx.Node, fx.Node] = {}
# For each remaining getitem, we map it to its source and key.
self.getitem_to_source: dict[fx.Node, tuple[fx.Node, Any]] = {}
# For each non-view source_node of getitems, we map it to a dictionary
# in the form of {key: getitem_node, ..., "unclaimed": None}, where
# "unclaimed" is a dummy key that represents all elements in the
# source_node that are not claimed by any getitems.
self.source_to_getitems: dict[fx.Node, dict[Any, fx.Node | None]] = {}
# For each non-view source_node of getitems with at least one unclaimed
# element, we map it to its unclaimed storage size.
self.source_to_unclaimed_size: dict[fx.Node, int] = {}
for node in nodes:
is_view = op_types.is_view(node)
is_getitem = node.target is operator.getitem
if not (is_view or is_getitem):
continue
assert not (is_view and is_getitem)
assert node.args and isinstance(node.args[0], fx.Node)
source = node.args[0]
if is_view:
assert not isinstance(source.meta["val"], list | tuple | dict)
if source in self.view_to_source:
source = self.view_to_source[source]
self.view_to_source[node] = source
if is_getitem:
assert isinstance(source.meta["val"], list | tuple | dict)
# Source of getitem can be a view (e.g. aten.split).
if source in self.view_to_source:
source = self.view_to_source[source]
# In this case, the getitem node should be treated
# the same way as a regular view.
self.view_to_source[node] = source
continue
# Source of getitem cannot be a getitem.
assert source not in self.getitem_to_source
# There must be a second argument that specifies the key.
assert len(node.args) >= 2
key = node.args[1]
self.getitem_to_source[node] = (source, key)
# Populate source_to_getitems.
if source not in self.source_to_getitems:
self.source_to_getitems[source] = {"unclaimed": None}
assert key not in self.source_to_getitems[source]
self.source_to_getitems[source][key] = node # type: ignore[index]
for source, getitem_map in self.source_to_getitems.items():
unclaimed_source_size = size_of(source)
for key, getitem_node in getitem_map.items():
if key != "unclaimed" and getitem_node is not None:
unclaimed_source_size -= size_of(getitem_node)
assert unclaimed_source_size >= 0
if unclaimed_source_size > 0:
self.source_to_unclaimed_size[source] = unclaimed_source_size
def is_view(self, node: fx.Node) -> bool:
return node in self.view_to_source
def is_getitem(self, node: fx.Node) -> bool:
return node in self.getitem_to_source
def get_source(self, node: fx.Node) -> fx.Node | tuple[fx.Node, Any]:
if self.is_view(node):
return self.view_to_source[node]
if self.is_getitem(node):
return self.getitem_to_source[node]
return node
def is_source_of_getitems(self, node: fx.Node) -> bool:
return node in self.source_to_getitems
def get_storage_keys(self, source_node: fx.Node) -> list[Any]:
assert source_node in self.source_to_getitems
return list(self.source_to_getitems[source_node].keys())
def get_unclaimed_storage_size(self, source_node: fx.Node) -> int:
return self.source_to_unclaimed_size.get(source_node, 0)
def get_getitem_by_key(self, source: fx.Node, key: Any) -> fx.Node | None:
assert source in self.source_to_getitems
assert key in self.source_to_getitems[source]
return self.source_to_getitems[source][key]
def _get_last_usage(
nodes: list[fx.Node], alias_info: AliasInfo
) -> dict[fx.Node, list[tuple[fx.Node, Any]]]:
"""
Determine the last usage point of each storage. This information is used to
identify when storages can be safely released.
Args:
- nodes (list[fx.Node]): A list of nodes from the FX graph, ordered
in execution order.
- alias_info (AliasInfo): An instance of AliasInfo containing aliasing
relationships between nodes in the graph.
Returns:
- Dict[fx.Node, list[tuple[fx.Node, Optional[Any]]]]: A mapping
from each node to a list of storages (represented as tuples of source node
and key) that are last used by that node. This helps in identifying which
storages can be released after the node's execution.
"""
storage_to_last_user: dict[tuple[fx.Node, Any], fx.Node] = {}
node_to_last_used_storages: dict[fx.Node, list[tuple[fx.Node, Any]]] = {}
def register_last_uses(use: fx.Node, user: fx.Node) -> None:
keys: list[Any] = []
if alias_info.is_view(use):
# When use is a view (or getitem of a view),
# user is essentially using the storage allocated at the
# creation of the source of use.
use = alias_info.get_source(use) # type: ignore[assignment]
if alias_info.is_source_of_getitems(use): # type: ignore[arg-type]
# When use is a source of getitems, user is using all separate
# storages of use.
keys.extend(alias_info.get_storage_keys(use)) # type: ignore[arg-type]
elif alias_info.is_getitem(use): # type: ignore[arg-type]
# When use is a getitem, user is essentially using a separate
# storage of the source of use specified by key.
use, key = alias_info.get_source(use) # type: ignore[assignment,misc]
keys.append(key)
else:
keys.append(None)
assert keys
for key in keys:
if (use, key) not in storage_to_last_user: # type: ignore[comparison-overlap]
storage_to_last_user[(use, key)] = user # type: ignore[index]
node_to_last_used_storages.setdefault(user, []).append((use, key)) # type: ignore[arg-type]
for node in reversed(nodes):
fx.node.map_arg(node.args, lambda n: register_last_uses(n, node))
fx.node.map_arg(node.kwargs, lambda n: register_last_uses(n, node))
return node_to_last_used_storages
alias_info = AliasInfo(nodes)
node_to_last_used_storages = _get_last_usage(nodes, alias_info)
# Initialize memory profile
memory_profile = [0]
# Process the graph
for node in nodes:
if node.op == "placeholder":
out_mem = size_of(node)
memory_profile[0] += out_mem
elif node.op == "output":
pass
elif (
node.op == "call_function"
or node.op == "call_module"
or node.op == "call_method"
):
# Aliases don't allocate new memory
if alias_info.is_view(node) or alias_info.is_getitem(node):
memory_profile.append(memory_profile[-1])
else:
out_mem = size_of(node)
memory_profile.append(memory_profile[-1] + out_mem)
if node.op in ("placeholder", "get_attr", "output"):
continue
# Process storages that are no longer needed after this operation
storages_to_release = [
(use, key)
for use, key in node_to_last_used_storages.get(node, [])
if is_releasable(use)
]
freed_memory = 0
for node_to_release, key in storages_to_release:
released_memory_size = 0
if key is None:
released_memory_size = size_of(node_to_release)
elif key == "unclaimed":
released_memory_size = alias_info.get_unclaimed_storage_size(
node_to_release
)
else:
getitem_node = alias_info.get_getitem_by_key(node_to_release, key)
if getitem_node is not None:
released_memory_size = size_of(getitem_node)
freed_memory += released_memory_size
# Process allocations
for storage_key in alias_info.get_fresh_allocations(node):
if device_filter(storage_key.device):
current_memory += size_of(storage_key.storage.nbytes())
memory_profile.append(current_memory)
# Process deallocations
for storage_key in alias_info.get_storages_last_used(node):
allocator = alias_info.storage_to_allocator[storage_key]
if is_releasable(allocator):
if device_filter(storage_key.device):
current_memory -= size_of(storage_key.storage.nbytes())
memory_profile.append(current_memory)
assert freed_memory >= 0
memory_profile.append(memory_profile[-1] - freed_memory)
return memory_profile
def get_fwd_bwd_interactions(
fwd_graph: fx.Graph,
bwd_graph: fx.Graph,
size_of: Callable[[fx.Node], int],
size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
) -> tuple[int, OrderedSet[str]]:
"""
Analyze the interactions between the forward (fwd) and backward (bwd) graphs
@@ -274,77 +225,61 @@ def get_fwd_bwd_interactions(
Args:
- fwd_graph (fx.Graph): The forward graph representing the forward pass.
- bwd_graph (fx.Graph): The backward graph representing the backward pass.
- size_of (Callable[[fx.Node], int]): A function that returns the size
of a given node.
- size_of (Callable[[Union[int, torch.SymInt]], int]): A function that converts
byte counts (possibly symbolic) to concrete integers.
Returns:
- tuple[int, Set[fx.Node]]: A tuple containing:
- tuple[int, OrderedSet[str]]: A tuple containing:
1. The baseline memory usage during the backward pass, accounting for
nodes that persist from the forward pass (i.e., in fwd output but
storages that persist from the forward pass (i.e., in fwd output but
not in bwd input).
2. A set of nodes whose storage cannot be released during the bwd pass.
These include nodes that are views of primals or in bwd input
2. A set of node names whose storage cannot be released during the bwd pass.
These include nodes that use storage from primals or are in bwd input
but not in fwd output.
"""
def get_nodes_in_output(graph: fx.Graph) -> OrderedSet[fx.Node]:
"""
Get the nodes in the output of a graph.
size_of = size_of or _size_of_default
Args:
- graph (fx.Graph): The input graph.
# Build alias info for forward graph
fwd_nodes = list(fwd_graph.nodes)
fwd_alias_info = GraphAliasTracker(fwd_nodes)
Returns:
- list[fx.Node]: A list of nodes in the output of the graph.
"""
output_node = list(graph.nodes)[-1]
assert output_node.op == "output"
nodes_in_output: OrderedSet[fx.Node] = OrderedSet()
# Identify storages allocated by primal placeholder nodes
primal_storages: OrderedSet[StorageKey] = OrderedSet()
for node in fwd_graph.find_nodes(op="placeholder"):
if node.name.startswith("primals"):
primal_storages.update(fwd_alias_info.get_fresh_allocations(node))
def add_node(node: fx.Node) -> None:
nodes_in_output.add(node)
# Get storages in forward output
fwd_output_node = next(iter(reversed(fwd_graph.nodes)))
assert fwd_output_node.op == "output"
fwd_output_storages = fwd_alias_info.get_storage_uses(fwd_output_node)
# Using map_arg since output_node.args[0] can be of different types
# e.g. tuple, list, dict, fx.Node, etc.
fx.node.map_arg(output_node.args[0], lambda n: add_node(n))
return nodes_in_output
op_types = get_default_op_list()
bwd_baseline_memory = 0
# placeholder nodes besides primals of the bwd_graph that should also
# not be deleted during memory profile estimation of the bwd_graph
# Node names that should not be deleted during memory profile estimation of bwd_graph
do_not_delete: OrderedSet[str] = OrderedSet()
fwd_outputs = {}
for node in get_nodes_in_output(fwd_graph):
is_view_of_primal = False
if op_types.is_view(node):
source = node.args[0]
if isinstance(source, fx.Node) and source.name.startswith("primals"):
is_view_of_primal = True
fwd_outputs[node.name] = (size_of(node), is_view_of_primal)
bwd_inputs: OrderedSet[str] = OrderedSet()
for node in bwd_graph.nodes:
if node.op == "placeholder":
bwd_inputs.add(node.name)
if node.name.startswith("view"):
# if node is a view, then it has to be in fwd_outputs
assert node.name in fwd_outputs
_, is_view_of_primal = fwd_outputs[node.name]
if is_view_of_primal:
# Add node to do_not_delete because it is a view of a primal
do_not_delete.add(node.name)
# Collect all storages in backward inputs and identify nodes to not delete
bwd_input_storages: OrderedSet[StorageKey] = OrderedSet()
for node in bwd_graph.find_nodes(op="placeholder"):
node_storages = GraphAliasTracker._get_output_storages(node)
bwd_input_storages.update(node_storages)
# if node is not in fwd_outputs, then add it to do_not_delete
if node.name not in fwd_outputs:
do_not_delete.add(node.name)
# Check if this node uses primal storage
if node_storages & primal_storages:
do_not_delete.add(node.name)
# nodes that are in fwd_outputs but not in bwd_inputs occupy storage
# throughout the bwd pass
for name, (size, _) in fwd_outputs.items():
if name not in bwd_inputs:
bwd_baseline_memory += size
# Check if this node's storages are not in forward outputs
# (meaning it's an external input to backward pass)
if not (node_storages & fwd_output_storages):
do_not_delete.add(node.name)
# Calculate baseline memory: storages in fwd output but not in bwd input
# These storages persist throughout the backward pass
baseline_storages = fwd_output_storages - bwd_input_storages
bwd_baseline_memory = 0
for storage_key in baseline_storages:
if storage_key.device.type != "cpu":
bwd_baseline_memory += size_of(storage_key.storage.nbytes())
return bwd_baseline_memory, do_not_delete
@@ -353,24 +288,15 @@ def get_peak_memory(
fwd_graph: fx.Graph,
bwd_graph: fx.Graph,
) -> int:
def _safe_size_of(n: fx.Node) -> int:
try:
return _size_of(n)
except Exception:
log.warning("Failed size_of(%s). Returning 0 instead.", n)
return 0
def _is_releasable(n: fx.Node) -> bool:
# Storages of primals cannot be released during fwd or bwd pass.
return not n.name.startswith("primals")
fwd_peak_memory = max(
build_memory_profile(fwd_graph, _safe_size_of, _is_releasable)
)
fwd_peak_memory = max(build_memory_profile(fwd_graph, _is_releasable))
# tmp change
bwd_baseline_memory, bwd_do_not_delete = get_fwd_bwd_interactions(
fwd_graph, bwd_graph, _safe_size_of
fwd_graph,
bwd_graph,
)
def _is_bwd_releasable(n: fx.Node) -> bool:
@@ -379,7 +305,7 @@ def get_peak_memory(
return _is_releasable(n) and n.name not in bwd_do_not_delete
bwd_peak_memory = bwd_baseline_memory + max(
build_memory_profile(bwd_graph, _safe_size_of, _is_bwd_releasable)
build_memory_profile(bwd_graph, _is_bwd_releasable)
)
return max(
fwd_peak_memory,