mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Initialize Caffe2 only when running Caffe2 benchmarks (#19980)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19980 ghimport-source-id: ca31ca25b88a1c6219e4a32483f70738a8fdbf88 Differential Revision: D15229797 Pulled By: ilia-cher fbshipit-source-id: 0b23dbdba0c0f60932a75d8b1900c54285f5a8e4
This commit is contained in:
committed by
Facebook Github Bot
parent
0c7e98b765
commit
8c97f0b19e
@ -142,8 +142,10 @@ class BenchmarkRunner(object):
|
||||
# Currently, this is a sub-string matching.
|
||||
if self.args.operator and (self.args.operator not in full_test_id):
|
||||
continue
|
||||
if self.args.framework and (self.args.framework not in full_test_id):
|
||||
continue
|
||||
if self.args.framework:
|
||||
frameworks = benchmark_utils.get_requested_frameworks(self.args.framework)
|
||||
if all(fr not in full_test_id for fr in frameworks):
|
||||
continue
|
||||
|
||||
# To reduce variance, fix a numpy randseed to the test case,
|
||||
# so that the randomly generated input tensors remain the
|
||||
|
@ -8,7 +8,7 @@ import argparse
|
||||
|
||||
from caffe2.python import workspace
|
||||
|
||||
from operator_benchmark import benchmark_core
|
||||
from operator_benchmark import benchmark_core, benchmark_utils
|
||||
|
||||
"""Performance microbenchmarks's main binary.
|
||||
|
||||
@ -73,13 +73,14 @@ def main():
|
||||
|
||||
parser.add_argument(
|
||||
'--framework',
|
||||
help='Run PyTorch or Caffe2 operators',
|
||||
default=None)
|
||||
help='Comma-delimited list of frameworks to test (Caffe2, PyTorch)',
|
||||
default="Caffe2,PyTorch")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
|
||||
workspace.ClearGlobalNetObserver()
|
||||
if benchmark_utils.is_caffe2_enabled(args.framework):
|
||||
workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
|
||||
workspace.ClearGlobalNetObserver()
|
||||
|
||||
benchmark_core.BenchmarkRunner(args).run()
|
||||
|
||||
|
@ -61,3 +61,12 @@ def generate_configs(**configs):
|
||||
|
||||
results = configs['sample_func'](*result)
|
||||
return results
|
||||
|
||||
def is_caffe2_enabled(framework_arg):
|
||||
return 'Caffe2' in framework_arg
|
||||
|
||||
def is_pytorch_enabled(framework_arg):
|
||||
return 'PyTorch' in framework_arg
|
||||
|
||||
def get_requested_frameworks(framework_arg):
|
||||
return [fr.strip() for fr in framework_arg.split(',') if len(fr.strip()) > 0]
|
||||
|
Reference in New Issue
Block a user