mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add PT batchnorm op to the benchmark suite (#21201)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21201 as title Reviewed By: hl475 Differential Revision: D15482581 fbshipit-source-id: d93713a35be41e76d077df419cb24585f69d72eb
This commit is contained in:
committed by
Facebook Github Bot
parent
ed1078bde3
commit
00b3e69211
@ -1,38 +0,0 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import time
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def benchmark_batch_norm(data_shape):
|
||||
C = data_shape[1]
|
||||
x = torch.rand(data_shape)
|
||||
mean = torch.rand(C)
|
||||
var = torch.rand(C)
|
||||
weight = torch.rand(C)
|
||||
bias = torch.rand(C)
|
||||
NITER = 10000
|
||||
input_size = numpy.prod(data_shape)
|
||||
total_size = 2 * input_size + 4 * C
|
||||
for i in range(-10, NITER):
|
||||
if i == 0:
|
||||
s = time.time()
|
||||
F.batch_norm(x, mean, var, weight, bias)
|
||||
elapsed_sec = (time.time() - s) / NITER
|
||||
print(
|
||||
"batch_norm: data shape: %s, bandwidth: %.2f GB/s"
|
||||
% (data_shape, (total_size * 4) / elapsed_sec / 1e9)
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
data_shapes = [[1, 256, 3136], [1, 2 ** 16, 1], [128, 2048, 1]]
|
||||
for data_shape in data_shapes:
|
||||
benchmark_batch_norm(data_shape)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
42
benchmarks/operator_benchmark/ops/pt/batchnorm_test.py
Normal file
42
benchmarks/operator_benchmark/ops/pt/batchnorm_test.py
Normal file
@ -0,0 +1,42 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
|
||||
import operator_benchmark as op_bench
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
"""Microbenchmarks for batchnorm operator."""
|
||||
|
||||
configs = op_bench.config_list(
|
||||
attrs=[
|
||||
[1, 256, 3136],
|
||||
[1, 2 ** 16, 1],
|
||||
[128, 2048, 1],
|
||||
],
|
||||
attr_names=["M", "N", "K"],
|
||||
tags=["short"]
|
||||
)
|
||||
|
||||
|
||||
class BatchNormBenchmark(op_bench.TorchBenchmarkBase):
|
||||
def init(self, M, N, K):
|
||||
self.input_one = torch.rand(M, N, K)
|
||||
self.mean = torch.rand(N)
|
||||
self.var = torch.rand(N)
|
||||
self.weight = torch.rand(N)
|
||||
self.bias = torch.rand(N)
|
||||
self.set_module_name("batchnorm")
|
||||
|
||||
def forward(self):
|
||||
return F.batch_norm(self.input_one, self.mean, self.var, self.weight, self.bias)
|
||||
|
||||
|
||||
op_bench.generate_pt_test(configs, BatchNormBenchmark)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
op_bench.benchmark_runner.main()
|
Reference in New Issue
Block a user