**Problem:** Fusion can accumulate a large number of reads, which can lead to a significant increase in peak memory utilization. Imagine we have the following code snippet:

```
total = torch.rand(N, N)
for _ in range(r):
    x = torch.rand(N, N)
    total = total + x
```

The default execution is memory efficient, as only two tensors of size N-by-N are in memory at any given time. With fusion, however, the additions are fused into a single operation and the execution becomes something like:

```
x_1 = torch.rand(N, N)
x_2 = torch.rand(N, N)
...
x_r = torch.rand(N, N)
total = x_1 + x_2 + ... + x_r
```

Though this is run-time efficient, it is not memory efficient for large `N` and/or large `r`, since all `r` input buffers must be live at once.

[internal only] see [post](https://fb.workplace.com/groups/1075192433118967/permalink/1703374333634104/) for additional details

**Solution:** Our proposed solution is to ban fusion in cases where a large amount of reads would be accumulated. This is in addition to some existing logic during torch compile:

* During lowering (i.e., in `ir.py`), the config `realize_acc_reads_threshold`, which defaults to 8, controls _the number of_ buffers that can be accumulated for a single operator. However, this check is oblivious to the size of those buffers. Hence, we additionally introduce a config `realize_acc_reads_size_threshold` to control _the total size_ of the buffers that can be accumulated (see the usage sketch at the end of this description).
* During scheduling (i.e., in `scheduler.py`), additional fusions are performed, so we also need to catch this pattern there. The decisions are implemented in `choices.py`.

**Results:** For a small example similar to the one in the test case (but with a larger `N` and a higher number of loop repeats), the memory snapshots before and after the change are shown below. Note that the snapshot on the right is zoomed out so that the y-axes of the two snapshots match.

<img width="1328" alt="image" src="https://github.com/user-attachments/assets/670b5961-8454-4379-ae0f-62d4e7946c64" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157563
Approved by: https://github.com/jansel, https://github.com/mlazos
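As a rough usage sketch of the thresholds described above: this assumes both knobs are exposed as attributes of `torch._inductor.config` (where inductor configs typically live); the threshold's unit and the chosen values here are illustrative assumptions, not confirmed by this PR.

```
import torch
import torch._inductor.config as inductor_config

# Existing knob: cap on the *number* of read buffers accumulated per
# operator (defaults to 8, per the description above).
inductor_config.realize_acc_reads_threshold = 8

# New knob from this PR: cap on the *size* of accumulated read buffers.
# The exact unit (bytes vs. elements) and a sensible value are
# assumptions here.
inductor_config.realize_acc_reads_size_threshold = 4 * 1024**2

N, r = 2048, 64  # hypothetical sizes for illustration

def accumulate():
    # Same pattern as the Problem section: r chained additions that
    # fusion would otherwise collapse into one many-input kernel.
    total = torch.rand(N, N)
    for _ in range(r):
        total = total + torch.rand(N, N)
    return total

# With the size threshold in effect, inductor should realize
# intermediate buffers instead of accumulating all r reads into a
# single fused op, keeping peak memory closer to the eager baseline.
out = torch.compile(accumulate)()
```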
# PyTorch Benchmarks
This folder contains scripts that produce reproducible timings of various PyTorch features.
It also provides mechanisms to compare PyTorch with other frameworks.
## Setup environment
Make sure you're on a machine with CUDA, torchvision, and pytorch installed. Install in the following order:

```
# Install torchvision. It comes with the pytorch stable release binary
python -m pip install torch torchvision

# Install the latest pytorch master from source.
# It should supersede the installation from the release binary.
cd $PYTORCH_HOME
python -m pip install --no-build-isolation -v -e .

# Check the pytorch installation version
python -c "import torch; print(torch.__version__)"
```
## Benchmark List
Please refer to each subfolder to discover each benchmark suite. Links are provided where descriptions exist: