Files
pytorch/binaries/bench_gen/bench_gen.py
Zsolt Dollenstein b004307252 [codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle

Reviewed By: zertosh

Differential Revision: D30279364

fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
2021-08-12 10:58:35 -07:00

104 lines
3.1 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import ast
from caffe2.python import workspace, brew
from caffe2.python.model_helper import ModelHelper
from caffe2.python.predictor import mobile_exporter
def parse_kwarg(kwarg_str):
    """Parse one ``key=value`` command-line token into a ``(key, value)`` pair.

    The value is evaluated as a Python literal when possible (``3`` becomes
    an int, ``True`` a bool, ``[1,2]`` a list). Anything that is not a valid
    literal is returned unchanged as a string.

    Args:
        kwarg_str: Text of the form ``key=value``. Only the first ``=``
            separates key from value, so the value may itself contain ``=``.

    Returns:
        A ``(key, value)`` tuple; ``key`` is always a string.

    Raises:
        ValueError: If ``kwarg_str`` contains no ``=`` at all.
    """
    # Split on the first "=" only. A token without "=" still fails the
    # two-name unpacking with ValueError, exactly as it always did.
    key, value = kwarg_str.split("=", 1)
    try:
        value = ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError):
        # Bare words such as NCHW raise ValueError; text that is not even
        # parseable as an expression ("foo bar", "") raises SyntaxError;
        # some malformed literals raise TypeError. In every case keep the
        # raw string.
        pass
    return key, value
def main(args):
    """Build a net of repeated operator instances and write it to disk.

    Adds ``int(args.instances)`` copies of the brew operator named by
    ``args.operator`` to a fresh ``ModelHelper``, runs the parameter init
    net once, exports init/predict nets through ``mobile_exporter.Export``
    and writes both in serialized form to ``args.predict_net`` and
    ``args.init_net``.

    Args:
        args: Parsed command-line namespace. Reads ``kwargs``,
            ``benchmark_name``, ``operator``, ``input_name``,
            ``output_name``, ``instances``, ``chain``, ``debug``,
            ``predict_net`` and ``init_net``.
    """
    # Defaults go in first so that user-supplied kwargs override them.
    op_kwargs = {"order": "NCHW", "use_cudnn": False}
    op_kwargs.update(dict(args.kwargs))

    model = ModelHelper(name=args.benchmark_name)

    brew_op_name = args.operator  # assumes a brew type op name
    src_base = args.input_name
    dst_base = args.output_name

    instance_count = int(args.instances)
    for index in range(instance_count):
        # Only chained instances after the first read a numbered blob.
        if index > 0 and args.chain:
            src_blob = src_base + str(index)
        else:
            src_blob = src_base
        dst_blob = dst_base + str(index + 1)
        getattr(brew, brew_op_name)(model, src_blob, dst_blob, **op_kwargs)
        if args.chain:
            # Swapping the base names makes the next instance read the blob
            # this one just wrote (e.g. output1 -> data2 -> output3 ...).
            src_base, dst_base = dst_base, src_base

    workspace.RunNetOnce(model.param_init_net)

    init_net, predict_net = mobile_exporter.Export(workspace, model.net, model.params)

    if args.debug:
        for heading, net in (("init_net:", init_net), ("predict_net:", predict_net)):
            print(heading)
            for net_op in net.op:
                print(" ", net_op.type, net_op.input, "-->", net_op.output)

    # Predict net is written first, then the init net.
    for path, net in ((args.predict_net, predict_net), (args.init_net, init_net)):
        with open(path, "wb") as out_file:
            out_file.write(net.SerializeToString())
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and hand
    # the resulting namespace to main().
    parser = argparse.ArgumentParser(
        description="Utility to generate Caffe2 benchmark models."
    )
    parser.add_argument("operator", help="Caffe2 operator to benchmark.")
    # NOTE(review): --blob and --context are accepted here, but the main()
    # visible in this file does not read them -- confirm whether they are
    # still needed.
    parser.add_argument(
        "-b",
        "--blob",
        help="Instantiate a blob --blob name=dim1,dim2,dim3",
        action="append",
    )
    parser.add_argument("--context", help="Context to run on.", default="CPU")
    # Each token is converted by parse_kwarg into a (key, value) pair; main()
    # turns the resulting list into a dict.
    parser.add_argument(
        "--kwargs",
        help="kwargs to pass to operator.",
        nargs="*",
        type=parse_kwarg,
        default=[],
    )
    parser.add_argument(
        "--init_net", help="Output initialization net.", default="init_net.pb"
    )
    parser.add_argument(
        "--predict_net", help="Output prediction net.", default="predict_net.pb"
    )
    parser.add_argument(
        "--benchmark_name", help="Name of the benchmark network", default="benchmark"
    )
    parser.add_argument("--input_name", help="Name of the input blob.", default="data")
    parser.add_argument(
        "--output_name", help="Name of the output blob.", default="output"
    )
    # Converting here lets argparse reject a non-integer value with a usage
    # error; main() still applies int() itself, which is a no-op on an int.
    parser.add_argument(
        "--instances",
        help="Number of instances to run the operator.",
        type=int,
        default=1,
    )
    parser.add_argument(
        "-d", "--debug", help="Print debug information.", action="store_true"
    )
    parser.add_argument(
        "-c",
        "--chain",
        help="Chain ops together (create data dependencies)",
        action="store_true",
    )
    args = parser.parse_args()
    main(args)