add PT batchnorm op to the benchmark suite (#21201)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21201

Adds a PyTorch (PT) batchnorm operator microbenchmark to the operator_benchmark suite and removes the old standalone batch_norm timing script it replaces.

Reviewed By: hl475

Differential Revision: D15482581

fbshipit-source-id: d93713a35be41e76d077df419cb24585f69d72eb
Mingzhe Li
2019-05-31 19:31:43 -07:00
committed by Facebook Github Bot
parent ed1078bde3
commit 00b3e69211
2 changed files with 42 additions and 38 deletions


@@ -1,38 +0,0 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import time
import numpy
import torch
import torch.nn.functional as F


def benchmark_batch_norm(data_shape):
    C = data_shape[1]
    x = torch.rand(data_shape)
    mean = torch.rand(C)
    var = torch.rand(C)
    weight = torch.rand(C)
    bias = torch.rand(C)
    NITER = 10000
    input_size = numpy.prod(data_shape)
    # Elements touched per call: read + write of the input plus the four
    # per-channel vectors (converted to bytes with the "* 4" below).
    total_size = 2 * input_size + 4 * C
    # Run 10 warmup iterations before starting the timer at i == 0.
    for i in range(-10, NITER):
        if i == 0:
            s = time.time()
        F.batch_norm(x, mean, var, weight, bias)
    elapsed_sec = (time.time() - s) / NITER
    print(
        "batch_norm: data shape: %s, bandwidth: %.2f GB/s"
        % (data_shape, (total_size * 4) / elapsed_sec / 1e9)
    )


def main():
    data_shapes = [[1, 256, 3136], [1, 2 ** 16, 1], [128, 2048, 1]]
    for data_shape in data_shapes:
        benchmark_batch_norm(data_shape)


if __name__ == "__main__":
    main()
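
For reference (not part of the commit), the GB/s figure the removed script printed follows directly from its traffic model; a minimal sketch of the arithmetic for the first shape, assuming float32 elements (4 bytes each):

# Sketch only: reproduces the removed script's bandwidth arithmetic for [1, 256, 3136].
data_shape = [1, 256, 3136]
C = data_shape[1]
input_size = 1 * 256 * 3136           # 802,816 elements
total_size = 2 * input_size + 4 * C   # input read + write, plus mean/var/weight/bias: 1,606,656 elements
bytes_per_call = total_size * 4       # ~6.4 MB of memory traffic per batch_norm call
# bandwidth (GB/s) = bytes_per_call / elapsed_sec / 1e9, with elapsed_sec the mean time per call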


@@ -0,0 +1,42 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import operator_benchmark as op_bench
import torch
import torch.nn.functional as F

"""Microbenchmarks for batchnorm operator."""

configs = op_bench.config_list(
    attrs=[
        [1, 256, 3136],
        [1, 2 ** 16, 1],
        [128, 2048, 1],
    ],
    attr_names=["M", "N", "K"],
    tags=["short"]
)


class BatchNormBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K):
        self.input_one = torch.rand(M, N, K)
        self.mean = torch.rand(N)
        self.var = torch.rand(N)
        self.weight = torch.rand(N)
        self.bias = torch.rand(N)
        self.set_module_name("batchnorm")

    def forward(self):
        return F.batch_norm(self.input_one, self.mean, self.var, self.weight, self.bias)


op_bench.generate_pt_test(configs, BatchNormBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
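
As a side note (not part of the commit), the call being timed uses F.batch_norm in its default inference form (training=False, eps=1e-5), so it normalizes with the supplied statistics and applies the per-channel affine parameters; a small sanity sketch, with all names local to the example:

# Sketch only: checks what the benchmarked call computes for the first config (M=1, N=256, K=3136).
import torch
import torch.nn.functional as F

x = torch.rand(1, 256, 3136)
mean, var, weight, bias = (torch.rand(256) for _ in range(4))

out = F.batch_norm(x, mean, var, weight, bias)

# Manual per-channel reference: (x - mean) / sqrt(var + eps) * weight + bias.
eps = 1e-5
ref = (x - mean[None, :, None]) / torch.sqrt(var[None, :, None] + eps)
ref = ref * weight[None, :, None] + bias[None, :, None]
assert torch.allclose(out, ref, atol=1e-5)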