import argparse
import itertools
import os

# from . import conv  # noqa: F401
# from . import normalization  # noqa: F401
# from . import pooling  # noqa: F401
from . import (  # noqa: F401
    attention,
    benchmark,
    broadcast,
    concat,
    elementwise,
    matmul,
    reduction,
    rnn_eltwise,
    softmax,
    swish,
    tensor_engine,
)


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""Benchmark operators in specific shapes.
Works only with Python3.\n A few examples:
* benchmark.py: runs all the default configs with all the benchmarks.
* benchmark.py reduce: runs all the default configs for every benchmark whose name has the prefix 'reduce'
* benchmark.py layernorm_fwd_cpu_128_32_128_128: run a particular benchmark in that config""",
    )
    parser.add_argument(
        "benchmark_names",
        type=str,
        default=None,
        nargs="*",
        help="name of the benchmark to run",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu,cuda",
        help="a comma separated list of device names",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="fwd,both",
        help="a comma separated list of running modes",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        help="a comma separated list of Data Types: {float32[default], float16}",
    )
    parser.add_argument(
        "--input-iter",
        type=str,
        default=None,
        help="a comma separated list of Tensor dimensions that includes a start, \
            stop, and increment that can be constant or a power of 2 \
            {start:stop:inc,start:stop:pow2}",
    )
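    # Illustrative (hypothetical) value: --input-iter "64:1024:pow2,32:128:32" sweeps the
    # first dimension over 64, 128, 256, 512, 1024 and the second over 32, 64, 96, 128.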
    parser.add_argument(
        "--engine",
        type=str,
        default="pt",
        help="the underlying tensor engine. only pt for now",
    )
    parser.add_argument(
        "--jit-mode",
        "--jit_mode",
        type=str,
        default="trace",
        help="the jit mode to use: one of {trace, none}",
    )
    parser.add_argument(
        "--cuda-pointwise-loop-levels",
        "--cuda_pointwise_loop_levels",
        type=int,
        default=None,
        help="number of loop levels for CUDA pointwise operations: 2 or 3",
    )
    parser.add_argument(
        "--cuda-pointwise-block-count",
        "--cuda_pointwise_block_count",
        type=int,
        default=None,
        help="number of blocks for CUDA pointwise operations",
    )
    parser.add_argument(
        "--cuda-pointwise-block-size",
        "--cuda_pointwise_block_size",
        type=int,
        default=None,
        help="block size for CUDA pointwise operations",
    )
    parser.add_argument(
        "--cuda-fuser",
        "--cuda_fuser",
        type=str,
        default="te",
        help="The CUDA fuser backend to use: one of {te, nvf, old, none}",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="stdout",
        help="The output format of the benchmark run {stdout[default], json}",
    )
    parser.add_argument(
        "--print-ir",
        action="store_true",
        help="Print the IR graph of the Fusion.",
    )
    parser.add_argument(
        "--print-kernel",
        action="store_true",
        help="Print generated kernel(s).",
    )
    parser.add_argument(
        "--no-dynamic-shape",
        action="store_true",
        help="Disable shape randomization in dynamic benchmarks.",
    )
    parser.add_argument(
        "--cpu-fusion",
        "--cpu_fusion",
        default=False,
        action="store_true",
        help="Enable CPU fusion.",
    )
    parser.add_argument(
        "--cat-wo-conditionals",
        "--cat_wo_conditionals",
        default=False,
        action="store_true",
        help="Enable cat (concat) without conditionals.",
    )
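
    # Illustrative invocations (the module path is an assumption; adjust to your checkout):
    #   python -m benchmarks.tensorexpr                  # all benchmarks, default configs
    #   python -m benchmarks.tensorexpr reduce           # every benchmark matching 'reduce'
    #   python -m benchmarks.tensorexpr --device cuda --cuda-fuser nvf matmul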

    args = parser.parse_args()

    if args.cuda_fuser == "te":
        import torch

        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._get_graph_executor_optimize(True)
    elif args.cuda_fuser == "old":
        import torch

        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
    elif args.cuda_fuser == "nvf":
        import torch

        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._get_graph_executor_optimize(True)
    else:
        # any other value, including the 'none' listed in the --cuda-fuser help, lands here
        raise ValueError(f"Undefined fuser: {args.cuda_fuser}")

    if args.cpu_fusion:
        import torch

        torch._C._jit_override_can_fuse_on_cpu(True)
    else:
        import torch

        torch._C._jit_override_can_fuse_on_cpu(False)

    if args.cat_wo_conditionals:
        import torch

        torch._C._jit_cat_wo_conditionals(True)
    else:
        import torch

        torch._C._jit_cat_wo_conditionals(False)

    # note: these environment variables generally only take effect if they are set
    # before the corresponding runtimes initialize their thread pools
    def set_global_threads(num_threads):
        os.environ["OMP_NUM_THREADS"] = str(num_threads)
        os.environ["MKL_NUM_THREADS"] = str(num_threads)
        os.environ["TVM_NUM_THREADS"] = str(num_threads)
        os.environ["NNC_NUM_THREADS"] = str(num_threads)

    devices = args.device.split(",")
    # accept 'gpu' as an alias for the 'cuda' device
    devices = ["cuda" if device == "gpu" else device for device in devices]
    cpu_count = 0
    for index, device in enumerate(devices):
        if device.startswith("cpu"):
            cpu_count += 1
            if cpu_count > 1:
                raise ValueError(
                    f"more than one CPU device is not allowed: {cpu_count:d}"
                )
            if device == "cpu":
                continue
            num_threads_str = device[3:]
            try:
                # see if the device is in 'cpu1' or 'cpu4' format
                num_threads = int(num_threads_str)
                set_global_threads(num_threads)
                devices[index] = "cpu"
            except ValueError:
                continue

    modes = args.mode.split(",")

    datatypes = args.dtype.split(",")
    for index, dtype in enumerate(datatypes):
        datatypes[index] = getattr(torch, dtype)
        if not datatypes[index]:
            raise AttributeError(f"DataType: {dtype} is not valid!")

    tensor_engine.set_engine_mode(args.engine)

    def run_default_configs(bench_cls, allow_skip=True):
        for mode, device, dtype, config in itertools.product(
            modes, devices, datatypes, bench_cls.default_configs()
        ):
            bench = bench_cls(mode, device, dtype, *config)
            bench.output_type = args.output
            bench.jit_mode = args.jit_mode
            if not bench.is_supported():
                if allow_skip:
                    continue
                else:
                    raise ValueError(
                        f"attempted to run an unsupported benchmark: {bench.desc()}"
                    )
            bench.run(args)

    def run_with_input_iter(bench_cls, input_iter, allow_skip=True):
        tensor_dim_specs = input_iter.split(",")
        tensor_dim_specs = [dim.split(":") for dim in tensor_dim_specs]

        configs = []
        for start, stop, inc in tensor_dim_specs:
            dim_list = []
            if inc == "pow2":
                curr = int(start)
                while curr <= int(stop):
                    dim_list.append(curr)
                    curr <<= 1
            elif inc == "pow2+1":
                # steps curr -> 2 * curr - 1, which walks 2**k + 1 values when
                # start has that form (e.g. 3, 5, 9, 17, ...)
                curr = int(start)
                while curr <= int(stop):
                    dim_list.append(curr)
                    curr -= 1
                    curr <<= 1
                    curr += 1
            else:
                dim_list = list(range(int(start), int(stop) + int(inc), int(inc)))
            configs.append(dim_list)
        configs = itertools.product(*configs)

        for mode, device, dtype, config in itertools.product(
            modes, devices, datatypes, list(configs)
        ):
            bench = bench_cls(mode, device, dtype, *config)
            bench.output_type = args.output
            bench.jit_mode = args.jit_mode
            if not bench.is_supported():
                if allow_skip:
                    continue
                else:
                    raise ValueError(
                        f"attempted to run an unsupported benchmark: {bench.desc()}"
                    )
            bench.run(args)

    benchmark_classes = benchmark.benchmark_classes
    if not args.benchmark_names:
        # by default, run all the benchmarks
        for benchmark_cls in benchmark_classes:
            run_default_configs(benchmark_cls, allow_skip=True)
    else:
        for name in args.benchmark_names:
            # if the name matches a benchmark class's module name, run all the benchmarks for that class
            match_class_name = False
            for bench_cls in benchmark_classes:
                if name in bench_cls.module():
                    match_class_name = True
                    if (args.input_iter is not None) and bench_cls.input_iterable():
                        run_with_input_iter(bench_cls, args.input_iter, allow_skip=True)
                    else:
                        if args.input_iter is not None:
                            print(
                                f"WARNING: Incompatible benchmark class called with input_iter arg: {name}"
                            )
                        run_default_configs(bench_cls, allow_skip=True)

            if match_class_name:
                continue

            # if not a class module, parse the config out of the name and run that single config,
            # e.g. 'layernorm_fwd_cpu_128_32_128_128' -> mode 'fwd', device 'cpu', dims 128, 32, 128, 128
            match_class_name = False
            for bench_cls in benchmark_classes:
                cls_module = bench_cls.module()
                if name.startswith(cls_module):
                    match_class_name = True
                    if name[len(cls_module)] != "_":
                        raise ValueError(f"invalid name: {name}")
                    config_str = name[(len(cls_module) + 1) :]
                    config = config_str.split("_")
                    if len(config) < 2:
                        raise ValueError(f"invalid config: {config}")
                    mode, device = config[0:2]
                    # TODO: make sure virtual devices such as 'cpu1' and 'cpu4' are supported.
                    if mode not in ["fwd", "both"]:
                        raise ValueError(f"invalid mode: {mode}")
                    for i, entry in enumerate(config):
                        try:
                            value = int(entry)
                            config[i] = value
                        except ValueError:
                            pass
                    # TODO: output dtype in the config and parse it back from the str
                    bench = bench_cls(config[0], config[1], torch.float32, *config[2:])
                    bench.jit_mode = args.jit_mode
                    bench.output_type = args.output
                    bench.run(args)

            if not match_class_name:
                available_classes = ", ".join(
                    [bench_cls.module() for bench_cls in benchmark_classes]
                )
                raise ValueError(
                    f"invalid name: {name}\nAvailable benchmark classes:\n{available_classes}"
                )


if __name__ == "__main__":
    main()