PyTorch ThroughputBenchmark (#20766)

Summary:
This is useful for measuring the inference performance of your
models. This is a very basic benchmark for now. We don't support
batching on the benchmark side, and no inter- or intra-op parallelism is
supported yet, just caller-based parallelism.

The main philosophy here is that the user should be able to provide inputs
from Python and just stack them within the benchmark. The API should be
exactly the same as passing inputs to module.forward.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20766

Test Plan: Added a new unit test

Differential Revision: D15435461

Pulled By: salexspb

fbshipit-source-id: db08829dc3f4398bb1d8aa16cc4a58b6c72f16c6
This commit is contained in:
Alexander Sidorov
2019-06-23 12:49:30 -07:00
committed by Facebook Github Bot
parent c0f96aaf01
commit 9b45237618
12 changed files with 674 additions and 0 deletions

View File

@ -0,0 +1,79 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from torch.utils import ThroughputBenchmark
from torch.testing import assert_allclose
from common_utils import run_tests, TestCase
class TwoLayerNet(torch.jit.ScriptModule):
    """Scripted two-input network: one shared hidden layer, one output layer.

    Both inputs go through the same ``linear1`` layer, are clamped at zero,
    concatenated along dimension 1 and projected by ``linear2``.
    """

    def __init__(self, D_in, H, D_out):
        """Create the layers.

        Args:
            D_in: number of input features of ``linear1``.
            H: number of output features of ``linear1``.
            D_out: number of output features of ``linear2``.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        # linear2 consumes the concatenation of two H-wide activations.
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    @torch.jit.script_method
    def forward(self, x1, x2):
        # Same hidden layer for both inputs; negatives are clamped to zero.
        hidden_a = torch.clamp(self.linear1(x1), min=0)
        hidden_b = torch.clamp(self.linear1(x2), min=0)
        # Join the two activations along dimension 1 before the output layer.
        merged = torch.cat([hidden_a, hidden_b], dim=1)
        return self.linear2(merged)
class TwoLayerNetModule(torch.nn.Module):
    """Eager two-input network: one shared hidden layer, one output layer."""

    def __init__(self, D_in, H, D_out):
        """Create the layers.

        Args:
            D_in: number of input features of ``linear1``.
            H: number of output features of ``linear1``.
            D_out: number of output features of ``linear2``.
        """
        super(TwoLayerNetModule, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        # linear2 consumes the concatenation of two H-wide activations.
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    def forward(self, x1, x2):
        """Return ``linear2`` applied to both clamped hidden activations.

        ``x1`` and ``x2`` each pass through ``linear1`` and are clamped at
        zero; the results are concatenated along dimension 1.
        """
        hidden_a = torch.clamp(self.linear1(x1), min=0)
        hidden_b = torch.clamp(self.linear1(x2), min=0)
        merged = torch.cat([hidden_a, hidden_b], dim=1)
        return self.linear2(merged)
class TestThroughputBenchmark(TestCase):
    """Exercises ThroughputBenchmark with a scripted and an eager module."""

    def linear_test(self, Module):
        """Compare benchmark output with direct calls, then run a benchmark.

        Args:
            Module: callable invoked as ``Module(D_in, H, D_out)``; the
                object it returns is called with two tensors.
        """
        D_in = 10
        H = 5
        D_out = 15
        B = 8
        NUM_INPUTS = 2

        # Build the module before the inputs so the order of random draws
        # (parameter init first, then input tensors) stays fixed.
        module = Module(D_in, H, D_out)

        inputs = [
            [torch.randn(B, D_in), torch.randn(B, D_in)]
            for _ in range(NUM_INPUTS)
        ]
        bench = ThroughputBenchmark(module)

        for first, second in inputs:
            # can do both args and kwargs here
            bench.add_input(first, x2=second)

        for sample in inputs:
            # or just unpack the list of inputs
            module_result = module(*sample)
            bench_result = bench.run_once(*sample)
            # The benchmark wrapper must agree with a direct module call.
            assert_allclose(bench_result, module_result)

        stats = bench.benchmark(
            num_calling_threads=4,
            num_warmup_iters=100,
            num_iters=1000,
        )

        # Informational output only; no assertion is made on the timings.
        print("Avg latency (ms): {}".format(stats.latency_avg_ms))
        print("Number of iterations: {}".format(stats.num_iters))

    def test_script_module(self):
        """Run the check against the scripted ``TwoLayerNet``."""
        self.linear_test(TwoLayerNet)

    def test_module(self):
        """Run the check against the eager ``TwoLayerNetModule``."""
        self.linear_test(TwoLayerNetModule)
# Script entry point: hand control to the shared test runner.
if __name__ == "__main__":
    run_tests()