Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 04:04:57 +08:00)

Add simple add op based framework overhead benchmark. (#23076)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23076

Adds both tracing-based and non-tracing-based variants of the benchmark.

Reviewed By: mingzhe09088
Differential Revision: D16097280
fbshipit-source-id: 3a137092f7ccc3dd2d29d95e10178ec89d3ce892
Committed by Facebook Github Bot
Parent: 4223e2f9e9
Commit: 0621068cdc

benchmarks/framework_overhead_benchmark/SimpleAddModule.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from utils import NUM_PT_LOOP_ITERS

def add_tensors_loop(x, y):
    # Chain NUM_PT_LOOP_ITERS adds so that per-op framework overhead
    # dominates the measurement rather than a single op launch.
    z = torch.add(x, y)
    for i in range(NUM_PT_LOOP_ITERS):
        z = torch.add(z, x)
    return z

class SimpleAddModule(torch.nn.Module):
    def __init__(self, add_op):
        super(SimpleAddModule, self).__init__()
        self.add_op = add_op

    def forward(self, x, y):
        return self.add_op(x, y)
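For orientation, a minimal sketch of exercising this module by hand, outside the benchmark harness (the single-element tensors match what WrapperModule feeds it; the snippet is illustrative and not part of the commit):

import torch
from SimpleAddModule import SimpleAddModule, add_tensors_loop

# Wrap the looped add in a module and run one forward pass.
m = SimpleAddModule(add_tensors_loop)
x, y = torch.randn(1), torch.randn(1)
out = m(x, y)  # equivalent to calling add_tensors_loop(x, y)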
benchmarks/framework_overhead_benchmark/framework_overhead_benchmark.py (new file, 80 lines)

@@ -0,0 +1,80 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from utils import ms_to_us, benchmark_module, BenchmarkConfig, ModuleConfig
import argparse

from SimpleAddModule import SimpleAddModule, add_tensors_loop
from pt_wrapper_module import WrapperModule

""" Framework overhead benchmark script.
Benchmarks framework overhead.
Currently supported ops: add.
As of now runs only the forward pass.
Supports both graph mode and eager mode. In graph mode the module is traced via JIT tracing.
The debug option prints the traced graph if graph mode is enabled.
The graph can be saved via the save option; it is saved in the directory where the benchmark is run.
Example build/run:
buck run @mode/opt <path-to-framework_overhead_benchmark>:framework_overhead_benchmark --
 --op add_op (runs graph mode, the default)
buck run @mode/opt <path-to-framework_overhead_benchmark>:framework_overhead_benchmark --
 --op add_op --eager_mode (runs eager mode)
"""

SUPPORTED_OPS = {"add_op"}

def parse_op_args(op):
    # Split a comma-separated op string into a list of op names.
    op_list = op.split(",")
    return op_list

def print_results(result):
    print("===================================")
    for key, value in result.items():
        print("{}, latency per iter (us): {}".format(key, ms_to_us(value)))
    print("===================================")

def benchmark_simple_fn(args, config, module_config, module_type, result):
    """ Benchmarks a PyTorch traceable function specified in the config.
    Instantiates a wrapper object that wraps the object of module_type and runs the forward
    method using benchmark_module.
    Args:
        args:           parsed command-line arguments; the debug and save flags are used.
        config:         contains the number of warmup and benchmark iterations.
        module_config:  contains the op, the number of parameters the op takes,
                        and whether graph mode is enabled.
        module_type:    type of the module to be wrapped, e.g. SimpleAddModule for the add op.
        result:         dictionary instance to be populated with the benchmark result (latency per iter).
    """
    print("Benchmarking {}".format(module_type.__name__))
    f_name = module_config.pt_fn.__name__ + ":Num Operands=" + str(module_config.num_params)
    graph_mode_str = "Graph mode" + ":" + str(module_config.graph_mode)
    result_key = ','.join((f_name, graph_mode_str))
    module = WrapperModule(module_type, module_config, args.debug, args.save)
    latency_per_iter_ms = benchmark_module(config, module)
    result[result_key] = latency_per_iter_ms

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--op", default="add_op", dest="op", type=str)
    parser.add_argument("--debug", default=False, dest="debug", action="store_true")
    parser.add_argument("--save", default=False, dest="save", action="store_true")
    parser.add_argument("--eager_mode", default=False, dest="eager_mode", action="store_true")
    parser.add_argument("--num_warmup_iters", type=int, default=100)
    parser.add_argument("--num_iters", type=int, default=1000)
    args = parser.parse_args()

    if args.op not in SUPPORTED_OPS:
        print("Op {} is not supported. Supported ops are: {}".format(args.op, SUPPORTED_OPS))
        return

    num_warmup_iters = args.num_warmup_iters
    num_iters = args.num_iters
    config = BenchmarkConfig(num_warmup_iters, num_iters)
    graph_mode = True
    if args.eager_mode:
        graph_mode = False
    result = {}
    if args.op == "add_op":
        num_params = 2
        module_config = ModuleConfig(add_tensors_loop, num_params, graph_mode)
        benchmark_simple_fn(args, config, module_config, SimpleAddModule, result)
    print_results(result)

if __name__ == "__main__":
    main()
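Outside of buck, the script can also be launched directly with Python from its directory; the flags below are the ones registered with argparse above (illustrative invocations, assuming torch is installed):

# Graph (traced) mode is the default:
python framework_overhead_benchmark.py --op add_op

# Eager mode with shorter warmup and measurement runs:
python framework_overhead_benchmark.py --op add_op --eager_mode --num_warmup_iters 10 --num_iters 100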
							
								
								
									
benchmarks/framework_overhead_benchmark/pt_wrapper_module.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch

class WrapperModule(object):
    """ Wraps an instance of wrapped_type.
    For graph_mode, traces the instance of wrapped_type.
    Randomly initializes num_params tensors, each with a single float element.
    Args:
        wrapped_type:
            - Object type to be wrapped.
                Expects the wrapped_type to:
                   - be constructed with the pt_fn specified in module_config.
                   - provide a forward method that takes module_config.num_params args.
        module_config:
            - Specifies the pt_fn to construct wrapped_type with, whether graph_mode
              is enabled, and the number of parameters wrapped_type's forward method
              takes.
        debug:
            - Whether debug mode is enabled.
        save:
            - In graph mode, whether the graph is to be saved.
    """
    def __init__(self, wrapped_type, module_config, debug, save=False):
        pt_fn = module_config.pt_fn
        self.module = wrapped_type(pt_fn)
        self.tensor_inputs = []
        self.module_name = wrapped_type.__name__
        for _ in range(module_config.num_params):
            self.tensor_inputs.append(torch.randn(1))
        if module_config.graph_mode:
            self.module = torch.jit.trace(self.module, self.tensor_inputs)
            if save:
                file_name = self.module_name + "_" + pt_fn.__name__ + ".pt"
                torch.jit.save(self.module, file_name)
                print("Generated graph is saved in {}".format(file_name))
        print("Benchmarking module {} with fn {}: Graph mode:{}".format(self.module_name, pt_fn.__name__, module_config.graph_mode))
        if (debug and isinstance(self.module, torch.jit.ScriptModule)):
            print(self.module.graph)
            print(self.module.code)

    def forward(self, niters):
        # Run the wrapped forward niters times with gradients disabled,
        # so only framework dispatch overhead is measured.
        with torch.no_grad():
            for _ in range(niters):
                self.module.forward(*self.tensor_inputs)
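The wrapper can also be driven programmatically, without the CLI; a minimal sketch using the pieces above, with iteration counts chosen arbitrarily for a quick run:

from utils import BenchmarkConfig, ModuleConfig, benchmark_module
from pt_wrapper_module import WrapperModule
from SimpleAddModule import SimpleAddModule, add_tensors_loop

# Trace SimpleAddModule(add_tensors_loop) over two single-element tensors.
cfg = ModuleConfig(pt_fn=add_tensors_loop, num_params=2, graph_mode=True)
module = WrapperModule(SimpleAddModule, cfg, debug=False)
latency_ms = benchmark_module(BenchmarkConfig(num_warmup_iters=10, num_iters=100), module)
print(latency_ms)  # per-add latency in milliseconds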
							
								
								
									
benchmarks/framework_overhead_benchmark/utils.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import time
from collections import namedtuple

NUM_PT_LOOP_ITERS = 1000
BenchmarkConfig = namedtuple('BenchmarkConfig', 'num_warmup_iters num_iters')
ModuleConfig = namedtuple('ModuleConfig', 'pt_fn num_params graph_mode')

def ms_to_us(time_ms):
    return (time_ms * 1e3)

def secs_to_us(time_s):
    return (time_s * 1e6)

def secs_to_ms(time_s):
    return (time_s * 1e3)

def benchmark_module(config, module):
    # Warm up, then time num_iters calls of the module's forward. Each forward
    # call runs the wrapped op NUM_PT_LOOP_ITERS times, so normalize by both
    # counts to report per-op latency in milliseconds.
    module.forward(config.num_warmup_iters)
    print("Running module for {} iterations".format(config.num_iters))
    start = time.time()
    module.forward(config.num_iters)
    end = time.time()
    time_elapsed_s = (end - start)
    return (secs_to_ms(time_elapsed_s) / config.num_iters / NUM_PT_LOOP_ITERS)
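To make the unit handling concrete, a worked example with assumed numbers: if the timed num_iters = 1000 forward calls take 2.0 s of wall time, then with NUM_PT_LOOP_ITERS = 1000 the reported per-op latency comes out to 2 us:

from utils import secs_to_ms, ms_to_us, NUM_PT_LOOP_ITERS

elapsed_s = 2.0   # assumed wall time for the 1000 timed forward calls
num_iters = 1000
latency_ms = secs_to_ms(elapsed_s) / num_iters / NUM_PT_LOOP_ITERS  # 0.002 ms
print(ms_to_us(latency_ms))  # 2.0, i.e. about 2 us per add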