mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
make bench_gen.py work for 3d conv (#12433)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12433 To test 3d conv, we need to pass lists in spec argument. We also don't want to set use_cudnn=True which is the default in brew. Reviewed By: llyfacebook, csummersea Differential Revision: D10234315 fbshipit-source-id: 96a39992a97e020d6e9dac103e6d64df0cc1020b
This commit is contained in:
committed by
Facebook Github Bot
parent
00aedfc0e2
commit
f1f521f71b
@ -6,6 +6,7 @@ from __future__ import print_function
|
|||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import ast
|
||||||
|
|
||||||
from caffe2.python.model_helper import ModelHelper
|
from caffe2.python.model_helper import ModelHelper
|
||||||
from caffe2.python.predictor import mobile_exporter
|
from caffe2.python.predictor import mobile_exporter
|
||||||
@ -15,18 +16,15 @@ from caffe2.python import workspace, brew
|
|||||||
def parse_kwarg(kwarg_str):
|
def parse_kwarg(kwarg_str):
|
||||||
key, value = kwarg_str.split('=')
|
key, value = kwarg_str.split('=')
|
||||||
try:
|
try:
|
||||||
value = int(value)
|
value = ast.literal_eval(value)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
try:
|
pass
|
||||||
value = float(value)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
return key, value
|
return key, value
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
# User defined keyword arguments
|
# User defined keyword arguments
|
||||||
kwargs = {"order": "NCHW"}
|
kwargs = {"order": "NCHW", "use_cudnn": False}
|
||||||
kwargs.update(dict(args.kwargs))
|
kwargs.update(dict(args.kwargs))
|
||||||
|
|
||||||
model = ModelHelper(name=args.benchmark_name)
|
model = ModelHelper(name=args.benchmark_name)
|
||||||
|
Reference in New Issue
Block a user