Additional operators in operator benchmark (#145625)

The list of added operators: add_, addcmul, arange, baddbmm, bmm, clamp, div, div_, gelu, index_add, logical_and, mul_, sub_, topk, where.

This pull request is the same as an earlier one, https://github.com/pytorch/pytorch/pull/145121, which was inadvertently deleted while merging.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145625
Approved by: https://github.com/jeffdaily
Committed by: PyTorch MergeBot
Parent: 6a4fb4b615
Commit: f3ddc08ddc
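
For context on running what follows: each pt/*_test.py file ends in op_bench.benchmark_runner.main(), so every benchmark module is its own entry point. A minimal sketch of typical invocations, assuming the operator_benchmark layout and its tag-filter option (the exact flag spelling is not stated in this commit):

    # from within benchmarks/operator_benchmark/ (assumed layout)
    python -m pt.topk_test --tag-filter short   # one module, short configs only
    python -m benchmark_all_test                # whole suite, including the new modules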
benchmarks/operator_benchmark/benchmark_all_test.py

@@ -1,9 +1,12 @@
from pt import ( # noqa: F401
    add_test,
    ao_sparsifier_test,
    arange_test,
    as_strided_test,
    batchnorm_test,
    binary_inplace_test,
    binary_test,
    bmm_test,
    cat_test,
    channel_shuffle_test,
    chunk_test,
@@ -15,18 +18,25 @@ from pt import ( # noqa: F401
    groupnorm_test,
    hardsigmoid_test,
    hardswish_test,
    index_add__test,
    index_select_test,
    instancenorm_test,
    interpolate_test,
    layernorm_test,
    linear_test,
    matmul_test,
    mm_test,
    nan_to_num_test,
    pool_test,
    remainder_test,
    softmax_test,
    split_test,
    stack_test,
    sum_test,
    tensor_to_test,
    ternary_test,
    topk_test,
    where_test,
)

import operator_benchmark as op_bench
benchmarks/operator_benchmark/pt/arange_test.py (new file, 48 lines)
@@ -0,0 +1,48 @@
import operator_benchmark as op_bench

import torch


"""Microbenchmarks for arange operator"""


# Configs for PT arange operator
configs_short = op_bench.config_list(
    attr_names=["start", "end", "step"],
    attrs=[
        [0, 1000, 2.5],
        [-1024, 2048, 1],
    ],
    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
    tags=["short"],
)

configs_long = op_bench.cross_product_configs(
    start=[-1024, 8],
    end=[16, 2048],
    step=[8, 0.1],
    device=["cpu", "cuda"],
    dtype=[torch.float, torch.bfloat16],
    tags=["long"],
)


class ArangeBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, start, end, step, dtype, device):
        self.inputs = {
            "start": start,
            "end": end,
            "step": step,
            "dtype": dtype,
            "device": device,
        }
        self.set_module_name("arange")

    def forward(self, start, end, step, dtype, device):
        return torch.arange(start=start, end=end, step=step, dtype=dtype, device=device)


op_bench.generate_pt_test(configs_short + configs_long, ArangeBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
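
As a worked example of the first short config above (values taken from the diff, dtype and device spelled out):

    import torch

    # start=0, end=1000, step=2.5 -> ceil((1000 - 0) / 2.5) = 400 elements
    t = torch.arange(start=0, end=1000, step=2.5, dtype=torch.float, device="cpu")
    assert t.numel() == 400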
benchmarks/operator_benchmark/pt/binary_inplace_test.py (new file, 140 lines)
@@ -0,0 +1,140 @@
import operator_benchmark as op_bench

import torch


"""Microbenchmarks for inplace binary operators."""


def add_(in1, in2):
    return in1.add_(in2)


def sub_(in1, in2):
    return in1.sub_(in2)


def div_(in1, in2):
    return in1.div_(in2)


def mul_(in1, in2):
    return in1.mul_(in2)


def copy_(in1, in2):
    return in1.copy_(in2)


######
# Benchmark ops performance for inplace add + sub + mul + copy
######
binary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["add_", add_],
        ["sub_", sub_],
        # ["div_", div_],  # benchmarked separately below because of its data type
        ["mul_", mul_],
        ["copy_", copy_],
    ],
)

binary_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 1, 1],
        [64, 64, 64],
        [64, 64, 128],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
        "dtype_one": [torch.int32],
        "dtype_two": [torch.int32],
    },
    tags=["short"],
)

binary_long_configs = op_bench.cross_product_configs(
    M=[8, 128],
    N=[32, 64],
    K=[256, 512],
    device=["cpu", "cuda"],
    dtype_one=[torch.int8, torch.int32],
    dtype_two=[torch.int8, torch.int32],
    tags=["long"],
)


class InpBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
        self.inputs = {
            "input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one),
            "input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two),
        }
        self.op_func = op_func

    def forward(self, input_one, input_two):
        return self.op_func(input_one, input_two)


op_bench.generate_pt_tests_from_op_list(
    binary_ops_list, binary_short_configs + binary_long_configs, InpBinaryOpBenchmark
)


######
# Benchmark ops performance for inplace div
######
# In-place division is benchmarked separately, as its data needs to be float
binary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["div_", div_],
    ],
)

binary_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 1, 1],
        [64, 64, 64],
        [64, 64, 128],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
        "dtype_one": [torch.float],
        "dtype_two": [torch.float],
    },
    tags=["short"],
)

binary_long_configs = op_bench.cross_product_configs(
    M=[8, 128],
    N=[32, 64],
    K=[256, 512],
    device=["cpu", "cuda"],
    dtype_one=[torch.float, torch.float],
    dtype_two=[torch.float, torch.float],
    tags=["long"],
)


class InpBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
        self.inputs = {
            "input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one),
            "input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two),
        }
        self.op_func = op_func

    def forward(self, input_one, input_two):
        return self.op_func(input_one, input_two)


op_bench.generate_pt_tests_from_op_list(
    binary_ops_list, binary_short_configs + binary_long_configs, InpBinaryOpBenchmark
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
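
A quick sketch of why div_ gets its own float-only configs (the comment in the file states the rationale; this snippet just demonstrates it, with arbitrary shapes):

    import torch

    a = torch.ones(2, 2, dtype=torch.int32)
    b = torch.full((2, 2), 2, dtype=torch.int32)
    try:
        a.div_(b)  # true division yields float, which cannot be written back into int32 storage
    except RuntimeError as err:
        print("in-place integer division fails:", err)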
benchmarks/operator_benchmark/pt/binary_test.py

@@ -11,6 +11,9 @@ binary_ops_bcast_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["add", torch.add],
        ["sub", torch.sub],
        ["div", torch.div],
        ["mul", torch.mul],
    ],
)

@@ -45,16 +48,14 @@ op_bench.generate_pt_tests_from_op_list(
)


def copy(in1, in2):
    return in1.copy_(in2)


# Benchmark ops performance without broadcast
binary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["add", torch.add],
        ["copy_", copy],
        ["sub", torch.sub],
        ["div", torch.div],
        ["mul", torch.mul],
    ],
)

@@ -101,5 +102,104 @@ op_bench.generate_pt_tests_from_op_list(
)


######
# Benchmark ops performance for boolean dtype
######


# Benchmark ops performance with broadcast
binary_ops_bcast_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[["logical_and", torch.logical_and]],
)

# Configs with broadcast
binary_configs_broadcast = op_bench.config_list(
    attr_names=["in_one", "in_two"],
    attrs=[
        [[64, 1, 64], [1, 64, 1]],
    ],
    cross_product_configs={
        "device": ["cpu"],
        "dtype": [torch.bool],
    },
    tags=["short"],
)


class BinaryOpBcastBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, in_one, in_two, dtype, device, op_func):
        self.inputs = {
            "in_one": torch.bernoulli(0.5 * torch.ones(in_one, device=device)).to(
                dtype=dtype
            ),
            "in_two": torch.bernoulli(0.5 * torch.ones(in_two, device=device)).to(
                dtype=dtype
            ),
        }
        self.op_func = op_func

    def forward(self, in_one, in_two):
        return self.op_func(in_one, in_two)


op_bench.generate_pt_tests_from_op_list(
    binary_ops_bcast_list, binary_configs_broadcast, BinaryOpBcastBenchmark
)


# Benchmark ops performance without broadcast
binary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[["logical_and", torch.logical_and]],
)

binary_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 1, 1],
        [64, 64, 64],
        [64, 64, 128],
    ],
    cross_product_configs={
        "device": ["cpu", "cuda"],
        "dtype_one": [torch.bool],
        "dtype_two": [torch.bool],
    },
    tags=["short"],
)

binary_long_configs = op_bench.cross_product_configs(
    M=[8, 128],
    N=[32, 64],
    K=[256, 512],
    device=["cpu", "cuda"],
    dtype_one=[torch.bool, torch.bool],
    dtype_two=[torch.bool, torch.bool],
    tags=["long"],
)


class BinaryOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
        self.inputs = {
            "input_one": torch.bernoulli(0.5 * torch.ones(M, N, K, device=device)).to(
                dtype=dtype_one
            ),
            "input_two": torch.bernoulli(0.5 * torch.ones(M, N, K, device=device)).to(
                dtype=dtype_two
            ),
        }
        self.op_func = op_func

    def forward(self, input_one, input_two):
        return self.op_func(input_one, input_two)


op_bench.generate_pt_tests_from_op_list(
    binary_ops_list, binary_short_configs + binary_long_configs, BinaryOpBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
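
A small sanity sketch of the boolean inputs built above (shapes taken from the broadcast config, everything else arbitrary):

    import torch

    # bernoulli on a tensor of 0.5s draws ~50% ones; casting to bool yields a random mask
    in_one = torch.bernoulli(0.5 * torch.ones(64, 1, 64)).to(dtype=torch.bool)
    in_two = torch.bernoulli(0.5 * torch.ones(1, 64, 1)).to(dtype=torch.bool)
    out = torch.logical_and(in_one, in_two)  # shapes broadcast to (64, 64, 64)
    assert out.shape == (64, 64, 64)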
benchmarks/operator_benchmark/pt/bmm_test.py

@@ -3,43 +3,86 @@ import operator_benchmark as op_bench
 import torch
 
 
-"""Microbenchmarks for add_ operator. Supports both Caffe2/PyTorch."""
+"""Microbenchmarks for batched operators."""
 
 
-class BmmBenchmark(op_bench.TorchBenchmarkBase):
-    def init(self, B, M, N, K, device, op):
-        self.inputs = {
-            "batch1": torch.rand(
-                (B, M, K), device=device, requires_grad=self.auto_set()
-            ),
-            "batch2": torch.rand(
-                (
-                    B,
-                    K,
-                    N,
-                ),
-                device=device,
-                requires_grad=self.auto_set(),
-            ),
-        }
-        self.set_module_name(f"bmm (actual op={op}")
-        self.op = torch.bmm if op == "bmm" else torch.matmul
-
-    def forward(self, batch1, batch2):
-        return self.op(batch1, batch2)
-
-
-bmm_configs = op_bench.cross_product_configs(
-    B=[2, 100],
-    M=[8, 256],
-    N=[256, 16],
-    K=[16, 32],
-    device=["cpu"],
-    tags=["short"],
-    op=["bmm", "matmul"],
-)
-
-op_bench.generate_pt_test(bmm_configs, BmmBenchmark)
+# binary ops (two inputs in shape of batches)
+batched_binary_ops = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[
+        ["bmm", torch.bmm],
+    ],
+)
+
+batched_binary_configs_short = op_bench.config_list(
+    attr_names=["B", "M", "N", "K"],
+    attrs=[
+        [2, 1, 8, 2],
+        [128, 64, 32, 64],
+    ],
+    cross_product_configs={
+        "device": ["cpu"],
+        "dtype": [torch.float, torch.bfloat16],
+    },
+    tags=["short"],
+)
+
+batched_binary_configs_long = op_bench.cross_product_configs(
+    B=[1, 128],
+    M=[8, 128],
+    N=[32, 64],
+    K=[4, 256],
+    device=["cpu", "cuda"],
+    dtype=[torch.float, torch.bfloat16],
+    tags=["long"],
+)
+
+
+class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, B, M, N, K, device, dtype, op_func):
+        self.inputs = {
+            "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
+            "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
+        }
+        self.op_func = op_func
+
+    def forward(self, batch1, batch2):
+        return self.op_func(batch1, batch2)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    batched_binary_ops,
+    batched_binary_configs_short + batched_binary_configs_long,
+    BatchedBinaryOpBenchmark,
+)
+
+
+# batched ternary ops
+batched_ternary_ops = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[["baddbmm", torch.baddbmm]],
+)
+
+
+class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, B, M, N, K, device, dtype, op_func):
+        self.inputs = {
+            "input_": torch.rand((B, M, K), device=device).to(dtype=dtype),
+            "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
+            "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
+        }
+        self.op_func = op_func
+
+    def forward(self, input_, batch1, batch2):
+        return self.op_func(input_, batch1, batch2)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    batched_ternary_ops,
+    batched_binary_configs_short + batched_binary_configs_long,
+    BatchedTernaryOpBenchmark,
+)
+
+# TODO: does it automatically register new scripts?
 
 if __name__ == "__main__":
     op_bench.benchmark_runner.main()
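
A shape check for the baddbmm configs above (B, M, N, K taken from the first short config):

    import torch

    B, M, N, K = 2, 1, 8, 2
    input_ = torch.rand(B, M, K)
    batch1 = torch.rand(B, M, N)
    batch2 = torch.rand(B, N, K)
    out = torch.baddbmm(input_, batch1, batch2)  # input_ + batch1 @ batch2
    assert out.shape == (B, M, K)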
benchmarks/operator_benchmark/pt/index_add__test.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import numpy

import operator_benchmark as op_bench

import torch


"""Microbenchmarks for index_add_ operator."""


configs_short = op_bench.config_list(
    attr_names=["M", "N", "K", "dim"],
    attrs=[[8, 32, 1, 0], [256, 512, 1, 1], [512, 512, 1, 2]],
    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
    tags=["short"],
)

configs_long = op_bench.cross_product_configs(
    M=[1, 128, 1024],
    N=[2, 256, 512],
    K=[1, 2, 8],
    dim=[0, 1, 2],
    device=["cpu", "cuda"],
    dtype=[torch.float],
    tags=["long"],
)


class IndexAddBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, dim, dtype, device):
        # creating the original tensor
        tensor = torch.rand(M, N, K, dtype=dtype, device=device)

        # creating the index
        index_max_len = tensor.shape[dim]
        index_len = numpy.random.randint(1, index_max_len + 1)
        index = torch.tensor(
            numpy.random.choice(index_max_len, index_len, replace=False), device=device
        )

        src_dims = [M, N, K]
        src_dims[dim] = index_len
        source = torch.rand(*src_dims, dtype=dtype, device=device)

        self.inputs = {
            "tensor": tensor,
            "dim": dim,
            "index": index,
            "source": source,
        }
        self.set_module_name("index_add_")

    def forward(self, tensor, dim, index, source):
        return tensor.index_add_(dim, index, source)


op_bench.generate_pt_test(configs_short + configs_long, IndexAddBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
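
The semantics being measured, in miniature (values arbitrary):

    import torch

    tensor = torch.zeros(3, 2)
    index = torch.tensor([0, 2])
    source = torch.ones(2, 2)
    tensor.index_add_(0, index, source)  # rows 0 and 2 each accumulate a row of source
    # tensor is now [[1., 1.], [0., 0.], [1., 1.]]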
benchmarks/operator_benchmark/pt/matmul_test.py

@@ -36,7 +36,7 @@ batch_mm_op_list = op_bench.op_list(
     attr_names=["op_name", "op_func"],
     attrs=[
         ["einsum_bmm", torch.einsum],
-        ["bmm", torch.bmm],
+        # ["bmm", torch.bmm],
     ],
 )
benchmarks/operator_benchmark/pt/mm_test.py (new file, 53 lines)
@@ -0,0 +1,53 @@
import operator_benchmark as op_bench

import torch


"""Microbenchmarks for torch.mm."""


# Benchmark ops performance without broadcast
ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[["mm", torch.mm]],
)

mm_short_configs = op_bench.config_list(
    attr_names=["M", "N", "K"],
    attrs=[
        [1, 1, 1],
        [64, 64, 64],
        [64, 64, 128],
    ],
    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
    tags=["short"],
)

mm_long_configs = op_bench.cross_product_configs(
    M=[8, 128],
    N=[32, 64],
    K=[256, 512],
    device=["cpu", "cuda"],
    dtype=[torch.float, torch.bfloat16],
    tags=["long"],
)


class MmOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, K, device, dtype, op_func):
        self.inputs = {
            "input_one": torch.randn(M, N, device=device).to(dtype=dtype),
            "input_two": torch.randn(N, K, device=device).to(dtype=dtype),
        }
        self.op_func = op_func

    def forward(self, input_one, input_two):
        return self.op_func(input_one, input_two)


op_bench.generate_pt_tests_from_op_list(
    ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark
)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
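
A one-line reminder of what these configs exercise (shapes from the short configs; torch.mm is strictly 2-D, unlike torch.matmul):

    import torch

    out = torch.mm(torch.randn(64, 64), torch.randn(64, 128))
    assert out.shape == (64, 128)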
benchmarks/operator_benchmark/pt/ternary_test.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import operator_benchmark as op_bench

import torch


"""Microbenchmarks for ternary operators."""


ternary_ops = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        ["addcmul", torch.addcmul],
        ["addcdiv", torch.addcdiv],
    ],
)

ternary_configs_short = op_bench.config_list(
    attr_names=["M", "N"],
    attrs=[
        [1, 2],
        [32, 64],
    ],
    cross_product_configs={
        "device": ["cpu"],
        "dtype": [torch.float, torch.bfloat16],
    },
    tags=["short"],
)

ternary_configs_long = op_bench.cross_product_configs(
    M=[8, 128],
    N=[32, 64],
    device=["cpu", "cuda"],
    dtype=[torch.float, torch.bfloat16],
    tags=["long"],
)


class TernaryOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device, dtype, op_func):
        self.inputs = {
            "input_": torch.rand((M, N), device=device).to(dtype=dtype),
            "tensor1": torch.rand((M, N), device=device).to(dtype=dtype),
            "tensor2": torch.rand((M, N), device=device).to(dtype=dtype),
        }
        self.op_func = op_func

    def forward(self, input_, tensor1, tensor2):
        return self.op_func(input_, tensor1, tensor2)


op_bench.generate_pt_tests_from_op_list(
    ternary_ops, ternary_configs_short + ternary_configs_long, TernaryOpBenchmark
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
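
Both ops compute input + value * tensor1 * tensor2 (with a quotient instead of a product for addcdiv); a minimal check with arbitrary values:

    import torch

    out = torch.addcmul(
        torch.zeros(2, 2), torch.full((2, 2), 3.0), torch.full((2, 2), 4.0)
    )
    assert bool(out.eq(12.0).all())  # 0 + 1 * 3 * 4, elementwise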
benchmarks/operator_benchmark/pt/topk_test.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import operator_benchmark as op_bench

import torch


"""Microbenchmarks for topk operator"""


topk_configs_short = op_bench.config_list(
    attr_names=["shape", "k", "dim"],
    attrs=[
        [(16, 4), 4, 1],
        [(1024 * 1024,), 16, 0],
    ],
    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
    tags=["short"],
)

topk_configs_long = op_bench.cross_product_configs(
    shape=[(64, 2), (1024 * 1024,), (128,)],
    k=[1, 2, 4, 16, 32],
    dim=[0],
    device=["cpu", "cuda"],
    dtype=[torch.float, torch.bfloat16],
    tags=["long"],
)


class TopkBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, shape, k, dim, dtype, device):
        self.inputs = {
            "input": torch.randn(shape, device=device, dtype=dtype),
            "k": k,
            "dim": dim,
        }
        self.set_module_name("topk")

    def forward(self, input, k, dim):
        return torch.topk(input, k=k, dim=dim)


op_bench.generate_pt_test(topk_configs_short + topk_configs_long, TopkBenchmark)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
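
What each benchmark step returns, in miniature (values arbitrary):

    import torch

    values, indices = torch.topk(torch.tensor([1.0, 5.0, 3.0]), k=2, dim=0)
    # values == tensor([5., 3.]), indices == tensor([1, 2])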
benchmarks/operator_benchmark/pt/unary_test.py

@@ -72,6 +72,10 @@ def long_(input):
    return input.long()


def clamp(input):
    return torch.clamp(input, min=0.25, max=0.75)


unary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
@@ -86,6 +90,7 @@ unary_ops_list = op_bench.op_list(
        ["atan_", torch.atan_],
        ["ceil", torch.ceil],
        ["ceil_", torch.ceil_],
        ["clamp", clamp],
        ["clone", torch.clone],
        ["cos", torch.cos],
        ["cos_", torch.cos_],
@@ -104,6 +109,7 @@ unary_ops_list = op_bench.op_list(
        ["floor_", torch.floor_],
        ["frac", torch.frac],
        ["frac_", torch.frac_],
        ["gelu", torch.nn.functional.gelu],
        ["hardshrink", torch.hardshrink],
        ["lgamma", torch.lgamma],
        ["log", torch.log],
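
The clamp wrapper above pins min/max so the op fits the unary (single-input) harness; its effect on a sample input:

    import torch

    x = torch.tensor([0.0, 0.5, 1.0])
    print(torch.clamp(x, min=0.25, max=0.75))  # tensor([0.2500, 0.5000, 0.7500])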
benchmarks/operator_benchmark/pt/where_test.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import operator_benchmark as op_bench

import torch


"""Microbenchmarks for where operator."""


configs_short = op_bench.config_list(
    attr_names=["cond_shape", "input_shape", "other_shape"],
    attrs=[
        [(8, 16, 1), (1,), (1,)],
        [(8, 16, 1), (16, 1), (8, 16, 1)],
        [(8, 16, 1), (8, 1, 1), (1,)],
    ],
    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
    tags=["short"],
)

configs_long = op_bench.cross_product_configs(
    cond_shape=[(64, 16, 1), (64, 16, 8), (1024, 64, 16, 128)],
    input_shape=[(1,), (16, 1), (64, 16, 1)],
    other_shape=[(1,), (16, 1), (64, 16, 1)],
    device=["cpu", "cuda"],
    dtype=[torch.float],
    tags=["long"],
)


class WhereBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, cond_shape, input_shape, other_shape, dtype, device):
        def _create_tensor(shape):
            return torch.randn(*shape, dtype=dtype, device=device)

        self.inputs = {
            "condition": _create_tensor(cond_shape) > 0,
            "input": _create_tensor(input_shape),
            "other": _create_tensor(other_shape),
        }
        self.set_module_name("where")

    def forward(self, condition, input, other):
        return torch.where(condition, input, other)


op_bench.generate_pt_test(configs_short + configs_long, WhereBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
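
A shape sketch of the second short config above, showing the three-way broadcast being measured:

    import torch

    condition = torch.randn(8, 16, 1) > 0
    out = torch.where(condition, torch.zeros(16, 1), torch.ones(8, 16, 1))
    assert out.shape == (8, 16, 1)  # condition, input, other broadcast together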