Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
PyTorch ThroughputBenchmark (#20766)
Summary: This is useful for measuring the inference performance of your models. It is a very basic benchmark for now: we don't support batching on the benchmark side, and no inter- or intra-op parallelism is supported yet, only caller-based parallelism. The main philosophy is that the user should be able to provide inputs from Python and the benchmark just stacks them; the API is exactly the same as passing inputs to module.forward.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/20766

Test Plan: Added a new unit test

Differential Revision: D15435461

Pulled By: salexspb

fbshipit-source-id: db08829dc3f4398bb1d8aa16cc4a58b6c72f16c6
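A minimal usage sketch of the API, based on the unit test added below (module, x1, and x2 are placeholders for your own scripted or eager module and its inputs):

    from torch.utils import ThroughputBenchmark

    bench = ThroughputBenchmark(module)
    # inputs are registered with the same calling convention as module.forward,
    # so both positional args and kwargs work
    bench.add_input(x1, x2=x2)
    # run_once executes the module once and returns its output,
    # handy for checking that the benchmark produces the same results
    out = bench.run_once(x1, x2=x2)
    # caller-based parallelism: several threads invoke the module concurrently
    stats = bench.benchmark(
        num_calling_threads=4,
        num_warmup_iters=100,
        num_iters=1000,
    )
    print(stats.latency_avg_ms, stats.num_iters)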
Committed by: Facebook Github Bot
Parent: c0f96aaf01
Commit: 9b45237618
test/test_throughput_benchmark.py (new file, 79 lines)
@@ -0,0 +1,79 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
from torch.utils import ThroughputBenchmark
from torch.testing import assert_allclose

from common_utils import run_tests, TestCase


class TwoLayerNet(torch.jit.ScriptModule):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    @torch.jit.script_method
    def forward(self, x1, x2):
        h1_relu = self.linear1(x1).clamp(min=0)
        h2_relu = self.linear1(x2).clamp(min=0)
        cat = torch.cat((h1_relu, h2_relu), 1)
        y_pred = self.linear2(cat)
        return y_pred


class TwoLayerNetModule(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNetModule, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    def forward(self, x1, x2):
        h1_relu = self.linear1(x1).clamp(min=0)
        h2_relu = self.linear1(x2).clamp(min=0)
        cat = torch.cat((h1_relu, h2_relu), 1)
        y_pred = self.linear2(cat)
        return y_pred


class TestThroughputBenchmark(TestCase):
    def linear_test(self, Module):
        D_in = 10
        H = 5
        D_out = 15
        B = 8
        NUM_INPUTS = 2

        module = Module(D_in, H, D_out)

        inputs = []
        for i in range(NUM_INPUTS):
            inputs.append([torch.randn(B, D_in), torch.randn(B, D_in)])
        bench = ThroughputBenchmark(module)

        for input in inputs:
            # can do both args and kwargs here
            bench.add_input(input[0], x2=input[1])

        for i in range(NUM_INPUTS):
            # or just unpack the list of inputs
            module_result = module(*inputs[i])
            bench_result = bench.run_once(*inputs[i])
            assert_allclose(bench_result, module_result)

        stats = bench.benchmark(
            num_calling_threads=4,
            num_warmup_iters=100,
            num_iters=1000,
        )

        print("Avg latency (ms): {}".format(stats.latency_avg_ms))
        print("Number of iterations: {}".format(stats.num_iters))

    def test_script_module(self):
        self.linear_test(TwoLayerNet)

    def test_module(self):
        self.linear_test(TwoLayerNetModule)


if __name__ == '__main__':
    run_tests()