mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105429 Approved by: https://github.com/malfet
107 lines
2.9 KiB
Python
107 lines
2.9 KiB
Python
from . import benchmark
|
|
|
|
|
|
class ConvImplBench(benchmark.Benchmark):
    """Shared 2-D convolution benchmark for the regular and depthwise cases.

    The ``case`` argument picks the ``groups`` value handed to
    ``conv2d_layer``: ``"conv"`` uses 1 and ``"depthwise_conv"`` uses ``iC``.
    """

    def __init__(self, case, mode, device, dtype, kernel_size, N, iC, H, W, oC):
        """Create the input data and the convolution layer.

        Args:
            case: ``"conv"`` or ``"depthwise_conv"``.
            mode: Forwarded to the base class; ``"fwd"`` and ``"both"`` are
                the values the workload methods of this class recognise.
            device: Forwarded to the base class; any value other than
                ``"cpu"`` makes the layer go through ``to_device``.
            dtype: Forwarded to the base class.
            kernel_size: Kernel size passed to ``conv2d_layer``.
            N, iC, H, W: Sizes used, in this order, for the random input.
            oC: Output channel count passed to ``conv2d_layer``.

        Raises:
            ValueError: If ``case`` is not one of the two supported names.
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.kernel_size = kernel_size
        self.N, self.iC, self.H, self.W, self.oC = N, iC, H, W, oC

        # The input is generated before ``case`` is validated, exactly as in
        # the original statement order.
        input_sizes = [N, iC, H, W]
        self.data = self.rand(
            input_sizes, device=device, requires_grad=self.requires_grad
        )

        if case not in ("conv", "depthwise_conv"):
            raise ValueError(f"invalid case: {case}")
        self.groups = 1 if case == "conv" else iC

        self.conv = self.conv2d_layer(iC, oC, kernel_size, groups=self.groups)
        if device != "cpu":
            self.to_device(self.conv, device)

    def forward(self):
        """Apply the convolution layer to the stored input and return the result."""
        return self.conv(self.data)

    def config(self):
        """Return ``[kernel_size, N, iC, H, W, oC]`` for this instance."""
        field_names = ("kernel_size", "N", "iC", "H", "W", "oC")
        return [getattr(self, field) for field in field_names]

    def memory_workload(self):
        """Estimate memory traffic as weighted element counts.

        Three buffers are counted: input ``N*iC*H*W``, output ``N*oC*H*W``
        and kernel ``oC*(iC/groups)*kernel_size**2``.  In ``"fwd"`` mode
        every buffer is weighted once for both estimates; in any other mode
        the ``"sol"`` estimate weights each buffer twice and the
        ``"algorithmic"`` estimate three times.

        Returns:
            dict with the keys ``"sol"`` and ``"algorithmic"``.  The totals
            are floats because the kernel size uses true division.
        """
        forward_only = self.mode == "fwd"
        sol_passes = 1 if forward_only else 1 + 1
        algorithmic_passes = 1 if forward_only else 1 + (1 + 1)

        channels_per_group = self.iC / self.groups
        # Order matters for the running sums: input, output, then kernel.
        element_counts = (
            self.N * self.iC * self.H * self.W,
            self.N * self.oC * self.H * self.W,
            self.oC * channels_per_group * self.kernel_size * self.kernel_size,
        )

        return {
            "sol": sum(count * sol_passes for count in element_counts),
            "algorithmic": sum(
                count * algorithmic_passes for count in element_counts
            ),
        }

    def compute_workload(self):
        """Estimate the operation count of the benchmark.

        The base count is ``2 * N * iC / groups * oC * kernel_size**2 * H * W``
        (the factor 2 presumably counts one multiply and one add per term --
        TODO confirm).  It is used once for ``"fwd"`` and three times for
        ``"both"``.

        Raises:
            ValueError: If ``self.mode`` is neither ``"fwd"`` nor ``"both"``.
        """
        mode = self.mode
        if mode not in ("fwd", "both"):
            raise ValueError(f"invalid mode: {mode}")
        passes = 1 if mode == "fwd" else 1 + (1 + 1)

        # Kept as one left-to-right product so the float result is unchanged.
        multiply_accumulates = (
            self.N
            * self.iC
            / self.groups
            * self.oC
            * self.kernel_size
            * self.kernel_size
            * self.H
            * self.W
        )
        return (multiply_accumulates * 2) * passes

    @staticmethod
    def default_configs():
        """Return the default configuration list (a single six-value entry).

        The values presumably follow the ``config()`` order
        ``[kernel_size, N, iC, H, W, oC]`` -- TODO confirm against the runner.
        """
        default_entry = [3, 64, 32, 128, 128, 64]
        return [default_entry]
|
|
|
|
|
|
class ConvBench(ConvImplBench):
    """Convolution benchmark that fixes the ``"conv"`` case."""

    def __init__(self, *args):
        """Forward ``args`` to ``ConvImplBench`` after the ``"conv"`` case name."""
        case_name = "conv"
        super().__init__(case_name, *args)

    @staticmethod
    def module():
        """Return the module name of this benchmark, ``"conv"``."""
        return "conv"
|
|
|
|
|
|
class DepthwiseConvBench(ConvImplBench):
    """Convolution benchmark that fixes the ``"depthwise_conv"`` case."""

    def __init__(self, *args):
        """Forward ``args`` to ``ConvImplBench`` after the ``"depthwise_conv"`` case name."""
        case_name = "depthwise_conv"
        super().__init__(case_name, *args)

    @staticmethod
    def module():
        """Return the module name of this benchmark, ``"depthwise_conv"``."""
        return "depthwise_conv"
|
|
|
|
|
|
# Register both benchmark classes, in the same order as before.
for _bench_cls in (ConvBench, DepthwiseConvBench):
    benchmark.register_benchmark_class(_bench_cls)
|