mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
The list of added operators: add_, addcmul, arange, baddbmm, bmm, clamp, div, div_, gelu, index_add, logical_and, mul_, sub_, topk, where. This pull request is the same as a previous one, https://github.com/pytorch/pytorch/pull/145121, which was inadvertently deleted while merging. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145625 Approved by: https://github.com/jeffdaily
124 lines
2.9 KiB
Python
import operator_benchmark as op_bench
|
|
|
|
import torch
|
|
|
|
|
|
"""
|
|
Microbenchmarks for batch matrix mult with einsum and torch.bmm.
|
|
"""
|
|
|
|
# Short-tag shapes for the batched-matmul benchmark: one tiny, one small,
# and one medium (B, M, N, K) combination, each run on CPU and CUDA.
batch_mm_configs_short = op_bench.config_list(
    attr_names=["B", "M", "N", "K"],
    attrs=[
        [4, 5, 3, 2],
        [32, 25, 20, 30],
        [128, 100, 120, 110],
    ],
    cross_product_configs={"device": ["cpu", "cuda"]},
    tags=["short"],
)
|
|
|
|
# Long-tag shapes for the batched-matmul benchmark: large (B, M, N, K)
# combinations that stress throughput rather than dispatch overhead.
batch_mm_configs_long = op_bench.config_list(
    attr_names=["B", "M", "N", "K"],
    attrs=[
        [128, 256, 128, 256],
        [512, 1024, 1024, 512],
    ],
    cross_product_configs={"device": ["cpu", "cuda"]},
    tags=["long"],
)
|
|
|
|
# Operators exercised by BatchMatrixMultBenchmark. Only the einsum variant
# is currently enabled; the direct torch.bmm entry is commented out below.
batch_mm_op_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["einsum_bmm", torch.einsum],
        # NOTE(review): bmm is intentionally left disabled here; the benchmark
        # class still contains the torch.bmm branch, so re-enabling this row
        # is all that is needed to benchmark it.
        # ["bmm", torch.bmm],
    ],
)
|
|
|
|
|
|
class BatchMatrixMultBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark batched matrix multiplication via torch.einsum or torch.bmm."""

    def init(self, B, M, N, K, device, op_func):
        # Random operands shaped for a (B, M, N) @ (B, N, K) -> (B, M, K) product.
        lhs = torch.rand(B, M, N, device=device)
        rhs = torch.rand(B, N, K, device=device)
        self.inputs = {"input_one": lhs, "input_two": rhs}
        self.op_func = op_func

    def forward(self, input_one, input_two):
        # Dispatch on the wrapped function's name: torch.einsum has
        # __name__ == "einsum"; anything else falls through to torch.bmm.
        if self.op_func.__name__ != "einsum":
            return torch.bmm(input_one, input_two)
        return torch.einsum("bij,bjk->bik", input_one, input_two)
|
|
|
|
|
|
"""
|
|
Microbenchmarks for element-wise matrix mult with einsum and torch.mul.
|
|
"""
|
|
|
|
# Short-tag shapes for the element-wise benchmark: (B, M, N) combinations
# from tiny to medium, each run on CPU and CUDA.
batch_elementwise_configs_short = op_bench.config_list(
    attr_names=["B", "M", "N"],
    attrs=[
        [4, 5, 3],
        [32, 25, 20],
        [100, 90, 110],
    ],
    cross_product_configs={"device": ["cpu", "cuda"]},
    tags=["short"],
)
|
|
|
|
|
|
# Long-tag shapes for the element-wise benchmark. Unlike the config_list
# variants above, this takes the full cross product of B x M x N x device
# (3 * 3 * 3 shape combinations on each of cpu and cuda).
# NOTE(review): `tags` here is presumably consumed as run metadata by
# op_bench.cross_product_configs rather than crossed as an attribute —
# confirm against the operator_benchmark framework.
batch_elementwise_configs_long = op_bench.cross_product_configs(
    B=[128, 512, 1024],
    M=[128, 512, 1024],
    N=[128, 512, 1024],
    device=["cpu", "cuda"],
    tags=["long"],
)
|
|
|
|
# Operators exercised by BatchElementWiseBenchmark: einsum-based
# element-wise multiply versus the direct torch.mul call.
batch_elementwise_op_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[["einsum_elementwise", torch.einsum], ["mul", torch.mul]],
)
|
|
|
|
|
|
class BatchElementWiseBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark element-wise multiplication via torch.einsum or torch.mul."""

    def init(self, B, M, N, device, op_func):
        # Two random operands of identical shape (B, M, N) for the
        # element-wise product.
        shape = (B, M, N)
        self.inputs = {
            "input_one": torch.rand(*shape, device=device),
            "input_two": torch.rand(*shape, device=device),
        }
        self.op_func = op_func

    def forward(self, input_one, input_two):
        # Dispatch on the wrapped function's name: torch.einsum has
        # __name__ == "einsum"; anything else falls through to torch.mul.
        if self.op_func.__name__ != "einsum":
            return torch.mul(input_one, input_two)
        return torch.einsum("bij,bij->bij", input_one, input_two)
|
|
|
|
|
|
# Register benchmark cases: every op in each op_list is crossed with the
# short and long configs for its benchmark class.
op_bench.generate_pt_tests_from_op_list(
    batch_mm_op_list,
    batch_mm_configs_short + batch_mm_configs_long,
    BatchMatrixMultBenchmark,
)


op_bench.generate_pt_tests_from_op_list(
    batch_elementwise_op_list,
    batch_elementwise_configs_short + batch_elementwise_configs_long,
    BatchElementWiseBenchmark,
)
|
|
|
|
|
|
# Entry point: hand control to the operator_benchmark runner, which
# executes all benchmarks registered above.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
|