The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126 Approved by: https://github.com/kit1980
import itertools
import operator

import numpy as np
import scipy.special  # NumPy has no erf; scipy.special.erf is the reference

import torch

from . import benchmark


class BroadcastMulBench(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, case, M, N, K):
        super().__init__(mode, device, dtype)
        self.case = case
        self.M = M
        self.N = N
        self.K = K

        # Each case broadcasts two 3-D operands to a full [M, N, K] result.
        if case == "row":
            self.d1 = self.rand(
                [M, N, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            self.d2 = self.rand(
                [M, 1, K], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
        elif case == "mid":
            self.d1 = self.rand(
                [M, N, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            self.d2 = self.rand(
                [1, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
        elif case == "col":
            self.d1 = self.rand(
                [M, 1, K], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            self.d2 = self.rand(
                [1, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
        else:
            raise ValueError(f"invalid case: {case}")

        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        # Note: despite the class name, the benchmarked op is an add.
        y = d1 + d2
        return y

    def reference(self):
        return self.numpy(self.d1) + self.numpy(self.d2)

    def config(self):
        return [self.M, self.N, self.K]

    @staticmethod
    def default_configs():
        return [[128, 256, 128]]

    def memory_workload(self):
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            sol_count = (1) + (1)
            algorithmic_count = 1 + (1 + 1)

        buffer_size = self.M * self.N * self.K
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }
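
# Illustrative note (not part of the upstream file): in the "row" case the
# operand shapes [M, N, 1] and [M, 1, K] broadcast to [M, N, K]. With the
# default config [128, 256, 128]:
#
#   torch.broadcast_shapes((128, 256, 1), (128, 1, 128))  # -> (128, 256, 128)
#
# so buffer_size = 128 * 256 * 128 = 4_194_304 elements per workload unit.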


class BroadcastRowBench(BroadcastMulBench):
    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "row", M, N, K)

    @staticmethod
    def module():
        return "broadcast_row"


class BroadcastMidBench(BroadcastMulBench):
    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "mid", M, N, K)

    @staticmethod
    def module():
        return "broadcast_mid"


class BroadcastColBench(BroadcastMulBench):
    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "col", M, N, K)

    @staticmethod
    def module():
        return "broadcast_col"


class BroadcastThreeArgs(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, M, N, K, L):
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.K = K
        self.L = L

        self.d1 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [L, K, 1, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )

        self.inputs = [self.d1, self.d2, self.d3]

    def forward(self, d1, d2, d3):
        y = d1 + d2 + d3
        return y

    def reference(self):
        return self.numpy(self.d1) + self.numpy(self.d2) + self.numpy(self.d3)

    def config(self):
        return [self.M, self.N, self.K, self.L]

    @staticmethod
    def default_configs():
        return [[32, 16, 64, 128]]

    def memory_workload(self):
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            sol_count = (1) + (1)
            algorithmic_count = 1 + (1 + 1 + 1)

        # 4 is the assumed element size in bytes (float32)
        buffer_size = self.M * self.N * self.K * self.L * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def module():
        return "broadcast_3args"


# benchmark.register_benchmark_class(BroadcastRowBench)
# benchmark.register_benchmark_class(BroadcastMidBench)
# benchmark.register_benchmark_class(BroadcastColBench)
# benchmark.register_benchmark_class(BroadcastThreeArgs)


# TODO: merge this with elementwise bench
# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
class BroadcastBench(benchmark.Benchmark):
    # List of customization class variables.
    op_str = None
    binary_op_pt_func = None
    binary_op_np_func = None
    unary_op_pt_func = None
    unary_op_np_func = None
    split_input = True

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.K = K
        self.d1 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [K, 1, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d4 = self.rand(
            [K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        # Default to add / identity when no custom ops are configured.
        if not binary_op:

            def binary_op(x, y):
                return x + y

        if not unary_op:

            def unary_op(x):
                return x

        if self.split_input:
            d1 = unary_op(d1)
            d2 = unary_op(d2)
            d3 = unary_op(d3)
            d4 = unary_op(d4)
        else:
            # Shared-input mode: d3 is derived from d1, so inputs overlap.
            d1, d2, d3, d4 = (
                unary_op(d1),
                unary_op(d2),
                unary_op(d1 + 0.001),
                unary_op(d4),
            )
        a = binary_op(d1, d2)
        b = binary_op(d3, d4)
        c = a + b
        return c

    def forward(self, d1, d2, d3, d4):
        binary_op = self.__class__.binary_op_pt_func
        unary_op = self.__class__.unary_op_pt_func
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def reference(self):
        binary_op = self.__class__.binary_op_np_func
        unary_op = self.__class__.unary_op_np_func
        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def config(self):
        return [self.M, self.N, self.K]

    @classmethod
    def module(cls):
        return "broadcast_" + cls.op_str

    def memory_workload(self):
        input_count = len(self.inputs)
        if self.mode == "fwd":
            if self.split_input:
                sol_count = 1
                algorithmic_count = 1
            else:
                sol_count = 1
                algorithmic_count = 1
        else:
            if self.split_input:
                sol_count = 1
                algorithmic_count = input_count
            else:
                sol_count = 1
                algorithmic_count = input_count

        # 4 is the assumed element size in bytes (float32)
        buffer_size = self.M * self.N * self.K * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [[1 << 8, 1 << 7, 1 << 9]]
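
# Dataflow sketch (illustrative, not part of the upstream file): with u the
# unary op and b the binary op, _eval computes
#
#   c = b(u(d1), u(d2)) + b(u(d3), u(d4))
#
# where in shared-input mode d3 is replaced by d1 + 0.001 before u is applied.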


def register_broadcast_ops():
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, scipy.special.erf],  # NumPy has no erf
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
    ]

    for split_input, binary_op in itertools.product([True, False], binary_op_list):
        # Make a copy of BroadcastBench
        if len(binary_op) == 2:
            [op_str, op_pt_func] = binary_op
            op_np_func = op_pt_func
        elif len(binary_op) == 3:
            [op_str, op_pt_func, op_np_func] = binary_op
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        bm_cls = type("BroadcastBench_" + op_str, (BroadcastBench,), {})
        bm_cls.op_str = op_str
        bm_cls.binary_op_pt_func = op_pt_func
        bm_cls.binary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)

    for split_input, unary_op in itertools.product([True, False], unary_op_list):
        # Make a copy of BroadcastBench
        if len(unary_op) == 2:
            [op_str, op_pt_func] = unary_op
            op_np_func = op_pt_func
        elif len(unary_op) == 3:
            [op_str, op_pt_func, op_np_func] = unary_op
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        bm_cls = type("BroadcastBench_" + op_str, (BroadcastBench,), {})
        bm_cls.op_str = op_str
        bm_cls.unary_op_pt_func = op_pt_func
        bm_cls.unary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


register_broadcast_ops()
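
# Illustrative sketch (not part of the upstream file): the first binary-loop
# iteration registers a class equivalent to
#
#   class BroadcastBench_split_mul(BroadcastBench):
#       op_str = "split_mul"
#       binary_op_pt_func = operator.mul
#       binary_op_np_func = operator.mul
#       split_input = True
#
# whose module() returns "broadcast_split_mul".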