mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
This is follow-up of #164695 to apply ruff SIM rules to more files. Most changes are about simplifying dict.get because None is already the default value. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165031 Approved by: https://github.com/mlazos
962 lines
41 KiB
Python
962 lines
41 KiB
Python
import math
|
|
import os
|
|
import sys
|
|
from collections import OrderedDict
|
|
from dataclasses import astuple, dataclass
|
|
from typing import Any, NamedTuple, Optional
|
|
from typing_extensions import Self
|
|
|
|
import torch
|
|
from torch import nan, nn, UntypedStorage
|
|
from torch._guards import active_fake_mode
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
from torch.distributed._tools.common_utils import get_untyped_storages
|
|
from torch.distributed._tools.mod_tracker import ModTracker
|
|
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
|
|
from torch.testing._internal.composite_compliance import (
|
|
is_inplace,
|
|
is_inplace_view_fn,
|
|
is_view_fn,
|
|
)
|
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
from torch.utils._pytree import tree_flatten
|
|
from torch.utils.checkpoint import SAC_IGNORED_OPS
|
|
|
|
|
|
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["SACEstimator", "SACStats", "MSPS", "SACTradeOffStats", "SACGreedyOrderMeta"]
|
|
aten = torch.ops.aten
|
|
|
|
_ADDITIONAL_IGNORED_OPS = {
|
|
aten.lift_fresh.default, # type: ignore[attr-defined]
|
|
torch.ops.profiler._record_function_exit._RecordFunction, # type: ignore[attr-defined]
|
|
aten.clone.default, # type: ignore[attr-defined] # seems needed for torch.compile
|
|
}
|
|
OPS_TO_ALWAYS_SKIP = SAC_IGNORED_OPS | _ADDITIONAL_IGNORED_OPS
|
|
# This value is hard-coded here:
|
|
# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117
|
|
_PYTORCH_MIN_ALLOCATE = (
|
|
2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1
|
|
)
|
|
|
|
|
|
def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None:
    """Print ``table_data`` under ``headers`` as an reST-style table.

    Rendering is delegated to the optional ``tabulate`` package.

    Raises:
        ImportError: if ``tabulate`` is not installed.
    """
    try:
        from tabulate import tabulate as render_table
    except ImportError as missing_dep:
        raise ImportError("Please install tabulate.") from missing_dep

    rendered = render_table(table_data, headers=headers, tablefmt="rst")
    print(rendered)
|
|
|
|
|
|
# Based on:
# https://github.com/facebookresearch/xformers/blob/main/xformers/checkpoint.py#L71
@dataclass
class _SACMetadata:
    """
    Per-operator record captured by ``SACEstimator`` while tracing.

    Note:
        Field order is significant: ``SACEstimator._get_sac_stats`` unpacks
        ``dataclasses.astuple`` of these records positionally.

    Attributes:
        func (Any): The dispatched operator.
        time_taken (float): Estimated runtime of the operator.
        memory_used (float): Bytes of output storage attributed to the operator
            (zero for view-like operators).
        curr_idx (int): Position of the operator in the global operator sequence.
        output_ids (tuple[int, ...]): Hashes of the untyped storages backing
            the operator's outputs.
        inplace_info (tuple[int, ...]): ``(op_idx, parent_idx)`` for an in-place
            operator; empty for every other operator.
        is_view_like (bool): True for view and in-place-view operators.
        is_rand_op (bool): True if the operator is treated as random.
    """

    func: Any
    time_taken: float
    memory_used: float
    curr_idx: int
    output_ids: tuple[int, ...]
    inplace_info: tuple[int, ...]
    is_view_like: bool
    is_rand_op: bool
|
|
|
|
|
|
@dataclass
class _SACModMetadata:
    """
    SAC bookkeeping for one non-leaf module, created when its forward starts.

    Attributes:
        start_idx (int): Length of the global operator list when the module's
            forward began, i.e. the global index of its first operator.
        force_store_random (bool): If True, random operators recorded for this
            module are always stored instead of recomputed.
        sac_metadata (list[_SACMetadata]): Operator records attributed to this
            module, in execution order.
    """

    start_idx: int
    force_store_random: bool
    sac_metadata: list[_SACMetadata]
|
|
|
|
|
|
@dataclass
class SACStats:
    """
    Activation-checkpointing statistics of one module, one entry per operator.

    All index lists refer to positions in ``func_names``/``runtimes``/``memory``.

    Attributes:
        func_names (list[str]): Operator names.
        runtimes (list[float]): Operator runtimes in milliseconds.
        memory (list[int]): Operator memory usage in bytes.
        view_like_ops (list[int]): Indices of view-like operators.
        rand_ops (list[int]): Indices of random operators.
        saved_autograd_ops (list[int]): Indices of operators whose results the
            autograd engine saves for backward.
        inplace_ops (list[tuple[int, int]]): ``(op_idx, first_parent_idx)`` pairs
            for in-place operators.
        force_store_random (bool): Whether random operator results must be stored.
    """

    func_names: list[str]
    runtimes: list[float]
    memory: list[int]
    view_like_ops: list[int]
    rand_ops: list[int]
    saved_autograd_ops: list[int]
    inplace_ops: list[tuple[int, int]]
    force_store_random: bool
|
|
|
|
|
|
class MSPS(NamedTuple):
    """
    Memory-savings-per-second record for one recomputation candidate.

    A candidate is either a single operator or a whole operator group, which
    is identified by its head index.

    Attributes:
        func_names (set[str]): Names of the operators covered by the candidate.
        op_idx (int): Index of the operator (the group head for groups).
        memory (int): Memory usage in bytes.
        runtime (float): Runtime in milliseconds.
        msps (float): ``memory / runtime``, used to rank candidates.
    """

    func_names: set[str]
    op_idx: int
    memory: int
    runtime: float
    msps: float
