mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make bench_gen.py work for 3d conv (#12433)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12433 To test 3d conv, we need to pass lists in spec argument. We also don't want to set use_cudnn=True which is the default in brew. Reviewed By: llyfacebook, csummersea Differential Revision: D10234315 fbshipit-source-id: 96a39992a97e020d6e9dac103e6d64df0cc1020b
This commit is contained in:
committed by
Facebook Github Bot
parent
00aedfc0e2
commit
f1f521f71b
@ -6,6 +6,7 @@ from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
from caffe2.python.model_helper import ModelHelper
|
||||
from caffe2.python.predictor import mobile_exporter
|
||||
@ -15,18 +16,15 @@ from caffe2.python import workspace, brew
|
||||
def parse_kwarg(kwarg_str):
|
||||
key, value = kwarg_str.split('=')
|
||||
try:
|
||||
value = int(value)
|
||||
value = ast.literal_eval(value)
|
||||
except ValueError:
|
||||
try:
|
||||
value = float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
pass
|
||||
return key, value
|
||||
|
||||
|
||||
def main(args):
|
||||
# User defined keyword arguments
|
||||
kwargs = {"order": "NCHW"}
|
||||
kwargs = {"order": "NCHW", "use_cudnn": False}
|
||||
kwargs.update(dict(args.kwargs))
|
||||
|
||||
model = ModelHelper(name=args.benchmark_name)
|
||||
|
Reference in New Issue
Block a user