[doc] Rewrite benchmarks/dynamo/README.md (#115485)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115485
Approved by: https://github.com/yanboliang
Author: Jason Ansel
Date: 2023-12-09 12:12:33 -08:00
Committed by: PyTorch MergeBot
Parent: 8ddc549c0f
Commit: 4490d4692b


# `torch.compile()` Benchmarking

This directory contains benchmarking code for TorchDynamo and many
backends, including TorchInductor. It includes three main benchmark suites,
each with its own low-level runner script (a minimal example invocation is
shown after the list):
- [TorchBenchmark](https://github.com/pytorch/benchmark): A diverse set of models, initially seeded from
highly cited research models as ranked by [Papers With Code](https://paperswithcode.com). See the [torchbench
installation](https://github.com/pytorch/benchmark#installation) instructions and `torchbench.py` for the low-level runner.
The [Makefile](Makefile) also contains the commands needed to set up TorchBenchmark to match the versions used in
PyTorch CI.
- Models from [HuggingFace](https://github.com/huggingface/transformers): Primarily transformer models, with
representative models chosen for each available category. The low-level runner (`huggingface.py`) automatically
downloads and installs the needed dependencies on first run.
- Models from [TIMM](https://github.com/huggingface/pytorch-image-models): Primarily vision models, with representative
models chosen for each available category. The low-level runner (`timm_models.py`) automatically downloads and
installs the needed dependencies on first run.
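
As a quick smoke test, a first run against one of these suites might look like
the following. This is only a minimal sketch: the flags are described in detail
under "Running Locally" below, and the output filename is arbitrary.

```
# Check TorchInductor correctness against eager on the HuggingFace suite
# (bfloat16 inference); dependencies such as transformers are installed on first use.
python benchmarks/dynamo/huggingface.py --accuracy --inference --bfloat16 --backend=inductor --output=huggingface_accuracy.csv
```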
## GPU Performance Dashboard
Daily results from the benchmarks here are available in the [TorchInductor
Performance Dashboard](https://hud.pytorch.org/benchmark/compilers),
currently run on an NVIDIA A100 GPU.

The [inductor-perf-test-nightly.yml](https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml)
workflow generates the data in the performance dashboard. If you have the needed permissions, you can benchmark
your own branch on the PyTorch GitHub repo as follows (a command-line alternative is sketched after these steps):
1) Select "Run workflow" in the top right of the [workflow](https://github.com/pytorch/pytorch/actions/workflows/inductor-perf-test-nightly.yml)
2) Select the branch you want to benchmark
3) Choose the options (such as training vs inference)
4) Click "Run workflow"
5) Wait for the job to complete (4 to 12 hours depending on backlog)
6) Go to the [dashboard](https://hud.pytorch.org/benchmark/compilers)
7) Select your branch and commit at the top of the dashboard
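
If you prefer the command line, the same workflow can be dispatched with the
GitHub CLI. This is a sketch, assuming `gh` is installed and you have the
required permissions; check the workflow file for the exact input names before
passing any `-f` options.

```
# Dispatch the nightly inductor perf workflow against your branch.
gh workflow run inductor-perf-test-nightly.yml --repo pytorch/pytorch --ref my-branch
# Workflow options (e.g. training vs. inference) are workflow inputs and can be
# passed as -f name=value; see the workflow file for the available input names.
```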

The dashboard compares two commits: a "Base Commit" and a "New Commit".
An entry such as `2.38x → 2.41x` means that the performance improved
from `2.38x` in the base to `2.41x` in the new commit. All performance
results are normalized to eager mode PyTorch (`1x`), and higher is better.
## CPU Performance Dashboard
The [TorchInductor CPU Performance
Dashboard](https://github.com/pytorch/pytorch/issues/93531) is tracked
on a GitHub issue and updated periodically.
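
The same benchmark scripts can also be run locally on CPU by passing
`--device=cpu`. The command below is only a rough sketch using flags described
under "Running Locally"; the exact configurations used for the CPU dashboard
may differ.

```
# Example local CPU inference run with TorchInductor; --freezing and
# --cpp-wrapper enable the inference-only optimizations described below.
python benchmarks/dynamo/timm_models.py --performance --inference --float32 --device=cpu --backend=inductor --freezing --cpp-wrapper --output=timm_cpu_inference.csv
```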
## Running Locally
Raw commands used to generate the data for
the performance dashboards can be found
[here](https://github.com/pytorch/pytorch/blob/641ec2115f300a3e3b39c75f6a32ee3f64afcf30/.ci/pytorch/test.sh#L343-L418).

To summarize, there are three scripts, one for each benchmark suite:
- `./benchmarks/dynamo/torchbench.py ...`
- `./benchmarks/dynamo/huggingface.py ...`
- `./benchmarks/dynamo/timm_models.py ...`

Each of these scripts takes the same set of arguments. The ones used by the dashboards are listed below; an example combining several of them follows the list:
- `--accuracy` or `--performance`: selects between checking correctness and measuring speedup (the dashboard runs both).
- `--training` or `--inference`: selects between measuring training or inference (the dashboard runs both).
- `--device=cuda` or `--device=cpu`: selects device to measure.
- `--amp`, `--bfloat16`, `--float16`, `--float32`: selects the precision; the dashboard uses `--amp` for training and `--bfloat16` for inference.
- `--cold-start-latency`: disables caching to accurately measure compile times.
- `--backend=inductor`: selects TorchInductor as the compiler backend to measure. Many more are available, see `--help`.
- `--output=<filename>.csv`: where to write results.
- `--dynamic-shapes --dynamic-batch-only`: used when the `dynamic` config is enabled.
- `--disable-cudagraphs`: used by configurations without cudagraphs enabled (the default).
- `--freezing`: enables additional inference-only optimizations.
- `--cpp-wrapper`: enables C++ wrapper code to lower overheads.
- `TORCHINDUCTOR_MAX_AUTOTUNE=1` (environment variable): used to measure max-autotune mode, which is run weekly due to longer compile times.
- `--export-aot-inductor`: benchmarks ahead-of-time compilation mode.
- `--total-partitions` and `--partition-id`: used to parallelize benchmarking across different machines.
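
For instance, a max-autotune inference measurement combining several of the
options above might look like the following sketch (the output filename is
arbitrary):

```
# Max-autotune mode is selected via an environment variable rather than a flag.
TORCHINDUCTOR_MAX_AUTOTUNE=1 python benchmarks/dynamo/torchbench.py --performance --inference --bfloat16 --backend=inductor --cold-start-latency --output=torchbench_max_autotune.csv
```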

For debugging, you can run just a single benchmark by adding the `--only=<NAME>` flag.
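
For example (the model name here is just an illustration; any model from the
corresponding suite works):

```
# Benchmark a single TorchBench model end to end.
python benchmarks/dynamo/torchbench.py --performance --inference --bfloat16 --backend=inductor --only=resnet50 --output=single_model.csv
```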
A complete list of options can be seen by running each of the runners with the `--help` flag.

As an example, the commands to run the first line of the dashboard (performance only) would be:
```
./benchmarks/dynamo/torchbench.py --performance --training --amp --backend=inductor --output=torchbench_training.csv
./benchmarks/dynamo/torchbench.py --performance --inference --bfloat16 --backend=inductor --output=torchbench_inference.csv
./benchmarks/dynamo/huggingface.py --performance --training --amp --backend=inductor --output=huggingface_training.csv
./benchmarks/dynamo/huggingface.py --performance --inference --bfloat16 --backend=inductor --output=huggingface_inference.csv
./benchmarks/dynamo/timm_models.py --performance --training --amp --backend=inductor --output=timm_models_training.csv
./benchmarks/dynamo/timm_models.py --performance --inference --bfloat16 --backend=inductor --output=timm_models_inference.csv
```