|
|
|
|
|
|
@dataclass
class SACTradeOffStats:
    """
    Piecewise-linear fit of a module's memory-vs-recomputation trade-off.

    Attributes:
        n_segments (int): Number of linear segments fitted to the trade-off curve.
        slopes (list[float]): Slope of each fitted segment.
        intercepts (list[float]): Intercept of each fitted segment.
        fit_breaks (list[float]): Breakpoints between the fitted segments.
        tradeoff_curve (OrderedDict[float, float]): Sampled curve mapping the
            fraction of memory discarded to the fraction of runtime recomputed.
        sac_memory (int): Total memory, in bytes, of the operators available
            for activation checkpointing.
        sac_runtime (float): Total runtime, in milliseconds, of the operators
            available for activation checkpointing.
    """

    n_segments: int
    slopes: list[float]
    intercepts: list[float]
    fit_breaks: list[float]
    tradeoff_curve: OrderedDict[float, float]
    sac_memory: int
    sac_runtime: float
|
|
|
|
|
|
@dataclass
class SACGreedyOrderMeta:
    """
    Result of the greedy ordering of recomputation candidates for one module.

    Attributes:
        recomputed_ops (set[int]): Operator indices that are always recomputed.
        stored_ops (set[int]): Operator (group-head) indices that are always stored.
        inplace_op_groups (dict[int, set[int]]): Group head -> operators of each
            in-place operator group.
        random_ops_group (dict[int, set[int]]): Random-op group head -> all
            random operators.
        msps_meta (list[MSPS]): Remaining recomputation candidates in greedy order.
    """

    recomputed_ops: set[int]
    stored_ops: set[int]
    inplace_op_groups: dict[int, set[int]]
    random_ops_group: dict[int, set[int]]
    msps_meta: list[MSPS]
