Add simple add op based framework overhead benchmark. (#23076)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23076

Adds both tracing-based (graph mode) and non-tracing-based (eager mode) variants.

Reviewed By: mingzhe09088

Differential Revision: D16097280

fbshipit-source-id: 3a137092f7ccc3dd2d29d95e10178ec89d3ce892
Author: Kimish Patel
Date: 2019-07-22 11:18:18 -07:00
Committed by: Facebook Github Bot
Parent: 4223e2f9e9
Commit: 0621068cdc
4 changed files with 166 additions and 0 deletions
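At a high level, the two variants added here differ only in whether the add function is called eagerly or through a torch.jit.trace'd graph. A minimal standalone sketch of that distinction (add_once is an illustrative name, not part of this diff):

import torch

def add_once(x, y):
    # Illustrative stand-in for the benchmark's add op.
    return torch.add(x, y)

x, y = torch.randn(1), torch.randn(1)
eager_out = add_once(x, y)                  # non-tracing (eager) path
traced = torch.jit.trace(add_once, (x, y))  # tracing (graph) path
traced_out = traced(x, y)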


@@ -0,0 +1,17 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch

from utils import NUM_PT_LOOP_ITERS


def add_tensors_loop(x, y):
    # Chains NUM_PT_LOOP_ITERS adds so the timed region contains many op calls.
    z = torch.add(x, y)
    for i in range(NUM_PT_LOOP_ITERS):
        z = torch.add(z, x)
    return z


class SimpleAddModule(torch.nn.Module):
    def __init__(self, add_op):
        super(SimpleAddModule, self).__init__()
        self.add_op = add_op

    def forward(self, x, y):
        return self.add_op(x, y)
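A minimal sketch of exercising this module on its own, using only the names defined above (tensor shapes are illustrative):

import torch
from SimpleAddModule import SimpleAddModule, add_tensors_loop

module = SimpleAddModule(add_tensors_loop)
x, y = torch.randn(1), torch.randn(1)
out = module(x, y)  # x + y followed by NUM_PT_LOOP_ITERS more adds of x
traced = torch.jit.trace(module, (x, y))  # graph-mode variant used by the benchmark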


@@ -0,0 +1,80 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from utils import ms_to_us, benchmark_module, BenchmarkConfig, ModuleConfig
import argparse
from SimpleAddModule import SimpleAddModule, add_tensors_loop
from pt_wrapper_module import WrapperModule
""" Framework overhead benchmark script.
Benchmark framework overhead.
Currently supported ops: add.
As of now runs only forward pass.
Supports both graph mode and eager mode. In graph mode the module is traced via JIT tracing.
Debug option prints the traced graph is graph_mode is enabled.
Graph can be saved via save option. Saved in the directory where benchmark is run.
Example build/run:
buck run @mode/opt <path-to-framework_overhead_benchmark>:framework_overhead_benchmark --
--add_op --graph_mode --eager_mode (Runs both graph mode and eager mode)
buck run @mode/opt <path-to-framework_overhead_benchmark>:framework_overhead_benchmark --
--add_op --graph_mode (Runs only graph mode)
"""
SUPPORTED_OPS = {"add_op"}
def parse_op_args(op):
    op_list = op.split(",")
    return op_list


def print_results(result):
    print("===================================")
    for key, value in result.items():
        print("{}, latency per iter (us):{}".format(key, ms_to_us(value)))
    print("===================================")
def benchmark_simple_fn(args, config, module_config, module_type, result):
    """ Benchmarks a PyTorch traceable function specified in the config.
    Instantiates a wrapper object that wraps an instance of module_type and runs its forward
    method via benchmark_module.
    Args:
        config: contains the number of warmup and benchmark iterations.
        module_config: contains the op, the number of parameters that the op takes,
            and whether graph mode is enabled.
        module_type: type of the module to be wrapped, e.g. SimpleAddModule for the add op.
        result: dictionary instance to be populated with the benchmark result (latency per iter).
    """
    print("Benchmarking {}".format(module_type.__name__))
    f_name = module_config.pt_fn.__name__ + ":Num Operands=" + str(module_config.num_params)
    graph_mode_str = "Graph mode" + ":" + str(module_config.graph_mode)
    result_key = ','.join((f_name, graph_mode_str))
    module = WrapperModule(module_type, module_config, args.debug, args.save)
    latency_per_iter_ms = benchmark_module(config, module)
    result[result_key] = latency_per_iter_ms
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--op", default="add_op", dest="op", type=str)
    parser.add_argument("--debug", default=False, dest="debug", action="store_true")
    parser.add_argument("--save", default=False, dest="save", action="store_true")
    parser.add_argument("--eager_mode", default=False, dest="eager_mode", action="store_true")
    parser.add_argument("--num_warmup_iters", type=int, default=100)
    parser.add_argument("--num_iters", type=int, default=1000)
    args = parser.parse_args()
    if args.op not in SUPPORTED_OPS:
        print("Op {} is not supported: Supported ops are:{}".format(args.op, SUPPORTED_OPS))
        return
    num_warmup_iters = args.num_warmup_iters
    num_iters = args.num_iters
    config = BenchmarkConfig(num_warmup_iters, num_iters)
    graph_mode = True
    if args.eager_mode:
        graph_mode = False
    result = {}
    if args.op == "add_op":
        num_params = 2
        module_config = ModuleConfig(add_tensors_loop, num_params, graph_mode)
        benchmark_simple_fn(args, config, module_config, SimpleAddModule, result)
    print_results(result)


if __name__ == "__main__":
    main()
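For completeness, a sketch of driving this benchmark programmatically rather than through argparse. It assumes the driver file above is importable as framework_overhead_benchmark (the file name is not shown in the diff); everything else uses the definitions above:

import argparse

# NOTE: the module name framework_overhead_benchmark is an assumption.
from framework_overhead_benchmark import benchmark_simple_fn, print_results
from SimpleAddModule import SimpleAddModule, add_tensors_loop
from utils import BenchmarkConfig, ModuleConfig

args = argparse.Namespace(debug=False, save=False)            # only the fields benchmark_simple_fn reads
config = BenchmarkConfig(num_warmup_iters=10, num_iters=100)  # smaller counts than the CLI defaults
module_config = ModuleConfig(pt_fn=add_tensors_loop, num_params=2, graph_mode=True)
result = {}
benchmark_simple_fn(args, config, module_config, SimpleAddModule, result)
print_results(result)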


@@ -0,0 +1,44 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
class WrapperModule(object):
    """ Wraps an instance of wrapped_type.
    When graph_mode is enabled, traces the instance of wrapped_type.
    Randomly initializes num_params tensors, each with a single float element.
    Args:
        wrapped_type:
            - Object type to be wrapped.
              Expects the wrapped_type to:
                - be constructed with the pt_fn specified in module_config.
                - provide a forward method that takes module_config.num_params args.
        module_config:
            - Specifies the pt_fn to construct wrapped_type with, whether graph_mode
              is enabled, and the number of parameters wrapped_type's forward method
              takes.
        debug:
            - Whether debug mode is enabled.
        save:
            - In graph mode, whether the traced graph is to be saved.
    """
    def __init__(self, wrapped_type, module_config, debug, save=False):
        pt_fn = module_config.pt_fn
        self.module = wrapped_type(pt_fn)
        self.tensor_inputs = []
        self.module_name = wrapped_type.__name__
        for _ in range(module_config.num_params):
            self.tensor_inputs.append(torch.randn(1))
        if module_config.graph_mode:
            self.module = torch.jit.trace(self.module, self.tensor_inputs)
            if save:
                file_name = self.module_name + "_" + pt_fn.__name__ + ".pt"
                torch.jit.save(self.module, file_name)
                print("Generated graph is saved in {}".format(file_name))
        print("Benchmarking module {} with fn {}: Graph mode:{}".format(self.module_name, pt_fn.__name__, module_config.graph_mode))
        if debug and isinstance(self.module, torch.jit.ScriptModule):
            print(self.module.graph)
            print(self.module.code)

    def forward(self, niters):
        with torch.no_grad():
            for _ in range(niters):
                self.module.forward(*self.tensor_inputs)
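A minimal sketch exercising WrapperModule directly with the add module from this change (the iteration count is illustrative):

from pt_wrapper_module import WrapperModule
from SimpleAddModule import SimpleAddModule, add_tensors_loop
from utils import ModuleConfig

module_config = ModuleConfig(pt_fn=add_tensors_loop, num_params=2, graph_mode=True)
wrapper = WrapperModule(SimpleAddModule, module_config, debug=False, save=False)
wrapper.forward(10)  # 10 traced forward passes under torch.no_grad()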


@@ -0,0 +1,25 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import time
from collections import namedtuple
NUM_PT_LOOP_ITERS = 1000
BenchmarkConfig = namedtuple('BenchmarkConfig', 'num_warmup_iters num_iters')
ModuleConfig = namedtuple('ModuleConfig', 'pt_fn num_params graph_mode')
def ms_to_us(time_ms):
    return (time_ms * 1e3)


def secs_to_us(time_s):
    return (time_s * 1e6)


def secs_to_ms(time_s):
    return (time_s * 1e3)


def benchmark_module(config, module):
    # Warm-up runs are not timed.
    module.forward(config.num_warmup_iters)
    print("Running module for {} iterations".format(config.num_iters))
    start = time.time()
    module.forward(config.num_iters)
    end = time.time()
    time_elapsed_s = (end - start)
    # Report latency in ms per add: divide by the number of timed forward calls
    # and by NUM_PT_LOOP_ITERS, the loop count inside each forward call.
    return (secs_to_ms(time_elapsed_s) / config.num_iters / NUM_PT_LOOP_ITERS)
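As a worked check of the units benchmark_module returns (the 2.0 s elapsed time is made up for illustration):

from utils import secs_to_ms, ms_to_us, NUM_PT_LOOP_ITERS

elapsed_s, num_iters = 2.0, 1000
latency_ms = secs_to_ms(elapsed_s) / num_iters / NUM_PT_LOOP_ITERS  # 2000 ms / 1000 / 1000 = 0.002 ms
print(ms_to_us(latency_ms))  # 2.0, i.e. 2 us per add, the figure print_results reports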