Additional operators in operator benchmark (#145625)

The list of added operators:
add_, addcmul, arange, baddbmm…, bmm, clamp, div, div_, gelu, index_add, logical_and, mul_, sub_, topk, where

This pull request is the same as a previous one: https://github.com/pytorch/pytorch/pull/145121 which inadvertently got deleted while merging.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145625
Approved by: https://github.com/jeffdaily
This commit is contained in:
Arash Pakbin
2025-01-26 19:20:02 +00:00
committed by PyTorch MergeBot
parent 6a4fb4b615
commit f3ddc08ddc
12 changed files with 656 additions and 40 deletions

View File

@ -1,9 +1,12 @@
from pt import ( # noqa: F401
add_test,
ao_sparsifier_test,
arange_test,
as_strided_test,
batchnorm_test,
binary_inplace_test,
binary_test,
bmm_test,
cat_test,
channel_shuffle_test,
chunk_test,
@ -15,18 +18,25 @@ from pt import ( # noqa: F401
groupnorm_test,
hardsigmoid_test,
hardswish_test,
index_add__test,
index_select_test,
instancenorm_test,
interpolate_test,
layernorm_test,
linear_test,
matmul_test,
mm_test,
nan_to_num_test,
pool_test,
remainder_test,
softmax_test,
split_test,
stack_test,
sum_test,
tensor_to_test,
ternary_test,
topk_test,
where_test,
)
import operator_benchmark as op_bench

View File

@ -0,0 +1,48 @@
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for arange operator"""
# Configs for PT stack operator
configs_short = op_bench.config_list(
attr_names=["start", "end", "step"],
attrs=[
[0, 1000, 2.5],
[-1024, 2048, 1],
],
cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
tags=["short"],
)
configs_long = op_bench.cross_product_configs(
start=[-1024, 8],
end=[16, 2048],
step=[8, 0.1],
device=["cpu", "cuda"],
dtype=[torch.float, torch.bfloat16],
tags=["long"],
)
class ArangeBenchmark(op_bench.TorchBenchmarkBase):
def init(self, start, end, step, dtype, device):
self.inputs = {
"start": start,
"end": end,
"step": step,
"dtype": dtype,
"device": device,
}
self.set_module_name("arange")
def forward(self, start, end, step, dtype, device):
return torch.arange(start=start, end=end, step=step, dtype=dtype, device=device)
op_bench.generate_pt_test(configs_short + configs_long, ArangeBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -0,0 +1,140 @@
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for inplace binary operators."""
def add_(in1, in2):
return in1.add_(in2)
def sub_(in1, in2):
return in1.sub_(in2)
def div_(in1, in2):
return in1.div_(in2)
def mul_(in1, in2):
return in1.mul_(in2)
def copy_(in1, in2):
return in1.copy_(in2)
######
# Benchmark ops performance for inplace add + sub + mul + copy
######
binary_ops_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["add_", add_],
["sub_", sub_],
# ["div_", div_ ], # done separately below because of data type
["mul_", mul_],
["copy_", copy_],
],
)
binary_short_configs = op_bench.config_list(
attr_names=["M", "N", "K"],
attrs=[
[1, 1, 1],
[64, 64, 64],
[64, 64, 128],
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype_one": [torch.int32],
"dtype_two": [torch.int32],
},
tags=["short"],
)
binary_long_configs = op_bench.cross_product_configs(
M=[8, 128],
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype_one=[torch.int8, torch.int32],
dtype_two=[torch.int8, torch.int32],
tags=["long"],
)
class InpBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
self.inputs = {
"input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one),
"input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two),
}
self.op_func = op_func
def forward(self, input_one, input_two):
return self.op_func(input_one, input_two)
op_bench.generate_pt_tests_from_op_list(
binary_ops_list, binary_short_configs + binary_long_configs, InpBinaryOpBenchmark
)
######
# Benchmark ops performance for inplace div
######
# Performing division inplace benchmarks separately, as data needs to be float
binary_ops_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["div_", div_],
],
)
binary_short_configs = op_bench.config_list(
attr_names=["M", "N", "K"],
attrs=[
[1, 1, 1],
[64, 64, 64],
[64, 64, 128],
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype_one": [torch.float],
"dtype_two": [torch.float],
},
tags=["short"],
)
binary_long_configs = op_bench.cross_product_configs(
M=[8, 128],
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype_one=[torch.float, torch.float],
dtype_two=[torch.float, torch.float],
tags=["long"],
)
class InpBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
self.inputs = {
"input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one),
"input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two),
}
self.op_func = op_func
def forward(self, input_one, input_two):
return self.op_func(input_one, input_two)
op_bench.generate_pt_tests_from_op_list(
binary_ops_list, binary_short_configs + binary_long_configs, InpBinaryOpBenchmark
)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -11,6 +11,9 @@ binary_ops_bcast_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["add", torch.add],
["sub", torch.sub],
["div", torch.div],
["mul", torch.mul],
],
)
@ -45,16 +48,14 @@ op_bench.generate_pt_tests_from_op_list(
)
def copy(in1, in2):
return in1.copy_(in2)
# Benchmark ops performance without broadcast
binary_ops_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["add", torch.add],
["copy_", copy],
["sub", torch.sub],
["div", torch.div],
["mul", torch.mul],
],
)
@ -101,5 +102,104 @@ op_bench.generate_pt_tests_from_op_list(
)
######
# Benchmark ops performance for boolean dtype
######
# Benchmark ops performance with broadcast
binary_ops_bcast_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[["logical_and", torch.logical_and]],
)
# Configs with broadcast
binary_configs_broadcast = op_bench.config_list(
attr_names=["in_one", "in_two"],
attrs=[
[[64, 1, 64], [1, 64, 1]],
],
cross_product_configs={
"device": ["cpu"],
"dtype": [torch.bool],
},
tags=["short"],
)
class BinaryOpBcastBenchmark(op_bench.TorchBenchmarkBase):
def init(self, in_one, in_two, dtype, device, op_func):
self.inputs = {
"in_one": torch.bernoulli(0.5 * torch.ones(in_one, device=device)).to(
dtype=dtype
),
"in_two": torch.bernoulli(0.5 * torch.ones(in_two, device=device)).to(
dtype=dtype
),
}
self.op_func = op_func
def forward(self, in_one, in_two):
return self.op_func(in_one, in_two)
op_bench.generate_pt_tests_from_op_list(
binary_ops_bcast_list, binary_configs_broadcast, BinaryOpBcastBenchmark
)
# Benchmark ops performance without broadcast
binary_ops_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[["logical_and", torch.logical_and]],
)
binary_short_configs = op_bench.config_list(
attr_names=["M", "N", "K"],
attrs=[
[1, 1, 1],
[64, 64, 64],
[64, 64, 128],
],
cross_product_configs={
"device": ["cpu", "cuda"],
"dtype_one": [torch.bool],
"dtype_two": [torch.bool],
},
tags=["short"],
)
binary_long_configs = op_bench.cross_product_configs(
M=[8, 128],
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype_one=[torch.bool, torch.bool],
dtype_two=[torch.bool, torch.bool],
tags=["long"],
)
class BinaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
self.inputs = {
"input_one": torch.bernoulli(0.5 * torch.ones(M, N, K, device=device)).to(
dtype=dtype_one
),
"input_two": torch.bernoulli(0.5 * torch.ones(M, N, K, device=device)).to(
dtype=dtype_two
),
}
self.op_func = op_func
def forward(self, input_one, input_two):
return self.op_func(input_one, input_two)
op_bench.generate_pt_tests_from_op_list(
binary_ops_list, binary_short_configs + binary_long_configs, BinaryOpBenchmark
)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -3,43 +3,86 @@ import operator_benchmark as op_bench
import torch
"""Microbenchmarks for add_ operator. Supports both Caffe2/PyTorch."""
"""Microbenchmarks for batched operators."""
class BmmBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device, op):
self.inputs = {
"batch1": torch.rand(
(B, M, K), device=device, requires_grad=self.auto_set()
),
"batch2": torch.rand(
(
B,
K,
N,
),
device=device,
requires_grad=self.auto_set(),
),
}
self.set_module_name(f"bmm (actual op={op}")
self.op = torch.bmm if op == "bmm" else torch.matmul
def forward(self, batch1, batch2):
return self.op(batch1, batch2)
bmm_configs = op_bench.cross_product_configs(
B=[2, 100],
M=[8, 256],
N=[256, 16],
K=[16, 32],
device=["cpu"],
tags=["short"],
op=["bmm", "matmul"],
# binary ops (two inputs in shape of batches)
batched_binary_ops = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["bmm", torch.bmm],
],
)
op_bench.generate_pt_test(bmm_configs, BmmBenchmark)
batched_binary_configs_short = op_bench.config_list(
attr_names=["B", "M", "N", "K"],
attrs=[
[2, 1, 8, 2],
[128, 64, 32, 64],
],
cross_product_configs={
"device": ["cpu"],
"dtype": [torch.float, torch.bfloat16],
},
tags=["short"],
)
batched_binary_configs_long = op_bench.cross_product_configs(
B=[1, 128],
M=[8, 128],
N=[32, 64],
K=[4, 256],
device=["cpu", "cuda"],
dtype=[torch.float, torch.bfloat16],
tags=["long"],
)
class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device, dtype, op_func):
self.inputs = {
"batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
"batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
}
self.op_func = op_func
def forward(self, batch1, batch2):
return self.op_func(batch1, batch2)
op_bench.generate_pt_tests_from_op_list(
batched_binary_ops,
batched_binary_configs_short + batched_binary_configs_long,
BatchedBinaryOpBenchmark,
)
# batched ternary ops
batched_ternary_ops = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[["baddbmm", torch.baddbmm]],
)
class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, B, M, N, K, device, dtype, op_func):
self.inputs = {
"input_": torch.rand((B, M, K), device=device).to(dtype=dtype),
"batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
"batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
}
self.op_func = op_func
def forward(self, input_, batch1, batch2):
return self.op_func(input_, batch1, batch2)
op_bench.generate_pt_tests_from_op_list(
batched_ternary_ops,
batched_binary_configs_short + batched_binary_configs_long,
BatchedTernaryOpBenchmark,
)
# TODO: does it automatically register new scripts?
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -0,0 +1,62 @@
import numpy
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for index_add_ operator."""
configs_short = op_bench.config_list(
attr_names=["M", "N", "K", "dim"],
attrs=[[8, 32, 1, 0], [256, 512, 1, 1], [512, 512, 1, 2]],
cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
tags=["short"],
)
configs_long = op_bench.cross_product_configs(
M=[1, 128, 1024],
N=[2, 256, 512],
K=[1, 2, 8],
dim=[0, 1, 2],
device=["cpu", "cuda"],
dtype=[torch.float],
tags=["long"],
)
class IndexAddBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, dim, dtype, device):
# creating the original tensor
tensor = torch.rand(M, N, K, dtype=dtype, device=device)
# creating index
index_max_len = tensor.shape[dim]
index_len = numpy.random.randint(1, index_max_len + 1)
index = torch.tensor(
numpy.random.choice(index_max_len, index_len, replace=False), device=device
)
src_dims = [M, N, K]
src_dims[dim] = index_len
source = torch.rand(*src_dims, dtype=dtype, device=device)
self.inputs = {
"tensor": tensor,
"dim": dim,
"index": index,
"source": source,
}
self.set_module_name("index_add_")
def forward(self, tensor, dim, index, source):
return tensor.index_add_(dim, index, source)
op_bench.generate_pt_test(configs_short + configs_long, IndexAddBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -36,7 +36,7 @@ batch_mm_op_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["einsum_bmm", torch.einsum],
["bmm", torch.bmm],
# ["bmm", torch.bmm],
],
)

View File

@ -0,0 +1,53 @@
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for torch.mm."""
# Benchmark ops performance without broadcast
ops_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[["mm", torch.mm]],
)
mm_short_configs = op_bench.config_list(
attr_names=["M", "N", "K"],
attrs=[
[1, 1, 1],
[64, 64, 64],
[64, 64, 128],
],
cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
tags=["short"],
)
mm_long_configs = op_bench.cross_product_configs(
M=[8, 128],
N=[32, 64],
K=[256, 512],
device=["cpu", "cuda"],
dtype=[torch.float, torch.bfloat16],
tags=["long"],
)
class MmOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, K, device, dtype, op_func):
self.inputs = {
"input_one": torch.randn(M, N, device=device).to(dtype=dtype),
"input_two": torch.randn(N, K, device=device).to(dtype=dtype),
}
self.op_func = op_func
def forward(self, input_one, input_two):
return self.op_func(input_one, input_two)
op_bench.generate_pt_tests_from_op_list(
ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark
)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -0,0 +1,57 @@
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for ternary operators."""
ternary_ops = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
["addcmul", torch.addcmul],
["addcdiv", torch.addcdiv],
],
)
ternary_configs_short = op_bench.config_list(
attr_names=["M", "N"],
attrs=[
[1, 2],
[32, 64],
],
cross_product_configs={
"device": ["cpu"],
"dtype": [torch.float, torch.bfloat16],
},
tags=["short"],
)
ternary_configs_long = op_bench.cross_product_configs(
M=[8, 128],
N=[32, 64],
device=["cpu", "cuda"],
dtype=[torch.float, torch.bfloat16],
tags=["long"],
)
class TernaryOpBenchmark(op_bench.TorchBenchmarkBase):
def init(self, M, N, device, dtype, op_func):
self.inputs = {
"input_": torch.rand((M, N), device=device).to(dtype=dtype),
"tensor1": torch.rand((M, N), device=device).to(dtype=dtype),
"tensor2": torch.rand((M, N), device=device).to(dtype=dtype),
}
self.op_func = op_func
def forward(self, input_, tensor1, tensor2):
return self.op_func(input_, tensor1, tensor2)
op_bench.generate_pt_tests_from_op_list(
ternary_ops, ternary_configs_short + ternary_configs_long, TernaryOpBenchmark
)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -0,0 +1,46 @@
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for topk operator"""
topk_configs_short = op_bench.config_list(
attr_names=["shape", "k", "dim"],
attrs=[
[(16, 4), 4, 1],
[(1024 * 1024,), 16, 0],
],
cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
tags=["short"],
)
topk_configs_long = op_bench.cross_product_configs(
shape=[(64, 2), (1024 * 1024,), (128,)],
k=[1, 2, 4, 16, 32],
dim=[0],
device=["cpu", "cuda"],
dtype=[torch.float, torch.bfloat16],
tags=["long"],
)
class TopkBenchmark(op_bench.TorchBenchmarkBase):
def init(self, shape, k, dim, dtype, device):
self.inputs = {
"input": torch.randn(shape, device=device, dtype=dtype),
"k": k,
"dim": dim,
}
self.set_module_name("topk")
def forward(self, input, k, dim):
return torch.topk(input, k=k, dim=dim)
op_bench.generate_pt_test(topk_configs_short + topk_configs_long, TopkBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()

View File

@ -72,6 +72,10 @@ def long_(input):
return input.long()
def clamp(input):
return torch.clamp(input, min=0.25, max=0.75)
unary_ops_list = op_bench.op_list(
attr_names=["op_name", "op_func"],
attrs=[
@ -86,6 +90,7 @@ unary_ops_list = op_bench.op_list(
["atan_", torch.atan_],
["ceil", torch.ceil],
["ceil_", torch.ceil_],
["clamp", clamp],
["clone", torch.clone],
["cos", torch.cos],
["cos_", torch.cos_],
@ -104,6 +109,7 @@ unary_ops_list = op_bench.op_list(
["floor_", torch.floor_],
["frac", torch.frac],
["frac_", torch.frac_],
["gelu", torch.nn.functional.gelu],
["hardshrink", torch.hardshrink],
["lgamma", torch.lgamma],
["log", torch.log],

View File

@ -0,0 +1,51 @@
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for where operator."""
configs_short = op_bench.config_list(
attr_names=["cond_shape", "input_shape", "other_shape"],
attrs=[
[(8, 16, 1), (1,), (1,)],
[(8, 16, 1), (16, 1), (8, 16, 1)],
[(8, 16, 1), (8, 1, 1), (1,)],
],
cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
tags=["short"],
)
configs_long = op_bench.cross_product_configs(
cond_shape=[(64, 16, 1), (64, 16, 8), (1024, 64, 16, 128)],
input_shape=[(1,), (16, 1), (64, 16, 1)],
other_shape=[(1,), (16, 1), (64, 16, 1)],
device=["cpu", "cuda"],
dtype=[torch.float],
tags=["long"],
)
class WhereBenchmark(op_bench.TorchBenchmarkBase):
def init(self, cond_shape, input_shape, other_shape, dtype, device):
def _create_tensor(shape):
return torch.randn(*shape, dtype=dtype, device=device)
self.inputs = {
"condition": _create_tensor(cond_shape) > 0,
"input": _create_tensor(input_shape),
"other": _create_tensor(other_shape),
}
self.set_module_name("where")
def forward(self, condition, input, other):
return torch.where(condition, input, other)
op_bench.generate_pt_test(configs_short + configs_long, WhereBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()