[RELAND] Add __torch_function__ benchmarks (#36138)
Summary: Re-land of https://github.com/pytorch/pytorch/issues/35530 and https://github.com/pytorch/pytorch/issues/34645

Pull Request resolved: https://github.com/pytorch/pytorch/pull/36138

Differential Revision: D20893770

Pulled By: ezyang

fbshipit-source-id: 75ab688a086f5fb87412a853df5246c0c39704ca
Committed by: Facebook GitHub Bot
Parent: 3aeb2b1562
Commit: 7c825bad10
@@ -218,6 +218,18 @@ test_custom_script_ops() {
   fi
 }
 
+test_torch_function_benchmark() {
+  echo "Testing __torch_function__ benchmarks"
+  pushd benchmarks/overrides_benchmark
+  python bench.py -n 1 -m 2
+  python pyspybench.py Tensor -n 1
+  python pyspybench.py SubTensor -n 1
+  python pyspybench.py WithTorchFunction -n 1
+  python pyspybench.py SubWithTorchFunction -n 1
+  popd
+  assert_git_not_dirty
+}
+
 test_xla() {
   export XLA_USE_XRT=1 XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0"
   # Issue #30717: randomize the port of XLA/gRPC workers is listening on to reduce flaky tests.

@@ -286,6 +298,7 @@ elif [[ "${BUILD_ENVIRONMENT}" == *-test2 || "${JOB_BASE_NAME}" == *-test2 ]]; t
   test_aten
   test_libtorch
   test_custom_script_ops
+  test_torch_function_benchmark
 elif [[ "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
   test_bazel
 else

@@ -295,4 +308,5 @@ else
   test_aten
   test_libtorch
   test_custom_script_ops
+  test_torch_function_benchmark
 fi
benchmarks/overrides_benchmark/README.md (new file, 40 lines)
@@ -0,0 +1,40 @@
# `__torch_function__` micro-benchmarks

This benchmark suite provides a systematic way to measure the overhead of `__torch_function__`.

## Getting started

### Initial Setup

Install `py-spy`:

```bash
pip install py-spy
```

Note that more extensive documentation on using `py-spy` is available in `CONTRIBUTING.md`.

### Running the benchmark

Run one of the following commands in the terminal, with the working directory being `${PYTORCH_CLONE_DIR}/benchmarks/overrides_benchmark`:

```bash
# Benchmark all the cases
python bench.py

# Generate a flame graph for each case.
py-spy record -o tensor.svg --native -- python pyspybench.py Tensor
py-spy record -o subtensor.svg --native -- python pyspybench.py SubTensor
py-spy record -o overridden.svg --native -- python pyspybench.py WithTorchFunction
py-spy record -o suboverridden.svg --native -- python pyspybench.py SubWithTorchFunction
```

Here is a brief overview of what the results should look like if the benchmarks run correctly:

* Overhead for `torch` functions when run on `torch.Tensor` objects is on the order of 2 μs.
* `__torch_function__` should add zero overhead for `torch.Tensor` inputs, a small overhead for subclasses of `torch.Tensor`, and a couple of microseconds for `Tensor`-likes with `__torch_function__`.
* Changing the dispatching mechanism may result in changes that are on the order of 100 ns, which are hard to detect due to noise, but important.

## Reporting benchmark results

When modifying any of the machinery around `__torch_function__`, run the benchmark for both the feature branch and the point where it diverges from `master`. For each of these:

* Run `bench.py`, and include the output in your result.
* For each case where `bench.py` shows a regression, run the commands described above, prefixing the output SVG filename (the input to the `-o` switch) with `base-` or `branch-` depending on the commit you are running the benchmark on (a sketch follows this list).
* For each SVG, open it in the browser, take a screenshot, and include it in your result. Also include a ZIP file containing all of the SVGs produced.
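Below is a minimal sketch of how the base/branch comparison might be scripted. It is illustrative only: the branch name `my-feature-branch`, the `base-bench.txt`/`branch-bench.txt` filenames, and the use of `git merge-base` to find the divergence point are assumptions rather than part of the benchmark suite, and PyTorch must be rebuilt after each checkout for the numbers to reflect that commit.

```bash
# Illustrative comparison workflow (not part of the benchmark suite).
# Assumes the working directory is benchmarks/overrides_benchmark and that
# PyTorch is rebuilt after each checkout so the benchmarked code matches it.
BASE_COMMIT=$(git merge-base master HEAD)   # point where the branch diverged

# Benchmark the base commit.
git checkout "$BASE_COMMIT"
python bench.py | tee base-bench.txt
py-spy record -o base-tensor.svg --native -- python pyspybench.py Tensor

# Benchmark the feature branch (hypothetical branch name).
git checkout my-feature-branch
python bench.py | tee branch-bench.txt
py-spy record -o branch-tensor.svg --native -- python pyspybench.py Tensor
```

The same `py-spy record` invocation can be repeated for the other three classes, keeping the `base-`/`branch-` prefix convention described above.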
benchmarks/overrides_benchmark/bench.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import torch
import time
import argparse

from common import SubTensor, WithTorchFunction, SubWithTorchFunction

NUM_REPEATS = 1000
NUM_REPEAT_OF_REPEATS = 1000


def bench(t1, t2):
    bench_times = []
    for _ in range(NUM_REPEAT_OF_REPEATS):
        time_start = time.time()
        for _ in range(NUM_REPEATS):
            torch.add(t1, t2)
        bench_times.append(time.time() - time_start)

    bench_time = float(torch.min(torch.Tensor(bench_times))) / 1000
    bench_std = float(torch.std(torch.Tensor(bench_times))) / 1000

    return bench_time, bench_std


def main():
    global NUM_REPEATS
    global NUM_REPEAT_OF_REPEATS

    parser = argparse.ArgumentParser(
        description="Run the __torch_function__ benchmarks."
    )
    parser.add_argument(
        "--nreps",
        "-n",
        type=int,
        default=NUM_REPEATS,
        help="The number of repeats for one measurement.",
    )
    parser.add_argument(
        "--nrepreps",
        "-m",
        type=int,
        default=NUM_REPEAT_OF_REPEATS,
        help="The number of measurements.",
    )
    args = parser.parse_args()

    NUM_REPEATS = args.nreps
    NUM_REPEAT_OF_REPEATS = args.nrepreps

    types = torch.Tensor, SubTensor, WithTorchFunction, SubWithTorchFunction

    for t in types:
        tensor_1 = t(1)
        tensor_2 = t(2)

        bench_min, bench_std = bench(tensor_1, tensor_2)
        print(
            "Type {0} had a minimum time of {1} us"
            " and a standard deviation of {2} us.".format(
                t.__name__, (10 ** 6 * bench_min), (10 ** 6) * bench_std
            )
        )


if __name__ == "__main__":
    main()
benchmarks/overrides_benchmark/common.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import torch

NUM_REPEATS = 1000
NUM_REPEAT_OF_REPEATS = 1000


class SubTensor(torch.Tensor):
    pass


class WithTorchFunction:
    def __init__(self, data, requires_grad=False):
        if isinstance(data, torch.Tensor):
            self._tensor = data
            return

        self._tensor = torch.Tensor(data, requires_grad)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        return WithTorchFunction(args[0]._tensor + args[1]._tensor)


class SubWithTorchFunction(torch.Tensor):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        return args[0] + args[1]
benchmarks/overrides_benchmark/pyspybench.py (new file, 28 lines)
@@ -0,0 +1,28 @@
import torch
import argparse
from common import SubTensor, WithTorchFunction, SubWithTorchFunction  # noqa: F401

Tensor = torch.Tensor

NUM_REPEATS = 1000000

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class", metavar="TensorClass", type=str, help="The class to benchmark."
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    args = parser.parse_args()

    TensorClass = globals()[args.tensor_class]
    NUM_REPEATS = args.nreps

    t1 = TensorClass(1)
    t2 = TensorClass(2)

    for _ in range(NUM_REPEATS):
        torch.add(t1, t2)
@@ -13,6 +13,9 @@ https://github.com/pytorch/pytorch/issues/24015 and
 https://www.numpy.org/neps/nep-0018-array-function-protocol.html
 )
 
+If changing this file in a way that can affect ``__torch_function__`` overhead,
+please report the benchmarks in ``benchmarks/overrides_benchmark``. See the
+instructions in the ``README.md`` in that directory.
 """
 
 import __future__
@@ -215,6 +215,10 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb
  * precedence.
  *
  * 'obj' is an object to check for a __torch_function__ implementation
  *
+ * If changing this file in a way that can affect the __torch_function__
+ * overhead, please report the benchmarks in 'benchmarks/overrides_benchmark'.
+ * See the instructions in the 'README.md' in that directory.
+ *
  */