Initialize Caffe2 only when running Caffe2 benchmarks (#19980)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19980
ghimport-source-id: ca31ca25b88a1c6219e4a32483f70738a8fdbf88

Differential Revision: D15229797

Pulled By: ilia-cher

fbshipit-source-id: 0b23dbdba0c0f60932a75d8b1900c54285f5a8e4
This commit is contained in:
Ilia Cherniavskii
2019-05-06 19:04:00 -07:00
committed by Facebook Github Bot
parent 0c7e98b765
commit 8c97f0b19e
3 changed files with 19 additions and 7 deletions

View File

@ -142,8 +142,10 @@ class BenchmarkRunner(object):
# Currently, this is a sub-string matching.
if self.args.operator and (self.args.operator not in full_test_id):
continue
if self.args.framework and (self.args.framework not in full_test_id):
continue
if self.args.framework:
frameworks = benchmark_utils.get_requested_frameworks(self.args.framework)
if all(fr not in full_test_id for fr in frameworks):
continue
# To reduce variance, fix a numpy randseed to the test case,
# so that the randomly generated input tensors remain the

View File

@ -8,7 +8,7 @@ import argparse
from caffe2.python import workspace
from operator_benchmark import benchmark_core
from operator_benchmark import benchmark_core, benchmark_utils
"""Performance microbenchmarks's main binary.
@ -73,13 +73,14 @@ def main():
parser.add_argument(
'--framework',
help='Run PyTorch or Caffe2 operators',
default=None)
help='Comma-delimited list of frameworks to test (Caffe2, PyTorch)',
default="Caffe2,PyTorch")
args = parser.parse_args()
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
workspace.ClearGlobalNetObserver()
if benchmark_utils.is_caffe2_enabled(args.framework):
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
workspace.ClearGlobalNetObserver()
benchmark_core.BenchmarkRunner(args).run()

View File

@ -61,3 +61,12 @@ def generate_configs(**configs):
results = configs['sample_func'](*result)
return results
def is_caffe2_enabled(framework_arg):
return 'Caffe2' in framework_arg
def is_pytorch_enabled(framework_arg):
return 'PyTorch' in framework_arg
def get_requested_frameworks(framework_arg):
return [fr.strip() for fr in framework_arg.split(',') if len(fr.strip()) > 0]