This PR adds a Selective Activation Checkpointing (SAC) Estimator, built on top of the `Runtime Estimator`, for estimating memory and recomputation-time trade-offs. It provides a `TorchDispatchMode`-based context manager that estimates the memory and runtime trade-offs of functions or `torch.nn.Module`s for SAC, using the `Runtime Estimator` #134243 under the hood to support two estimation modes: 'operator-level-benchmark' and 'operator-level-cost-model' (roofline model). The SAC Estimator provides detailed statistics and metadata for each module's operators, including the greedy order for selecting operators to be recomputed/checkpointed and per-module trade-off graphs. It is designed to be used under `FakeTensorMode` and currently supports estimation of compute time and memory usage.

It's inspired by [XFormers SAC](https://github.com/facebookresearch/xformers/blob/main/xformers/checkpoint.py) by @fmassa.

End-to-end example:

```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.sac_estimator import SACEstimator
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)

if __name__ == "__main__":
    dev = torch.cuda.current_device()
    vocab_size = 8192
    bsz, seq_len = 8, 1024
    model_args = ModelArgs(
        n_layers=4,
        n_heads=12,
        vocab_size=vocab_size,
        max_seq_len=seq_len,
        dim=768,
        dropout_p=0.1,
    )
    with FakeTensorMode():
        with torch.device(dev):
            model = Transformer(model_args)
        inp = torch.randint(
            0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev
        )
        sace = SACEstimator()
        with sace(estimate_mode_type='operator-level-cost-model'):
            loss = model(inp).sum()
        loss.backward()
        sace.pwlf_sac_tradeoff_curve(n_segments=2, save_tradeoff_graphs=True)
        sace.display_modulewise_sac_stats(depth=4, print_tabular=True)
```

Example AC Stats for one of the transformer layers:

Example AC Trade-off for one of the transformer layers:

Example AC Trade-off graph for one of the transformer layers:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135208
Approved by: https://github.com/weifengpy
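As a supplement to the example above, the same workflow can target the benchmark-based mode by changing only the `estimate_mode_type` argument. Below is a minimal sketch that reuses `model`, `inp`, and the `FakeTensorMode` context from the end-to-end example; whether the benchmark mode is meaningful without a real CUDA device present is an assumption not covered by this description.

```
# Minimal sketch: same setup as the end-to-end example above, but selecting the
# 'operator-level-benchmark' estimation mode instead of the roofline cost model.
sace = SACEstimator()
with sace(estimate_mode_type='operator-level-benchmark'):
    loss = model(inp).sum()
loss.backward()

# Fit per-module trade-off curves and print module-wise SAC statistics,
# exactly as in the cost-model example.
sace.pwlf_sac_tradeoff_curve(n_segments=2, save_tradeoff_graphs=True)
sace.display_modulewise_sac_stats(depth=4, print_tabular=True)
```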
```
from .fsdp2_mem_tracker import FSDPMemTracker
from .mem_tracker import MemTracker
from .memory_tracker import MemoryTracker
from .mod_tracker import ModTracker
from .runtime_estimator import RuntimeEstimator
from .sac_estimator import (
    MSPS,
    SACEstimator,
    SACGreedyOrderMeta,
    SACStats,
    SACTradeOffStats,
)
```
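Given the relative imports and the module path used in the PR example, this file is presumably `torch/distributed/_tools/__init__.py`, so the names re-exported here can also be imported directly from the package. A minimal usage sketch, assuming that package path:

```
# Hypothetical import sketch: pull the re-exported names from the package
# (path assumed to be torch.distributed._tools from the relative imports above).
from torch.distributed._tools import (
    MSPS,
    RuntimeEstimator,
    SACEstimator,
    SACGreedyOrderMeta,
    SACStats,
    SACTradeOffStats,
)

sac_estimator = SACEstimator()          # same class as torch.distributed._tools.sac_estimator.SACEstimator
runtime_estimator = RuntimeEstimator()  # the Runtime Estimator the SAC Estimator builds on
```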