From f3ddc08ddc0b11b93f756871b8f239d180dd9af8 Mon Sep 17 00:00:00 2001
From: Arash Pakbin
Date: Sun, 26 Jan 2025 19:20:02 +0000
Subject: [PATCH] Additional operators in operator benchmark (#145625)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The list of added operators: add_, addcmul, arange, baddbmm…, bmm, clamp,
div, div_, gelu, index_add, logical_and, mul_, sub_, topk, where

This pull request is the same as a previous one,
https://github.com/pytorch/pytorch/pull/145121, which was inadvertently
deleted while merging.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145625
Approved by: https://github.com/jeffdaily
---
 .../benchmark_all_other_test.py               |  10 ++
 .../operator_benchmark/pt/arange_test.py      |  48 ++++++
 .../pt/binary_inplace_test.py                 | 140 ++++++++++++++++++
 .../operator_benchmark/pt/binary_test.py      | 110 +++++++++++++-
 benchmarks/operator_benchmark/pt/bmm_test.py  | 111 +++++++++-----
 .../operator_benchmark/pt/index_add__test.py  |  62 ++++++++
 .../operator_benchmark/pt/matrix_mult_test.py |   2 +-
 benchmarks/operator_benchmark/pt/mm_test.py   |  53 +++++++
 .../operator_benchmark/pt/ternary_test.py     |  57 +++++++
 benchmarks/operator_benchmark/pt/topk_test.py |  46 ++++++
 .../operator_benchmark/pt/unary_test.py       |   6 +
 .../operator_benchmark/pt/where_test.py       |  51 +++++++
 12 files changed, 656 insertions(+), 40 deletions(-)
 create mode 100644 benchmarks/operator_benchmark/pt/arange_test.py
 create mode 100644 benchmarks/operator_benchmark/pt/binary_inplace_test.py
 create mode 100644 benchmarks/operator_benchmark/pt/index_add__test.py
 create mode 100644 benchmarks/operator_benchmark/pt/mm_test.py
 create mode 100644 benchmarks/operator_benchmark/pt/ternary_test.py
 create mode 100644 benchmarks/operator_benchmark/pt/topk_test.py
 create mode 100644 benchmarks/operator_benchmark/pt/where_test.py

diff --git a/benchmarks/operator_benchmark/benchmark_all_other_test.py b/benchmarks/operator_benchmark/benchmark_all_other_test.py
index 05022e8407f0..e368c281d9a4 100644
--- a/benchmarks/operator_benchmark/benchmark_all_other_test.py
+++ b/benchmarks/operator_benchmark/benchmark_all_other_test.py
@@ -1,9 +1,12 @@
 from pt import (  # noqa: F401
     add_test,
     ao_sparsifier_test,
+    arange_test,
     as_strided_test,
     batchnorm_test,
+    binary_inplace_test,
     binary_test,
+    bmm_test,
     cat_test,
     channel_shuffle_test,
     chunk_test,
@@ -15,18 +18,25 @@ from pt import (  # noqa: F401
     groupnorm_test,
     hardsigmoid_test,
     hardswish_test,
+    index_add__test,
+    index_select_test,
     instancenorm_test,
     interpolate_test,
     layernorm_test,
     linear_test,
     matmul_test,
+    mm_test,
     nan_to_num_test,
     pool_test,
     remainder_test,
     softmax_test,
     split_test,
+    stack_test,
     sum_test,
     tensor_to_test,
+    ternary_test,
+    topk_test,
+    where_test,
 )
 
 import operator_benchmark as op_bench
diff --git a/benchmarks/operator_benchmark/pt/arange_test.py b/benchmarks/operator_benchmark/pt/arange_test.py
new file mode 100644
index 000000000000..c3d039cb56bd
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/arange_test.py
@@ -0,0 +1,48 @@
+import operator_benchmark as op_bench
+
+import torch
+
+
+"""Microbenchmarks for arange operator"""
+
+# Configs for PT arange operator
+configs_short = op_bench.config_list(
+    attr_names=["start", "end", "step"],
+    attrs=[
+        [0, 1000, 2.5],
+        [-1024, 2048, 1],
+    ],
+    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
+    tags=["short"],
+)
+
+configs_long = op_bench.cross_product_configs(
+    start=[-1024, 8],
+    end=[16, 2048],
+    step=[8, 0.1],
+    device=["cpu", "cuda"],
+    dtype=[torch.float, torch.bfloat16],
+    tags=["long"],
+)
+
+
+class ArangeBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, start, end, step, dtype, device):
+        self.inputs = {
+            "start": start,
+            "end": end,
+            "step": step,
+            "dtype": dtype,
+            "device": device,
+        }
+
+        self.set_module_name("arange")
+
+    def forward(self, start, end, step, dtype, device):
+        return torch.arange(start=start, end=end, step=step, dtype=dtype, device=device)
+
+
+op_bench.generate_pt_test(configs_short + configs_long, ArangeBenchmark)
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
diff --git a/benchmarks/operator_benchmark/pt/binary_inplace_test.py b/benchmarks/operator_benchmark/pt/binary_inplace_test.py
new file mode 100644
index 000000000000..ce5391045872
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/binary_inplace_test.py
@@ -0,0 +1,140 @@
+import operator_benchmark as op_bench
+
+import torch
+
+
+"""Microbenchmarks for inplace binary operators."""
+
+
+def add_(in1, in2):
+    return in1.add_(in2)
+
+
+def sub_(in1, in2):
+    return in1.sub_(in2)
+
+
+def div_(in1, in2):
+    return in1.div_(in2)
+
+
+def mul_(in1, in2):
+    return in1.mul_(in2)
+
+
+def copy_(in1, in2):
+    return in1.copy_(in2)
+
+
+######
+# Benchmark ops performance for inplace add + sub + mul + copy
+######
+binary_ops_list = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[
+        ["add_", add_],
+        ["sub_", sub_],
+        # ["div_", div_],  # done separately below because of data type
+        ["mul_", mul_],
+        ["copy_", copy_],
+    ],
+)
+
+binary_short_configs = op_bench.config_list(
+    attr_names=["M", "N", "K"],
+    attrs=[
+        [1, 1, 1],
+        [64, 64, 64],
+        [64, 64, 128],
+    ],
+    cross_product_configs={
+        "device": ["cpu", "cuda"],
+        "dtype_one": [torch.int32],
+        "dtype_two": [torch.int32],
+    },
+    tags=["short"],
+)
+
+binary_long_configs = op_bench.cross_product_configs(
+    M=[8, 128],
+    N=[32, 64],
+    K=[256, 512],
+    device=["cpu", "cuda"],
+    dtype_one=[torch.int8, torch.int32],
+    dtype_two=[torch.int8, torch.int32],
+    tags=["long"],
+)
+
+
+class InpBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
+        self.inputs = {
+            "input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one),
+            "input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two),
+        }
+        self.op_func = op_func
+
+    def forward(self, input_one, input_two):
+        return self.op_func(input_one, input_two)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    binary_ops_list, binary_short_configs + binary_long_configs, InpBinaryOpBenchmark
+)
+
+
+######
+# Benchmark ops performance for inplace div
+######
+# Performing division inplace benchmarks separately, as data needs to be float
+binary_ops_list = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[
+        ["div_", div_],
+    ],
+)
+
+binary_short_configs = op_bench.config_list(
+    attr_names=["M", "N", "K"],
+    attrs=[
+        [1, 1, 1],
+        [64, 64, 64],
+        [64, 64, 128],
+    ],
+    cross_product_configs={
+        "device": ["cpu", "cuda"],
+        "dtype_one": [torch.float],
+        "dtype_two": [torch.float],
+    },
+    tags=["short"],
+)
+
+binary_long_configs = op_bench.cross_product_configs(
+    M=[8, 128],
+    N=[32, 64],
+    K=[256, 512],
+    device=["cpu", "cuda"],
+    dtype_one=[torch.float],
+    dtype_two=[torch.float],
+    tags=["long"],
+)
+
+
+class InpBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
+        self.inputs = {
+            "input_one": torch.randn(M, N, K, device=device).to(dtype=dtype_one),
+            "input_two": torch.randn(M, N, K, device=device).to(dtype=dtype_two),
+        }
+        self.op_func = op_func
+
+    def forward(self, input_one, input_two):
+        return self.op_func(input_one, input_two)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    binary_ops_list, binary_short_configs + binary_long_configs, InpBinaryOpBenchmark
+)
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
diff --git a/benchmarks/operator_benchmark/pt/binary_test.py b/benchmarks/operator_benchmark/pt/binary_test.py
index 4a4144a96ee8..60b1bba7933f 100644
--- a/benchmarks/operator_benchmark/pt/binary_test.py
+++ b/benchmarks/operator_benchmark/pt/binary_test.py
@@ -11,6 +11,9 @@ binary_ops_bcast_list = op_bench.op_list(
     attr_names=["op_name", "op_func"],
     attrs=[
         ["add", torch.add],
+        ["sub", torch.sub],
+        ["div", torch.div],
+        ["mul", torch.mul],
     ],
 )
 
@@ -45,16 +48,14 @@ op_bench.generate_pt_tests_from_op_list(
 )
 
 
-def copy(in1, in2):
-    return in1.copy_(in2)
-
-
 # Benchmark ops performance without broadcast
 binary_ops_list = op_bench.op_list(
     attr_names=["op_name", "op_func"],
     attrs=[
         ["add", torch.add],
-        ["copy_", copy],
+        ["sub", torch.sub],
+        ["div", torch.div],
+        ["mul", torch.mul],
     ],
 )
 
@@ -101,5 +102,104 @@ op_bench.generate_pt_tests_from_op_list(
 )
 
 
+######
+# Benchmark ops performance for boolean dtype
+######
+
+
+# Benchmark ops performance with broadcast
+binary_ops_bcast_list = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[["logical_and", torch.logical_and]],
+)
+
+# Configs with broadcast
+binary_configs_broadcast = op_bench.config_list(
+    attr_names=["in_one", "in_two"],
+    attrs=[
+        [[64, 1, 64], [1, 64, 1]],
+    ],
+    cross_product_configs={
+        "device": ["cpu"],
+        "dtype": [torch.bool],
+    },
+    tags=["short"],
+)
+
+
+class BinaryOpBcastBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, in_one, in_two, dtype, device, op_func):
+        self.inputs = {
+            "in_one": torch.bernoulli(0.5 * torch.ones(in_one, device=device)).to(
+                dtype=dtype
+            ),
+            "in_two": torch.bernoulli(0.5 * torch.ones(in_two, device=device)).to(
+                dtype=dtype
+            ),
+        }
+        self.op_func = op_func
+
+    def forward(self, in_one, in_two):
+        return self.op_func(in_one, in_two)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    binary_ops_bcast_list, binary_configs_broadcast, BinaryOpBcastBenchmark
+)
+
+
+# Benchmark ops performance without broadcast
+binary_ops_list = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[["logical_and", torch.logical_and]],
+)
+
+binary_short_configs = op_bench.config_list(
+    attr_names=["M", "N", "K"],
+    attrs=[
+        [1, 1, 1],
+        [64, 64, 64],
+        [64, 64, 128],
+    ],
+    cross_product_configs={
+        "device": ["cpu", "cuda"],
+        "dtype_one": [torch.bool],
+        "dtype_two": [torch.bool],
+    },
+    tags=["short"],
+)
+
+binary_long_configs = op_bench.cross_product_configs(
+    M=[8, 128],
+    N=[32, 64],
+    K=[256, 512],
+    device=["cpu", "cuda"],
+    dtype_one=[torch.bool],
+    dtype_two=[torch.bool],
+    tags=["long"],
+)
+
+
+class BinaryOpBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, M, N, K, device, dtype_one, dtype_two, op_func):
+        self.inputs = {
+            "input_one": torch.bernoulli(0.5 * torch.ones(M, N, K, device=device)).to(
+                dtype=dtype_one
+            ),
+            "input_two": torch.bernoulli(0.5 * torch.ones(M, N, K, device=device)).to(
+                dtype=dtype_two
+            ),
+        }
+        self.op_func = op_func
+
+    def forward(self, input_one, input_two):
+        return self.op_func(input_one, input_two)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    binary_ops_list, binary_short_configs + binary_long_configs, BinaryOpBenchmark
+)
+
+
 if __name__ == "__main__":
     op_bench.benchmark_runner.main()
diff --git a/benchmarks/operator_benchmark/pt/bmm_test.py b/benchmarks/operator_benchmark/pt/bmm_test.py
index 8ff5d0b5e1b0..1c6d1f9aca55 100644
--- a/benchmarks/operator_benchmark/pt/bmm_test.py
+++ b/benchmarks/operator_benchmark/pt/bmm_test.py
@@ -3,43 +3,86 @@ import operator_benchmark as op_bench
 import torch
 
 
-"""Microbenchmarks for add_ operator. Supports both Caffe2/PyTorch."""
+"""Microbenchmarks for batched operators."""
 
-
-class BmmBenchmark(op_bench.TorchBenchmarkBase):
-    def init(self, B, M, N, K, device, op):
-        self.inputs = {
-            "batch1": torch.rand(
-                (B, M, K), device=device, requires_grad=self.auto_set()
-            ),
-            "batch2": torch.rand(
-                (
-                    B,
-                    K,
-                    N,
-                ),
-                device=device,
-                requires_grad=self.auto_set(),
-            ),
-        }
-        self.set_module_name(f"bmm (actual op={op}")
-        self.op = torch.bmm if op == "bmm" else torch.matmul
-
-    def forward(self, batch1, batch2):
-        return self.op(batch1, batch2)
-
-
-bmm_configs = op_bench.cross_product_configs(
-    B=[2, 100],
-    M=[8, 256],
-    N=[256, 16],
-    K=[16, 32],
-    device=["cpu"],
-    tags=["short"],
-    op=["bmm", "matmul"],
+# binary ops (two inputs in shape of batches)
+batched_binary_ops = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[
+        ["bmm", torch.bmm],
+    ],
 )
 
-op_bench.generate_pt_test(bmm_configs, BmmBenchmark)
+batched_binary_configs_short = op_bench.config_list(
+    attr_names=["B", "M", "N", "K"],
+    attrs=[
+        [2, 1, 8, 2],
+        [128, 64, 32, 64],
+    ],
+    cross_product_configs={
+        "device": ["cpu"],
+        "dtype": [torch.float, torch.bfloat16],
+    },
+    tags=["short"],
+)
+
+batched_binary_configs_long = op_bench.cross_product_configs(
+    B=[1, 128],
+    M=[8, 128],
+    N=[32, 64],
+    K=[4, 256],
+    device=["cpu", "cuda"],
+    dtype=[torch.float, torch.bfloat16],
+    tags=["long"],
+)
+
+
+class BatchedBinaryOpBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, B, M, N, K, device, dtype, op_func):
+        self.inputs = {
+            "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
+            "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
+        }
+        self.op_func = op_func
+
+    def forward(self, batch1, batch2):
+        return self.op_func(batch1, batch2)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    batched_binary_ops,
+    batched_binary_configs_short + batched_binary_configs_long,
+    BatchedBinaryOpBenchmark,
+)
+
+
+# batched ternary ops
+batched_ternary_ops = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[["baddbmm", torch.baddbmm]],
+)
+
+
+class BatchedTernaryOpBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, B, M, N, K, device, dtype, op_func):
+        self.inputs = {
+            "input_": torch.rand((B, M, K), device=device).to(dtype=dtype),
+            "batch1": torch.rand((B, M, N), device=device).to(dtype=dtype),
+            "batch2": torch.rand((B, N, K), device=device).to(dtype=dtype),
+        }
+        self.op_func = op_func
+
+    def forward(self, input_, batch1, batch2):
+        return self.op_func(input_, batch1, batch2)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    batched_ternary_ops,
+    batched_binary_configs_short + batched_binary_configs_long,
+    BatchedTernaryOpBenchmark,
+)
+
+# TODO: does it automatically register new scripts?
 
 if __name__ == "__main__":
     op_bench.benchmark_runner.main()
diff --git a/benchmarks/operator_benchmark/pt/index_add__test.py b/benchmarks/operator_benchmark/pt/index_add__test.py
new file mode 100644
index 000000000000..d30de1975be6
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/index_add__test.py
@@ -0,0 +1,62 @@
+import numpy
+
+import operator_benchmark as op_bench
+
+import torch
+
+
+"""Microbenchmarks for index_add_ operator."""
+
+
+configs_short = op_bench.config_list(
+    attr_names=["M", "N", "K", "dim"],
+    attrs=[[8, 32, 1, 0], [256, 512, 1, 1], [512, 512, 1, 2]],
+    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
+    tags=["short"],
+)
+
+
+configs_long = op_bench.cross_product_configs(
+    M=[1, 128, 1024],
+    N=[2, 256, 512],
+    K=[1, 2, 8],
+    dim=[0, 1, 2],
+    device=["cpu", "cuda"],
+    dtype=[torch.float],
+    tags=["long"],
+)
+
+
+class IndexAddBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, M, N, K, dim, dtype, device):
+        # creating the original tensor
+        tensor = torch.rand(M, N, K, dtype=dtype, device=device)
+
+        # creating index
+        index_max_len = tensor.shape[dim]
+        index_len = numpy.random.randint(1, index_max_len + 1)
+        index = torch.tensor(
+            numpy.random.choice(index_max_len, index_len, replace=False), device=device
+        )
+
+        src_dims = [M, N, K]
+        src_dims[dim] = index_len
+        source = torch.rand(*src_dims, dtype=dtype, device=device)
+
+        self.inputs = {
+            "tensor": tensor,
+            "dim": dim,
+            "index": index,
+            "source": source,
+        }
+        self.set_module_name("index_add_")
+
+    def forward(self, tensor, dim, index, source):
+        return tensor.index_add_(dim, index, source)
+
+
+op_bench.generate_pt_test(configs_short + configs_long, IndexAddBenchmark)
+
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
diff --git a/benchmarks/operator_benchmark/pt/matrix_mult_test.py b/benchmarks/operator_benchmark/pt/matrix_mult_test.py
index c905b5661927..48e5ca66806f 100644
--- a/benchmarks/operator_benchmark/pt/matrix_mult_test.py
+++ b/benchmarks/operator_benchmark/pt/matrix_mult_test.py
@@ -36,7 +36,7 @@ batch_mm_op_list = op_bench.op_list(
     attr_names=["op_name", "op_func"],
     attrs=[
         ["einsum_bmm", torch.einsum],
-        ["bmm", torch.bmm],
+        # ["bmm", torch.bmm],
     ],
 )
 
diff --git a/benchmarks/operator_benchmark/pt/mm_test.py b/benchmarks/operator_benchmark/pt/mm_test.py
new file mode 100644
index 000000000000..bf2a2651e8fb
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/mm_test.py
@@ -0,0 +1,53 @@
+import operator_benchmark as op_bench
+
+import torch
+
+
+"""Microbenchmarks for torch.mm."""
+
+# Benchmark ops performance without broadcast
+ops_list = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[["mm", torch.mm]],
+)
+
+mm_short_configs = op_bench.config_list(
+    attr_names=["M", "N", "K"],
+    attrs=[
+        [1, 1, 1],
+        [64, 64, 64],
+        [64, 64, 128],
+    ],
+    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
+    tags=["short"],
+)
+
+mm_long_configs = op_bench.cross_product_configs(
+    M=[8, 128],
+    N=[32, 64],
+    K=[256, 512],
+    device=["cpu", "cuda"],
+    dtype=[torch.float, torch.bfloat16],
+    tags=["long"],
+)
+
+
+class MmOpBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, M, N, K, device, dtype, op_func):
+        self.inputs = {
+            "input_one": torch.randn(M, N, device=device).to(dtype=dtype),
+            "input_two": torch.randn(N, K, device=device).to(dtype=dtype),
+        }
+        self.op_func = op_func
+
+    def forward(self, input_one, input_two):
+        return self.op_func(input_one, input_two)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    ops_list, mm_short_configs + mm_long_configs, MmOpBenchmark
+)
+
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
diff --git a/benchmarks/operator_benchmark/pt/ternary_test.py b/benchmarks/operator_benchmark/pt/ternary_test.py
new file mode 100644
index 000000000000..23c3c77d04ad
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/ternary_test.py
@@ -0,0 +1,57 @@
+import operator_benchmark as op_bench
+
+import torch
+
+
+"""Microbenchmarks for ternary operators."""
+
+
+ternary_ops = op_bench.op_list(
+    attr_names=["op_name", "op_func"],
+    attrs=[
+        ["addcmul", torch.addcmul],
+        ["addcdiv", torch.addcdiv],
+    ],
+)
+
+ternary_configs_short = op_bench.config_list(
+    attr_names=["M", "N"],
+    attrs=[
+        [1, 2],
+        [32, 64],
+    ],
+    cross_product_configs={
+        "device": ["cpu"],
+        "dtype": [torch.float, torch.bfloat16],
+    },
+    tags=["short"],
+)
+
+ternary_configs_long = op_bench.cross_product_configs(
+    M=[8, 128],
+    N=[32, 64],
+    device=["cpu", "cuda"],
+    dtype=[torch.float, torch.bfloat16],
+    tags=["long"],
+)
+
+
+class TernaryOpBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, M, N, device, dtype, op_func):
+        self.inputs = {
+            "input_": torch.rand((M, N), device=device).to(dtype=dtype),
+            "tensor1": torch.rand((M, N), device=device).to(dtype=dtype),
+            "tensor2": torch.rand((M, N), device=device).to(dtype=dtype),
+        }
+        self.op_func = op_func
+
+    def forward(self, input_, tensor1, tensor2):
+        return self.op_func(input_, tensor1, tensor2)
+
+
+op_bench.generate_pt_tests_from_op_list(
+    ternary_ops, ternary_configs_short + ternary_configs_long, TernaryOpBenchmark
+)
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
diff --git a/benchmarks/operator_benchmark/pt/topk_test.py b/benchmarks/operator_benchmark/pt/topk_test.py
new file mode 100644
index 000000000000..28fc251e8b17
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/topk_test.py
@@ -0,0 +1,46 @@
+import operator_benchmark as op_bench
+
+import torch
+
+
+"""Microbenchmarks for topk operator"""
+
+
+topk_configs_short = op_bench.config_list(
+    attr_names=["shape", "k", "dim"],
+    attrs=[
+        [(16, 4), 4, 1],
+        [(1024 * 1024,), 16, 0],
+    ],
+    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
+    tags=["short"],
+)
+
+topk_configs_long = op_bench.cross_product_configs(
+    shape=[(64, 2), (1024 * 1024,), (128,)],
+    k=[1, 2, 4, 16, 32],
+    dim=[0],
+    device=["cpu", "cuda"],
+    dtype=[torch.float, torch.bfloat16],
+    tags=["long"],
+)
+
+
+class TopkBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, shape, k, dim, dtype, device):
+        self.inputs = {
+            "input": torch.randn(shape, device=device, dtype=dtype),
+            "k": k,
+            "dim": dim,
+        }
+
+        self.set_module_name("topk")
+
+    def forward(self, input, k, dim):
+        return torch.topk(input, k=k, dim=dim)
+
+
+op_bench.generate_pt_test(topk_configs_short + topk_configs_long, TopkBenchmark)
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
diff --git a/benchmarks/operator_benchmark/pt/unary_test.py b/benchmarks/operator_benchmark/pt/unary_test.py
index e605c7313965..f2b7c40d974b 100644
--- a/benchmarks/operator_benchmark/pt/unary_test.py
+++ b/benchmarks/operator_benchmark/pt/unary_test.py
@@ -72,6 +72,10 @@ def long_(input):
     return input.long()
 
 
+def clamp(input):
+    return torch.clamp(input, min=0.25, max=0.75)
+
+
 unary_ops_list = op_bench.op_list(
     attr_names=["op_name", "op_func"],
     attrs=[
@@ -86,6 +90,7 @@ unary_ops_list = op_bench.op_list(
         ["atan_", torch.atan_],
         ["ceil", torch.ceil],
         ["ceil_", torch.ceil_],
+        ["clamp", clamp],
         ["clone", torch.clone],
         ["cos", torch.cos],
         ["cos_", torch.cos_],
@@ -104,6 +109,7 @@ unary_ops_list = op_bench.op_list(
         ["floor_", torch.floor_],
         ["frac", torch.frac],
         ["frac_", torch.frac_],
+        ["gelu", torch.nn.functional.gelu],
         ["hardshrink", torch.hardshrink],
         ["lgamma", torch.lgamma],
         ["log", torch.log],
diff --git a/benchmarks/operator_benchmark/pt/where_test.py b/benchmarks/operator_benchmark/pt/where_test.py
new file mode 100644
index 000000000000..e94fbc4ccfa6
--- /dev/null
+++ b/benchmarks/operator_benchmark/pt/where_test.py
@@ -0,0 +1,51 @@
+import operator_benchmark as op_bench
+
+import torch
+
+
+"""Microbenchmarks for where operator."""
+
+
+configs_short = op_bench.config_list(
+    attr_names=["cond_shape", "input_shape", "other_shape"],
+    attrs=[
+        [(8, 16, 1), (1,), (1,)],
+        [(8, 16, 1), (16, 1), (8, 16, 1)],
+        [(8, 16, 1), (8, 1, 1), (1,)],
+    ],
+    cross_product_configs={"device": ["cpu"], "dtype": [torch.float]},
+    tags=["short"],
+)
+
+
+configs_long = op_bench.cross_product_configs(
+    cond_shape=[(64, 16, 1), (64, 16, 8), (1024, 64, 16, 128)],
+    input_shape=[(1,), (16, 1), (64, 16, 1)],
+    other_shape=[(1,), (16, 1), (64, 16, 1)],
+    device=["cpu", "cuda"],
+    dtype=[torch.float],
+    tags=["long"],
+)
+
+
+class WhereBenchmark(op_bench.TorchBenchmarkBase):
+    def init(self, cond_shape, input_shape, other_shape, dtype, device):
+        def _create_tensor(shape):
+            return torch.randn(*shape, dtype=dtype, device=device)
+
+        self.inputs = {
+            "condition": _create_tensor(cond_shape) > 0,
+            "input": _create_tensor(input_shape),
+            "other": _create_tensor(other_shape),
+        }
+        self.set_module_name("where")
+
+    def forward(self, condition, input, other):
+        return torch.where(condition, input, other)
+
+
+op_bench.generate_pt_test(configs_short + configs_long, WhereBenchmark)
+
+
+if __name__ == "__main__":
+    op_bench.benchmark_runner.main()
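
A quick way to exercise the new benchmarks (a usage sketch, with assumptions
flagged): every file added above registers its cases with the
operator_benchmark harness and ends in a __main__ guard, so each can be run as
a standalone module from benchmarks/operator_benchmark once the patch is
applied, for example:

    # Run one of the new standalone benchmark scripts. By default the harness
    # runs the "short"-tagged configs; the "long" sets add the CUDA and
    # bfloat16 variants.
    python -m pt.arange_test

    # Run everything imported by benchmark_all_other_test.py, including all of
    # the operators added in this patch.
    python -m benchmark_all_other_test

The harness also accepts a tag filter on the command line for selecting the
"long" config sets; the exact flag spelling (--tag-filter vs. --tag_filter)
should be verified against benchmark_runner's argparse setup.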