Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
Update the operator benchmarking to benchmark using torch.compile (#161394)
This pull request enhances the PyTorch operator benchmarking suite by introducing support for benchmarking with `torch.compile` mode, in addition to the existing Eager and JIT modes. It also adds peak memory measurement for the forward/backward pass, improves the JSON output format consumed by the dashboard for reporting, and introduces additional CLI options. The new CLI flags are:

- Added `--use-compile` CLI argument and corresponding logic to run benchmarks using `torch.compile`, including mutual exclusivity with `--use-jit`
- Added `--benchmark-name` argument for customizing the benchmark name in the output
- Updated the default value for `--output-json-for-dashboard` to `benchmark-results.json` for a more predictable output file name

Sample command to run a single operator: `python -m pt.mm_test --use-compile`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161394
Approved by: https://github.com/jbschlosser
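For orientation, the sketch below reduces the measurement pattern this change adds to a standalone script: the same callable is timed in Eager and `torch.compile` (Inductor) mode, and peak device memory is read back through `torch.get_device_module`, falling back gracefully on devices without memory statistics. The operator (`torch.mm`), tensor shapes, iteration count, and the `benchmark` helper are illustrative choices for this sketch, not part of the benchmark suite itself.

```python
# Minimal sketch of the new measurement pattern (illustrative, not the harness itself).
import timeit

import torch


def benchmark(fn, iters=100, device="cuda"):
    """Time fn for `iters` iterations; return average latency (us) and peak memory (KB)."""
    device_module = torch.get_device_module(device)
    if hasattr(device_module, "reset_peak_memory_stats"):
        device_module.reset_peak_memory_stats()
    start = timeit.default_timer()
    for _ in range(iters):
        fn()
    if hasattr(device_module, "synchronize"):
        device_module.synchronize()
    elapsed_us = (timeit.default_timer() - start) / iters * 1e6
    peak_kb = 0.0
    if hasattr(device_module, "max_memory_allocated"):
        peak_kb = device_module.max_memory_allocated() / 1024
    return elapsed_us, peak_kb


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(1024, 1024, device=device)
    b = torch.randn(1024, 1024, device=device)

    def eager_fn():
        return torch.mm(a, b)

    compiled_fn = torch.compile(eager_fn, backend="inductor")
    compiled_fn()  # warm up so compilation time is not part of the measurement

    for name, fn in (("Eager", eager_fn), ("Compile", compiled_fn)):
        us, kb = benchmark(fn, device=device)
        print(f"{name} Execution Time (us) : {us:.3f}")
        print(f"Peak Memory (KB) : {kb}")
```

In the suite itself, the same steps appear in `_measure_metrics` and `run_compile_forward` in the diff below.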
Committed by: PyTorch MergeBot
Parent: 82f1eb9b03
Commit: af60398c3a
@@ -4,6 +4,7 @@ import csv
 import functools
 import json
 import os
+import platform
 import timeit
 from collections import namedtuple
 from dataclasses import asdict, dataclass
@@ -191,6 +192,11 @@ class BenchmarkRunner:
         self.predefined_minimum_secs = 1
         self.max_iters = 1e6
         self.use_jit = args.use_jit
+        self.use_compile = args.use_compile
+        if self.use_jit and self.use_compile:
+            raise ValueError(
+                "use_jit and use_compile are mutually exclusive, please specify one."
+            )
         self.num_runs = args.num_runs
         self.print_per_iter = False
         self.output_csv = args.output_csv
@@ -222,7 +228,7 @@ class BenchmarkRunner:
         if self.args.operators:
             print(f"# {self.args.operators}")

-    def _print_perf_result(self, reported_run_time_us, test_case):
+    def _print_perf_result(self, results, test_case):
         if self.args.report_aibench:
             # Output for AIBench
             # Print out per iteration execution time instead of avg time
@@ -236,12 +242,14 @@ class BenchmarkRunner:
                             "type": test_name,
                             "metric": "latency",
                             "unit": "us",
-                            "value": str(reported_run_time_us[run]),
+                            "value": str(results["reported_run_time_us"][run]),
                         }
                     )
                 )
         else:
-            print(f"# Mode: {'JIT' if self.use_jit else 'Eager'}")
+            print(
+                f"# Mode: {'JIT' if self.use_jit else 'Compile' if self.use_compile else 'Eager'}"
+            )
             print(
                 f"# Name: {test_case.test_config.test_name}\n# Input: {test_case.test_config.input_config}"
             )
@@ -250,25 +258,33 @@ class BenchmarkRunner:
             if self.num_runs > 1:
                 for run in range(self.num_runs):
                     print(
-                        f"Run: {run}, {mode} Execution Time (us) : {reported_run_time_us[run]:.3f}"
+                        f"Run: {run}, {mode} Execution Time (us) : {results['reported_run_time_us'][run]:.3f}"
                     )
                 print()
             else:
-                print(f"{mode} Execution Time (us) : {reported_run_time_us[0]:.3f}\n")
+                print(
+                    f"{mode} Execution Time (us) : {results['reported_run_time_us'][0]:.3f}"
+                )
+                print(f"Peak Memory (KB) : {results['peak_memory']}\n")

-    def _perf_result_to_dict(self, reported_run_time_us, test_case):
+    def _perf_result_to_dict(self, results, test_case):
         """This function is the parallel of _print_perf_result, which instead of
         writing information to terminal, returns a dictionary.
         """
         if self.args.report_aibench:
             return {}

         out = {
             "test_name": test_case.test_config.test_name,
             "input_config": test_case.test_config.input_config,
-            "mode": "JIT" if self.use_jit else "Eager",
+            "runtime": (
+                "JIT" if self.use_jit else "Compile" if self.use_compile else "Eager"
+            ),
             "run": "Backward" if test_case.test_config.run_backward else "Forward",
-            "latency": round(reported_run_time_us[0], 3),
+            "latency": round(results["reported_run_time_us"][0], 3),
             "latency unit": "us",
+            "peak memory": results["peak_memory"],
+            "memory unit": "KB",
         }

         # parsing test_case.test_config.input_config, adding it as entries to the 'out' dictionary
@@ -330,6 +346,8 @@ class BenchmarkRunner:
         func = test_case.run_forward
         if self.use_jit:
             func = test_case.run_jit_forward
+        if self.use_compile:
+            func = test_case.run_compile_forward
         forward_time = timeit.timeit(
             functools.partial(func, iters, print_per_iter, cuda_sync), number=1
         )
@@ -346,7 +364,7 @@ class BenchmarkRunner:
         )
         return backward_time

-    def _measure_time(self, launch_test, test_case, iters, print_per_iter):
+    def _measure_metrics(self, launch_test, test_case, iters, print_per_iter):
         """
         This function execute the operator for <iters> iterations then look at the time.
         If it's not significant, the number of iterations will be increased before rerun.
@@ -354,8 +372,20 @@ class BenchmarkRunner:
         """
         curr_test_total_time = 0
         time_trace = []
+        peak_memory = 0
+        sample_input = next(iter(test_case.op_bench.inputs.values()))
+        device = sample_input.device
+        device_module = torch.get_device_module(device.type)
+        # TODO: add support for cpu memory measurement
         while True:
+            if hasattr(device_module, "reset_peak_memory_stats"):
+                device_module.reset_peak_memory_stats(device)
             run_time_sec = launch_test(test_case, iters, print_per_iter)
+            if hasattr(device_module, "synchronize"):
+                device_module.synchronize(device)
+            # Memory measurement process
+            if hasattr(device_module, "max_memory_allocated"):
+                peak_memory = device_module.max_memory_allocated(device)
             curr_test_total_time += run_time_sec
             # Analyze time after each run to decide if the result is stable
             results_are_significant = self._iteration_result_is_significant(
@@ -369,7 +399,13 @@ class BenchmarkRunner:
             time_trace.append(report_run_time)
             # Print out the time spent in each epoch in ms
             if self.args.report_aibench:
-                mode = "JIT" if self.use_jit else "Eager"
+                mode = (
+                    "JIT"
+                    if self.use_jit
+                    else "Compile"
+                    if self.use_compile
+                    else "Eager"
+                )
                 test_name = "_".join(
                     [test_case.framework, test_case.test_config.test_name, mode]
                 )
@@ -381,7 +417,7 @@ class BenchmarkRunner:
                             "metric": "latency",
                             "unit": "ms",
                             "value": str(report_run_time / 1e3),
-                        }
+                        },
                     )
                 )
             if results_are_significant:
@@ -391,7 +427,7 @@ class BenchmarkRunner:
             # iteration count, and run the benchmark again...
             iters = self._predict_num_iter_needed(iters)
         reported_run_time_us = np.percentile(np.array(time_trace), 50)
-        return reported_run_time_us
+        return reported_run_time_us, peak_memory / 1024

     def _check_keep(self, test_flag, cmd_flag):
         return cmd_flag is None or test_flag == cmd_flag
@@ -478,6 +514,7 @@ class BenchmarkRunner:
         self,
         perf_list,
         output_file,
+        benchmark_name="PyTorch operator benchmark",
     ):
         """
         Write the result into JSON format, so that it can be uploaded to the benchmark database
@@ -495,8 +532,10 @@ class BenchmarkRunner:
             input_config = perf_item.get("input_config", "")
             run_type = perf_item.get("run")
             latency = perf_item.get("latency", 0)
-            dtype = "float32"  # default
+            peak_memory = perf_item.get("peak memory", 0)
             device = perf_item.get("device", "unknown")
+            dtype = perf_item.get("dtype", "torch.float").split(".")[1]
+            runtime = perf_item.get("runtime", None)

             # Extract mode based on run_type
             mode = None
@@ -505,6 +544,22 @@ class BenchmarkRunner:
             elif run_type == "Backward":
                 mode = "training"

+            # Extract use_compile from it
+            if runtime == "Compile":
+                use_compile = True
+            elif runtime == "Eager":
+                use_compile = False
+            else:
+                use_compile = None
+
+            device_arch = (
+                torch.cuda.get_device_name(0)
+                if device == "cuda"
+                else platform.processor()
+                if device == "cpu"
+                else "unknown"
+            )
+
             # Create the record
             @dataclass
             class BenchmarkInfo:
@@ -532,12 +587,18 @@ class BenchmarkRunner:
                 model: ModelInfo
                 metric: MetricInfo

-            record = BenchmarkRecord(
+            # Add record for latency
+            record_latency = BenchmarkRecord(
                 benchmark=BenchmarkInfo(
-                    name="PyTorch operator benchmark",
+                    name=benchmark_name,
                     mode=mode,
                     dtype=dtype,
-                    extra_info={"input_config": input_config},
+                    extra_info={
+                        "input_config": input_config,
+                        "device": device,
+                        "arch": device_arch,
+                        "use_compile": use_compile,
+                    },
                 ),
                 model=ModelInfo(
                     name=test_name, type="micro-benchmark", origins=["pytorch"]
@@ -549,8 +610,17 @@ class BenchmarkRunner:
                     target_value=None,
                 ),
             )
+            records.append(asdict(record_latency))

-            records.append(asdict(record))
+            # Add record for peak memory
+            record_memory = copy.deepcopy(record_latency)
+            record_memory.metric = MetricInfo(
+                name="peak memory",
+                unit="KB",
+                benchmark_values=[peak_memory],
+                target_value=None,
+            )
+            records.append(asdict(record_memory))

         # Write all records to the output file
         with open(output_file, "w", encoding="utf-8") as f:
@@ -566,6 +636,7 @@ class BenchmarkRunner:
             "tag",
             "run_backward",
             "Execution Time",
+            "Peak Memory (KB)",
         ]

         if self.args.output_json or self.args.output_json_for_dashboard:
@@ -603,13 +674,16 @@ class BenchmarkRunner:
             test_case, self.args.warmup_iterations, print_per_iter=False
         )
         # Actual Execution
-        reported_time = [
-            self._measure_time(
+        results = [
+            self._measure_metrics(
                 launch_func, test_case, self.iters, self.print_per_iter
             )
             for _ in range(self.num_runs)
         ]
-        self._print_perf_result(reported_time, test_case)
+        result_dict = dict()
+        result_dict["reported_run_time_us"] = [r[0] for r in results]
+        result_dict["peak_memory"] = results[0][1]
+        self._print_perf_result(results=result_dict, test_case=test_case)

         # output results to csv
         self._output_csv(
@@ -625,16 +699,17 @@ class BenchmarkRunner:
                 ),
                 test_case.test_config.tag,
                 test_case.test_config.run_backward,
-                reported_time[0],
+                result_dict["reported_run_time_us"][0],
+                result_dict["peak_memory"],
             ],
         )
         if self.args.output_json or self.args.output_json_for_dashboard:
-            perf_list.append(
-                self._perf_result_to_dict(reported_time, test_case)
-            )
+            perf_list.append(self._perf_result_to_dict(result_dict, test_case))

         if self.args.output_json_for_dashboard:
-            self._output_json(perf_list, self.args.output_json_for_dashboard)
+            self._output_json(
+                perf_list, self.args.output_json_for_dashboard, self.args.benchmark_name
+            )

         if self.args.output_json:
             with open(self.args.output_json, "w") as f:
@@ -4,6 +4,15 @@ import time
 import torch


+# Import the C++ extension to register the _consume operator
+try:
+    import benchmark_cpp_extension  # noqa: F401
+except ImportError as err:
+    # If the extension isn't built, the script must raise an error
+    raise ImportError(
+        "Failed to import C++ extension, please build it using \ncd pt_extension \npython -m pip install ."
+    ) from err
+
 """PyTorch performance microbenchmarks.

 This module contains PyTorch-specific functionalities for performance
@@ -71,6 +80,16 @@ class TorchBenchmarkBase(torch.nn.Module):
         for _ in range(iters):
             torch.ops.operator_benchmark._consume(self.forward_impl())

+    def forward_impl_eager(self):
+        # This is to supply the inputs to the forward function which
+        # will be called in both the eager and compile mode of local runs
+        return self.forward(*self.get_inputs())
+
+    def forward_consume_eager(self, iters: int):
+        # Eager version of forward_consume without decorators (compilation handled by torch.compile)
+        for _ in range(iters):
+            torch.ops.operator_benchmark._consume(self.forward_impl_eager())
+
     def module_name(self):
         """this is used to label the operator being benchmarked"""
         if self.user_given_name:
@@ -117,18 +136,32 @@ class PyTorchOperatorTestCase:
         self.framework = "PyTorch"
         self.time_series = []
         self._jit_forward_graph = None
+        self._compile_forward_graph = None

     def _generate_jit_forward_graph(self):
         """generate a graph for the forward function via scripting"""
         scripted_op_bench = torch.jit.script(self.op_bench)
         return scripted_op_bench.forward_consume

+    def _generate_compile_forward_graph(self):
+        """generate a compiled graph for the forward function via torch.compile"""
+        compiled_forward_consume = torch.compile(
+            self.op_bench.forward_consume_eager, backend="inductor"
+        )
+        return compiled_forward_consume
+
     def run_jit_forward(self, num_runs, print_per_iter=False, cuda_sync=False):
         """Run the forward path of an op with JIT mode"""
         if self._jit_forward_graph is None:
             self._jit_forward_graph = self._generate_jit_forward_graph()
         self._jit_forward_graph(num_runs)

+    def run_compile_forward(self, num_runs, print_per_iter=False, cuda_sync=False):
+        """Run the forward path of an op with compile mode"""
+        if self._compile_forward_graph is None:
+            self._compile_forward_graph = self._generate_compile_forward_graph()
+        self._compile_forward_graph(num_runs)
+
     def _print_per_iter(self):
         # print last 50 values
         length = min(len(self.time_series), 50)
@@ -150,14 +183,14 @@ class PyTorchOperatorTestCase:
         if print_per_iter:
             for _ in range(num_runs):
                 start_time = time.time()
-                self.output = self.op_bench.forward_impl()
+                self.output = self.op_bench.forward_impl_eager()
                 if cuda_sync:
                     torch.cuda.synchronize(torch.cuda.current_device())
                 end_time = time.time()
                 self.time_series.append((end_time - start_time) * 1e3)
         else:
             for _ in range(num_runs):
-                self.output = self.op_bench.forward_impl()
+                self.output = self.op_bench.forward_impl_eager()
                 if cuda_sync:
                     torch.cuda.synchronize(torch.cuda.current_device())

@@ -62,6 +62,13 @@ def parse_args():
         default=None,
     )

+    parser.add_argument(
+        "--benchmark-name",
+        "--benchmark_name",
+        help="Name of the benchmark to store results to",
+        default="PyTorch operator benchmark",
+    )
+
     parser.add_argument(
         "--list-tests",
         "--list_tests",
@@ -135,6 +142,16 @@ def parse_args():
         help="Run operators with PyTorch JIT mode",
     )

+    parser.add_argument(
+        "--use-compile",
+        "--use_compile",
+        type=benchmark_utils.str2bool,
+        nargs="?",
+        const=True,
+        default=False,
+        help="Run operators with PyTorch Compile mode",
+    )
+
     parser.add_argument(
         "--forward-only",
         "--forward_only",
@@ -162,7 +179,7 @@ def parse_args():
         "--output-json-for-dashboard",
         "--output_json_for_dashboard",
         help="Save results in JSON format for display on the OSS dashboard",
-        default="False",
+        default="benchmark-results.json",
    )

    args, _ = parser.parse_known_args()