mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle Reviewed By: zertosh Differential Revision: D30279364 fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
aac3c7bd06
commit
b004307252
@ -3,13 +3,13 @@
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
from caffe2.python import workspace, brew
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
from caffe2.python.predictor import mobile_exporter
|
||||
from caffe2.python import workspace, brew
|
||||
|
||||
|
||||
def parse_kwarg(kwarg_str):
|
||||
key, value = kwarg_str.split('=')
|
||||
key, value = kwarg_str.split("=")
|
||||
try:
|
||||
value = ast.literal_eval(value)
|
||||
except ValueError:
|
||||
@ -30,7 +30,7 @@ def main(args):
|
||||
|
||||
iters = int(args.instances)
|
||||
for i in range(iters):
|
||||
input_blob_name = input_name + (str(i) if i > 0 and args.chain else '')
|
||||
input_blob_name = input_name + (str(i) if i > 0 and args.chain else "")
|
||||
output_blob_name = output_name + str(i + 1)
|
||||
add_op = getattr(brew, op_type)
|
||||
add_op(model, input_blob_name, output_blob_name, **kwargs)
|
||||
@ -39,9 +39,7 @@ def main(args):
|
||||
|
||||
workspace.RunNetOnce(model.param_init_net)
|
||||
|
||||
init_net, predict_net = mobile_exporter.Export(
|
||||
workspace, model.net, model.params
|
||||
)
|
||||
init_net, predict_net = mobile_exporter.Export(workspace, model.net, model.params)
|
||||
|
||||
if args.debug:
|
||||
print("init_net:")
|
||||
@ -51,40 +49,55 @@ def main(args):
|
||||
for op in predict_net.op:
|
||||
print(" ", op.type, op.input, "-->", op.output)
|
||||
|
||||
with open(args.predict_net, 'wb') as f:
|
||||
with open(args.predict_net, "wb") as f:
|
||||
f.write(predict_net.SerializeToString())
|
||||
with open(args.init_net, 'wb') as f:
|
||||
with open(args.init_net, "wb") as f:
|
||||
f.write(init_net.SerializeToString())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Utilitity to generate Caffe2 benchmark models.")
|
||||
description="Utilitity to generate Caffe2 benchmark models."
|
||||
)
|
||||
parser.add_argument("operator", help="Caffe2 operator to benchmark.")
|
||||
parser.add_argument("-b", "--blob",
|
||||
help="Instantiate a blob --blob name=dim1,dim2,dim3",
|
||||
action='append')
|
||||
parser.add_argument(
|
||||
"-b",
|
||||
"--blob",
|
||||
help="Instantiate a blob --blob name=dim1,dim2,dim3",
|
||||
action="append",
|
||||
)
|
||||
parser.add_argument("--context", help="Context to run on.", default="CPU")
|
||||
parser.add_argument("--kwargs", help="kwargs to pass to operator.",
|
||||
nargs="*", type=parse_kwarg, default=[])
|
||||
parser.add_argument("--init_net", help="Output initialization net.",
|
||||
default="init_net.pb")
|
||||
parser.add_argument("--predict_net", help="Output prediction net.",
|
||||
default="predict_net.pb")
|
||||
parser.add_argument("--benchmark_name",
|
||||
help="Name of the benchmark network",
|
||||
default="benchmark")
|
||||
parser.add_argument("--input_name", help="Name of the input blob.",
|
||||
default="data")
|
||||
parser.add_argument("--output_name", help="Name of the output blob.",
|
||||
default="output")
|
||||
parser.add_argument("--instances",
|
||||
help="Number of instances to run the operator.",
|
||||
default="1")
|
||||
parser.add_argument("-d", "--debug", help="Print debug information.",
|
||||
action='store_true')
|
||||
parser.add_argument("-c", "--chain",
|
||||
help="Chain ops together (create data dependencies)",
|
||||
action='store_true')
|
||||
parser.add_argument(
|
||||
"--kwargs",
|
||||
help="kwargs to pass to operator.",
|
||||
nargs="*",
|
||||
type=parse_kwarg,
|
||||
default=[],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_net", help="Output initialization net.", default="init_net.pb"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--predict_net", help="Output prediction net.", default="predict_net.pb"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--benchmark_name", help="Name of the benchmark network", default="benchmark"
|
||||
)
|
||||
parser.add_argument("--input_name", help="Name of the input blob.", default="data")
|
||||
parser.add_argument(
|
||||
"--output_name", help="Name of the output blob.", default="output"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instances", help="Number of instances to run the operator.", default="1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d", "--debug", help="Print debug information.", action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--chain",
|
||||
help="Chain ops together (create data dependencies)",
|
||||
action="store_true",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user