Files
pytorch/benchmarks/operator_benchmark/benchmark_pytorch.py
Ilia Cherniavskii 19e6886576 Intra-op parallel microbenchmarks for PT (#19997)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19997
ghimport-source-id: 420d4a68a1ef879beee2734adba8abb575e0b0ab

Differential Revision: D15231375

Pulled By: ilia-cher

fbshipit-source-id: ce7248ea2ebb54d25c9d831c6e3f23f3534557dd
2019-05-06 20:21:45 -07:00

47 lines
1.6 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from operator_benchmark import benchmark_core
import torch
"""PyTorch performance microbenchmarks.
This module contains PyTorch-specific functionality for performance
microbenchmarks.
"""
def _generate_input(shape, dtype, is_contig):
    """Create one input tensor with the given shape and dtype.

    float32/float64 tensors are filled with ``torch.rand``; non-floating
    dtypes with ``torch.randint`` in [0, 100); any other floating dtype
    (e.g. float16) is filled with ones instead of random values.

    When ``is_contig`` is falsy, a tensor twice as large in every dimension
    is allocated and every other element along each dimension is selected,
    so the returned view has ``shape`` but is not contiguous.
    """
    alloc_shape = list(shape)
    if not is_contig:
        alloc_shape = [size * 2 for size in alloc_shape]

    if dtype in [torch.float32, torch.float64]:
        tensor = torch.rand(alloc_shape, dtype=dtype)
    elif not dtype.is_floating_point:
        # torch.rand only produces floating point values.
        tensor = torch.randint(low=0, high=100, size=alloc_shape, dtype=dtype)
    else:
        # Remaining floating dtypes (e.g. float16) skip torch.rand.
        tensor = torch.ones(alloc_shape, dtype=dtype)

    if not is_contig:
        # Index with an explicit tuple of slices. A *list* of slices is only
        # treated as multi-dimensional indexing through legacy NumPy-style
        # "sequence as tuple" handling, which NumPy has since removed.
        tensor = tensor[tuple(slice(0, size, 2) for size in alloc_shape)]
        assert list(tensor.size()) == list(shape)
        assert not tensor.is_contiguous()
    return tensor


def PyTorchOperatorTestCase(test_name, op_type, input_shapes, op_args, run_mode):
    """Benchmark Tester function for the PyTorch framework.

    Generates one input tensor per entry of ``input_shapes``, then registers
    a closure that calls ``op_type(*inputs, num_runs)`` through
    ``benchmark_core.add_benchmark_tester`` under the "PyTorch" framework.

    Args:
        test_name: Name of the test; forwarded to ``add_benchmark_tester``.
        op_type: Callable invoked as ``op_type(input_0, ..., input_k, num_runs)``.
        input_shapes: Iterable of shapes (sequences of ints), one per input.
        op_args: Dict of benchmark arguments; forwarded to
            ``add_benchmark_tester``. Two optional keys are read here:
            ``'contig'`` (default True): when falsy, inputs are generated as
            non-contiguous strided views; ``'dtype'`` (default
            ``torch.float32``): dtype of the generated inputs.
        run_mode: Forwarded unchanged to ``add_benchmark_tester``.
    """
    is_contig = op_args.get('contig', True)
    dtype = op_args.get('dtype', torch.float32)
    inputs = [_generate_input(shape, dtype, is_contig) for shape in input_shapes]

    def benchmark_func(num_runs):
        # Spelled as list concatenation rather than op_type(*inputs, num_runs)
        # to stay compatible with Python 2, matching the file's __future__ imports.
        op_type(*(inputs + [num_runs]))

    benchmark_core.add_benchmark_tester("PyTorch", test_name, input_shapes, op_args, run_mode, benchmark_func)