diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py
index cb836bb5eaa4..0b7fcf4e555f 100644
--- a/benchmarks/operator_benchmark/benchmark_core.py
+++ b/benchmarks/operator_benchmark/benchmark_core.py
@@ -4,6 +4,7 @@ import csv
 import functools
 import json
 import os
+import platform
 import timeit
 from collections import namedtuple
 from dataclasses import asdict, dataclass
@@ -191,6 +192,11 @@ class BenchmarkRunner:
         self.predefined_minimum_secs = 1
         self.max_iters = 1e6
         self.use_jit = args.use_jit
+        self.use_compile = args.use_compile
+        if self.use_jit and self.use_compile:
+            raise ValueError(
+                "use_jit and use_compile are mutually exclusive, please specify one."
+            )
         self.num_runs = args.num_runs
         self.print_per_iter = False
         self.output_csv = args.output_csv
@@ -222,7 +228,7 @@ class BenchmarkRunner:
         if self.args.operators:
             print(f"# {self.args.operators}")
 
-    def _print_perf_result(self, reported_run_time_us, test_case):
+    def _print_perf_result(self, results, test_case):
         if self.args.report_aibench:
             # Output for AIBench
             # Print out per iteration execution time instead of avg time
@@ -236,12 +242,14 @@ class BenchmarkRunner:
                             "type": test_name,
                             "metric": "latency",
                             "unit": "us",
-                            "value": str(reported_run_time_us[run]),
+                            "value": str(results["reported_run_time_us"][run]),
                         }
                     )
                 )
         else:
-            print(f"# Mode: {'JIT' if self.use_jit else 'Eager'}")
+            print(
+                f"# Mode: {'JIT' if self.use_jit else 'Compile' if self.use_compile else 'Eager'}"
+            )
             print(
                 f"# Name: {test_case.test_config.test_name}\n# Input: {test_case.test_config.input_config}"
             )
@@ -250,25 +258,33 @@ class BenchmarkRunner:
             if self.num_runs > 1:
                 for run in range(self.num_runs):
                     print(
-                        f"Run: {run}, {mode} Execution Time (us) : {reported_run_time_us[run]:.3f}"
+                        f"Run: {run}, {mode} Execution Time (us) : {results['reported_run_time_us'][run]:.3f}"
                    )
                 print()
             else:
-                print(f"{mode} Execution Time (us) : {reported_run_time_us[0]:.3f}\n")
+                print(
+                    f"{mode} Execution Time (us) : {results['reported_run_time_us'][0]:.3f}"
+                )
+                print(f"Peak Memory (KB) : {results['peak_memory']}\n")
 
-    def _perf_result_to_dict(self, reported_run_time_us, test_case):
+    def _perf_result_to_dict(self, results, test_case):
         """This function is the parallel of _print_perf_result, which instead of
         writing information to terminal, returns a dictionary.
""" if self.args.report_aibench: return {} + out = { "test_name": test_case.test_config.test_name, "input_config": test_case.test_config.input_config, - "mode": "JIT" if self.use_jit else "Eager", + "runtime": ( + "JIT" if self.use_jit else "Compile" if self.use_compile else "Eager" + ), "run": "Backward" if test_case.test_config.run_backward else "Forward", - "latency": round(reported_run_time_us[0], 3), + "latency": round(results["reported_run_time_us"][0], 3), "latency unit": "us", + "peak memory": results["peak_memory"], + "memory unit": "KB", } # parsing test_case.test_config.input_config, adding it as entries to the 'out' dictionary @@ -330,6 +346,8 @@ class BenchmarkRunner: func = test_case.run_forward if self.use_jit: func = test_case.run_jit_forward + if self.use_compile: + func = test_case.run_compile_forward forward_time = timeit.timeit( functools.partial(func, iters, print_per_iter, cuda_sync), number=1 ) @@ -346,7 +364,7 @@ class BenchmarkRunner: ) return backward_time - def _measure_time(self, launch_test, test_case, iters, print_per_iter): + def _measure_metrics(self, launch_test, test_case, iters, print_per_iter): """ This function execute the operator for iterations then look at the time. If it's not significant, the number of iterations will be increased before rerun. @@ -354,8 +372,20 @@ class BenchmarkRunner: """ curr_test_total_time = 0 time_trace = [] + peak_memory = 0 + sample_input = next(iter(test_case.op_bench.inputs.values())) + device = sample_input.device + device_module = torch.get_device_module(device.type) + # TODO: add support for cpu memory measurement while True: + if hasattr(device_module, "reset_peak_memory_stats"): + device_module.reset_peak_memory_stats(device) run_time_sec = launch_test(test_case, iters, print_per_iter) + if hasattr(device_module, "synchronize"): + device_module.synchronize(device) + # Memory measurement process + if hasattr(device_module, "max_memory_allocated"): + peak_memory = device_module.max_memory_allocated(device) curr_test_total_time += run_time_sec # Analyze time after each run to decide if the result is stable results_are_significant = self._iteration_result_is_significant( @@ -369,7 +399,13 @@ class BenchmarkRunner: time_trace.append(report_run_time) # Print out the time spent in each epoch in ms if self.args.report_aibench: - mode = "JIT" if self.use_jit else "Eager" + mode = ( + "JIT" + if self.use_jit + else "Compile" + if self.use_compile + else "Eager" + ) test_name = "_".join( [test_case.framework, test_case.test_config.test_name, mode] ) @@ -381,7 +417,7 @@ class BenchmarkRunner: "metric": "latency", "unit": "ms", "value": str(report_run_time / 1e3), - } + }, ) ) if results_are_significant: @@ -391,7 +427,7 @@ class BenchmarkRunner: # iteration count, and run the benchmark again... 
iters = self._predict_num_iter_needed(iters) reported_run_time_us = np.percentile(np.array(time_trace), 50) - return reported_run_time_us + return reported_run_time_us, peak_memory / 1024 def _check_keep(self, test_flag, cmd_flag): return cmd_flag is None or test_flag == cmd_flag @@ -478,6 +514,7 @@ class BenchmarkRunner: self, perf_list, output_file, + benchmark_name="PyTorch operator benchmark", ): """ Write the result into JSON format, so that it can be uploaded to the benchmark database @@ -495,8 +532,10 @@ class BenchmarkRunner: input_config = perf_item.get("input_config", "") run_type = perf_item.get("run") latency = perf_item.get("latency", 0) - - dtype = "float32" # default + peak_memory = perf_item.get("peak memory", 0) + device = perf_item.get("device", "unknown") + dtype = perf_item.get("dtype", "torch.float").split(".")[1] + runtime = perf_item.get("runtime", None) # Extract mode based on run_type mode = None @@ -505,6 +544,22 @@ class BenchmarkRunner: elif run_type == "Backward": mode = "training" + # Extract use_compile from it + if runtime == "Compile": + use_compile = True + elif runtime == "Eager": + use_compile = False + else: + use_compile = None + + device_arch = ( + torch.cuda.get_device_name(0) + if device == "cuda" + else platform.processor() + if device == "cpu" + else "unknown" + ) + # Create the record @dataclass class BenchmarkInfo: @@ -532,12 +587,18 @@ class BenchmarkRunner: model: ModelInfo metric: MetricInfo - record = BenchmarkRecord( + # Add record for latency + record_latency = BenchmarkRecord( benchmark=BenchmarkInfo( - name="PyTorch operator benchmark", + name=benchmark_name, mode=mode, dtype=dtype, - extra_info={"input_config": input_config}, + extra_info={ + "input_config": input_config, + "device": device, + "arch": device_arch, + "use_compile": use_compile, + }, ), model=ModelInfo( name=test_name, type="micro-benchmark", origins=["pytorch"] @@ -549,8 +610,17 @@ class BenchmarkRunner: target_value=None, ), ) + records.append(asdict(record_latency)) - records.append(asdict(record)) + # Add record for peak memory + record_memory = copy.deepcopy(record_latency) + record_memory.metric = MetricInfo( + name="peak memory", + unit="KB", + benchmark_values=[peak_memory], + target_value=None, + ) + records.append(asdict(record_memory)) # Write all records to the output file with open(output_file, "w", encoding="utf-8") as f: @@ -566,6 +636,7 @@ class BenchmarkRunner: "tag", "run_backward", "Execution Time", + "Peak Memory (KB)", ] if self.args.output_json or self.args.output_json_for_dashboard: @@ -603,13 +674,16 @@ class BenchmarkRunner: test_case, self.args.warmup_iterations, print_per_iter=False ) # Actual Execution - reported_time = [ - self._measure_time( + results = [ + self._measure_metrics( launch_func, test_case, self.iters, self.print_per_iter ) for _ in range(self.num_runs) ] - self._print_perf_result(reported_time, test_case) + result_dict = dict() + result_dict["reported_run_time_us"] = [r[0] for r in results] + result_dict["peak_memory"] = results[0][1] + self._print_perf_result(results=result_dict, test_case=test_case) # output results to csv self._output_csv( @@ -625,16 +699,17 @@ class BenchmarkRunner: ), test_case.test_config.tag, test_case.test_config.run_backward, - reported_time[0], + result_dict["reported_run_time_us"][0], + result_dict["peak_memory"], ], ) if self.args.output_json or self.args.output_json_for_dashboard: - perf_list.append( - self._perf_result_to_dict(reported_time, test_case) - ) + 
+                    perf_list.append(self._perf_result_to_dict(result_dict, test_case))
 
         if self.args.output_json_for_dashboard:
-            self._output_json(perf_list, self.args.output_json_for_dashboard)
+            self._output_json(
+                perf_list, self.args.output_json_for_dashboard, self.args.benchmark_name
+            )
 
         if self.args.output_json:
             with open(self.args.output_json, "w") as f:
diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py
index 52ae47047daa..a7ff40ebb340 100644
--- a/benchmarks/operator_benchmark/benchmark_pytorch.py
+++ b/benchmarks/operator_benchmark/benchmark_pytorch.py
@@ -4,6 +4,15 @@ import time
 
 import torch
+# Import the C++ extension to register the _consume operator
+try:
+    import benchmark_cpp_extension  # noqa: F401
+except ImportError as err:
+    # If the extension isn't built, the script must raise an error
+    raise ImportError(
+        "Failed to import C++ extension, please build it using \ncd pt_extension \npython -m pip install ."
+    ) from err
+
 
 """PyTorch performance microbenchmarks.
 
 This module contains PyTorch-specific functionalities for performance
@@ -71,6 +80,16 @@ class TorchBenchmarkBase(torch.nn.Module):
         for _ in range(iters):
             torch.ops.operator_benchmark._consume(self.forward_impl())
 
+    def forward_impl_eager(self):
+        # This is to supply the inputs to the forward function which
+        # will be called in both the eager and compile mode of local runs
+        return self.forward(*self.get_inputs())
+
+    def forward_consume_eager(self, iters: int):
+        # Eager version of forward_consume without decorators (compilation handled by torch.compile)
+        for _ in range(iters):
+            torch.ops.operator_benchmark._consume(self.forward_impl_eager())
+
     def module_name(self):
         """this is used to label the operator being benchmarked"""
         if self.user_given_name:
@@ -117,18 +136,32 @@ class PyTorchOperatorTestCase:
         self.framework = "PyTorch"
         self.time_series = []
         self._jit_forward_graph = None
+        self._compile_forward_graph = None
 
     def _generate_jit_forward_graph(self):
         """generate a graph for the forward function via scripting"""
         scripted_op_bench = torch.jit.script(self.op_bench)
         return scripted_op_bench.forward_consume
 
+    def _generate_compile_forward_graph(self):
+        """generate a compiled graph for the forward function via torch.compile"""
+        compiled_forward_consume = torch.compile(
+            self.op_bench.forward_consume_eager, backend="inductor"
+        )
+        return compiled_forward_consume
+
     def run_jit_forward(self, num_runs, print_per_iter=False, cuda_sync=False):
         """Run the forward path of an op with JIT mode"""
         if self._jit_forward_graph is None:
             self._jit_forward_graph = self._generate_jit_forward_graph()
         self._jit_forward_graph(num_runs)
 
+    def run_compile_forward(self, num_runs, print_per_iter=False, cuda_sync=False):
+        """Run the forward path of an op with compile mode"""
+        if self._compile_forward_graph is None:
+            self._compile_forward_graph = self._generate_compile_forward_graph()
+        self._compile_forward_graph(num_runs)
+
     def _print_per_iter(self):
         # print last 50 values
         length = min(len(self.time_series), 50)
@@ -150,14 +183,14 @@ class PyTorchOperatorTestCase:
         if print_per_iter:
             for _ in range(num_runs):
                 start_time = time.time()
-                self.output = self.op_bench.forward_impl()
+                self.output = self.op_bench.forward_impl_eager()
                 if cuda_sync:
                     torch.cuda.synchronize(torch.cuda.current_device())
                 end_time = time.time()
                 self.time_series.append((end_time - start_time) * 1e3)
         else:
             for _ in range(num_runs):
-                self.output = self.op_bench.forward_impl()
+                self.output = self.op_bench.forward_impl_eager()
                 if cuda_sync:
                     torch.cuda.synchronize(torch.cuda.current_device())
 
diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py
index 9dfab781498e..6568cf9bf3ee 100644
--- a/benchmarks/operator_benchmark/benchmark_runner.py
+++ b/benchmarks/operator_benchmark/benchmark_runner.py
@@ -62,6 +62,13 @@ def parse_args():
         default=None,
     )
 
+    parser.add_argument(
+        "--benchmark-name",
+        "--benchmark_name",
+        help="Name of the benchmark to store results to",
+        default="PyTorch operator benchmark",
+    )
+
     parser.add_argument(
         "--list-tests",
         "--list_tests",
@@ -135,6 +142,16 @@ def parse_args():
         help="Run operators with PyTorch JIT mode",
     )
 
+    parser.add_argument(
+        "--use-compile",
+        "--use_compile",
+        type=benchmark_utils.str2bool,
+        nargs="?",
+        const=True,
+        default=False,
+        help="Run operators with PyTorch Compile mode",
+    )
+
     parser.add_argument(
         "--forward-only",
         "--forward_only",
@@ -162,7 +179,7 @@ def parse_args():
         "--output-json-for-dashboard",
         "--output_json_for_dashboard",
         help="Save results in JSON format for display on the OSS dashboard",
-        default="False",
+        default="benchmark-results.json",
     )
 
     args, _ = parser.parse_known_args()
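
Reviewer note (not part of the patch): the sketch below is an illustrative, self-contained example of the two mechanisms this diff wires together, wrapping an eager forward with `torch.compile(..., backend="inductor")` the way `_generate_compile_forward_graph` does, and bracketing the call with the `hasattr`-guarded `reset_peak_memory_stats` / `synchronize` / `max_memory_allocated` pattern used in `_measure_metrics`. The toy `addmm` workload, tensor shapes, and function names are assumptions made for the example only; on CPU the guards simply skip memory measurement, which is why the diff leaves a TODO for CPU memory support. On the runner side, the new `--use-compile` flag (mutually exclusive with `--use-jit`) selects `run_compile_forward`, while `--benchmark-name` and the new `--output-json-for-dashboard` default control how the latency and peak-memory records are written out.

```python
# Illustrative sketch only -- not part of this patch.
# Mirrors the compile + peak-memory pattern added in benchmark_core.py and
# benchmark_pytorch.py; the addmm workload and shapes are arbitrary.
import torch


def forward_impl_eager(x, w, b):
    # Stand-in for TorchBenchmarkBase.forward_impl_eager(): a plain eager forward.
    return torch.addmm(b, x, w)


def measure_once():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(64, 64, device=device)
    w = torch.randn(64, 64, device=device)
    b = torch.randn(64, device=device)

    # Compile path: same wrapping as _generate_compile_forward_graph.
    compiled_forward = torch.compile(forward_impl_eager, backend="inductor")

    # Peak-memory bracket: same hasattr-guarded calls as _measure_metrics, so a
    # CPU-only run degrades to "no memory numbers" instead of failing.
    device_module = torch.get_device_module(device.type)
    if hasattr(device_module, "reset_peak_memory_stats"):
        device_module.reset_peak_memory_stats(device)

    out = compiled_forward(x, w, b)

    if hasattr(device_module, "synchronize"):
        device_module.synchronize(device)
    peak_kb = 0.0
    if hasattr(device_module, "max_memory_allocated"):
        peak_kb = device_module.max_memory_allocated(device) / 1024
    print(f"output shape: {tuple(out.shape)}, peak memory (KB): {peak_kb}")


if __name__ == "__main__":
    measure_once()
```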