pytorch/torch/distributed/_tools/__init__.py
Sanket Purandare a77145ae2f Selective Activation Checkpointing (SAC) Estimator for estimating memory and recomputation time trade-offs. (#135208)
This PR adds a Selective Activation Checkpointing (SAC) Estimator, built on top of the `Runtime Estimator`, for estimating memory and recomputation time trade-offs.
It provides a `TorchDispatchMode`-based context manager that estimates the memory and runtime trade-offs of functions or `torch.nn.Module`s for SAC. Under the hood it uses the `Runtime Estimator` (#134243) to support two estimation modes: 'operator-level-benchmark' and 'operator-level-cost-model' (roofline model). The SAC Estimator provides detailed statistics and metadata for the operators of each module, including a greedy order for selecting operators to be recomputed or checkpointed, and per-module trade-off graphs. It is designed to be used under `FakeTensorMode` and currently supports estimating compute time and memory usage.
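The description mentions a greedy order for choosing which operators to recompute. As a rough illustration only, the sketch below ranks operators by activation bytes freed per millisecond of recomputation and marks them for recomputation in that order until the stored activations fit a memory budget. The `OpStat` record, the ranking key, and the budget loop are hypothetical and assumed for illustration; this is not `SACEstimator`'s implementation. It uses only the standard library, so it runs without PyTorch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class OpStat:
    """Hypothetical per-operator record (not a torch.distributed._tools type)."""

    name: str
    memory_bytes: int  # activation memory freed if this op's output is recomputed
    runtime_ms: float  # extra backward-pass time spent re-running this op


def memory_per_ms(op: OpStat) -> float:
    """Bytes freed per millisecond of recomputation; zero-cost ops rank first."""
    return float("inf") if op.runtime_ms == 0 else op.memory_bytes / op.runtime_ms


def greedy_recompute_order(ops: list[OpStat]) -> list[OpStat]:
    """Assumed heuristic: recompute the ops that free the most memory per ms first."""
    return sorted(ops, key=memory_per_ms, reverse=True)


def choose_recompute(
    ops: list[OpStat], memory_budget_bytes: int
) -> tuple[list[str], list[str]]:
    """Mark ops for recomputation, in greedy order, until stored activations fit."""
    stored_bytes = sum(op.memory_bytes for op in ops)
    recomputed: list[str] = []
    for op in greedy_recompute_order(ops):
        if stored_bytes <= memory_budget_bytes:
            break
        recomputed.append(op.name)
        stored_bytes -= op.memory_bytes
    # Assumes op names are unique; everything not recomputed stays checkpointed.
    kept = [op.name for op in ops if op.name not in recomputed]
    return recomputed, kept


# Made-up numbers for three ops of one layer, for illustration only.
ops = [
    OpStat("mm", memory_bytes=4096, runtime_ms=2.0),
    OpStat("gelu", memory_bytes=4096, runtime_ms=0.1),
    OpStat("softmax", memory_bytes=2048, runtime_ms=0.4),
]
print(choose_recompute(ops, memory_budget_bytes=6000))  # (['gelu', 'softmax'], ['mm'])
```

Ranking by memory per unit of recompute time favors cheap elementwise ops like `gelu` over expensive matmuls, which matches the usual intuition behind selective checkpointing; a real estimator also has to handle constraints such as in-place and random ops that this sketch ignores.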

It is inspired by [XFormers SAC](https://github.com/facebookresearch/xformers/blob/main/xformers/checkpoint.py) by @fmassa.

End-to-end example:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.sac_estimator import SACEstimator
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)

if __name__ == "__main__":
    dev = torch.cuda.current_device()
    vocab_size = 8192
    bsz, seq_len = 8, 1024
    model_args = ModelArgs(
        n_layers=4,
        n_heads=12,
        vocab_size=vocab_size,
        max_seq_len=seq_len,
        dim=768,
        dropout_p=0.1,
    )
    with FakeTensorMode():
        with torch.device(dev):
            model = Transformer(model_args)
        inp = torch.randint(
            0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev
        )

        sace = SACEstimator()
        with sace(estimate_mode_type='operator-level-cost-model'):
            loss = model(inp).sum()
        loss.backward()
        sace.pwlf_sac_tradeoff_curve(n_segments=2, save_tradeoff_graphs=True)
        sace.display_modulewise_sac_stats(depth=4, print_tabular=True)
```
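Judging by its name and `n_segments` argument, `pwlf_sac_tradeoff_curve` fits a piecewise-linear curve to each module's trade-off; that reading is an assumption. The sketch below stops short of any fitting and only builds the raw points such a curve could be fitted to: walking a greedy recomputation order, it records the cumulative fraction of activation memory discarded against the cumulative fraction of recomputation time paid. The `tradeoff_curve` helper and its tuple input format are hypothetical.

```python
def tradeoff_curve(
    ordered_ops: list[tuple[str, int, float]],
) -> list[tuple[float, float]]:
    """Return cumulative (memory-discarded fraction, recompute-time fraction) points.

    `ordered_ops` holds hypothetical (name, memory_bytes, runtime_ms) tuples, already
    sorted in the order the ops would be chosen for recomputation.
    """
    total_mem = sum(mem for _, mem, _ in ordered_ops)
    total_time = sum(t for _, _, t in ordered_ops)
    points = [(0.0, 0.0)]  # nothing recomputed: no memory saved, no extra time
    mem_acc, time_acc = 0, 0.0
    for _, mem, t in ordered_ops:
        mem_acc += mem
        time_acc += t
        points.append((mem_acc / total_mem, time_acc / total_time))
    return points


# Made-up numbers, listed in a greedy order (most memory freed per ms first).
curve = tradeoff_curve([("gelu", 4096, 0.1), ("softmax", 2048, 0.4), ("mm", 4096, 2.0)])
for mem_frac, time_frac in curve:
    print(f"discard {mem_frac:.0%} of activations -> {time_frac:.0%} of recompute time")
```

With a good ordering the curve rises slowly at first (large memory savings for little recompute time) and steeply at the end, which is why a few linear segments can summarize it.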

Example AC Stats for one of the transformer layers:

![Screenshot 2024-10-11 at 10 09 13 PM](https://github.com/user-attachments/assets/1cf85564-4319-4732-bba1-89d505cda6ab)

Example AC Trade-off for one of the transformer layers:

![Screenshot 2024-10-11 at 10 09 58 PM](https://github.com/user-attachments/assets/5b2f343c-7e73-4c7d-bfea-3dcef2caa362)

Example AC Trade-off graph for one of the transformer layers:

![Transformer layers 3](https://github.com/user-attachments/assets/490d4b37-a916-4298-a14c-f78ffecbbde2)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135208
Approved by: https://github.com/weifengpy
2024-10-14 13:56:40 +00:00

from .fsdp2_mem_tracker import FSDPMemTracker
from .mem_tracker import MemTracker
from .memory_tracker import MemoryTracker
from .mod_tracker import ModTracker
from .runtime_estimator import RuntimeEstimator
from .sac_estimator import (
    MSPS,
    SACEstimator,
    SACGreedyOrderMeta,
    SACStats,
    SACTradeOffStats,
)