Files
pytorch/torch/distributed/_tools/sac_estimator.py
Aaron Gokaslan 3555ebb63d [BE]: Update ruff to 0.11.8 (#153249)
Fixes a ton of false negatives throughout the codebase. Ruff also properly validates NOQA comments now, and most of the changes fix typos there or remove file-wide flake8 suppressions that were also silencing Ruff issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153249
Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/seemethere
2025-05-12 18:30:52 +00:00

961 lines
41 KiB
Python

import math
import os
import sys
from collections import OrderedDict
from dataclasses import astuple, dataclass
from typing import Any, NamedTuple, Optional
from typing_extensions import Self
import torch
from torch import nan, nn, UntypedStorage
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.common_utils import get_untyped_storages
from torch.distributed._tools.mod_tracker import ModTracker
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.testing._internal.composite_compliance import (
is_inplace,
is_inplace_view_fn,
is_view_fn,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
from torch.utils.checkpoint import SAC_IGNORED_OPS
__all__ = ["SACEstimator", "SACStats", "MSPS", "SACTradeOffStats", "SACGreedyOrderMeta"]

aten = torch.ops.aten

# Operators filtered out of the SAC statistics on top of ``SAC_IGNORED_OPS``.
_ADDITIONAL_IGNORED_OPS = {
    aten.lift_fresh.default,  # type: ignore[attr-defined]
    torch.ops.profiler._record_function_exit._RecordFunction,  # type: ignore[attr-defined]
    aten.clone.default,  # type: ignore[attr-defined] # seems needed for torch.compile
}
OPS_TO_ALWAYS_SKIP = SAC_IGNORED_OPS | _ADDITIONAL_IGNORED_OPS

# Granularity (in bytes) that CUDA output sizes are rounded up to. It mirrors the value hard-coded in
# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117
# When PYTORCH_NO_CUDA_MEMORY_CACHING is a non-zero integer, the granularity is 1 (no rounding).
_PYTORCH_MIN_ALLOCATE = (
    1 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) != 0 else 2**9
)
def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None:
    """
    Print ``table_data`` under ``headers`` as an RST-formatted table.

    Args:
        headers (list[str]): Column titles.
        table_data (list[list[Any]]): One inner list per row.

    Raises:
        ImportError: If the optional ``tabulate`` package is not installed.
    """
    try:
        from tabulate import tabulate as _tabulate
    except ImportError as import_err:
        raise ImportError("Please install tabulate.") from import_err
    rendered_table = _tabulate(table_data, headers=headers, tablefmt="rst")
    print(rendered_table)
# Based on:
# https://github.com/facebookresearch/xformers/blob/main/xformers/checkpoint.py#L71
@dataclass
class _SACMetadata:
    """
    Stores metadata for a single operator for SAC.

    Attributes:
        func (Any): The operator function.
        time_taken (float): The estimated runtime of the operator (milliseconds).
        memory_used (int): The memory used by the operator's outputs, in bytes
            (0 for view-like operators).
        curr_idx (int): The current operator index.
        output_ids (tuple[int, ...]): The storage IDs of the operator's outputs.
        inplace_info (tuple[int, ...]): ``(op_idx, parent_idx)`` for an in-place operator,
            an empty tuple otherwise.
        is_view_like (bool): Whether the operator is view-like.
        is_rand_op (bool): Whether the operator is a random operator.
    """

    func: Any
    time_taken: float
    # Byte counts are integers (sums of ``UntypedStorage.nbytes()``).
    memory_used: int
    curr_idx: int
    output_ids: tuple[int, ...]
    inplace_info: tuple[int, ...]
    is_view_like: bool
    is_rand_op: bool
@dataclass()
class _SACModMetadata:
    """
    Per-module bookkeeping collected while a non-leaf module is running under SAC tracing.

    Attributes:
        start_idx (int): Index of the first operator recorded after the module's forward started.
        force_store_random (bool): Whether random operators of this module are forced to be stored.
        sac_metadata (list[_SACMetadata]): Metadata of every operator recorded for this module.
    """

    # First operator index belonging to this module.
    start_idx: int
    # Whether random ops must be stored rather than recomputed.
    force_store_random: bool
    # Operator metadata, appended as operators are dispatched.
    sac_metadata: list[_SACMetadata]
@dataclass
class SACStats:
    """
    A class for storing Activation Checkpointing statistics corresponding to a module.

    Attributes:
        func_names (list[str]): List of operator names.
        runtimes (list[float]): List of operator runtimes in milliseconds.
        memory (list[int]): List of operator memory usage in bytes.
        view_like_ops (list[int]): Indices of view-like operators.
        rand_ops (list[int]): Indices of random operators.
        saved_autograd_ops (list[int]): Indices of operator results saved by the autograd engine.
        inplace_ops (list[tuple[int, int]]): ``(op_idx, first_parent_idx)`` pairs for in-place operators.
        force_store_random (bool): Whether to force store random operator results.
    """

    func_names: list[str]
    runtimes: list[float]
    memory: list[int]
    view_like_ops: list[int]
    rand_ops: list[int]
    saved_autograd_ops: list[int]
    inplace_ops: list[tuple[int, int]]
    force_store_random: bool
class MSPS(NamedTuple):
    """
    Represents Memory and Runtime Statistics for an operator/operator group.

    Attributes:
        func_names (set[str]): Set of operator/operator group names.
        op_idx (int): Operator index (the group head index in case of operator groups).
        memory (int): Memory usage in bytes.
        runtime (float): Runtime in milliseconds.
        msps (float): Memory saved per unit of recomputation time, computed as ``memory / runtime``
            (bytes per millisecond).
    """

    func_names: set[str]
    op_idx: int
    memory: int
    runtime: float
    msps: float
@dataclass
class SACTradeOffStats:
    """
    Stores statistics for activation-checkpointing trade-off.

    Attributes:
        n_segments (int): Number of piecewise linear segments fitted to the trade-off curve.
        slopes (list[float]): Slopes of the linear segments fitted to the trade-off curve.
        intercepts (list[float]): Intercepts of the linear segments fitted to the trade-off curve.
        fit_breaks (list[float]): Breakpoints of the linear segments fitted to the trade-off curve.
        tradeoff_curve (OrderedDict[float, float]): Trade-off curve data of memory discarded vs recomputation time.
        sac_memory (int): Total memory of operations available for activation checkpointing in bytes.
        sac_runtime (float): Total runtime of operations available for activation checkpointing in milliseconds.
    """

    n_segments: int
    slopes: list[float]
    intercepts: list[float]
    fit_breaks: list[float]
    tradeoff_curve: OrderedDict[float, float]
    sac_memory: int
    sac_runtime: float
@dataclass()
class SACGreedyOrderMeta:
    """
    Result of the greedy ordering used for Selective Activation Checkpointing.

    Attributes:
        recomputed_ops (set[int]): Indices of operators that are always recomputed.
        stored_ops (set[int]): Indices of operators that are always stored.
        inplace_op_groups (dict[int, set[int]]): Maps each in-place group head to all operators of its group.
        random_ops_group (dict[int, set[int]]): Maps the random-op group head to all random operators.
        msps_meta (list[MSPS]): Memory/runtime statistics of the recomputation candidates, in greedy order.
    """

    # Operator indices recomputed by default.
    recomputed_ops: set[int]
    # Operator indices stored by default.
    stored_ops: set[int]
    # In-place group head -> member operator indices.
    inplace_op_groups: dict[int, set[int]]
    # Random group head -> random operator indices.
    random_ops_group: dict[int, set[int]]
    # Candidate statistics, ordered for greedy selection.
    msps_meta: list[MSPS]
class SACEstimator(TorchDispatchMode):
    """
    Estimates the memory and recomputation time trade-offs for applying Selective Activation Checkpointing (SAC).

    This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the memory and
    runtime trade-offs of functions or ``torch.nn.Module``s for Selective Activation Checkpointing (SAC). It provides
    detailed statistics and metadata information for operators of each module and provides a greedy order for selecting
    the operators to be recomputed/checkpointed. It also constructs the per-module trade-off graph of discarded memory
    vs recomputation time for the obtained greedy order. Using ``RuntimeEstimator`` under the hood, it supports two
    estimation modes, ``operator-level-benchmark`` and ``operator-level-cost-model`` (roofline model).

    Attributes:
        sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fully qualified name) to ``SACStats``.
        sac_mod_tradeoff_stats (Dict[str, SACTradeOffStats]): Dictionary from module FQN to ``SACTradeOffStats``.
        sac_mod_greedy_order_meta (Dict[str, SACGreedyOrderMeta]): Dictionary from module FQN to ``SACGreedyOrderMeta``.

    Note:
        1) This class is designed to be used under ``FakeTensorMode``.
        2) Currently, it only supports estimation of compute time and memory usage, and does not consider communication.

    Example usage:

        .. code-block:: python

            sac_estimator = SACEstimator()
            with FakeTensorMode():
                module = ...
                inp = ...
                with sac_estimator("operator-level-cost-model"):
                    output = module(inp)
            sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True)
    """

    def __init__(self) -> None:
        # Public results, keyed by module FQN.
        self.sac_mod_stats: dict[str, SACStats] = {}
        self.sac_mod_tradeoff_stats: dict[str, SACTradeOffStats] = {}
        self.sac_mod_greedy_order_meta: dict[str, SACGreedyOrderMeta] = {}
        self._mod_tracker = ModTracker()
        # Operator metadata recorded for the "Global" scope. Its length is used as the
        # index of the next operator (see ``_get_inplace_metadata``).
        self._sac_metadata: list[_SACMetadata] = []
        # Metadata of non-leaf modules, (re)initialized in ``_pre_fw_hook``.
        self._sac_mod_metadata: dict[str, _SACModMetadata] = {}
        # FQNs of modules without children; no SAC stats are computed for them.
        self._leaf_modules: set[str] = set()
        # The pack hook records the storages autograd saves; unpacking is the identity.
        self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks(
            self._pack_hook, lambda x: x
        )
        self._saved_tensor_ids: set[int] = set()
        # Default runtime estimator; ``__call__`` may switch it.
        self._estimate_runtime = RuntimeEstimator._roofline_estimate

    def _pack_hook(self, x: torch.Tensor) -> torch.Tensor:
        # Saved-tensor pack hook (used with torch.autograd.graph.saved_tensors_hooks).
        # Records the hashes of the untyped storages backing ``x`` in
        # ``_saved_tensor_ids`` and returns ``x`` unchanged.
        untyped_storages = get_untyped_storages(x)
        storage_ids = (hash(st) for st in untyped_storages)
        self._saved_tensor_ids.update(storage_ids)
        return x

    def _pre_fw_hook(self, mod: nn.Module, inputs: Any) -> None:
        # Pre-forward hook. A module with children gets a fresh ``_SACModMetadata``
        # whose ``start_idx`` is the index of the next operator to be recorded.
        # A module without children is registered as a leaf module.
        mod_fqn = self._mod_tracker.get_known_fqn(mod)
        assert mod_fqn is not None
        num_children = sum(1 for _ in mod.children())
        if num_children > 0:
            force_store_random = self._get_force_store_random(inputs)
            self._sac_mod_metadata[mod_fqn] = _SACModMetadata(
                start_idx=len(self._sac_metadata),
                force_store_random=force_store_random,
                sac_metadata=[],
            )
        else:
            self._leaf_modules.add(mod_fqn)

    def _post_fw_hook(self, mod: nn.Module, inputs: Any, outputs: Any) -> None:
        # Post-forward hook. For non-leaf modules it computes the ``SACStats`` from the
        # recorded operator metadata, then the ``SACGreedyOrderMeta`` from those stats.
        mod_fqn = self._mod_tracker.get_known_fqn(mod)
        assert mod_fqn is not None
        if mod_fqn in self._leaf_modules:
            return
        mod_metadata = self._sac_mod_metadata[mod_fqn]
        sac_stats = self._get_sac_stats(
            data=mod_metadata.sac_metadata,
            force_store_random=mod_metadata.force_store_random,
        )
        self.sac_mod_stats[mod_fqn] = sac_stats
        self.sac_mod_greedy_order_meta[mod_fqn] = self._get_greedy_order_meta(
            sac_stats
        )

    def _get_force_store_random(self, inputs: Any) -> bool:
        # True when none of the (pytree-flattened) inputs is a tensor.
        flat_inputs, _ = tree_flatten(inputs)
        return all(not isinstance(x, torch.Tensor) for x in flat_inputs)

    def _get_sac_stats(
        self, data: list[_SACMetadata], force_store_random: bool
    ) -> SACStats:
        # 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd
        # inserts those during backward and it breaks the fwd-bwd alignment
        filtered_data = [x for x in data if x.func not in OPS_TO_ALWAYS_SKIP]
        if not filtered_data:
            # No (non-skipped) operators were recorded for this module. Unpacking
            # ``zip(*[])`` below would raise, so report empty statistics instead.
            return SACStats(
                func_names=[],
                runtimes=[],
                memory=[],
                view_like_ops=[],
                rand_ops=[],
                saved_autograd_ops=[],
                inplace_ops=[],
                force_store_random=force_store_random,
            )

        (
            ops,
            runtimes_,
            memory_,
            new_ids,
            output_ids,
            inplace_ops_,
            view_like_ops_,
            rand_ops_,
        ) = zip(*[astuple(x) for x in filtered_data], strict=True)

        # 2. Extract the metadata information
        runtimes = list(runtimes_)
        memory = list(memory_)
        func_names = [op._overloadpacket.__name__ for op in ops]
        view_like_ops = [i for i, x in enumerate(view_like_ops_) if x]
        rand_ops = [i for i, x in enumerate(rand_ops_) if x]
        # NOTE(review): an operator without output storages has empty ``out_ids`` and
        # is counted as saved by autograd, because the empty set is a subset of any set.
        saved_autograd_ops = [
            i
            for i, out_ids in enumerate(output_ids)
            if set(out_ids).issubset(self._saved_tensor_ids)
        ]

        # 3. Remap the inplace indices (recorded as global operator indices) to
        # positions within ``filtered_data``, since OPS_TO_ALWAYS_SKIP were removed.
        # The first occurrence wins, matching ``tuple.index`` semantics, but each
        # lookup is O(1).
        # FIXME @sanketpurandare: Fix this by changing the parent of the inplace-op
        # to itself if the original parent is in OPS_TO_ALWAYS_SKIP.
        position_of: dict[int, int] = {}
        for pos, global_idx in enumerate(new_ids):
            position_of.setdefault(global_idx, pos)
        try:
            inplace_ops = [
                tuple(position_of[idx] for idx in x) for x in inplace_ops_ if x
            ]
        except KeyError as err:
            raise ValueError(
                f"The remapping of inplace ops failed since one of the inplace op parents"
                f" must have been present in {OPS_TO_ALWAYS_SKIP}"
            ) from err

        # 4. The last operation is always stored as the output of the checkpoint
        # block, so we can avoid recomputing it. We set the memory to zero
        # instead of adding a new constraint because we want both the 0 and 1
        # endpoints for memory_budget to be valid
        # FIXME @sanketpurandare: this heuristic for finding the last non-view non-inplace op
        # might not always be correct, which would yield suboptimal policies
        last_op = len(ops) - 1
        skip_ops_ = set(view_like_ops) | {x[0] for x in inplace_ops}
        for op in sorted(skip_ops_, reverse=True):
            if op == last_op:
                last_op -= 1
        # If every op is view-like or in-place there is no candidate to zero out;
        # without this guard ``memory[-1]`` would silently hit the last op.
        if last_op >= 0:
            memory[last_op] = 0

        # 5. Create a single ``SACStats`` object for the entire block of ``_SACMetadata``.
        return SACStats(
            func_names=func_names,
            runtimes=runtimes,
            memory=memory,
            view_like_ops=view_like_ops,
            rand_ops=rand_ops,
            saved_autograd_ops=saved_autograd_ops,
            inplace_ops=inplace_ops,  # type: ignore[arg-type]
            force_store_random=force_store_random,
        )

    def _get_inplace_metadata(
        self, func: Any, out_storages: set[UntypedStorage]
    ) -> tuple[int, tuple[int, ...], dict[str, tuple[int, ...]]]:
        # Returns (current op index, output storage ids, per-active-module inplace info).
        # The inplace info is ``()`` for non-inplace ops and ``(op_idx, parent_idx)`` otherwise.

        # 1. Get the current index of the metadata obtained so far
        curr_idx = len(self._sac_metadata)
        # 2. Get the set of active modules that are not leaf
        active_mod_fqns: set[str] = {
            par for par in self._mod_tracker.parents if par not in self._leaf_modules
        }
        # 3. Output ids are the identifiers (hashes) of the storage objects corresponding to the tensors
        output_ids = tuple(hash(st) for st in out_storages)
        # 4. If the function is not inplace, return
        if not is_inplace(func):
            return curr_idx, output_ids, dict.fromkeys(active_mod_fqns, ())

        op_idx = curr_idx
        # 5. Initialize the parent op ids of the inplace op for each of the active modules
        mod_op_parent_idxs: dict[str, int] = dict.fromkeys(active_mod_fqns, -1)
        output_id_set = set(output_ids)
        for i, d in enumerate(self._sac_metadata):
            # Once every module has found its parent, later iterations cannot change anything.
            if -1 not in mod_op_parent_idxs.values():
                break
            # 6. Find the first occurrence of a tensor corresponding to each module that
            # shares the same storage as the current tensor
            if not output_id_set.issubset(d.output_ids):
                continue
            for mod_fqn, op_parent_idx in mod_op_parent_idxs.items():
                if op_parent_idx != -1:
                    continue
                if acm_stats := self._sac_mod_metadata.get(mod_fqn, None):
                    # Only ops recorded after the module started belong to it.
                    if i >= acm_stats.start_idx:
                        mod_op_parent_idxs[mod_fqn] = i
                else:
                    assert mod_fqn == "Global"
                    mod_op_parent_idxs[mod_fqn] = i

        # 7. If no parent tensor is found, then it's probably an inplace op on the arguments
        # so one can just store the current-op idx as parent idx
        for mod_fqn, op_parent_idx in mod_op_parent_idxs.items():
            if op_parent_idx < 0:
                mod_op_parent_idxs[mod_fqn] = op_idx
        mod_inplace_info = {
            mod_fqn: (op_idx, mod_op_parent_idxs[mod_fqn])
            for mod_fqn in active_mod_fqns
        }
        return curr_idx, output_ids, mod_inplace_info  # type: ignore[return-value]

    def __torch_dispatch__(  # type: ignore[no-untyped-def]
        self, func, types, args=(), kwargs=None
    ):
        # The dispatcher may pass ``kwargs=None``; normalize so the lookups below are safe.
        kwargs = kwargs if kwargs is not None else {}
        # 1. Get the runtime estimate
        out, op_time = self._estimate_runtime(func, args, kwargs)
        flat_outs, _ = tree_flatten(out)
        out_storages_cuda: set[UntypedStorage] = set()
        out_storages_cpu: set[UntypedStorage] = set()
        cuda_devices: set[torch.device] = set()
        for o in flat_outs:
            if isinstance(o, torch.Tensor):
                if o.device.type == "cuda":
                    out_storages_cuda.update(get_untyped_storages(o))
                    cuda_devices.add(o.device)
                else:
                    out_storages_cpu.update(get_untyped_storages(o))

        # Check if there's more than 1 CUDA device
        assert len(cuda_devices) <= 1, (
            f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}"
        )

        # 2. Get the memory consumed by output. CUDA sizes are rounded up to a
        # multiple of _PYTORCH_MIN_ALLOCATE.
        nbytes_cuda = sum(
            math.ceil(st.nbytes() / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE
            for st in out_storages_cuda
        )
        nbytes_cpu = sum(st.nbytes() for st in out_storages_cpu)
        nbytes = nbytes_cuda + nbytes_cpu
        # 3. Get the current operator index, output storage identifiers and inplace metadata
        out_storages = out_storages_cuda | out_storages_cpu
        curr_idx, output_ids, mod_inplace_info = self._get_inplace_metadata(
            func, out_storages
        )
        # 4. Determine if the function is in-place, random-op or a view-like
        is_view_like = is_view_fn(func) or is_inplace_view_fn(func)
        is_rand_op = torch.Tag.nondeterministic_seeded in func.tags
        if is_view_like:
            nbytes = 0
        # sdpa has non-deterministic seed, but might be deterministic
        # if no dropout is applied. ``dropout_p`` is not a kwarg-only argument of this
        # op, so under ``__torch_dispatch__`` it normally arrives in ``args``.
        if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention":
            is_rand_op = self._get_op_arg(func, args, kwargs, "dropout_p", 0) != 0
        # 5. Create metadata information per active non-leaf module
        for mod_fqn in self._mod_tracker.parents:
            if mod_fqn in self._leaf_modules:
                continue
            acm = _SACMetadata(
                func=func,
                time_taken=op_time,
                memory_used=nbytes,
                curr_idx=curr_idx,
                output_ids=output_ids,
                inplace_info=mod_inplace_info[mod_fqn],
                is_view_like=is_view_like,
                is_rand_op=is_rand_op,
            )
            if acm_stats := self._sac_mod_metadata.get(mod_fqn, None):
                acm_stats.sac_metadata.append(acm)
            else:
                assert mod_fqn == "Global", (
                    f"Module {mod_fqn} not found in AC Mod Stats"
                )
                self._sac_metadata.append(acm)

        return out

    @staticmethod
    def _get_op_arg(
        func: Any,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        name: str,
        default: Any,
    ) -> Any:
        # Look up argument ``name`` of operator ``func``. It is taken from ``kwargs`` if
        # present, otherwise from ``args`` at the position given by the op schema. If it
        # was omitted (e.g. a trailing default), ``default`` is returned.
        if name in kwargs:
            return kwargs[name]
        for pos, schema_arg in enumerate(func._schema.arguments):
            if schema_arg.name == name:
                return args[pos] if pos < len(args) else default
        return default

    @staticmethod
    def _expand_op_group(
        op_idx: int,
        inplace_op_groups: dict[int, set[int]],
        random_ops_group: dict[int, set[int]],
    ) -> set[int]:
        # ``op_idx`` plus every member of the inplace group and/or random group it heads.
        # These operators are always recomputed or stored together.
        op_indices = {op_idx}
        if op_idx in inplace_op_groups:
            op_indices.update(inplace_op_groups[op_idx])
        if op_idx in random_ops_group:
            op_indices.update(random_ops_group[op_idx])
        return op_indices

    def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta:
        # An inplace-op group is a set of inplace-ops that operate on the same underlying tensor storage.
        # 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group
        #   The top-most op can itself be an inplace-op or can be a non-inplace op.
        # 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads.
        inplace_op_groups: dict[int, set[int]] = {}
        inplace_op_to_group_head: dict[int, int] = dict(sac_stats.inplace_ops)

        # Initialize inplace_op_groups using inplace_op_to_group_head
        for op_idx, group_head_idx in inplace_op_to_group_head.items():
            op_group = inplace_op_groups.setdefault(group_head_idx, {group_head_idx})
            op_group.add(op_idx)

        # Like inplace ops, all of the random ops in the function/module should be either recomputed or saved
        # as a group, because they affect the random seed generator. If force_store_random is set, all of
        # the random ops are stored by default. For ease of management, the top-most random op is the
        # leader of the random_ops_group.
        rand_ops = set(sac_stats.rand_ops)
        random_ops_group: dict[int, set[int]] = {}
        random_group_head_idx = min(sac_stats.rand_ops, default=-1)
        has_rand_ops = bool(sac_stats.rand_ops)
        if has_rand_ops:
            random_ops_group[random_group_head_idx] = set(sac_stats.rand_ops)

        # 1. Random ops are stored if force_store_random is set
        # 2. View-like ops are recomputed by default
        # 3. For inplace_op_groups:
        #   a) If the head of this group is an inplace op, then we have to store the entire group.
        #   b) If any op in the group is random and force_store_random is set, then the entire group is stored.
        #   c) If none of the ops in the group are random and the head of the group is not an in-place op, then
        #       this group can be considered for recomputation in its entirety
        stored_ops: set[int] = set()
        recomputed_ops: set[int] = set()
        # Case 1:
        if has_rand_ops and sac_stats.force_store_random:
            stored_ops.add(random_group_head_idx)
        # Case 2:
        recomputed_ops.update(sac_stats.view_like_ops)

        for group_head_idx, op_group in inplace_op_groups.items():
            # Case 3a:
            if group_head_idx in inplace_op_to_group_head:
                stored_ops.add(group_head_idx)
            # Case 3b: this must be a logical ``and``. A bitwise ``bool & count`` is 0
            # for even overlap counts and would skip such groups.
            if sac_stats.force_store_random and not op_group.isdisjoint(rand_ops):
                stored_ops.add(group_head_idx)

        # The potential recompute candidates are populated as:
        recompute_candidates: set[int] = set()
        # 1) The random group head if it is not stored
        if has_rand_ops and random_group_head_idx not in stored_ops:
            recompute_candidates.add(random_group_head_idx)
        # 2) The in-place op group heads that are not stored
        recompute_candidates.update(set(inplace_op_groups.keys()) - stored_ops)
        # 3) The non-inplace and non-random ops that are neither stored nor recomputed by default
        recompute_candidates.update(
            set(range(len(sac_stats.memory)))
            - recomputed_ops
            - stored_ops
            - set(inplace_op_to_group_head.keys())
            - rand_ops
        )

        # We define msps for a recomp candidate as the ratio memory/runtime, i.e. memory savings per unit runtime
        msps_meta: list[MSPS] = []
        for cand_idx in recompute_candidates:
            op_indices = self._expand_op_group(
                cand_idx, inplace_op_groups, random_ops_group
            )
            mem = sum(sac_stats.memory[op_idx] for op_idx in op_indices)
            runtime = sum(sac_stats.runtimes[op_idx] for op_idx in op_indices)
            func_names = {sac_stats.func_names[op_idx] for op_idx in op_indices}
            msps = (mem / runtime) if runtime > 0 else sys.float_info.max
            msps_meta.append(MSPS(func_names, cand_idx, mem, runtime, msps))
        # Candidates are ordered by decreasing msps: the ones saving the most memory per unit
        # of recomputation time come first.
        msps_meta.sort(key=lambda x: x.msps, reverse=True)
        return SACGreedyOrderMeta(
            recomputed_ops, stored_ops, inplace_op_groups, random_ops_group, msps_meta
        )

    def _get_sac_tradeoff_pwlf_stats(
        self,
        sac_stats: SACStats,
        greedy_order_meta: SACGreedyOrderMeta,
        n_segments: int = 2,
        save_tradeoff_graph: bool = False,
        filename: str = "ac_tradeoff",
    ) -> SACTradeOffStats:
        try:
            import numpy as np  # type: ignore[import-not-found]
            import pwlf  # type: ignore[import-untyped, import-not-found]
        except ImportError as err:
            raise ImportError("Please install pwlf and numpy package.") from err

        stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = (
            greedy_order_meta.stored_ops,
            greedy_order_meta.recomputed_ops,
            greedy_order_meta.inplace_op_groups,
            greedy_order_meta.random_ops_group,
            greedy_order_meta.msps_meta,
        )
        # 1. Initialize the discarded memory and recomputation runtime to the sum over the
        # already chosen recomputed_ops (including the groups they head)
        recomp_indices: set[int] = set()
        for r_idx in recomputed_ops:
            recomp_indices |= self._expand_op_group(
                r_idx, inplace_op_groups, random_ops_group
            )

        discarded_mem = sum(sac_stats.memory[op_idx] for op_idx in recomp_indices)
        recomp_runtime = sum(sac_stats.runtimes[op_idx] for op_idx in recomp_indices)
        # 2. Initialize the max recomputation time and total recomputation memory
        sac_runtime = sum(sac_stats.runtimes)
        sac_memory = sum(sac_stats.memory)
        # NOTE(review): the normalizations below divide by sac_memory and sac_runtime, so a
        # module whose ops have zero total memory or runtime raises ZeroDivisionError here.
        # 3. Tradeoff curve stores the KV pair of the discarded memory to total memory and,
        # recomputation time to total runtime incurred.
        delta = 1e-2
        tradeoff_curve: OrderedDict[float, float] = OrderedDict()
        # 4. Initialize the trade-off curve with the stats of already chosen recomputed_ops
        tradeoff_curve[(discarded_mem / sac_memory) + delta] = (
            recomp_runtime / sac_runtime
        )
        # 5. Update the trade-off curve with memory and runtime stats of SAC candidates in the
        # greedy order of their ``MSPS``.
        for cand in msps_meta:
            discarded_mem += cand.memory
            recomp_runtime += cand.runtime
            tradeoff_curve[(discarded_mem / sac_memory) + delta] = (
                recomp_runtime / sac_runtime
            )
        # 6. Finally, we add the memory and recomputation time of the always stored ops.
        stored_indices: set[int] = set()
        for s_idx in stored_ops:
            stored_indices |= self._expand_op_group(
                s_idx, inplace_op_groups, random_ops_group
            )
        discarded_mem += sum(sac_stats.memory[op_idx] for op_idx in stored_indices)
        recomp_runtime += sum(sac_stats.runtimes[op_idx] for op_idx in stored_indices)
        tradeoff_curve[(discarded_mem / sac_memory) + delta] = (
            recomp_runtime / sac_runtime
        )
        x_ = list(tradeoff_curve.keys())
        y_ = list(tradeoff_curve.values())
        # 7. We shift the y values to left and x values to right to upperbound the trade-off function
        # TODO: Write a better explanation why this needs to be done
        x = x_[: len(x_) - 1]
        y = y_[1:]
        tradeoff_pwlf = pwlf.PiecewiseLinFit(x, y)
        # 8. Fit a piecewise linear function with the specified number of segments to the trade-off curve.
        n_segments = max(min(len(x) - 2, n_segments), 1)
        tradeoff_pwlf.fit(n_segments=n_segments)

        # save prediction graph
        def save_prediction_graph(
            pwlf_: pwlf.PiecewiseLinFit, x: list[float], y: list[float], filename: str
        ) -> None:
            try:
                import matplotlib.pyplot as plt  # type: ignore[import-not-found]
            except ImportError as err:
                raise ImportError(
                    "Install matplotlib and numpy using pip: pip install matplotlib numpy"
                ) from err
            # predict for the determined points (``np`` comes from the enclosing scope)
            xHat = np.linspace(min(x), max(x), num=10000)
            yHat = pwlf_.predict(xHat)

            fig = plt.figure()
            try:
                # plot the results
                plt.plot(x, y, "o", label="Shifted")
                plt.plot(xHat, yHat, "-", label="Predicted")
                plt.plot(x_, y_, "x", label="Original")
                plt.ylabel("Recomp time / Total recomp time")
                plt.xlabel("Memory discarded / Total memory")
                plt.legend()
                plt.title(f"{filename}")
                plt.suptitle(
                    f"Total Memory = {sac_memory} B Total Runtime = {sac_runtime:.4f} ms",
                    fontsize=10,
                )
                folder_name = "tradeoff_graphs"
                os.makedirs(folder_name, exist_ok=True)
                # Save the plots in the folder
                plt.savefig(os.path.join(folder_name, f"{filename}.png"))
            finally:
                # Release the figure; otherwise one figure per module stays open.
                plt.close(fig)

        if save_tradeoff_graph:
            save_prediction_graph(tradeoff_pwlf, x, y, filename)
        # 9. Obtain the slopes, intercepts and breakpoints of the fitted piecewise linear functions
        slopes = tradeoff_pwlf.calc_slopes().tolist()
        assert isinstance(tradeoff_pwlf.intercepts, np.ndarray) and isinstance(
            tradeoff_pwlf.fit_breaks, np.ndarray
        )
        intercepts = tradeoff_pwlf.intercepts.tolist()
        fit_breaks = tradeoff_pwlf.fit_breaks.tolist()
        return SACTradeOffStats(
            n_segments=n_segments,
            slopes=slopes,
            intercepts=intercepts,  # type: ignore[arg-type]
            fit_breaks=fit_breaks,  # type: ignore[arg-type]
            tradeoff_curve=tradeoff_curve,
            sac_memory=sac_memory,
            sac_runtime=sac_runtime,
        )

    @staticmethod
    def _print_plain_table(headers: list[str], table_data: list[list[str]]) -> None:
        # Print ``headers`` followed by ``table_data`` as tab-separated, left-aligned
        # columns, each padded to its widest cell.
        rows = [headers, *table_data]
        max_widths = [0] * len(headers)
        for row in rows:
            for i, elem in enumerate(row):
                max_widths[i] = max(max_widths[i], len(elem))
        for row in rows:
            print("\t".join([f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)]))

    def display_sac_stats(
        self, sac_stats: SACStats, print_tabular: bool = False
    ) -> None:
        """
        Displays the SAC statistics.

        Args:
            sac_stats (SACStats): The SAC statistics to display.
            print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False.

        Prints:
            1. Total Memory: The total memory usage in bytes.
            2. Total Runtime: The total runtime in milliseconds.
            3. Store Random: A flag indicating whether to force store random operator results.

            Followed by a table with the following columns:
            1. Op Idx: The operator index.
            2. Op Name: The operator name.
            3. Runtimes (ms): The operator runtime in milliseconds.
            4. Memory (B): The operator memory usage in bytes.
            5. View-like: A flag indicating whether the operator is view-like.
            6. Random: A flag indicating whether the operator is random.
            7. Saved Autograd: A flag indicating whether the operator's result is saved by autograd engine.
            8. In-place: The index of the operator's first parent, or None if not in-place.

        If print_tabular is True, the table is printed in a tabular format.
        Otherwise, the table is printed in a plain text format.
        """
        print(
            f"Total Memory: {sum(sac_stats.memory)} B Total Runtime: {sum(sac_stats.runtimes)} ms"
            f" Store Random: {sac_stats.force_store_random}"
        )
        # Sets make the per-row membership checks O(1).
        view_like = set(sac_stats.view_like_ops)
        rand = set(sac_stats.rand_ops)
        saved = set(sac_stats.saved_autograd_ops)
        op_parent = dict(sac_stats.inplace_ops)
        table_data = [
            [
                str(i),
                fn_name,
                f"{sac_stats.runtimes[i]:.4f}",
                str(sac_stats.memory[i]),
                str(i in view_like),
                str(i in rand),
                str(i in saved),
                str(op_parent.get(i, None)),
            ]
            for i, fn_name in enumerate(sac_stats.func_names)
        ]

        # Define headers
        headers = [
            "Op Idx",
            "Op Name",
            "Runtimes(ms)",
            "Memory (B)",
            "View-like",
            "Random",
            "Saved Autograd",
            "In-place",
        ]
        if print_tabular:
            _display_stats_tabular(headers, table_data)
        else:
            self._print_plain_table(headers, table_data)

    def display_sac_tradeoff_stats(
        self,
        greedy_order_meta: SACGreedyOrderMeta,
        sac_stats: SACStats,
        print_tabular: bool = False,
    ) -> None:
        """
        Displays the SAC trade-off statistics.

        Args:
            greedy_order_meta (SACGreedyOrderMeta): The SAC greedy order metadata.
            sac_stats (SACStats): The SAC statistics.
            print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False.

        Prints:
            A table with the following columns:
            1. Op Id(s): The operator index(es).
            2. Op Name(s): The operator name(s).
            3. Discarded Mem (%): The fraction of discarded memory (``nan`` if the total memory is 0).
            4. Discarded Mem (B): The discarded memory in bytes.
            5. Recomp time (%): The fraction of recomputed time (``nan`` if the total runtime is 0).
            6. Recomp time (ms): The recomputed time in milliseconds.
            7. MSPS: The memory per second.
            8. Always Stored: A flag indicating whether the operator is always stored.
            9. Always Recomputed: A flag indicating whether the operator is always recomputed.

        If print_tabular is True, the table is printed in a tabular format.
        Otherwise, the table is printed in a plain text format.
        """
        table_data = []
        total_memory, total_runtime = sum(sac_stats.memory), sum(sac_stats.runtimes)
        discarded_mem: int = 0
        recomp_runtime: float = 0.0

        def ratio(num: float, den: float) -> float:
            # Avoid ZeroDivisionError for modules with zero total memory/runtime.
            return num / den if den else nan

        def append_row(
            op_indices: set[int],
            func_names: set[str],
            msps: Optional[float] = None,
            stored: Optional[bool] = False,
            recomputed: Optional[bool] = False,
        ) -> None:
            # Reads the running totals ``discarded_mem``/``recomp_runtime`` at call time.
            row = [
                str(op_indices),
                str(func_names),
                f"{ratio(discarded_mem, total_memory):.4f}",
                str(discarded_mem),
                f"{ratio(recomp_runtime, total_runtime):.4f}",
                str(recomp_runtime),
                f"{msps:.2e}" if msps is not None else str(nan),
                str(stored),
                str(recomputed),
            ]
            table_data.append(row)

        stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = (
            greedy_order_meta.stored_ops,
            greedy_order_meta.recomputed_ops,
            greedy_order_meta.inplace_op_groups,
            greedy_order_meta.random_ops_group,
            greedy_order_meta.msps_meta,
        )

        for op_idx in recomputed_ops:
            op_indices = self._expand_op_group(
                op_idx, inplace_op_groups, random_ops_group
            )
            discarded_mem += sum(sac_stats.memory[i] for i in op_indices)
            recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices)
            func_names = {sac_stats.func_names[i] for i in op_indices}
            append_row(op_indices, func_names, recomputed=True)

        for cand in msps_meta:
            discarded_mem += cand.memory
            recomp_runtime += cand.runtime
            op_indices = self._expand_op_group(
                cand.op_idx, inplace_op_groups, random_ops_group
            )
            append_row(op_indices, cand.func_names, msps=cand.msps)

        for op_idx in stored_ops:
            op_indices = self._expand_op_group(
                op_idx, inplace_op_groups, random_ops_group
            )
            discarded_mem += sum(sac_stats.memory[i] for i in op_indices)
            recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices)
            func_names = {sac_stats.func_names[i] for i in op_indices}
            append_row(op_indices, func_names, stored=True)

        headers = [
            "Op Id(s)",
            "Op Name(s)",
            "Discarded Mem (%)",
            "Discarded Mem (B)",
            "Recomp time (%)",
            "Recomp time (ms)",
            "MSPS",
            "Always Stored",
            "Always Recomputed",
        ]
        if print_tabular:
            _display_stats_tabular(headers, table_data)
        else:
            self._print_plain_table(headers, table_data)

    def pwlf_sac_tradeoff_curve(
        self,
        n_segments: int = 2,
        save_tradeoff_graphs: bool = False,
    ) -> None:
        """
        Fits a piecewise linear function with the specified number of segments to the SAC trade-off curve of
        discarded memory vs recomputation time.

        Args:
            n_segments (int, optional): The number of segments to be used for fitting the piecewise linear function to
                the trade-off curve. Defaults to 2.
            save_tradeoff_graphs (bool, optional): Whether to save the trade-off graphs to file. Defaults to False.

        If save_tradeoff_graphs is True, the trade-off graphs are saved to file using the module FQN as the filename.
        """
        for mod_fqn, sac_stats in self.sac_mod_stats.items():
            self.sac_mod_tradeoff_stats[mod_fqn] = self._get_sac_tradeoff_pwlf_stats(
                sac_stats=sac_stats,
                greedy_order_meta=self.sac_mod_greedy_order_meta[mod_fqn],
                n_segments=n_segments,
                save_tradeoff_graph=save_tradeoff_graphs,
                filename=mod_fqn,
            )

    def display_modulewise_sac_stats(
        self, depth: int = 2, print_tabular: bool = False
    ) -> None:
        """
        Displays the SAC and trade-off statistics for each module.

        Args:
            depth (int, optional): The maximum depth of modules to display. Defaults to 2.
            print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False.

        Prints:
            For each module with depth less than or equal to the specified depth:
            1. The SAC statistics for the module (using display_sac_stats).
            2. The SAC trade-off statistics for the module (using display_sac_tradeoff_stats).

        If print_tabular is True, the statistics are printed in a tabular format.
        Otherwise, the statistics are printed in a plain text format.
        """
        for mod_fqn, sac_stats in self.sac_mod_stats.items():
            # Depth is the number of dot-separated components of the FQN.
            mod_depth = mod_fqn.count(".") + 1
            if mod_depth > depth:
                continue
            print(f"Module: {mod_fqn}")
            self.display_sac_stats(sac_stats, print_tabular)
            print(f"AC Trade-off for Module: {mod_fqn} MSPS = Memory/Runtime")
            self.display_sac_tradeoff_stats(
                self.sac_mod_greedy_order_meta[mod_fqn], sac_stats, print_tabular
            )

    def __call__(self, estimate_mode_type: str) -> Self:
        """
        Sets the estimate mode type.

        Currently supported modes:
            - "operator-level-benchmark": Estimates runtime using operator benchmarking.
            - "operator-level-cost-model": Estimates runtime using roofline cost model.

        Args:
            estimate_mode_type (str): The type of estimate mode to use.

        Returns:
            SACEstimator: The SAC estimator instance.

        Raises:
            NotImplementedError: If the estimate mode type is not supported.
        """
        if estimate_mode_type == "operator-level-benchmark":
            self._estimate_runtime = RuntimeEstimator._benchmark_estimate
        elif estimate_mode_type == "operator-level-cost-model":
            self._estimate_runtime = RuntimeEstimator._roofline_estimate
        else:
            raise NotImplementedError(
                f"estimate_mode_type {estimate_mode_type} not supported"
            )
        return self

    def __enter__(self) -> Self:  # type: ignore[no-untyped-def]
        # Requires an active FakeTensorMode, which is shared with RuntimeEstimator.
        fake_mode = active_fake_mode()
        assert isinstance(fake_mode, FakeTensorMode), (
            "SAC Estimator should be called in FakeTensorMode"
        )
        RuntimeEstimator.fake_mode = fake_mode
        self._mod_tracker.register_user_hooks(
            pre_fw_hook=self._pre_fw_hook,
            post_fw_hook=self._post_fw_hook,
        )
        self._mod_tracker.__enter__()
        self._saved_tensor_hook_ctx.__enter__()
        return super().__enter__()

    def __exit__(self, *args: Any) -> None:  # type: ignore[no-untyped-def]
        self._saved_tensor_hook_ctx.__exit__(*args)
        self._mod_tracker.__exit__(*args)
        super().__exit__(*args)