Commit Graph

6 Commits

Author SHA1 Message Date
995df34b19 [BE][PYFMT] migrate PYFMT for torch.{distributed,distributions} to ruff format (#144547)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144547
Approved by: https://github.com/kwen2501
2025-02-28 07:35:56 +00:00
00ffeca1b1 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-21 04:23:29 +00:00
6374332d33 Revert "PEP585 update - torch/distributed (#145164)"
This reverts commit 6cb186e279bc179a6bb63f0226e24ab42a07b394.

Reverted https://github.com/pytorch/pytorch/pull/145164 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing an inductor test ([comment](https://github.com/pytorch/pytorch/pull/145164#issuecomment-2602875679))
2025-01-20 16:46:46 +00:00
6cb186e279 PEP585 update - torch/distributed (#145164)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145164
Approved by: https://github.com/bobrenjc93
2025-01-20 00:19:01 +00:00
08be9ec312 Migrate from Tuple -> tuple in torch/distributed (#144258)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144258
Approved by: https://github.com/aorenste
2025-01-10 08:34:54 +00:00
de35d3062f Runtime Estimator for estimating GPU compute time (#134243)
This PR adds a basic Runtime Estimator for single-device models.
It estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``.
It provides a ``TorchDispatchMode`` based context manager that can estimate the eager runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and roofline cost modeling (`operator-level-cost-model`).
For modules executed under this context manager, it agggregates the forward and backward operation runtimes and records their execution orders.

```
import torch
from torch import nn, optim
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)

if __name__ == "__main__":
    def _train_step(
        model: nn.Module,
        optimizer: optim.Optimizer,
        inp: torch.Tensor,
    ):
        out = model(inp)
        loss = out.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    dev = torch.cuda.current_device()
    vocab_size = 8192
    bsz, seq_len = 32, 1024
    model_args = ModelArgs(
        n_layers=4,
        n_heads=12,
        vocab_size=vocab_size,
        max_seq_len=seq_len,
        dim=768,
        dropout_p=0.1,
    )
    runtime_estimator = RuntimeEstimator()

    with FakeTensorMode():
        with torch.device(dev):
            model = Transformer(model_args)
        optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        inp = torch.randint(0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev)
        with runtime_estimator("operator-level-benchmark"):
            _train_step(model, optimizer, inp)
        with runtime_estimator("operator-level-cost-model"):
            _train_step(model, optimizer, inp)

    # Actual model runtime
    with torch.device(dev):
        model = Transformer(model_args)
    optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True)
    inp = torch.randint(0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev)
    warmup_iters, actual_iters = 2, 5
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup_iters):
        _train_step(model, optimizer, inp)
    start_event.record()
    for _ in range(actual_iters):
        _train_step(model, optimizer, inp)
    end_event.record()
    torch.cuda.synchronize()
    measured_time = start_event.elapsed_time(end_event) / actual_iters
    print(f"Actual total_time: {measured_time:.3f} ms")
  ```

<img width="506" alt="Screenshot 2024-08-26 at 11 27 15 PM" src="https://github.com/user-attachments/assets/04d243c9-21a6-4389-8c20-80958980788c">

@weifengpy @xuanzhang816 @gnadathur

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134243
Approved by: https://github.com/weifengpy
2024-08-28 20:06:54 +00:00