diff --git a/binaries/bench_gen/bench_gen.py b/binaries/bench_gen/bench_gen.py index 048e151c2c02..579aae3e1b92 100644 --- a/binaries/bench_gen/bench_gen.py +++ b/binaries/bench_gen/bench_gen.py @@ -6,6 +6,7 @@ from __future__ import print_function from __future__ import unicode_literals import argparse +import ast from caffe2.python.model_helper import ModelHelper from caffe2.python.predictor import mobile_exporter @@ -15,18 +16,15 @@ from caffe2.python import workspace, brew def parse_kwarg(kwarg_str): key, value = kwarg_str.split('=') try: - value = int(value) + value = ast.literal_eval(value) except ValueError: - try: - value = float(value) - except ValueError: - pass + pass return key, value def main(args): # User defined keyword arguments - kwargs = {"order": "NCHW"} + kwargs = {"order": "NCHW", "use_cudnn": False} kwargs.update(dict(args.kwargs)) model = ModelHelper(name=args.benchmark_name)