# Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
# Pull Request resolved: https://github.com/pytorch/pytorch/pull/136359
# Approved by: https://github.com/albanD
import itertools
|
|
import operator
|
|
|
|
import numpy as np
|
|
import scipy.special
|
|
|
|
import torch
|
|
|
|
from . import benchmark
|
|
|
|
|
|
# A template class for elementwise operations.
# A derived class overrides the class-level variables below to customize its
# behavior.
class ElementBench(benchmark.Benchmark):
|
|
# List of customization class variables.
|
|
op_str = None
|
|
binary_op_pt_func = None
|
|
binary_op_np_func = None
|
|
unary_op_pt_func = None
|
|
unary_op_np_func = None
|
|
split_input = True
|
|
|
|
def __init__(self, mode, device, dtype, N):
|
|
super().__init__(mode, device, dtype)
|
|
self.N = N
|
|
self.d1 = self.rand(
|
|
[N], device=device, dtype=dtype, requires_grad=self.requires_grad
|
|
)
|
|
self.d2 = self.rand(
|
|
[N], device=device, dtype=dtype, requires_grad=self.requires_grad
|
|
)
|
|
self.d3 = self.rand(
|
|
[N], device=device, dtype=dtype, requires_grad=self.requires_grad
|
|
)
|
|
self.d4 = self.rand(
|
|
[N], device=device, dtype=dtype, requires_grad=self.requires_grad
|
|
)
|
|
self.inputs = [self.d1, self.d2, self.d3, self.d4]
|
|
self.deterministic = "rand" not in self.op_str
|
|
|
|
def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
|
|
if not binary_op:
|
|
|
|
def binary_op(x, y):
|
|
return x + y
|
|
|
|
if not unary_op:
|
|
|
|
def unary_op(x):
|
|
return x
|
|
|
|
if self.split_input:
|
|
d1 = unary_op(d1)
|
|
d2 = unary_op(d2)
|
|
d3 = unary_op(d3)
|
|
d4 = unary_op(d4)
|
|
else:
|
|
d2 = unary_op(d1 + 0.001)
|
|
d3 = unary_op(d1 + 0.002)
|
|
d4 = unary_op(d1 + 0.003)
|
|
d1 = unary_op(d1)
|
|
a = binary_op(d1, d2)
|
|
b = binary_op(d3, d4)
|
|
c = a + b
|
|
return c
|
|
|
|
def forward(self, d1, d2, d3, d4):
|
|
binary_op = self.__class__.binary_op_pt_func
|
|
unary_op = self.__class__.unary_op_pt_func
|
|
return self._eval(d1, d2, d3, d4, binary_op, unary_op)
|
|
|
|
def reference(self):
|
|
binary_op = self.__class__.binary_op_np_func
|
|
unary_op = self.__class__.unary_op_np_func
|
|
[d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
|
|
return self._eval(d1, d2, d3, d4, binary_op, unary_op)
|
|
|
|
def config(self):
|
|
return [self.N]
|
|
|
|
@classmethod
|
|
def module(cls):
|
|
return "element_" + cls.op_str
|
|
|
|
def memory_workload(self):
|
|
input_count = len(self.inputs)
|
|
if self.mode == "fwd":
|
|
if self.split_input:
|
|
sol_count = input_count + 1
|
|
algorithmic_count = input_count + 1
|
|
else:
|
|
sol_count = 1 + 1
|
|
algorithmic_count = 1 + 1
|
|
if "rand" in self.op_str:
|
|
sol_count = 1
|
|
algorithmic_count = 1
|
|
else:
|
|
if self.split_input:
|
|
sol_count = (input_count + 1) + (1 + input_count)
|
|
algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
|
|
else:
|
|
sol_count = 1 + 1
|
|
algorithmic_count = 1 + 1
|
|
if "rand" in self.op_str:
|
|
sol_count = 1
|
|
algorithmic_count = 1
|
|
|
|
buffer_size = self.N
|
|
return {
|
|
"sol": buffer_size * sol_count,
|
|
"algorithmic": buffer_size * algorithmic_count,
|
|
}
|
|
|
|
@staticmethod
|
|
def default_configs():
|
|
return [[1 << 25]]
|
|
|
|
|
|
def register_element_ops():
|
|
binary_op_list = [
|
|
["mul", operator.mul],
|
|
["add", operator.add],
|
|
["sub", operator.sub],
|
|
["div", lambda a, b: a / (b + 1e-4)],
|
|
[
|
|
"pow",
|
|
torch.pow,
|
|
np.power,
|
|
], # no fuson triggered
|
|
["max", torch.max, np.maximum],
|
|
["min", torch.min, np.minimum],
|
|
]
|
|
|
|
unary_op_list = [
|
|
["erf", torch.erf, scipy.special.erf],
|
|
["exp", torch.exp, np.exp],
|
|
["sin", torch.sin, np.sin],
|
|
["cos", torch.cos, np.cos],
|
|
["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)],
|
|
]
|
|
|
|
for split_input, binary_op in itertools.product([True, False], binary_op_list):
|
|
# Make a copy of ElementBench
|
|
if len(binary_op) == 2:
|
|
[op_str, op_pt_func] = binary_op
|
|
op_np_func = op_pt_func
|
|
elif len(binary_op) == 3:
|
|
[op_str, op_pt_func, op_np_func] = binary_op
|
|
split_str = "split" if split_input else "shared"
|
|
op_str = split_str + "_" + op_str
|
|
bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
|
|
bm_cls.op_str = op_str
|
|
bm_cls.binary_op_pt_func = op_pt_func
|
|
bm_cls.binary_op_np_func = op_np_func
|
|
bm_cls.split_input = split_input
|
|
benchmark.register_benchmark_class(bm_cls)
|
|
|
|
for split_input, unary_op in itertools.product([True, False], unary_op_list):
|
|
# Make a copy of ElementBench
|
|
if len(unary_op) == 2:
|
|
[op_str, op_pt_func] = unary_op
|
|
op_np_func = op_pt_func
|
|
elif len(unary_op) == 3:
|
|
[op_str, op_pt_func, op_np_func] = unary_op
|
|
split_str = "split" if split_input else "shared"
|
|
op_str = split_str + "_" + op_str
|
|
bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
|
|
bm_cls.op_str = op_str
|
|
bm_cls.unary_op_pt_func = op_pt_func
|
|
bm_cls.unary_op_np_func = op_np_func
|
|
bm_cls.split_input = split_input
|
|
benchmark.register_benchmark_class(bm_cls)
|
|
|
|
|
|
# benchmark.register_benchmark_class(ElementMulBench)
|
|
register_element_ops()
|
|
|
|
|
|
class SimpleElementBench(benchmark.Benchmark):
|
|
def __init__(self, mode, device, dtype, N):
|
|
super().__init__(mode, device, dtype)
|
|
self.N = N
|
|
self.data = self.rand(
|
|
[N], device=device, dtype=dtype, requires_grad=self.requires_grad
|
|
)
|
|
self.inputs = [self.data]
|
|
|
|
def forward(self, data):
|
|
a = data + 0.001
|
|
b = a + 0.002
|
|
return b
|
|
|
|
def reference(self):
|
|
binary_op = self.__class__.binary_op_np_func
|
|
unary_op = self.__class__.unary_op_np_func
|
|
[d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
|
|
return self._eval(d1, d2, d3, d4, binary_op, unary_op)
|
|
|
|
def config(self):
|
|
return [self.N]
|
|
|
|
@staticmethod
|
|
def input_iterable():
|
|
return True
|
|
|
|
@classmethod
|
|
def module(cls):
|
|
return "simple_element"
|
|
|
|
def memory_workload(self):
|
|
if self.mode == "fwd":
|
|
sol_count = 2
|
|
algorithmic_count = 2
|
|
else:
|
|
sol_count = 2
|
|
algorithmic_count = 2
|
|
|
|
buffer_size = self.N
|
|
return {
|
|
"sol": buffer_size * sol_count,
|
|
"algorithmic": buffer_size * algorithmic_count,
|
|
}
|
|
|
|
@staticmethod
|
|
def default_configs():
|
|
return [[1 << 25]]
|
|
|
|
|
|
benchmark.register_benchmark_class(SimpleElementBench)
|
|
|
|
|
|
class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench):
|
|
def __init__(self, mode, device, dtype, N):
|
|
benchmark.DynamicShape.__init__(self)
|
|
SimpleElementBench.__init__(self, mode, device, dtype, N)
|
|
|
|
@classmethod
|
|
def module(cls):
|
|
return "simple_dynamic_element"
|
|
|
|
def instantiate_input(self):
|
|
(N,) = self.rand_shape([self.N])
|
|
data = self.rand(
|
|
[N], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad
|
|
)
|
|
self.inputs = [data]
|
|
|
|
|
|
benchmark.register_benchmark_class(DynamicSimpleElementBench)
|