pytorch/benchmarks/operator_benchmark/benchmark_test_generator.py
Ilia Cherniavskii eecf52b444 Fix in benchmark_test_generator (#20237)
Summary:
Add missing import
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20237

Differential Revision: D15245957

Pulled By: ilia-cher

fbshipit-source-id: 0f71aa08eb9ecac32002a1644838d06ab9faa37c
2019-05-07 17:03:25 -07:00

93 lines · 3.1 KiB · Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase
from operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase
from operator_benchmark.benchmark_utils import *  # noqa

import torch


def generate_test(configs, map_config, ops, OperatorTestCase):
    """
    Create PyTorch/Caffe2 operator test cases based on configs.
    configs usually includes both long and short configs; each case is mapped
    by map_config to the input_shapes and args that the operator can consume
    (a sketch of the expected configs layout follows this function).
    OperatorTestCase then builds an operator with those inputs/outputs and args.
    """
    for config in configs:
        for case in config:
            shapes = {}
            run_mode = None  # reset per case; set by the 'mode' entry below
            for item in case:
                if 'mode' in item:
                    run_mode = item['mode']
                    continue
                shapes.update(item)
            assert run_mode is not None, "Missing mode in configs"
            for op in ops:
                shapes_args = map_config(test_name=op[0], **shapes)
                if shapes_args is not None:
                    OperatorTestCase(
                        test_name=op[0],
                        op_type=op[1],
                        input_shapes=shapes_args[0],
                        op_args=shapes_args[1],
                        run_mode=run_mode)
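# A minimal sketch of the `configs` layout that the loops above assume, for
# illustration only (the variable name and concrete values are made up, not
# taken from this file): each element of `configs` is one config, each config
# is a list of cases, and each case is a list of single-key dicts where the
# 'mode' entry selects the run mode and every other entry becomes a kwarg for
# map_config.
#
#   example_add_configs = [
#       [  # one config, e.g. a "short" sweep
#           [{'M': 8}, {'N': 16}, {'K': 32}, {'mode': 'short'}],
#           [{'M': 64}, {'N': 64}, {'K': 128}, {'mode': 'short'}],
#       ],
#   ]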


def generate_pt_test(configs, pt_map_func, pt_ops):
    """
    This function creates PyTorch operators which will be benchmarked.
    """
    generate_test(configs, pt_map_func, pt_ops, PyTorchOperatorTestCase)


def generate_c2_test(configs, c2_map_func, c2_ops):
    """
    This function creates Caffe2 operators which will be benchmarked.
    """
    generate_test(configs, c2_map_func, c2_ops, Caffe2OperatorTestCase)
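# Usage sketch for the two generators above, for illustration only: the config
# variable is hypothetical, and the exact op_type values each TestCase expects
# (a callable for PyTorch, an operator name for Caffe2) are assumptions here.
#
#   example_add_configs = ...  # built with the benchmark_utils helpers
#   generate_pt_test(example_add_configs, map_pt_config_add, [('add', torch.add)])
#   generate_c2_test(example_add_configs, map_c2_config_add, [('add', 'Add')])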


def map_c2_config_add(test_name, M, N, K):
    input_one = (M, N, K)
    input_two = (M, N, K)
    input_shapes = [input_one, input_two]
    args = {}
    return (input_shapes, args)


# Elementwise add uses the same (M, N, K) input shapes on the PyTorch side,
# so the Caffe2 mapping is reused directly.
map_pt_config_add = map_c2_config_add


def map_c2_config_matmul(test_name, M, N, K, trans_a, trans_b, contig, dtype):
    # Skip configs the Caffe2 benchmark does not cover: non-contiguous inputs
    # and dtypes other than float32.
    if not contig or dtype != torch.float32:
        return None
    input_one = (N, M) if trans_a else (M, N)
    input_two = (K, N) if trans_b else (N, K)
    input_shapes = [input_one, input_two]
    args = {'trans_a': trans_a, 'trans_b': trans_b}
    return (input_shapes, args)


def map_pt_config_matmul(test_name, M, N, K, trans_a, trans_b, contig, dtype):
    # Skip transposed variants; matmul is benchmarked on plain
    # (M, N) x (N, K) operands here.
    if trans_a or trans_b:
        return None
    input_shapes = [(M, N), (N, K)]
    args = {'contig': contig, 'dtype': dtype}
    return (input_shapes, args)


def map_pt_config_intraop(test_name, N, M, contig, dtype):
    # Bitwise ops are skipped for floating-point dtypes.
    if test_name in ['bitor', 'cbitor']:
        if dtype.is_floating_point:
            return None
    # tanh/sigmoid/sumall are skipped for non-floating-point dtypes.
    if test_name in ['tanh', 'sigmoid', 'sumall']:
        if not dtype.is_floating_point:
            return None
    input_shapes = [(N, M), (N, M)]
    args = {'contig': contig, 'dtype': dtype}
    return (input_shapes, args)