Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Refactor memory estimator to use node storages, add test (#164783)
- Update the Memory Estimator to use node storages for analysis, which simplifies bookkeeping compared to manually inspecting operator schemas. This also allows the component to be reused elsewhere.
- Factor the logic out into a separate class, so that this same logic can be used in scheduling (node allocations / aliasing / uses).
- Add tests for correctness - right now only on fwd/bwd by itself, not with both together.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164783
Approved by: https://github.com/ruisizhang123
ghstack dependencies: #164738
Commit: aed5ed1076 (parent: af4c29fea8)
Committed by: PyTorch MergeBot
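
For orientation before the diffs, a minimal sketch of how the refactored entry point is driven, mirroring the new test file below (the toy fn and shapes here are hypothetical; the is_releasable predicate and the FakeTensorMode/make_fx tracing pattern are taken from the tests):

import torch
from torch._inductor.fx_passes.memory_estimator import build_memory_profile
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

def fn(x, w):
    return torch.nn.functional.relu(torch.nn.functional.linear(x, w))

with FakeTensorMode():
    # Trace on fake tensors; no real CUDA memory is allocated.
    x = torch.randn(32, 64, device="cuda")
    w = torch.randn(16, 64, device="cuda")
    gm = make_fx(fn)(x, w)

# Static, storage-based estimate; size_of defaults to hint_int over storage nbytes.
profile = build_memory_profile(
    gm.graph,
    is_releasable=lambda n: n.op not in ("placeholder", "get_attr"),
)
peak_bytes = max(profile)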
test/inductor/test_mem_estimation.py (new file, +173 lines)
@@ -0,0 +1,173 @@
# Owner(s): ["module: inductor"]

import functools
import weakref
from collections import Counter
from typing import Callable, Optional

import torch
from torch._inductor.fx_passes.memory_estimator import build_memory_profile
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakIdKeyDictionary


def tensor_storage_id(tensor):
    return tensor._typed_storage()._cdata


def device_filter(device):
    return device.type == "cuda"


class FakeTensorMemoryProfilerMode(TorchDispatchMode):
    def __init__(self, device_filter: Optional[Callable[[torch.device], bool]] = None):
        # counter of storage ids to live references
        self.storage_count: dict[int, int] = Counter()
        # live fake tensors
        self.live_tensors = WeakIdKeyDictionary()
        self.memory_use = 0
        self.max_memory = 0
        self.device_filter = device_filter

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs is not None else {}
        rs = func(*args, **kwargs)
        tree_map_only(torch._subclasses.FakeTensor, self.increase_memory_use, rs)
        return rs

    def increase_memory_use(self, tensor):
        # already accounted for
        if tensor in self.live_tensors:
            return

        if self.device_filter is not None and not self.device_filter(tensor.device):
            return

        self.live_tensors[tensor] = True
        nbytes = tensor.untyped_storage().nbytes()

        storage_id = tensor_storage_id(tensor)

        # new storage, add to memory
        if storage_id not in self.storage_count:
            self.change_memory(nbytes)

        self.storage_count[storage_id] += 1

        # when this tensor dies, we need to adjust memory
        weakref.finalize(
            tensor, functools.partial(self.tensor_cleanup, storage_id, nbytes)
        )

    def tensor_cleanup(self, storage_id, nbytes):
        self.storage_count[storage_id] -= 1
        if self.storage_count[storage_id] == 0:
            del self.storage_count[storage_id]
            self.change_memory(-nbytes)

    def change_memory(self, delta):
        self.memory_use += delta
        self.max_memory = max(self.memory_use, self.max_memory)


class TestMemoryProfilingResNet(InductorTestCase):
    def test_simple_linear_layers(self):
        """Test with a simple sequential model with explicit weights on CUDA."""

        def create_inputs_and_weights():
            """Create inputs and weights on CUDA."""
            x = torch.randn(32, 1000, device="cuda")
            w1 = torch.randn(500, 1000, device="cuda")
            w2 = torch.randn(100, 500, device="cuda")
            w3 = torch.randn(10, 100, device="cuda")
            return x, w1, w2, w3

        def fn(x, w1, w2, w3):
            h1 = torch.nn.functional.linear(x, w1)
            h1 = torch.nn.functional.relu(h1)
            h2 = torch.nn.functional.linear(h1, w2)
            h2 = torch.nn.functional.relu(h2)
            out = torch.nn.functional.linear(h2, w3)
            return out

        with FakeTensorMode():
            # Trace with make_fx
            x, w1, w2, w3 = create_inputs_and_weights()
            fx_graph = make_fx(fn)(x, w1, w2, w3)

            # Static analysis
            def is_releasable(node):
                return node.op not in ("placeholder", "get_attr")

            fx_memory_profile = build_memory_profile(fx_graph.graph, is_releasable)
            fx_peak = max(fx_memory_profile)

            # Runtime profiling
            profiler = FakeTensorMemoryProfilerMode()

            with profiler:
                x_runtime, w1_runtime, w2_runtime, w3_runtime = (
                    create_inputs_and_weights()
                )
                result = fn(x_runtime, w1_runtime, w2_runtime, w3_runtime)
                del result

            runtime_peak = profiler.max_memory

            self.assertEqual(fx_peak, runtime_peak)

    def test_conv_network(self):
        """Test with a convolutional network."""

        def create_inputs_and_weights():
            """Create inputs and weights on CUDA."""
            x = torch.randn(8, 3, 224, 224, device="cuda")
            conv1_weight = torch.randn(64, 3, 3, 3, device="cuda")
            conv2_weight = torch.randn(128, 64, 3, 3, device="cuda")
            linear_weight = torch.randn(10, 128 * 56 * 56, device="cuda")
            return x, conv1_weight, conv2_weight, linear_weight

        def fn(x, conv1_weight, conv2_weight, linear_weight):
            h = torch.nn.functional.conv2d(x, conv1_weight, padding=1)
            h = torch.nn.functional.relu(h)
            h = torch.nn.functional.max_pool2d(h, 2)
            h = torch.nn.functional.conv2d(h, conv2_weight, padding=1)
            h = torch.nn.functional.relu(h)
            h = torch.nn.functional.max_pool2d(h, 2)
            h = torch.flatten(h, 1)
            out = torch.nn.functional.linear(h, linear_weight)
            return out

        with FakeTensorMode():
            # Trace with make_fx
            x, conv1_weight, conv2_weight, linear_weight = create_inputs_and_weights()
            fx_graph = make_fx(fn)(x, conv1_weight, conv2_weight, linear_weight)

            def is_releasable(node):
                return node.op not in ("placeholder", "get_attr")

            fx_memory_profile = build_memory_profile(fx_graph.graph, is_releasable)
            fx_peak = max(fx_memory_profile)

            # Runtime profiling
            profiler = FakeTensorMemoryProfilerMode()

            with profiler:
                x_runtime, conv1_w, conv2_w, linear_w = create_inputs_and_weights()
                result = fn(x_runtime, conv1_w, conv2_w, linear_w)
                del result

            runtime_peak = profiler.max_memory

            self.assertEqual(fx_peak, runtime_peak)


if __name__ == "__main__":
    if IS_LINUX and HAS_CUDA_AND_TRITON:
        run_tests(needs="filelock")
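
For intuition, a small standalone sketch of how the dispatch mode above counts storages (hypothetical shapes; every name is defined in this test file):

with FakeTensorMode():
    profiler = FakeTensorMemoryProfilerMode(device_filter=device_filter)
    with profiler:
        a = torch.randn(1024, 1024, device="cuda")  # fresh storage: +4 MiB
        b = a.view(-1)  # alias of a's storage: refcount bump, no new memory
        del a, b  # last references die -> weakref.finalize subtracts the bytes
    print(profiler.max_memory)  # expected peak: 1024 * 1024 * 4 bytes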

torch/_inductor/fx_passes/memory_estimator.py
@@ -1,19 +1,160 @@
+import itertools
 import logging
-import operator
-from typing import Any, Callable
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Callable, Optional, Union

 import torch
 import torch.fx as fx
-from torch._functorch.partitioners import _size_of, get_default_op_list
+from torch.fx.experimental.symbolic_shapes import hint_int
 from torch.utils._ordered_set import OrderedSet
+from torch.utils._pytree import tree_map_only


 log = logging.getLogger(__name__)


+@dataclass(frozen=True)
+class StorageKey:
+    storage: torch.UntypedStorage
+    device: torch.device
+
+    def __hash__(self) -> int:
+        return self.storage._cdata
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, StorageKey):
+            return False
+        return (
+            self.storage._cdata == other.storage._cdata and self.device == other.device
+        )
+
+
+class GraphAliasTracker:
+    """
+    Tracks storage allocation and usage relationships in an FX graph.
+
+    Differentiates between:
+    - Fresh allocations: nodes that allocate new storage (not views/aliases)
+    - Uses: nodes that use a storage as input
+    """
+
+    def __init__(self, nodes: list[fx.Node]):
+        # Map from node to the fresh storages it allocates (not views/aliases)
+        self.node_to_fresh_allocations: dict[fx.Node, OrderedSet[StorageKey]] = {}
+
+        # Map from storage to the node that originally allocated it
+        self.storage_to_allocator: dict[StorageKey, fx.Node] = {}
+
+        # Map from node to all storages it uses as inputs
+        self.node_to_storage_uses: dict[fx.Node, OrderedSet[StorageKey]] = {}
+
+        # Map from storage to all nodes that use it
+        self.storage_to_uses: dict[StorageKey, OrderedSet[fx.Node]] = defaultdict(
+            OrderedSet
+        )
+
+        # Map from storage to the last node that uses it
+        self.storage_to_last_user: dict[StorageKey, fx.Node] = {}
+
+        # Map from node to storages that have their last use at that node
+        self.node_to_storages_last_used: dict[fx.Node, OrderedSet[StorageKey]] = (
+            defaultdict(OrderedSet)
+        )
+
+        # Track all output storages for each node (for building usage graph)
+        self.node_to_output_storages: dict[fx.Node, OrderedSet[StorageKey]] = {}
+
+        # First pass: build storage allocations and track uses
+        for node in nodes:
+            # Get output storages
+            output_storages = self._get_output_storages(node)
+            self.node_to_output_storages[node] = output_storages
+
+            # Track fresh allocations
+            fresh_allocations: OrderedSet[StorageKey] = OrderedSet()
+            for storage_key in output_storages:
+                if storage_key not in self.storage_to_allocator:
+                    self.storage_to_allocator[storage_key] = node
+                    fresh_allocations.add(storage_key)
+            self.node_to_fresh_allocations[node] = fresh_allocations
+
+            # Track input storage uses (safe because inputs were already processed)
+            input_storages = self._get_input_storages(node)
+            self.node_to_storage_uses[node] = input_storages
+            for storage_key in input_storages:
+                self.storage_to_uses[storage_key].add(node)
+
+        # Second pass: find last users (iterate in reverse)
+        for node in reversed(nodes):
+            input_storages = self.node_to_storage_uses[node]
+            for storage_key in input_storages:
+                if storage_key not in self.storage_to_last_user:
+                    self.storage_to_last_user[storage_key] = node
+                    self.node_to_storages_last_used[node].add(storage_key)
+
+    @staticmethod
+    def _get_output_storages(node: fx.Node) -> OrderedSet[StorageKey]:
+        """
+        Get all storages from a node's outputs.
+
+        Uses pytree to handle arbitrary nested structures.
+        """
+        val = node.meta.get("val")
+        if val is None:
+            return OrderedSet()
+
+        storages: OrderedSet[StorageKey] = OrderedSet()
+
+        def collect_storage(tensor: torch._subclasses.FakeTensor) -> None:
+            storages.add(StorageKey(tensor.untyped_storage(), tensor.device))
+
+        # Use tree_map_only to handle FakeTensors in nested structures
+        tree_map_only(torch._subclasses.FakeTensor, collect_storage, val)
+
+        return storages
+
+    def _get_input_storages(self, node: fx.Node) -> OrderedSet[StorageKey]:
+        """
+        Get all storages from a node's inputs.
+        """
+        input_storages: OrderedSet[StorageKey] = OrderedSet()
+
+        for input_node in node.all_input_nodes:
+            input_storages.update(self.node_to_output_storages[input_node])
+
+        return input_storages
+
+    def get_fresh_allocations(self, node: fx.Node) -> OrderedSet[StorageKey]:
+        """Get all fresh storage allocations by this node (not views/aliases)."""
+        return self.node_to_fresh_allocations[node]
+
+    def get_storage_uses(self, node: fx.Node) -> OrderedSet[StorageKey]:
+        """Get all storages that this node uses as inputs."""
+        return self.node_to_storage_uses[node]
+
+    def get_storages_last_used(
+        self,
+        node: fx.Node,
+    ) -> OrderedSet[StorageKey]:
+        """
+        Get storages whose last use is at this node.
+        """
+        return self.node_to_storages_last_used[node]
+
+
+def _size_of_default(num_bytes: Union[int, torch.SymInt]) -> int:
+    return hint_int(num_bytes, fallback=torch._inductor.config.unbacked_symint_fallback)
+
+
+def device_filter(device: torch.device) -> bool:
+    return device.type != "cpu"
+
+
 def build_memory_profile(
     graph: fx.Graph,
-    size_of: Callable[[fx.Node], int],
     is_releasable: Callable[[fx.Node], bool],
+    size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
 ) -> list[int]:
     """
     Function to estimate the memory profile of an input FX graph.
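
The hunk above introduces StorageKey and GraphAliasTracker. A quick sketch of the aliasing distinction they capture (toy function, traced the same way as in the new tests; imports as in the sketch near the top):

with FakeTensorMode():
    def f(x):
        y = x.view(-1)  # alias: shares x's storage
        z = x + 1       # fresh allocation
        return y, z

    gm = make_fx(f)(torch.randn(4, 4, device="cuda"))
    tracker = GraphAliasTracker(list(gm.graph.nodes))
    for node in gm.graph.nodes:
        # The placeholder and the add each allocate one fresh storage;
        # the view node allocates none, since its output storage was
        # already claimed by the placeholder.
        print(node.name, len(tracker.get_fresh_allocations(node)))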
@@ -21,11 +162,11 @@ def build_memory_profile(
     Args:
     - graph (fx.Graph): The input FX graph for which the memory profile
       is to be estimated.
-    - size_of (Callable[[fx.Node], int]): A function that returns
-      the size of a given node.
     - is_releasable (Callable[[fx.Node], bool]): A function that
       determines if a node's memory can be released (e.g. primal nodes
       cannot be released).
+    - size_of (Callable[[Union[int, torch.SymInt]], int]): A function that converts
+      byte counts (possibly symbolic) to concrete integers.

     Returns:
     - List[int]: A list representing the memory profile over the execution
@@ -33,239 +174,49 @@ def build_memory_profile(
       a particular point in the execution.
     """
+    size_of = size_of or _size_of_default
     nodes = list(graph.nodes)
-    op_types = get_default_op_list()
+    alias_info = GraphAliasTracker(nodes)

-    class AliasInfo:
-        """
-        Class for storing and accessing alias information of a FX graph.
-
-        Attributes:
-        - view_to_source: Maps view nodes to their source nodes
-        - getitem_to_source: Maps getitem nodes to (source_node, key) tuples
-        - source_to_getitems: Maps source nodes to dictionaries of
-          {key: getitem_node, "unclaimed": None}
-        - source_to_unclaimed_size: Maps source nodes to their storage size
-          unclaimed by any getitem_nodes
-        """
-
-        def __init__(self, nodes: list[fx.Node]):
-            """
-            Initialize the AliasInfo class with a list of FX graph nodes.
-
-            Args:
-            - nodes (list[fx.Node]): A list of nodes from an FX graph,
-              ordered in execution order.
-
-            The constructor analyzes the relationships between nodes in the FX graph
-            to populate alias information. It identifies two types of alias nodes:
-            getitem and view. For each view, it maps it to its source. For each
-            getitem, it maps it to its source and key. It also populates mappings
-            for source nodes to their getitems and calculates unclaimed storage sizes.
-            """
-            # For each view, we map it to its source.
-            # Note that we treat getitems of a view (e.g. aten.split) as views.
-            self.view_to_source: dict[fx.Node, fx.Node] = {}
-
-            # For each remaining getitem, we map it to its source and key.
-            self.getitem_to_source: dict[fx.Node, tuple[fx.Node, Any]] = {}
-
-            # For each none-view source_node of getitems, we map it to a dictionary
-            # in the form of {key: getitem_node, ..., "unclaimed": None}, where
-            # "unclaimed" is a dummy key that represents all elements in the
-            # source_node that is not claimed by any getitems.
-            self.source_to_getitems: dict[fx.Node, dict[Any, fx.Node | None]] = {}
-
-            # For each none-view source_node of getitems with at least one unclaimed
-            # elements, we map it to its unclaimed storage size.
-            self.source_to_unclaimed_size: dict[fx.Node, int] = {}
-
-            for node in nodes:
-                is_view = op_types.is_view(node)
-                is_getitem = node.target is operator.getitem
-                if not (is_view or is_getitem):
-                    continue
-                assert not (is_view and is_getitem)
-                assert node.args and isinstance(node.args[0], fx.Node)
-                source = node.args[0]
-                if is_view:
-                    assert not isinstance(source.meta["val"], list | tuple | dict)
-                    if source in self.view_to_source:
-                        source = self.view_to_source[source]
-                    self.view_to_source[node] = source
-                if is_getitem:
-                    assert isinstance(source.meta["val"], list | tuple | dict)
-                    # Source of getitem can be a view (e.g. aten.split).
-                    if source in self.view_to_source:
-                        source = self.view_to_source[source]
-                        # In this case, the getitem node should be treated
-                        # the same way as a regular view.
-                        self.view_to_source[node] = source
-                        continue
-                    # Source of getitem cannot be a getitem.
-                    assert source not in self.getitem_to_source
-
-                    # There must be a second argument that specifies the key.
-                    assert len(node.args) >= 2
-                    key = node.args[1]
-                    self.getitem_to_source[node] = (source, key)
-
-                    # Populate source_to_getitems.
-                    if source not in self.source_to_getitems:
-                        self.source_to_getitems[source] = {"unclaimed": None}
-                    assert key not in self.source_to_getitems[source]
-                    self.source_to_getitems[source][key] = node  # type: ignore[index]
-
-            for source, getitem_map in self.source_to_getitems.items():
-                unclaimed_source_size = size_of(source)
-                for key, getitem_node in getitem_map.items():
-                    if key != "unclaimed" and getitem_node is not None:
-                        unclaimed_source_size -= size_of(getitem_node)
-                assert unclaimed_source_size >= 0
-                if unclaimed_source_size > 0:
-                    self.source_to_unclaimed_size[source] = unclaimed_source_size
-
-        def is_view(self, node: fx.Node) -> bool:
-            return node in self.view_to_source
-
-        def is_getitem(self, node: fx.Node) -> bool:
-            return node in self.getitem_to_source
-
-        def get_source(self, node: fx.Node) -> fx.Node | tuple[fx.Node, Any]:
-            if self.is_view(node):
-                return self.view_to_source[node]
-            if self.is_getitem(node):
-                return self.getitem_to_source[node]
-            return node
-
-        def is_source_of_getitems(self, node: fx.Node) -> bool:
-            return node in self.source_to_getitems
-
-        def get_storage_keys(self, source_node: fx.Node) -> list[Any]:
-            assert source_node in self.source_to_getitems
-            return list(self.source_to_getitems[source_node].keys())
-
-        def get_unclaimed_storage_size(self, source_node: fx.Node) -> int:
-            return self.source_to_unclaimed_size.get(source_node, 0)
-
-        def get_getitem_by_key(self, source: fx.Node, key: Any) -> fx.Node | None:
-            assert source in self.source_to_getitems
-            assert key in self.source_to_getitems[source]
-            return self.source_to_getitems[source][key]
-
-    def _get_last_usage(
-        nodes: list[fx.Node], alias_info: AliasInfo
-    ) -> dict[fx.Node, list[tuple[fx.Node, Any]]]:
-        """
-        Determine the last usage point of each storage. This information is used to
-        identify when storages can be safely released.
-
-        Args:
-        - nodes (list[fx.Node]): A list of nodes from the FX graph, ordered
-          in execution order.
-        - alias_info (AliasInfo): An instance of AliasInfo containing aliasing
-          relationships between nodes in the graph.
-
-        Returns:
-        - Dict[fx.Node, list[tuple[fx.Node, Optional[Any]]]]: A mapping
-          from each node to a list of storages (represented as tuples of source node
-          and key) that are last used by that node. This helps in identifying which
-          storages can be released after the node's execution.
-        """
-        storage_to_last_user: dict[tuple[fx.Node, Any], fx.Node] = {}
-        node_to_last_used_storages: dict[fx.Node, list[tuple[fx.Node, Any]]] = {}
-
-        def register_last_uses(use: fx.Node, user: fx.Node) -> None:
-            keys: list[Any] = []
-            if alias_info.is_view(use):
-                # When use is a view (or getitem of a view),
-                # user is essentially using the storage allocated at the
-                # creation of the source of use.
-                use = alias_info.get_source(use)  # type: ignore[assignment]
-
-            if alias_info.is_source_of_getitems(use):  # type: ignore[arg-type]
-                # When use is a source of getitems, user is using all separate
-                # storages of use.
-                keys.extend(alias_info.get_storage_keys(use))  # type: ignore[arg-type]
-            elif alias_info.is_getitem(use):  # type: ignore[arg-type]
-                # When use is a getitem, user is essentially using a separate
-                # storage of the source of use specified by key.
-                use, key = alias_info.get_source(use)  # type: ignore[assignment,misc]
-                keys.append(key)
-            else:
-                keys.append(None)
-
-            assert keys
-
-            for key in keys:
-                if (use, key) not in storage_to_last_user:  # type: ignore[comparison-overlap]
-                    storage_to_last_user[(use, key)] = user  # type: ignore[index]
-                    node_to_last_used_storages.setdefault(user, []).append((use, key))  # type: ignore[arg-type]
-
-        for node in reversed(nodes):
-            fx.node.map_arg(node.args, lambda n: register_last_uses(n, node))
-            fx.node.map_arg(node.kwargs, lambda n: register_last_uses(n, node))
-
-        return node_to_last_used_storages
-
-    alias_info = AliasInfo(nodes)
-    node_to_last_used_storages = _get_last_usage(nodes, alias_info)
-
-    # Initialize memory profile
-    memory_profile = [0]
+    # Build memory profile
+    current_memory = 0
+
+    for node in itertools.chain(
+        graph.find_nodes(op="placeholder"), graph.find_nodes(op="get_attr")
+    ):
+        for storage_key in alias_info.get_fresh_allocations(node):
+            if device_filter(storage_key.device):
+                current_memory += size_of(storage_key.storage.nbytes())
+
+    memory_profile = [current_memory]

-    # Process the graph
     for node in nodes:
-        if node.op == "placeholder":
-            out_mem = size_of(node)
-            memory_profile[0] += out_mem
-        elif node.op == "output":
-            pass
-        elif (
-            node.op == "call_function"
-            or node.op == "call_module"
-            or node.op == "call_method"
-        ):
-            # Aliases don't allocate new memory
-            if alias_info.is_view(node) or alias_info.is_getitem(node):
-                memory_profile.append(memory_profile[-1])
-            else:
-                out_mem = size_of(node)
-                memory_profile.append(memory_profile[-1] + out_mem)
-
-            # Process storages that are no longer needed after this operation
-            storages_to_release = [
-                (use, key)
-                for use, key in node_to_last_used_storages.get(node, [])
-                if is_releasable(use)
-            ]
-            freed_memory = 0
-            for node_to_release, key in storages_to_release:
-                released_memory_size = 0
-                if key is None:
-                    released_memory_size = size_of(node_to_release)
-                elif key == "unclaimed":
-                    released_memory_size = alias_info.get_unclaimed_storage_size(
-                        node_to_release
-                    )
-                else:
-                    getitem_node = alias_info.get_getitem_by_key(node_to_release, key)
-                    if getitem_node is not None:
-                        released_memory_size = size_of(getitem_node)
-                freed_memory += released_memory_size
-
-            assert freed_memory >= 0
-            memory_profile.append(memory_profile[-1] - freed_memory)
+        if node.op in ("placeholder", "get_attr", "output"):
+            continue
+
+        # Process allocations
+        for storage_key in alias_info.get_fresh_allocations(node):
+            if device_filter(storage_key.device):
+                current_memory += size_of(storage_key.storage.nbytes())
+
+        memory_profile.append(current_memory)
+
+        # Process deallocations
+        for storage_key in alias_info.get_storages_last_used(node):
+            allocator = alias_info.storage_to_allocator[storage_key]
+            if is_releasable(allocator):
+                if device_filter(storage_key.device):
+                    current_memory -= size_of(storage_key.storage.nbytes())
+
+        memory_profile.append(current_memory)
+
     return memory_profile


 def get_fwd_bwd_interactions(
     fwd_graph: fx.Graph,
     bwd_graph: fx.Graph,
-    size_of: Callable[[fx.Node], int],
+    size_of: Optional[Callable[[Union[int, torch.SymInt]], int]] = None,
 ) -> tuple[int, OrderedSet[str]]:
     """
     Analyze the interactions between the forward (fwd) and backward (bwd) graphs
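
The rewritten loop above appends one profile entry after each node's allocations and another after its deallocations, with placeholder and get_attr storages counted up front. A self-contained sketch of the expected profile shape for a one-op graph (hypothetical 16-byte input; assumes the traced graph is just placeholder -> relu -> output):

import torch
from torch._inductor.fx_passes.memory_estimator import build_memory_profile
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx

with FakeTensorMode():
    gm = make_fx(torch.relu)(torch.randn(4, device="cuda"))  # 4 fp32 = 16 bytes

profile = build_memory_profile(
    gm.graph, is_releasable=lambda n: n.op not in ("placeholder", "get_attr")
)
# start:               [16]          placeholder storage counted up front
# after relu allocs:   [16, 32]      relu's output is a fresh storage
# after relu deallocs: [16, 32, 32]  the input's allocator is a placeholder,
#                                    which is_releasable rejects, so nothing frees
print(profile)  # expected: [16, 32, 32]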
@@ -274,77 +225,61 @@ def get_fwd_bwd_interactions(
     Args:
     - fwd_graph (fx.Graph): The forward graph representing the forward pass.
     - bwd_graph (fx.Graph): The backward graph representing the backward pass.
-    - size_of (Callable[[fx.Node], int]): A function that returns the size
-      of a given node.
+    - size_of (Callable[[Union[int, torch.SymInt]], int]): A function that converts
+      byte counts (possibly symbolic) to concrete integers.

     Returns:
-    - tuple[int, Set[fx.Node]]: A tuple containing:
+    - tuple[int, OrderedSet[str]]: A tuple containing:
         1. The baseline memory usage during the backward pass, accounting for
-           nodes that persist from the forward pass (i.e., in fwd output but
+           storages that persist from the forward pass (i.e., in fwd output but
            not in bwd input).
-        2. A set of nodes whose storage cannot be released during the bwd pass.
-           These include nodes that are views of primals or in bwd input
+        2. A set of node names whose storage cannot be released during the bwd pass.
+           These include nodes that use storage from primals or are in bwd input
            but not in fwd output.
     """

-    def get_nodes_in_output(graph: fx.Graph) -> OrderedSet[fx.Node]:
-        """
-        Get the nodes in the output of a graph.
-
-        Args:
-        - graph (fx.Graph): The input graph.
-
-        Returns:
-        - list[fx.Node]: A list of nodes in the output of the graph.
-        """
-        output_node = list(graph.nodes)[-1]
-        assert output_node.op == "output"
-        nodes_in_output: OrderedSet[fx.Node] = OrderedSet()
-
-        def add_node(node: fx.Node) -> None:
-            nodes_in_output.add(node)
-
-        # Using map_arg since output_node.args[0] can be of different types
-        # e.g. tuple, list, dict, fx.Node, etc.
-        fx.node.map_arg(output_node.args[0], lambda n: add_node(n))
-        return nodes_in_output
-
-    op_types = get_default_op_list()
-
-    bwd_baseline_memory = 0
-    # placeholder nodes besides primals of the bwd_graph that should also
-    # not be deleted during memory profile estimation of the bwd_graph
+    size_of = size_of or _size_of_default
+
+    # Build alias info for forward graph
+    fwd_nodes = list(fwd_graph.nodes)
+    fwd_alias_info = GraphAliasTracker(fwd_nodes)
+
+    # Identify storages allocated by primal placeholder nodes
+    primal_storages: OrderedSet[StorageKey] = OrderedSet()
+    for node in fwd_graph.find_nodes(op="placeholder"):
+        if node.name.startswith("primals"):
+            primal_storages.update(fwd_alias_info.get_fresh_allocations(node))
+
+    # Get storages in forward output
+    fwd_output_node = next(iter(reversed(fwd_graph.nodes)))
+    assert fwd_output_node.op == "output"
+    fwd_output_storages = fwd_alias_info.get_storage_uses(fwd_output_node)
+
+    # Node names that should not be deleted during memory profile estimation of bwd_graph
     do_not_delete: OrderedSet[str] = OrderedSet()

-    fwd_outputs = {}
-    for node in get_nodes_in_output(fwd_graph):
-        is_view_of_primal = False
-        if op_types.is_view(node):
-            source = node.args[0]
-            if isinstance(source, fx.Node) and source.name.startswith("primals"):
-                is_view_of_primal = True
-        fwd_outputs[node.name] = (size_of(node), is_view_of_primal)
-    bwd_inputs: OrderedSet[str] = OrderedSet()
-    for node in bwd_graph.nodes:
-        if node.op == "placeholder":
-            bwd_inputs.add(node.name)
-            if node.name.startswith("view"):
-                # if node is a view, then it has to be in fwd_outputs
-                assert node.name in fwd_outputs
-                _, is_view_of_primal = fwd_outputs[node.name]
-                if is_view_of_primal:
-                    # Add node to do_not_delete because it is a view of a primal
-                    do_not_delete.add(node.name)
-
-            # if node is not in fwd_outputs, then add it to do_not_delete
-            if node.name not in fwd_outputs:
-                do_not_delete.add(node.name)
-
-    # nodes that are in fwd_outputs but not in bwd_inputs take memory storage
-    # throughout the bwd pass
-    for name, (size, _) in fwd_outputs.items():
-        if name not in bwd_inputs:
-            bwd_baseline_memory += size
+    # Collect all storages in backward inputs and identify nodes to not delete
+    bwd_input_storages: OrderedSet[StorageKey] = OrderedSet()
+    for node in bwd_graph.find_nodes(op="placeholder"):
+        node_storages = GraphAliasTracker._get_output_storages(node)
+        bwd_input_storages.update(node_storages)
+
+        # Check if this node uses primal storage
+        if node_storages & primal_storages:
+            do_not_delete.add(node.name)
+
+        # Check if this node's storages are not in forward outputs
+        # (meaning it's an external input to backward pass)
+        if not (node_storages & fwd_output_storages):
+            do_not_delete.add(node.name)
+
+    # Calculate baseline memory: storages in fwd output but not in bwd input
+    # These storages persist throughout the backward pass
+    baseline_storages = fwd_output_storages - bwd_input_storages
+    bwd_baseline_memory = 0
+    for storage_key in baseline_storages:
+        if storage_key.device.type != "cpu":
+            bwd_baseline_memory += size_of(storage_key.storage.nbytes())

     return bwd_baseline_memory, do_not_delete
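
This is essentially the contract get_peak_memory (next hunk) relies on. A condensed sketch of the pattern, assuming fwd_graph/bwd_graph are a partitioned AOTAutograd-style pair using the "primals_*" naming (the names baseline and keep_alive are illustrative):

baseline, keep_alive = get_fwd_bwd_interactions(fwd_graph, bwd_graph)
bwd_peak = baseline + max(
    build_memory_profile(
        bwd_graph,
        lambda n: not n.name.startswith("primals") and n.name not in keep_alive,
    )
)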
@@ -353,24 +288,15 @@ def get_peak_memory(
     fwd_graph: fx.Graph,
     bwd_graph: fx.Graph,
 ) -> int:
-    def _safe_size_of(n: fx.Node) -> int:
-        try:
-            return _size_of(n)
-        except Exception:
-            log.warning("Failed size_of(%s). Returning 0 instead.", n)
-            return 0
-
     def _is_releasable(n: fx.Node) -> bool:
         # Storages of primals cannot be released during fwd or bwd pass.
         return not n.name.startswith("primals")

-    fwd_peak_memory = max(
-        build_memory_profile(fwd_graph, _safe_size_of, _is_releasable)
-    )
+    fwd_peak_memory = max(build_memory_profile(fwd_graph, _is_releasable))

-    # tmp change
     bwd_baseline_memory, bwd_do_not_delete = get_fwd_bwd_interactions(
-        fwd_graph, bwd_graph, _safe_size_of
+        fwd_graph,
+        bwd_graph,
     )

     def _is_bwd_releasable(n: fx.Node) -> bool:
@@ -379,7 +305,7 @@ def get_peak_memory(
         return _is_releasable(n) and n.name not in bwd_do_not_delete

     bwd_peak_memory = bwd_baseline_memory + max(
-        build_memory_profile(bwd_graph, _safe_size_of, _is_bwd_releasable)
+        build_memory_profile(bwd_graph, _is_bwd_releasable)
     )
     return max(
         fwd_peak_memory,