|
|
|
|
|
|
class SACEstimator(TorchDispatchMode):
|
|
"""
|
|
Estimates the memory and recomputation time trade-offs for applying Selective Activation Checkpointing (SAC).
|
|
|
|
This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the memory and
|
|
runtime trade-offs of functions or ``torch.nn.Module``s for Selective Activation Checkpointing (SAC). It provides
|
|
detailed statistics and metadata information for operators of each module and provides a greedy order for selecting
|
|
the operators to be recomputed/checkpointed. It also constructs the per-module trade-off graph of discarded memory
|
|
vs recomputation time for the obtained greedy order. Using ``RuntimeEstimator`` under the hood, it supports two
|
|
estimation modes, `operator-level-benchmark` and (`operator-level-cost-model` (roofline model).
|
|
|
|
Attributes:
|
|
sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fully qualified name) to ``SACStats``.
|
|
sac_mod_tradeoff_stats (Dict[str, SACTradeOffStats]): Dictionary from module FQN to ``SACTradeOffStats``.
|
|
sac_mod_greedy_order_meta (Dict[str, SACGreedyOrderMeta]): Dictionary from module FQN to ``SACGreedyOrderMeta``.
|
|
|
|
Note:
|
|
1) This class is designed to be used under ``FakeTensorMode``.
|
|
2) Currently, it only supports estimation of compute time and memory usage, and does not consider communication.
|
|
|
|
Example usage:
|
|
|
|
.. code-block:: python
|
|
|
|
sac_estimator = SACEstimator()
|
|
with FakeTensorMode():
|
|
module = ...
|
|
inp = ...
|
|
with sac_estimator("operator-level-cost-model"):
|
|
output = module(inp)
|
|
sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True)
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
self.sac_mod_stats: dict[str, SACStats] = {}
|
|
self.sac_mod_tradeoff_stats: dict[str, SACTradeOffStats] = {}
|
|
self.sac_mod_greedy_order_meta: dict[str, SACGreedyOrderMeta] = {}
|
|
self._mod_tracker = ModTracker()
|
|
self._sac_metadata: list[_SACMetadata] = []
|
|
self._sac_mod_metadata: dict[str, _SACModMetadata] = {}
|
|
self._leaf_modules: set[str] = set()
|
|
self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks(
|
|
self._pack_hook, lambda x: x
|
|
)
|
|
self._saved_tensor_ids: set[int] = set()
|
|
self._estimate_runtime = RuntimeEstimator._roofline_estimate
|
|
|
|
def _pack_hook(self, x: torch.Tensor) -> torch.Tensor:
|
|
# Hook function to track underlying storage IDs of tensors
|
|
# Updates the _saved_tensor_ids set with the IDs of the tensor's storages
|
|
# Used in conjunction with torch.autograd.graph.saved_tensors_hooks
|
|
untyped_storages = get_untyped_storages(x)
|
|
storage_ids = (hash(st) for st in untyped_storages)
|
|
self._saved_tensor_ids.update(storage_ids)
|
|
return x
|
|
|
|
def _pre_fw_hook(self, mod: nn.Module, inputs: Any) -> None:
    """Forward pre-hook: start collecting SAC metadata for ``mod``.

    A module without children is only remembered as a leaf. Any other module
    gets a fresh ``_SACModMetadata``. Its ``start_idx`` is the current length
    of the global operator list, and its ``force_store_random`` flag is
    derived from ``inputs``.
    """
    mod_fqn = self._mod_tracker.get_known_fqn(mod)
    assert mod_fqn is not None
    is_leaf = not any(True for _ in mod.children())
    if is_leaf:
        self._leaf_modules.add(mod_fqn)
        return
    store_random = self._get_force_store_random(inputs)
    self._sac_mod_metadata[mod_fqn] = _SACModMetadata(
        start_idx=len(self._sac_metadata),
        force_store_random=store_random,
        sac_metadata=[],
    )
|
|
|
|
def _post_fw_hook(self, mod: nn.Module, inputs: Any, outputs: Any) -> None:
    """Forward hook: finalise the SAC statistics of a non-leaf module.

    Leaf modules are ignored. For any other module, the operator records
    collected since its pre-forward hook are turned into ``SACStats`` and then
    into a ``SACGreedyOrderMeta``. Both are stored under the module's FQN.
    """
    mod_fqn = self._mod_tracker.get_known_fqn(mod)
    assert mod_fqn is not None
    if mod_fqn not in self._leaf_modules:
        mod_meta = self._sac_mod_metadata[mod_fqn]
        stats = self._get_sac_stats(
            data=mod_meta.sac_metadata,
            force_store_random=mod_meta.force_store_random,
        )
        self.sac_mod_stats[mod_fqn] = stats
        self.sac_mod_greedy_order_meta[mod_fqn] = self._get_greedy_order_meta(stats)
|
|
|
|
def _get_force_store_random(self, inputs: Any) -> bool:
    """Return True when the ``inputs`` pytree has no ``torch.Tensor`` leaves.

    An input pytree without tensors, including an empty one, forces the
    module's random operators to be stored.
    """
    leaves, _spec = tree_flatten(inputs)
    return not any(isinstance(leaf, torch.Tensor) for leaf in leaves)
|
|
|
|
def _get_sac_stats(
    self, data: list[_SACMetadata], force_store_random: bool
) -> SACStats:
    """
    Aggregate a module's per-operator records into a single ``SACStats``.

    Args:
        data (list[_SACMetadata]): Records attributed to the module, in
            execution order.
        force_store_random (bool): Copied into the result.

    Returns:
        SACStats: Operator indices refer to positions in ``data`` after removing
        operators in ``OPS_TO_ALWAYS_SKIP``. If nothing remains after filtering,
        every list in the result is empty.

    Raises:
        ValueError: if an in-place operator's parent was a skipped operator, so
            its index cannot be remapped.
    """
    # 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd
    # inserts those during backward and it breaks the fwd-bwd alignment
    filtered_data = [x for x in data if x.func not in OPS_TO_ALWAYS_SKIP]
    if not filtered_data:
        # A module whose ops were all skipped, or that issued none, would
        # otherwise fail on unpacking the empty ``zip`` below.
        return SACStats(
            func_names=[],
            runtimes=[],
            memory=[],
            view_like_ops=[],
            rand_ops=[],
            saved_autograd_ops=[],
            inplace_ops=[],
            force_store_random=force_store_random,
        )

    (
        ops,
        runtimes_,
        memory_,
        op_curr_idxs,
        output_ids,
        inplace_ops_,
        view_like_ops_,
        rand_ops_,
    ) = zip(*[astuple(x) for x in filtered_data], strict=True)

    # 2. Extract the metadata information
    runtimes = list(runtimes_)
    memory = list(memory_)
    func_names = [op._overloadpacket.__name__ for op in ops]
    view_like_ops = [i for i, x in enumerate(view_like_ops_) if x]
    rand_ops = [i for i, x in enumerate(rand_ops_) if x]
    saved_autograd_ops = [
        i
        for i, out_ids in enumerate(output_ids)
        if set(out_ids).issubset(self._saved_tensor_ids)
    ]

    # 3. Remap the inplace indices (global ``curr_idx`` values) to positions in
    # ``filtered_data``, as we have removed OPS_TO_ALWAYS_SKIP.
    # FIXME @sanketpurandare: Fix this by changing the parent of the inplace-op
    # to itself if the original parent is in OPS_TO_ALWAYS_SKIP.
    position_of: dict[int, int] = {}
    for pos, curr_idx in enumerate(op_curr_idxs):
        position_of.setdefault(curr_idx, pos)  # first occurrence, like tuple.index
    try:
        inplace_ops = [
            tuple(position_of[idx] for idx in info) for info in inplace_ops_ if info
        ]
    except KeyError as err:
        raise ValueError(
            f"The remapping of inplace ops failed since one of the inplace op parents"
            f" must have been present in {OPS_TO_ALWAYS_SKIP}"
        ) from err

    # 4. The last operation is always stored as the output of the checkpoint
    # block, so we can avoid recomputing it. We set the memory to zero
    # instead of adding a new constraint because we want both the 0 and 1
    # endpoints for memory_budget to be valid
    # FIXME @sanketpurandare: this heuristic for finding the last non-view non-inplace op
    # might not always be correct, which would yield suboptimal policies
    last_op = len(ops) - 1
    skip_ops_ = set(view_like_ops) | {x[0] for x in inplace_ops}
    for op in sorted(skip_ops_, reverse=True):
        if op == last_op:
            last_op -= 1
    # NOTE(review): if every op is view-like or in-place, last_op ends at -1
    # and memory[-1] (the final op) is zeroed; confirm this is intended.
    memory[last_op] = 0

    # 5. Create a single ``SACStats`` object for the entire block of ``_SACMetadata``.
    return SACStats(
        func_names=func_names,
        runtimes=runtimes,
        memory=memory,
        view_like_ops=view_like_ops,
        rand_ops=rand_ops,
        saved_autograd_ops=saved_autograd_ops,
        inplace_ops=inplace_ops,  # type: ignore[arg-type]
        force_store_random=force_store_random,
    )
|
|
|
|
def _get_inplace_metadata(
    self, func: Any, out_storages: set[UntypedStorage]
) -> tuple[int, tuple[int, ...], dict[str, tuple[int, ...]]]:
    """
    Compute the op index, output storage ids and per-module in-place info.

    Args:
        func (Any): The operator being dispatched.
        out_storages (set[UntypedStorage]): Storages backing the op's outputs.

    Returns:
        ``(curr_idx, output_ids, mod_inplace_info)``:

        * ``curr_idx`` is ``len(self._sac_metadata)``.
        * ``output_ids`` are the hashes of ``out_storages``.
        * ``mod_inplace_info`` maps every active non-leaf module FQN to ``()``
          for a non in-place op, or to ``(curr_idx, parent_idx)`` for an
          in-place op.

        ``parent_idx`` is the first earlier op whose outputs include all of
        this op's output storages, restricted to ops at or after the module's
        ``start_idx`` when the module has metadata. It is ``curr_idx`` itself
        when no such op exists.
    """
    # 1. Get the current index of the metadata obtained so far
    curr_idx = len(self._sac_metadata)
    # 2. Get the set of active modules that are not leaf
    active_mod_fqns: set[str] = {
        par for par in self._mod_tracker.parents if par not in self._leaf_modules
    }
    # 3. Output ids are the identifiers (hashes) of the storage objects backing the outputs
    output_ids = tuple(hash(st) for st in out_storages)
    # 4. If the function is not inplace, return
    if not is_inplace(func):
        return curr_idx, output_ids, dict.fromkeys(active_mod_fqns, ())

    op_idx = curr_idx
    output_id_set = set(output_ids)  # loop-invariant; built once
    # 5. Track, per active module, the parent op idx found so far
    parent_idxs: dict[str, int] = {}
    unresolved = set(active_mod_fqns)
    for i, past in enumerate(self._sac_metadata):
        if not unresolved:
            break  # every module already has its parent; later ops cannot change it
        # 6. Find the first occurrence of a tensor corresponding to each module that
        # shares the same storage as the current tensor
        if not output_id_set.issubset(past.output_ids):
            continue
        for mod_fqn in tuple(unresolved):
            acm_stats = self._sac_mod_metadata.get(mod_fqn)
            if acm_stats is not None:
                if i >= acm_stats.start_idx:
                    parent_idxs[mod_fqn] = i
                    unresolved.discard(mod_fqn)
            else:
                assert mod_fqn == "Global"
                parent_idxs[mod_fqn] = i
                unresolved.discard(mod_fqn)
    # 7. If no parent tensor is found, then it's probably an inplace op on the arguments
    # so one can just store the current-op idx as parent idx
    mod_inplace_info: dict[str, tuple[int, ...]] = {
        mod_fqn: (op_idx, parent_idxs.get(mod_fqn, op_idx))
        for mod_fqn in active_mod_fqns
    }
    return curr_idx, output_ids, mod_inplace_info
|
|
|
|
def __torch_dispatch__(  # type: ignore[no-untyped-def]
    self, func, types, args=..., kwargs=None
):
    """
    Run ``func`` through the runtime estimator and record SAC metadata for it.

    For every active non-leaf module, one ``_SACMetadata`` record is built.
    It is appended to the module's ``_SACModMetadata`` or, for the "Global"
    pseudo-module, to ``self._sac_metadata``. The operator's output is returned
    unchanged.
    """
    if kwargs is None:
        kwargs = {}
    # 1. Get the runtime estimate
    out, op_time = self._estimate_runtime(func, args, kwargs)
    flat_outs, _ = tree_flatten(out)
    out_storages_cuda: set[UntypedStorage] = set()
    out_storages_cpu: set[UntypedStorage] = set()
    cuda_devices: set[torch.device] = set()
    for o in flat_outs:
        if isinstance(o, torch.Tensor):
            if o.device.type == "cuda":
                out_storages_cuda.update(get_untyped_storages(o))
                cuda_devices.add(o.device)
            else:
                out_storages_cpu.update(get_untyped_storages(o))

    # Check if there's more than 1 CUDA device
    assert len(cuda_devices) <= 1, (
        f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}"
    )

    # 2. Get the memory consumed by output (CUDA sizes rounded up to the
    # caching allocator's granularity)
    nbytes_cuda = sum(
        math.ceil(st.nbytes() / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE
        for st in out_storages_cuda
    )
    nbytes_cpu = sum(st.nbytes() for st in out_storages_cpu)
    nbytes = nbytes_cuda + nbytes_cpu
    # 3. Get the current operator index, output storage identifiers and inplace metadata
    out_storages = out_storages_cuda | out_storages_cpu
    curr_idx, output_ids, mod_inplace_info = self._get_inplace_metadata(
        func, out_storages
    )
    # 4. Determine if the function is in-place, random-op or a view-like
    is_view_like = is_view_fn(func) or is_inplace_view_fn(func)
    is_rand_op = torch.Tag.nondeterministic_seeded in func.tags
    if is_view_like:
        nbytes = 0
    # sdpa has non-deterministic seed, but might be deterministic
    # if no dropout is applied
    if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention":
        # The dispatcher passes schema arguments that are not kwarg-only
        # positionally, so ``dropout_p`` normally arrives in ``args``. It is
        # dropped entirely when it and all later positional args are defaults.
        dropout_p = kwargs.get("dropout_p")
        if dropout_p is None:
            arg_names = [a.name for a in func._schema.arguments]
            positional = args if isinstance(args, (tuple, list)) else ()
            if "dropout_p" in arg_names:
                pos = arg_names.index("dropout_p")
                if pos < len(positional):
                    dropout_p = positional[pos]
        is_rand_op = dropout_p is not None and dropout_p != 0
    # 5. Create metadata information per active non-leaf module
    for mod_fqn in self._mod_tracker.parents:
        if mod_fqn in self._leaf_modules:
            continue
        acm = _SACMetadata(
            func=func,
            time_taken=op_time,
            memory_used=nbytes,
            curr_idx=curr_idx,
            output_ids=output_ids,
            inplace_info=mod_inplace_info[mod_fqn],
            is_view_like=is_view_like,
            is_rand_op=is_rand_op,
        )
        if acm_stats := self._sac_mod_metadata.get(mod_fqn):
            acm_stats.sac_metadata.append(acm)
        else:
            assert mod_fqn == "Global", (
                f"Module {mod_fqn} not found in AC Mod Stats"
            )
            self._sac_metadata.append(acm)

    return out
|
|
|
|
def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta:
    """
    Derive the greedy recomputation order for a module from its ``SACStats``.

    Operators are grouped into in-place groups and one random-op group, so
    each group is stored or recomputed as a unit. View-like ops are always
    recomputed, and some groups are always stored. The remaining candidates
    are ranked by memory saved per unit of recomputation time (``MSPS``),
    highest first.

    Args:
        sac_stats (SACStats): Statistics of the module's operators.

    Returns:
        SACGreedyOrderMeta: The always-recomputed and always-stored ops, the op
        groups, and the ranked recomputation candidates.
    """
    # An inplace-op group is a set of inplace-ops that operate on the same underlying tensor storage.
    # 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group
    # The top-most op can itself be an inplace-op or can be a non-inplace op.
    # 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads.
    inplace_op_groups: dict[int, set[int]] = {}
    inplace_op_to_group_head: dict[int, int] = dict(sac_stats.inplace_ops)

    # Initialize inplace_op_groups using inplace_op_to_group_head
    for op_idx, group_head_idx in inplace_op_to_group_head.items():
        op_group = inplace_op_groups.setdefault(group_head_idx, {group_head_idx})
        op_group.add(op_idx)

    # Like inplace ops, all of the random ops in the function/module should all be either recomputed or saved
    # as a group. This is because they affect the random seed generator. If force_store_random is set True,
    # all of the random ops will be stored by default. For ease of manageability, we store the top-most random op
    # as the leader of the random_ops_group.
    rand_ops_set = set(sac_stats.rand_ops)
    random_ops_group: dict[int, set[int]] = {}
    random_group_head_idx = min(sac_stats.rand_ops, default=-1)
    has_rand_ops = bool(sac_stats.rand_ops)
    if has_rand_ops:
        random_ops_group[random_group_head_idx] = set(sac_stats.rand_ops)

    # 1. Random ops are stored if force_store_random is set
    # 2. View-like ops are recomputed by default
    # 3. For inplace_op_groups:
    #   a) If the head of this group is an inplace op, then we have to store the entire group.
    #   b) If any op in the group is random and force_store_random is set, then entire group will be stored.
    #   c) If none of ops in the group are random and the head of the group is not an in-place op, then
    #       this group can be considered for recomputation in its entirety
    stored_ops: set[int] = set()
    recomputed_ops: set[int] = set()
    # Case 1:
    if has_rand_ops and sac_stats.force_store_random:
        stored_ops.add(random_group_head_idx)
    # Case 2:
    recomputed_ops.update(set(sac_stats.view_like_ops))

    for group_head_idx, op_group in inplace_op_groups.items():
        # Case 3a:
        if group_head_idx in inplace_op_to_group_head:
            stored_ops.add(group_head_idx)
        # Case 3b: logical test; ``bool & len(...)`` would be bitwise and
        # miss groups containing an even number of random ops.
        if sac_stats.force_store_random and not op_group.isdisjoint(rand_ops_set):
            stored_ops.add(group_head_idx)

    # The potential recompute candidates are populated as:
    recompute_candidates: set[int] = set()
    # 1) The random group head if it is not stored
    if has_rand_ops and random_group_head_idx not in stored_ops:
        recompute_candidates.add(random_group_head_idx)
    # 2) The in-place op group heads that are not stored
    recompute_candidates.update(set(inplace_op_groups.keys()) - stored_ops)
    # 3) The non-inplace and non-random ops that are neither stored nor recomputed by default
    recompute_candidates.update(
        set(range(len(sac_stats.memory)))
        - recomputed_ops
        - stored_ops
        - set(inplace_op_to_group_head.keys())
        - rand_ops_set
    )

    # We define msps for a recomp candidate as the ratio of memory/runtime aka memory savings per second
    msps_meta: list[MSPS] = []
    for cand_idx in recompute_candidates:
        op_indices = {cand_idx}
        if cand_idx in inplace_op_groups:
            op_indices.update(inplace_op_groups[cand_idx])
        if has_rand_ops and cand_idx == random_group_head_idx:
            op_indices.update(sac_stats.rand_ops)

        mem = sum(sac_stats.memory[op_idx] for op_idx in op_indices)
        runtime = sum(sac_stats.runtimes[op_idx] for op_idx in op_indices)
        func_names = {sac_stats.func_names[op_idx] for op_idx in op_indices}
        msps = (mem / runtime) if runtime > 0 else sys.float_info.max
        msps_meta.append(MSPS(func_names, cand_idx, mem, runtime, msps))
    # Candidates are recomputed greedily in decreasing order of msps
    # (most memory saved per unit of recomputation time first)
    msps_meta.sort(key=lambda x: x.msps, reverse=True)
    return SACGreedyOrderMeta(
        recomputed_ops, stored_ops, inplace_op_groups, random_ops_group, msps_meta
    )
|
|
|
|
def _get_sac_tradeoff_pwlf_stats(
    self,
    sac_stats: SACStats,
    greedy_order_meta: SACGreedyOrderMeta,
    n_segments: int = 2,
    save_tradeoff_graph: bool = False,
    filename: str = "ac_tradeoff",
) -> SACTradeOffStats:
    """
    Build the discarded-memory vs. recomputation-time curve of one module and
    fit a piecewise linear function to it.

    Args:
        sac_stats (SACStats): Per-operator statistics of the module.
        greedy_order_meta (SACGreedyOrderMeta): Greedy order derived from ``sac_stats``.
        n_segments (int): Requested number of linear segments. The value
            actually used is ``max(min(len(x) - 2, n_segments), 1)``, where
            ``x`` are the fitted sample points.
        save_tradeoff_graph (bool): If True, save a plot to
            ``tradeoff_graphs/<filename>.png``.
        filename (str): Base name, and plot title, of the saved graph.

    Returns:
        SACTradeOffStats: Fitted slopes, intercepts and breakpoints, plus the
        sampled curve and the module's total memory and runtime.

    Raises:
        ImportError: if ``pwlf``/``numpy`` are missing, or if ``matplotlib`` is
            missing when ``save_tradeoff_graph`` is set.
    """
    try:
        import numpy as np  # type: ignore[import-not-found]
        import pwlf  # type: ignore[import-untyped, import-not-found]
    except ImportError as err:
        raise ImportError("Please install pwlf and numpy package.") from err

    stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = (
        greedy_order_meta.stored_ops,
        greedy_order_meta.recomputed_ops,
        greedy_order_meta.inplace_op_groups,
        greedy_order_meta.random_ops_group,
        greedy_order_meta.msps_meta,
    )

    def expand_groups(group_heads: set[int]) -> set[int]:
        # Each head stands for itself plus its in-place / random op group members.
        members: set[int] = set()
        for head in group_heads:
            members.add(head)
            members.update(inplace_op_groups.get(head, ()))
            members.update(random_ops_group.get(head, ()))
        return members

    # 1. Initialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops
    recomp_indices = expand_groups(recomputed_ops)
    discarded_mem = sum(sac_stats.memory[op_idx] for op_idx in recomp_indices)
    recomp_runtime = sum(sac_stats.runtimes[op_idx] for op_idx in recomp_indices)
    # 2. Initialize the max recomputation time and total recomputation memory
    sac_runtime = sum(sac_stats.runtimes)
    sac_memory = sum(sac_stats.memory)
    # 3. Tradeoff curve stores the KV pair of the discarded memory to total memory and,
    # recomputation time to total runtime incurred.
    delta = 1e-2
    tradeoff_curve: OrderedDict[float, float] = OrderedDict()
    # 4. Initialize the trade-off curve with the stats of already chosen recomputed_ops
    tradeoff_curve[(discarded_mem / sac_memory) + delta] = (
        recomp_runtime / sac_runtime
    )
    # 5. Update the trade-off curve with memory and runtime stats of SAC candidates in the
    # greedy order of their ``MSPS``.
    for cand in msps_meta:
        discarded_mem += cand.memory
        recomp_runtime += cand.runtime
        tradeoff_curve[(discarded_mem / sac_memory) + delta] = (
            recomp_runtime / sac_runtime
        )
    # 6. Finally, we add the memory and recomputation time of the always stored ops.
    stored_indices = expand_groups(stored_ops)
    discarded_mem += sum(sac_stats.memory[op_idx] for op_idx in stored_indices)
    recomp_runtime += sum(sac_stats.runtimes[op_idx] for op_idx in stored_indices)
    tradeoff_curve[(discarded_mem / sac_memory) + delta] = (
        recomp_runtime / sac_runtime
    )
    x_ = list(tradeoff_curve.keys())
    y_ = list(tradeoff_curve.values())
    # 7. We shift the y values to left and x values to right to upperbound the trade-off function
    # TODO: Write a better explanation why this needs to be done
    x = x_[:-1]
    y = y_[1:]
    tradeoff_pwlf = pwlf.PiecewiseLinFit(x, y)
    # 8. Fit a piecewise linear function with the specified number of segments to the trade-off curve.
    n_segments = max(min(len(x) - 2, n_segments), 1)
    tradeoff_pwlf.fit(n_segments=n_segments)

    # save prediction graph
    def save_prediction_graph(
        pwlf_: pwlf.PiecewiseLinFit, x: list[float], y: list[float], filename: str
    ) -> None:
        try:
            import matplotlib.pyplot as plt  # type: ignore[import-not-found]
        except ImportError as err:
            raise ImportError(
                "Install matplotlib and numpy using pip: pip install matplotlib numpy"
            ) from err
        # predict for the determined points
        xHat = np.linspace(min(x), max(x), num=10000)
        yHat = pwlf_.predict(xHat)

        # plot the results
        fig = plt.figure()
        try:
            plt.plot(x, y, "o", label="Shifted")
            plt.plot(xHat, yHat, "-", label="Predicted")
            plt.plot(x_, y_, "x", label="Original")
            plt.ylabel("Recomp time / Total recomp time")
            plt.xlabel("Memory discarded / Total memory")
            plt.legend()
            plt.title(f"{filename}")
            plt.suptitle(
                f"Total Memory = {sac_memory} B Total Runtime = {sac_runtime:.4f} ms",
                fontsize=10,
            )
            folder_name = "tradeoff_graphs"
            os.makedirs(folder_name, exist_ok=True)
            # Save the plots in the folder
            plt.savefig(os.path.join(folder_name, f"{filename}.png"))
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    if save_tradeoff_graph:
        save_prediction_graph(tradeoff_pwlf, x, y, filename)
    # 9. Obtain the slopes, intercepts and breakpoints of the fitted piecewise linear functions
    slopes = tradeoff_pwlf.calc_slopes().tolist()
    assert isinstance(tradeoff_pwlf.intercepts, np.ndarray) and isinstance(
        tradeoff_pwlf.fit_breaks, np.ndarray
    )
    intercepts = tradeoff_pwlf.intercepts.tolist()
    fit_breaks = tradeoff_pwlf.fit_breaks.tolist()
    return SACTradeOffStats(
        n_segments=n_segments,
        slopes=slopes,
        intercepts=intercepts,  # type: ignore[arg-type]
        fit_breaks=fit_breaks,  # type: ignore[arg-type]
        tradeoff_curve=tradeoff_curve,
        sac_memory=sac_memory,
        sac_runtime=sac_runtime,
    )
|
|
|
|
def display_sac_stats(
    self, sac_stats: SACStats, print_tabular: bool = False
) -> None:
    """
    Print the per-operator SAC statistics of one module.

    A summary line comes first, with the total memory (B), the total runtime
    (ms) and the ``force_store_random`` flag. It is followed by one row per
    operator with the columns::

        Op Idx | Op Name | Runtimes(ms) | Memory (B) | View-like | Random |
        Saved Autograd | In-place

    "In-place" holds the index of the operator's first parent for in-place
    operators and ``None`` otherwise.

    Args:
        sac_stats (SACStats): Statistics to print.
        print_tabular (bool, optional): Render the rows with ``tabulate``
            instead of tab-separated, left-aligned plain text. Defaults to False.
    """
    print(
        f"Total Memory: {sum(sac_stats.memory)} B Total Runtime: {sum(sac_stats.runtimes)} ms"
        f" Store Random: {sac_stats.force_store_random}"
    )
    parent_of = dict(sac_stats.inplace_ops)
    rows = [
        [
            str(idx),
            name,
            f"{sac_stats.runtimes[idx]:.4f}",
            str(sac_stats.memory[idx]),
            str(idx in sac_stats.view_like_ops),
            str(idx in sac_stats.rand_ops),
            str(idx in sac_stats.saved_autograd_ops),
            str(parent_of.get(idx)),
        ]
        for idx, name in enumerate(sac_stats.func_names)
    ]
    headers = [
        "Op Idx",
        "Op Name",
        "Runtimes(ms)",
        "Memory (B)",
        "View-like",
        "Random",
        "Saved Autograd",
        "In-place",
    ]
    if print_tabular:
        _display_stats_tabular(headers, rows)
        return

    lines = [headers, *rows]
    # Each column is padded to its widest cell (header included).
    widths = [max(len(cell) for cell in column) for column in zip(*lines)]
    for line in lines:
        print("\t".join(f"{cell:<{width}}" for cell, width in zip(line, widths)))
|
|
|
|
def display_sac_tradeoff_stats(
    self,
    greedy_order_meta: SACGreedyOrderMeta,
    sac_stats: SACStats,
    print_tabular: bool = False,
) -> None:
    """
    Displays the SAC trade-off statistics.

    Rows are emitted in three groups, in this order:

    1. ops that are always recomputed (``greedy_order_meta.recomputed_ops``),
    2. the greedy candidates in ``greedy_order_meta.msps_meta`` order,
    3. ops that are always stored (``greedy_order_meta.stored_ops``).

    Each op is reported together with the in-place ops and random ops grouped
    under it. The memory and runtime columns are cumulative: every row includes
    the contribution of all rows printed before it.

    Args:
        greedy_order_meta (SACGreedyOrderMeta): The SAC greedy order metadata.
        sac_stats (SACStats): The SAC statistics.
        print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False.

    Prints:
        A table with the following columns:
            1. Op Id(s): The operator index(es) of the row's op group.
            2. Op Name(s): The operator name(s) of the row's op group.
            3. Discarded Mem (%): Cumulative discarded memory divided by the total
               memory, printed as a ratio with 4 decimals (``nan`` if the total is 0).
            4. Discarded Mem (B): Cumulative discarded memory in bytes.
            5. Recomp time (%): Cumulative recomputation time divided by the total
               runtime, printed as a ratio with 4 decimals (``nan`` if the total is 0).
            6. Recomp time (ms): Cumulative recomputation time in milliseconds.
            7. MSPS: Memory/Runtime of a greedy candidate (``nan`` for rows that
               are always stored or always recomputed).
            8. Always Stored: Whether the op group is always stored.
            9. Always Recomputed: Whether the op group is always recomputed.

        If print_tabular is True, the table is printed in a tabular format.
        Otherwise, the table is printed in a plain text format.
    """
    total_memory, total_runtime = sum(sac_stats.memory), sum(sac_stats.runtimes)
    inplace_op_groups = greedy_order_meta.inplace_op_groups
    random_ops_group = greedy_order_meta.random_ops_group

    def op_group(op_idx: int) -> set[int]:
        # An op is always reported together with the in-place and random ops
        # grouped under it in the greedy order metadata.
        op_indices = {op_idx}
        if op_idx in inplace_op_groups:
            op_indices.update(inplace_op_groups[op_idx])
        if op_idx in random_ops_group:
            op_indices.update(random_ops_group[op_idx])
        return op_indices

    def ratio(part: float, total: float) -> str:
        # Guard the division: an empty ``sac_stats`` (or one whose memory /
        # runtimes sum to zero) would otherwise raise ZeroDivisionError.
        return f"{part / total:.4f}" if total else str(nan)

    table_data: list[list[str]] = []
    discarded_mem: int = 0
    recomp_runtime: float = 0.0

    def append_row(
        op_indices: set[int],
        func_names: set[str],
        msps: Optional[float] = None,
        stored: bool = False,
        recomputed: bool = False,
    ) -> None:
        # Reads the running totals at call time, so each row shows cumulative values.
        table_data.append(
            [
                str(op_indices),
                str(func_names),
                ratio(discarded_mem, total_memory),
                str(discarded_mem),
                ratio(recomp_runtime, total_runtime),
                str(recomp_runtime),
                f"{msps:.2e}" if msps is not None else str(nan),
                str(stored),
                str(recomputed),
            ]
        )

    def append_fixed_group(op_idx: int, **flags: bool) -> None:
        # Always-recomputed / always-stored ops: accumulate the measured cost of
        # the whole op group from ``sac_stats`` before emitting the row.
        nonlocal discarded_mem, recomp_runtime
        op_indices = op_group(op_idx)
        discarded_mem += sum(sac_stats.memory[i] for i in op_indices)
        recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices)
        func_names = {sac_stats.func_names[i] for i in op_indices}
        append_row(op_indices, func_names, **flags)

    for op_idx in greedy_order_meta.recomputed_ops:
        append_fixed_group(op_idx, recomputed=True)

    for cand in greedy_order_meta.msps_meta:
        # Greedy candidates carry their own memory, runtime and function names.
        discarded_mem += cand.memory
        recomp_runtime += cand.runtime
        append_row(op_group(cand.op_idx), cand.func_names, msps=cand.msps)

    for op_idx in greedy_order_meta.stored_ops:
        append_fixed_group(op_idx, stored=True)

    headers = [
        "Op Id(s)",
        "Op Name(s)",
        "Discarded Mem (%)",
        "Discarded Mem (B)",
        "Recomp time (%)",
        "Recomp time (ms)",
        "MSPS",
        "Always Stored",
        "Always Recomputed",
    ]
    if print_tabular:
        _display_stats_tabular(headers, table_data)
    else:
        rows = [headers, *table_data]
        # Left-align every column to the width of its widest cell (header included).
        widths = [max(len(row[col]) for row in rows) for col in range(len(headers))]
        for row in rows:
            print("\t".join(f"{cell:<{width}}" for cell, width in zip(row, widths)))
|
|
|
|
def pwlf_sac_tradeoff_curve(
    self,
    n_segments: int = 2,
    save_tradeoff_graphs: bool = False,
) -> None:
    """
    Fit a piecewise linear function to every module's SAC trade-off curve.

    For each module in ``self.sac_mod_stats``, the curve of discarded memory vs
    recomputation time is fit with ``n_segments`` linear segments by
    ``self._get_sac_tradeoff_pwlf_stats``. The result is stored in
    ``self.sac_mod_tradeoff_stats`` under the module's FQN.

    Args:
        n_segments (int, optional): The number of segments used to fit the
            piecewise linear function to the trade-off curve. Defaults to 2.
        save_tradeoff_graphs (bool, optional): Whether to save the trade-off
            graphs to file. Defaults to False.

    If save_tradeoff_graphs is True, each module's graph is saved to a file named
    after the module FQN.
    """
    for fqn, mod_stats in self.sac_mod_stats.items():
        fit_kwargs = {
            "sac_stats": mod_stats,
            "greedy_order_meta": self.sac_mod_greedy_order_meta[fqn],
            "n_segments": n_segments,
            "save_tradeoff_graph": save_tradeoff_graphs,
            "filename": fqn,
        }
        fitted = self._get_sac_tradeoff_pwlf_stats(**fit_kwargs)
        self.sac_mod_tradeoff_stats[fqn] = fitted
|
|
|
|
def display_modulewise_sac_stats(
    self, depth: int = 2, print_tabular: bool = False
) -> None:
    """
    Print the SAC statistics and the SAC trade-off table of each shallow module.

    A module's depth is one plus the number of ``.`` separators in its FQN, so a
    top-level module has depth 1. Modules deeper than ``depth`` are skipped.

    Args:
        depth (int, optional): The maximum depth of modules to display. Defaults to 2.
        print_tabular (bool, optional): Forwarded to ``display_sac_stats`` and
            ``display_sac_tradeoff_stats`` to select tabular output. Defaults to False.

    Prints:
        For each selected module, in ``self.sac_mod_stats`` iteration order:
            1. A ``Module: <fqn>`` line followed by the output of ``display_sac_stats``.
            2. An ``AC Trade-off for Module: <fqn> MSPS = Memory/Runtime`` line
               followed by the output of ``display_sac_tradeoff_stats``.
    """

    def _too_deep(fqn: str) -> bool:
        return fqn.count(".") + 1 > depth

    selected = (
        (fqn, stats)
        for fqn, stats in self.sac_mod_stats.items()
        if not _too_deep(fqn)
    )
    for fqn, stats in selected:
        print(f"Module: {fqn}")
        self.display_sac_stats(stats, print_tabular)
        print(f"AC Trade-off for Module: {fqn} MSPS = Memory/Runtime")
        greedy_meta = self.sac_mod_greedy_order_meta[fqn]
        self.display_sac_tradeoff_stats(greedy_meta, stats, print_tabular)
|
|
|
|
def __call__(self, estimate_mode_type: str) -> Self:
|
|
"""
|
|
Sets the estimate mode type.
|
|
|
|
Currently supported modes:
|
|
- "operator-level-benchmark": Estimates runtime using operator benchmarking.
|
|
- "operator-level-cost-model": Estimates runtime using roofline cost model.
|
|
|
|
Args:
|
|
estimate_mode_type (str): The type of estimate mode to use.
|
|
|
|
Returns:
|
|
SACEstimator: The SAC estimator instance.
|
|
|
|
Raises:
|
|
NotImplementedError: If the estimate mode type is not supported.
|
|
"""
|
|
if estimate_mode_type == "operator-level-benchmark":
|
|
self._estimate_runtime = RuntimeEstimator._benchmark_estimate
|
|
elif estimate_mode_type == "operator-level-cost-model":
|
|
self._estimate_runtime = RuntimeEstimator._roofline_estimate
|
|
else:
|
|
raise NotImplementedError(
|
|
f"estimate_mode_type {estimate_mode_type} not supported"
|
|
)
|
|
return self
|
|
|
|
def __enter__(self) -> Self:  # type: ignore[no-untyped-def]
    """
    Start SAC estimation.

    Requires an active ``FakeTensorMode``, which is shared with
    ``RuntimeEstimator`` via ``RuntimeEstimator.fake_mode``. Then, in order:
    registers this estimator's forward pre/post hooks on the module tracker and
    enters it, enters the saved-tensor hooks context, and finally enters the
    parent context via ``super().__enter__()``.

    Returns:
        The value returned by ``super().__enter__()``.

    Raises:
        AssertionError: If no ``FakeTensorMode`` is currently active.
    """
    fake_mode = active_fake_mode()
    # Explicit check instead of ``assert`` so the guard is not stripped under
    # ``python -O``; AssertionError is kept so callers see the same exception type.
    if not isinstance(fake_mode, FakeTensorMode):
        raise AssertionError("SAC Estimator should be called in FakeTensorMode")
    RuntimeEstimator.fake_mode = fake_mode
    self._mod_tracker.register_user_hooks(
        pre_fw_hook=self._pre_fw_hook,
        post_fw_hook=self._post_fw_hook,
    )
    self._mod_tracker.__enter__()
    self._saved_tensor_hook_ctx.__enter__()
    return super().__enter__()
|
|
|
|
def __exit__(self, *args: Any) -> None:  # type: ignore[no-untyped-def]
    """
    Stop SAC estimation and tear down every context entered by ``__enter__``.

    Exits the saved-tensor hooks context, then the module tracker, then the
    parent context (``super().__exit__``), in that order. Each step runs even if
    an earlier one raises, so a failure in one teardown cannot leave the other
    contexts active. The first exception raised still propagates.

    Args:
        *args: The exception triple passed to ``__exit__``; forwarded to the
            module tracker and the parent context.
    """
    try:
        self._saved_tensor_hook_ctx.__exit__()
    finally:
        try:
            self._mod_tracker.__exit__(*args)
        finally:
            super().__exit__(*args)
|