from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from operator_benchmark import benchmark_core

import torch

"""PyTorch performance microbenchmarks.

This module contains PyTorch-specific functionality for performance
microbenchmarks.
"""

def PyTorchOperatorTestCase(test_name, op_type, input_shapes, op_args, run_mode):
    """Benchmark tester function for the PyTorch framework.

    Builds an input tensor for each shape in input_shapes (non-contiguous
    if op_args['contig'] is False) and registers a benchmark function that
    invokes op_type on those inputs.
    """
    inputs = []
    # Inputs are contiguous unless op_args explicitly sets 'contig' to False.
    is_contig = op_args.get('contig', True)
    dtype = op_args.get('dtype', torch.float32)
    for shape in input_shapes:
        tensor_shape = list(shape)
        if not is_contig:
            # Allocate a tensor twice as large in every dimension; a strided
            # view below brings it back down to the requested shape.
            tensor_shape = [s * 2 for s in tensor_shape]
        if dtype in [torch.float32, torch.float64]:  # skip float16
            tensor = torch.rand(tensor_shape, dtype=dtype)
        elif not dtype.is_floating_point:
            tensor = torch.randint(low=0, high=100, size=tensor_shape, dtype=dtype)
        else:
            tensor = torch.ones(tensor_shape, dtype=dtype)

        if not is_contig:
            # Take every other element along each dimension to produce a
            # non-contiguous view with the originally requested shape.
            slices = tuple(slice(0, dim, 2) for dim in tensor_shape)
            tensor = tensor[slices]
            assert list(tensor.size()) == list(shape)
            assert not tensor.is_contiguous()
        inputs.append(tensor)

    def benchmark_func(num_runs):
        # The operator callable receives the prepared inputs followed by
        # the number of runs as its final argument.
        op_type(*(inputs + [num_runs]))

    benchmark_core.add_benchmark_tester("PyTorch", test_name, input_shapes, op_args, run_mode, benchmark_func)