mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo makes `usort` do more and generates the changes in this PR. Except for `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126 Approved by: https://github.com/kit1980
138 lines
4.1 KiB
Python
138 lines
4.1 KiB
Python
import numpy as np
|
|
|
|
import torch
|
|
|
|
from . import benchmark
|
|
|
|
|
|
class Concat2D2InputBench(benchmark.Benchmark):
    """Benchmark that offsets two 2-D inputs and concatenates them.

    The two inputs have shapes ``[I1_D1, I1_D2]`` and ``[I2_D1, I2_D2]`` and
    are joined along ``concat_dim``.
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        """Store the configuration and build both input tensors.

        Args:
            mode: Benchmark mode; ``memory_workload`` distinguishes ``"fwd"``
                from every other value.
            device: Device forwarded to the base class and to ``randn``.
            dtype: Dtype forwarded to the base class and to ``randn``.
            I1_D1: First dimension of the first input.
            I1_D2: Second dimension of the first input.
            I2_D1: First dimension of the second input.
            I2_D2: Second dimension of the second input.
            concat_dim: Dimension along which the inputs are concatenated.
        """
        super().__init__(mode, device, dtype)
        self.I1_D1, self.I1_D2 = I1_D1, I1_D2
        self.I2_D1, self.I2_D2 = I2_D1, I2_D2
        self.concat_dim = concat_dim

        # Both inputs share the same creation options; only the shape differs.
        tensor_opts = {
            "device": device,
            "dtype": dtype,
            "requires_grad": self.requires_grad,
        }
        self.input1 = self.randn([I1_D1, I1_D2], **tensor_opts)
        self.input2 = self.randn([I2_D1, I2_D2], **tensor_opts)
        self.inputs = [self.input1, self.input2]

    def forward(self, input1, input2):
        """Add 0.00001 to each input and concatenate along ``concat_dim``."""
        lhs = self.add(input1, 0.00001)
        rhs = self.add(input2, 0.00001)
        return self.cat((lhs, rhs), dim=self.concat_dim)

    def reference(self):
        """Return the NumPy concatenation of the two stored inputs."""
        arrays = (self.numpy(self.input1), self.numpy(self.input2))
        return np.concatenate(arrays, axis=self.concat_dim)

    def config(self):
        """Return ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]``."""
        shape_and_dim = [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2]
        shape_and_dim.append(self.concat_dim)
        return shape_and_dim

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "concat2d2input"

    def memory_workload(self):
        """Return the ``"sol"`` and ``"algorithmic"`` workload estimates.

        Each estimate is the total input element count multiplied by a
        per-element factor: ``1 + 1`` ("sol") and ``3 + 1`` ("algorithmic")
        in ``"fwd"`` mode, and twice those factors in any other mode.
        """
        # Any mode other than "fwd" counts the per-element factors twice.
        passes = 1 if self.mode == "fwd" else 2
        sol_count = passes * (1 + 1)
        algorithmic_count = passes * (3 + 1)

        first_elems = self.I1_D1 * self.I1_D2
        second_elems = self.I2_D1 * self.I2_D2
        buffer_size = first_elems + second_elems
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]`` list."""
        rows_8k = 1 << 13
        rows_32k = 1 << 15
        return [
            [1, 160, 1, 14, 1],
            [1, 580, 1, 174, 1],
            [20, 160, 20, 14, 1],
            [20, 580, 20, 174, 1],
            [8, 512, 8, 512, 1],
            [rows_8k, 1060, rows_8k, 1040, 1],
            [rows_8k, 2000, rows_8k, 1074, 1],
            [rows_32k, 1060, rows_32k, 2670, 1],
            [rows_32k, 5120, rows_32k, 2512, 1],
        ]
|
|
|
|
|
|
# Pass the class to the benchmark module's registration hook
# (presumably so the harness can discover it; confirm in benchmark.py).
benchmark.register_benchmark_class(Concat2D2InputBench)
|
|
|
|
|
|
class ConcatGraphOptBench(benchmark.Benchmark):
    """Benchmark that offsets two 2-D inputs, concatenates them and applies ReLU.

    The two inputs have shapes ``[I1_D1, I1_D2]`` and ``[I2_D1, I2_D2]`` and
    are joined along ``concat_dim``. Constructing an instance also turns on
    two process-wide JIT flags (see ``__init__``).
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        """Store the configuration, build both inputs and set the JIT flags.

        Args:
            mode: Benchmark mode; ``memory_workload`` distinguishes ``"fwd"``
                from every other value.
            device: Device forwarded to the base class and to ``randn``.
            dtype: Dtype forwarded to the base class and to ``randn``.
            I1_D1: First dimension of the first input.
            I1_D2: Second dimension of the first input.
            I2_D1: First dimension of the second input.
            I2_D2: Second dimension of the second input.
            concat_dim: Dimension along which the inputs are concatenated.
        """
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]
        # These calls set global torch JIT flags and are never reset here, so
        # they stay in effect after this instance is created.
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_cat_wo_conditionals(True)

    def forward(self, input1, input2):
        """Add 0.00001 to each input, concatenate, then apply ReLU."""
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        return self.relu(y)

    def reference(self):
        """Return the NumPy equivalent of ``forward`` on the stored inputs.

        The concatenated array is clamped at zero so that it mirrors the ReLU
        applied in ``forward``; without the clamp every negative input value
        would make the reference disagree with the computed result.

        NOTE(review): the 0.00001 offset from ``forward`` is not reproduced
        here; presumably it is covered by the comparison tolerance - confirm.
        """
        joined = np.concatenate(
            (self.numpy(self.input1), self.numpy(self.input2)),
            axis=self.concat_dim,
        )
        # ReLU: element-wise max(x, 0).
        return np.maximum(joined, 0)

    def config(self):
        """Return ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]``."""
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "concatGraphOpt"

    def memory_workload(self):
        """Return the ``"sol"`` and ``"algorithmic"`` workload estimates.

        Each estimate is the total input element count multiplied by a
        per-element factor: ``1 + 1`` ("sol") and ``3 + 1`` ("algorithmic")
        in ``"fwd"`` mode, and twice those factors in any other mode.
        """
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]`` list."""
        return [
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
|
|
|
|
|
|
# Pass the class to the benchmark module's registration hook
# (presumably so the harness can discover it; confirm in benchmark.py).
benchmark.register_benchmark_class(ConcatGraphOptBench)
|