Back out "Revert D15435461: [pytorch][PR] PyTorch ThroughputBenchmark" (#22185)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22185

Original commit changeset: 72a0eac1658b

Differential Revision: D15981928

fbshipit-source-id: d2455d79e81c26ee90d41414cde8ac0f9b703bc3
Authored by Alexander Sidorov on 2019-06-26 16:01:58 -07:00
Committed by Facebook Github Bot
parent 3f2a839dda
commit f51de8b61a
13 changed files with 677 additions and 0 deletions


@@ -45,6 +45,7 @@
#endif // C10_USING_CUSTOM_GENERATED_MACROS
#ifdef _WIN32
#define C10_HIDDEN
#if defined(C10_BUILD_SHARED_LIBS)
#define C10_EXPORT __declspec(dllexport)
#define C10_IMPORT __declspec(dllimport)
@@ -55,8 +56,10 @@
#else // _WIN32
#if defined(__GNUC__)
#define C10_EXPORT __attribute__((__visibility__("default")))
#define C10_HIDDEN __attribute__((__visibility__("hidden")))
#else // defined(__GNUC__)
#define C10_EXPORT
#define C10_HIDDEN
#endif // defined(__GNUC__)
#define C10_IMPORT C10_EXPORT
#endif // _WIN32


@@ -0,0 +1,79 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from torch.utils import ThroughputBenchmark
from torch.testing import assert_allclose
from common_utils import run_tests, TestCase
class TwoLayerNet(torch.jit.ScriptModule):
def __init__(self, D_in, H, D_out):
super(TwoLayerNet, self).__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(2 * H, D_out)
@torch.jit.script_method
def forward(self, x1, x2):
h1_relu = self.linear1(x1).clamp(min=0)
h2_relu = self.linear1(x2).clamp(min=0)
cat = torch.cat((h1_relu, h2_relu), 1)
y_pred = self.linear2(cat)
return y_pred
class TwoLayerNetModule(torch.nn.Module):
def __init__(self, D_in, H, D_out):
super(TwoLayerNetModule, self).__init__()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(2 * H, D_out)
def forward(self, x1, x2):
h1_relu = self.linear1(x1).clamp(min=0)
h2_relu = self.linear1(x2).clamp(min=0)
cat = torch.cat((h1_relu, h2_relu), 1)
y_pred = self.linear2(cat)
return y_pred
class TestThroughputBenchmark(TestCase):
def linear_test(self, Module):
D_in = 10
H = 5
D_out = 15
B = 8
NUM_INPUTS = 2
module = Module(D_in, H, D_out)
inputs = []
for i in range(NUM_INPUTS):
inputs.append([torch.randn(B, D_in), torch.randn(B, D_in)])
bench = ThroughputBenchmark(module)
for input in inputs:
# can do both args and kwargs here
bench.add_input(input[0], x2=input[1])
for i in range(NUM_INPUTS):
# or just unpack the list of inputs
module_result = module(*inputs[i])
bench_result = bench.run_once(*inputs[i])
assert_allclose(bench_result, module_result)
stats = bench.benchmark(
num_calling_threads=4,
num_warmup_iters=100,
num_iters=1000,
)
print("Avg latency (ms): {}".format(stats.latency_avg_ms))
print("Number of iterations: {}".format(stats.num_iters))
def test_script_module(self):
self.linear_test(TwoLayerNet)
def test_module(self):
self.linear_test(TwoLayerNetModule)
if __name__ == '__main__':
run_tests()


@@ -248,6 +248,8 @@ def add_torch_libs():
"torch/csrc/onnx/init.cpp",
"torch/csrc/serialization.cpp",
"torch/csrc/tensor/python_tensor.cpp",
"torch/csrc/utils/init.cpp",
"torch/csrc/utils/throughput_benchmark.cpp",
"torch/csrc/utils.cpp",
"torch/csrc/utils/cuda_lazy_init.cpp",
"torch/csrc/utils/invalid_arguments.cpp",


@@ -84,6 +84,8 @@ set(TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/jit/script/python_tree_views.cpp
${TORCH_SRC_DIR}/csrc/multiprocessing/init.cpp
${TORCH_SRC_DIR}/csrc/onnx/init.cpp
${TORCH_SRC_DIR}/csrc/utils/init.cpp
${TORCH_SRC_DIR}/csrc/utils/throughput_benchmark.cpp
${TORCH_SRC_DIR}/csrc/serialization.cpp
${TORCH_SRC_DIR}/csrc/tensor/python_tensor.cpp
${TORCH_SRC_DIR}/csrc/utils.cpp


@@ -44,6 +44,7 @@
#include <torch/csrc/jit/init.h>
#include <torch/csrc/jit/python_ir.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/utils/init.h>
#include <torch/csrc/api/include/torch/python/init.h>
#ifdef USE_CUDNN
@@ -644,6 +645,7 @@ PyObject* initModule() {
// init.
torch::onnx::initONNXBindings(module);
torch::jit::initJITBindings(module);
torch::throughput_benchmark::initThroughputBenchmarkBindings(module);
torch::autograd::initNNFunctions(module);
torch::autograd::init_legacy_variable(module);
torch::python::init_bindings(module);


@@ -1,5 +1,7 @@
#pragma once
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace jit {

torch/csrc/utils/init.cpp (new file, 50 lines)

@@ -0,0 +1,50 @@
#include <ATen/core/ivalue.h>
#include <torch/csrc/utils/init.h>
#include <torch/csrc/utils/throughput_benchmark.h>
#include <pybind11/functional.h>
namespace torch {
namespace throughput_benchmark {
void initThroughputBenchmarkBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
using namespace torch::throughput_benchmark;
py::class_<BenchmarkConfig>(m, "BenchmarkConfig")
.def(py::init<>())
.def_readwrite(
"num_calling_threads", &BenchmarkConfig::num_calling_threads)
.def_readwrite("num_worker_threads", &BenchmarkConfig::num_worker_threads)
.def_readwrite("num_warmup_iters", &BenchmarkConfig::num_warmup_iters)
.def_readwrite("num_iters", &BenchmarkConfig::num_iters);
py::class_<BenchmarkExecutionStats>(m, "BenchmarkExecutionStats")
.def_readonly("latency_avg_ms", &BenchmarkExecutionStats::latency_avg_ms)
.def_readonly("num_iters", &BenchmarkExecutionStats::num_iters);
py::class_<ThroughputBenchmark>(m, "ThroughputBenchmark", py::dynamic_attr())
.def(py::init<std::shared_ptr<jit::script::Module>>())
.def(py::init<py::object>())
.def(
"add_input",
[](ThroughputBenchmark& self, py::args args, py::kwargs kwargs) {
self.addInput(std::move(args), std::move(kwargs));
})
.def(
"run_once",
[](ThroughputBenchmark& self, py::args args, py::kwargs kwargs) {
// Depending on whether this is a ScriptModule or an nn.Module, we will
// release the GIL or not further down in the stack
return self.runOnce(std::move(args), std::move(kwargs));
})
.def("benchmark", [](ThroughputBenchmark& self, BenchmarkConfig config) {
// The benchmark always runs without the GIL. The GIL will be re-acquired
// where needed; this happens only in the nn.Module mode, when manipulating
// inputs and running the actual inference
AutoNoGIL no_gil_guard;
return self.benchmark(config);
});
}
} // namespace throughput_benchmark
} // namespace torch
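
A minimal, hedged sketch (not part of this diff) of how the bindings registered above look from Python once initThroughputBenchmarkBindings has run. It talks to torch._C directly rather than through the torch.utils.ThroughputBenchmark wrapper added later in this commit, and the nn.Linear model is just an arbitrary stand-in:

import torch

# Hypothetical stand-in model; the constructor overloads above accept either a
# ScriptModule (pass module._c) or a plain nn.Module such as this one.
model = torch.nn.Linear(10, 10)

bench = torch._C.ThroughputBenchmark(model)  # py::object constructor overload
bench.add_input(torch.randn(8, 10))          # forwarded to the add_input lambda

config = torch._C.BenchmarkConfig()
config.num_calling_threads = 2
config.num_warmup_iters = 10
config.num_iters = 100

stats = bench.benchmark(config)              # the benchmark binding releases the GIL
print(stats.latency_avg_ms, stats.num_iters)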

torch/csrc/utils/init.h (new file, 11 lines)

@@ -0,0 +1,11 @@
#pragma once
#include <torch/csrc/utils/pybind.h>
namespace torch {
namespace throughput_benchmark {
void initThroughputBenchmarkBindings(PyObject* module);
} // namespace throughput_benchmark
} // namespace torch

View File

@@ -0,0 +1,130 @@
#pragma once
#include <random>
#include <thread>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/pybind_utils.h>
namespace torch {
namespace throughput_benchmark {
namespace detail {
template <class Input, class Output, class Model>
BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
const BenchmarkConfig& config) const {
CHECK(initialized_);
TORCH_CHECK(
config.num_worker_threads == 1,
"Only parallelization by callers is supported");
// We pre-generate inputs here for each of the threads. This allows us to
// safely move inputs out for each of the threads independently and thus avoid
// overhead from the benchmark runner itself
std::vector<std::vector<Input>> thread_inputs(config.num_calling_threads);
std::vector<size_t> input_iters(config.num_calling_threads);
{
std::random_device seeder;
std::mt19937 engine(seeder());
TORCH_CHECK(
!inputs_.empty(),
"Please provide benchmark inptus."
"Did you forget to call add_input()? ");
std::uniform_int_distribution<int> dist(0, inputs_.size() - 1);
for (int thread_id = 0; thread_id < config.num_calling_threads;
++thread_id) {
// Just in case, we generate enough inputs for each of the threads to run
// all iterations on its own. This way, if one thread does all the work, we
// will be fine
for (int i = 0; i < config.num_iters + config.num_warmup_iters; ++i) {
thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)]));
}
input_iters[thread_id] = 0;
}
}
std::mutex m;
std::condition_variable worker_main_cv;
std::condition_variable main_worker_cv;
// TODO: add GUARDED_BY once it is available
int64_t initialized{0};
int64_t finished{0};
bool start{false};
std::atomic<int64_t> num_forwards{0};
std::vector<std::thread> callers;
for (auto thread_id = 0; thread_id < config.num_calling_threads;
++thread_id) {
callers.emplace_back([&, thread_id]() {
// We use a condition variable as a barrier to make sure each thread
// performs the required warmup iterations before we start measuring
for (auto j = 0; j < config.num_warmup_iters; ++j) {
runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
++input_iters[thread_id];
}
{
std::unique_lock<std::mutex> lock(m);
++initialized;
worker_main_cv.notify_one();
while (!start) {
main_worker_cv.wait(lock);
}
}
LOG(INFO) << "Starting forward thread " << thread_id;
while (num_forwards.fetch_add(1) < config.num_iters) {
runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
++input_iters[thread_id];
}
{
std::unique_lock<std::mutex> lock(m);
++finished;
worker_main_cv.notify_one();
LOG(INFO) << "Shutting down forward thread " << thread_id
<< ". Total number of finished threads: " << finished;
}
});
}
using Clock = std::chrono::high_resolution_clock;
using TimePoint = std::chrono::time_point<Clock>;
TimePoint start_time;
{
std::unique_lock<std::mutex> lock(m);
while (initialized != config.num_calling_threads) {
worker_main_cv.wait(lock);
}
LOG(INFO) << "Starting threads";
start = true;
start_time = Clock::now();
}
main_worker_cv.notify_all();
{
std::unique_lock<std::mutex> lock(m);
worker_main_cv.wait(
lock, [&]() { return finished == config.num_calling_threads; });
}
auto end_time = std::chrono::high_resolution_clock::now();
LOG(INFO) << "Finished benchmark";
BenchmarkExecutionStats stats;
float total_time_ms = std::chrono::duration_cast<std::chrono::nanoseconds>(
end_time - start_time)
.count() /
1000.0 / 1000.0;
stats.latency_avg_ms =
total_time_ms * config.num_calling_threads / num_forwards;
stats.num_iters = num_forwards;
for (auto& t : callers) {
t.join();
}
return stats;
}
} // namespace detail
} // namespace throughput_benchmark
} // namespace torch
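
The control flow above is fairly dense. As a rough, hedged illustration (plain Python, not part of this diff), the same pattern (per-thread warmup, a start barrier, a shared iteration budget, and wall-clock timing of the measured phase) can be sketched with a lock-protected counter and threading.Barrier standing in for std::atomic and the condition variables:

import threading
import time

def benchmark(run_once, num_calling_threads=4, num_warmup_iters=100, num_iters=1000):
    # Shared iteration budget, standing in for std::atomic<int64_t> num_forwards.
    lock = threading.Lock()
    state = {"num_forwards": 0}
    start_barrier = threading.Barrier(num_calling_threads + 1)
    end_barrier = threading.Barrier(num_calling_threads + 1)

    def fetch_add():
        with lock:
            value = state["num_forwards"]
            state["num_forwards"] += 1
            return value

    def caller():
        for _ in range(num_warmup_iters):   # per-thread warmup, not timed
            run_once()
        start_barrier.wait()                # plays the role of the cv handshake
        while fetch_add() < num_iters:      # shared budget; may overshoot slightly
            run_once()
        end_barrier.wait()

    threads = [threading.Thread(target=caller) for _ in range(num_calling_threads)]
    for t in threads:
        t.start()
    start_barrier.wait()                    # every thread is warmed up: start the clock
    start = time.perf_counter()
    end_barrier.wait()                      # every thread has exhausted the budget
    total_ms = (time.perf_counter() - start) * 1000.0
    for t in threads:
        t.join()
    # Same formula as above: average per-call latency observed by one calling thread.
    return total_ms * num_calling_threads / state["num_forwards"]

if __name__ == "__main__":
    print(benchmark(lambda: time.sleep(0.001), num_warmup_iters=5, num_iters=50))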


@@ -0,0 +1,129 @@
#include <torch/csrc/utils/throughput_benchmark.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/utils/auto_gil.h>
namespace torch {
namespace throughput_benchmark {
void ThroughputBenchmark::addInput(py::args args, py::kwargs kwargs) {
CHECK(script_module_.initialized() ^ module_.initialized());
if (script_module_.initialized()) {
script_module_.addInput(std::move(args), std::move(kwargs));
} else {
CHECK(module_.initialized());
module_.addInput(std::move(args), std::move(kwargs));
}
}
py::object ThroughputBenchmark::runOnce(py::args&& args, py::kwargs&& kwargs) {
CHECK(script_module_.initialized() ^ module_.initialized());
if (script_module_.initialized()) {
c10::IValue result;
{
AutoNoGIL no_gil_guard;
result = script_module_.runOnce(std::move(args), std::move(kwargs));
}
return jit::toPyObject(std::move(result));
} else {
CHECK(module_.initialized());
return module_.runOnce(std::move(args), std::move(kwargs));
}
}
ThroughputBenchmark::ThroughputBenchmark(
std::shared_ptr<jit::script::Module> script_module)
: script_module_(std::move(script_module)) {}
ThroughputBenchmark::ThroughputBenchmark(
py::object module)
: module_(std::move(module)) {}
BenchmarkExecutionStats ThroughputBenchmark::benchmark(
const BenchmarkConfig& config) const {
CHECK(script_module_.initialized() ^ module_.initialized());
// Ideally, the main benchmark thread wouldn't hold the GIL after scheduling
// the worker threads. But for now we don't release it, as we would be
// implicitly manipulating py::object ref counts in the case of nn.Module
// benchmarking.
if (script_module_.initialized()) {
return script_module_.benchmark(config);
} else {
CHECK(module_.initialized());
TORCH_WARN("Starting benchmark on an nn.Module. This can be slow due "
"to Python GIL.For proper inference simulation you might want to switch to "
"a ScriptModule instead");
return module_.benchmark(config);
}
}
namespace detail {
template <>
void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const {
CHECK(initialized_);
// TODO: provide guarantees that compiler won't optimize this out
model_->get_method("forward").function()(std::move(input));
}
template <>
ScriptModuleOutput ScriptModuleBenchmark::runOnce(
py::args&& args,
py::kwargs&& kwargs) const {
CHECK(initialized_);
auto& function = model_->get_method("forward").function();
ScriptModuleInput stack = jit::createStackForSchema(
function.getSchema(),
std::move(args),
std::move(kwargs),
model_->module_object());
return function(std::move(stack));
}
template <>
void ModuleBenchmark::runOnce(ModuleInput&& input) const {
CHECK(initialized_);
AutoGIL gil_guard;
model_(*input.args, **input.kwargs);
}
template <>
ModuleOutput ModuleBenchmark::runOnce(py::args&& args, py::kwargs&& kwargs)
const {
CHECK(initialized_);
AutoGIL gil_guard;
return model_(*args, **kwargs);
}
template <>
void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
jit::Stack stack = jit::createStackForSchema(
model_->get_method("forward").function().getSchema(),
std::move(args),
std::move(kwargs),
model_->module_object());
inputs_.emplace_back(std::move(stack));
}
template <>
void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
inputs_.emplace_back(std::move(args), std::move(kwargs));
}
template <>
ModuleInput cloneInput<ModuleInput>(const ModuleInput& input) {
AutoGIL gil_guard;
py::args args = input.args;
py::kwargs kwargs = input.kwargs;
return {std::move(args), std::move(kwargs)};
}
template <>
ScriptModuleInput cloneInput<ScriptModuleInput>(
const ScriptModuleInput& input) {
return input;
}
} // namespace detail
} // namespace throughput_benchmark
} // namespace torch
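
The TORCH_WARN above is worth taking seriously in practice: in nn.Module mode every runOnce re-acquires the GIL, so the calling threads largely serialize on Python. A hedged sketch of the preferred setup, using a ScriptModule (ScriptedNet here is a made-up module modeled on the test's TwoLayerNet) so that inference runs with the GIL released:

import torch
from torch.utils import ThroughputBenchmark

class ScriptedNet(torch.jit.ScriptModule):
    # Minimal scriptable stand-in, following the TwoLayerNet pattern from the test above.
    def __init__(self):
        super(ScriptedNet, self).__init__()
        self.linear = torch.nn.Linear(10, 10)

    @torch.jit.script_method
    def forward(self, x):
        return self.linear(x)

bench = ThroughputBenchmark(ScriptedNet())  # ScriptModule path: forward runs without the GIL
bench.add_input(torch.randn(8, 10))
stats = bench.benchmark(num_calling_threads=4, num_warmup_iters=100, num_iters=1000)
print(stats.latency_avg_ms, stats.num_iters)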


@@ -0,0 +1,178 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/script/module.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <vector>
#include <memory>
namespace py = pybind11;
namespace torch {
namespace throughput_benchmark {
/**
* This struct is used to provide the results of a benchmark to the caller.
* In the future, all additional statistics should be added here.
*/
struct BenchmarkExecutionStats {
float latency_avg_ms{-1};
int64_t num_iters{-1};
};
/**
* Use this struct in order to configure a throughput benchmark run.
* This struct should include parameters related to threading, batching, number
* of iterations, warm-up, etc. More configs can be added as needed.
* The general rule here is that only things that C++ must(!) be aware of should
* be here. If we can keep other parts in Python, we should keep them there.
* This is typical for things that are not perf critical and don't affect the
* execution statistics the benchmark returns.
*/
struct BenchmarkConfig {
public:
// Calling threads are those threads that are calling into a module in
// parallel.
int num_calling_threads{1};
// Worker threads are not supported yet. This is just a placeholder to signal
// that we plan to support some sort of multi-threaded forward calls. We may
// change this setting in the future to support different intra- and inter-op
// parallelism, which is not available in PyTorch yet
int num_worker_threads{1};
// Warmup iters are used to make sure we run a module a few times before
// actually measuring things. This way we avoid cold caches and any other
// similar problems
int num_warmup_iters{1};
// Number of iterations the benchmark should run with. This number is separate
// from the warmup iterations
int64_t num_iters{100};
};
namespace detail {
/**
* A helper class to abstract out the different kinds of models whose throughput we test
*/
template <class Input, class Output, class Model>
class BenchmarkHelper {
public:
BenchmarkHelper(): initialized_{false} {}
explicit BenchmarkHelper(Model model): model_(model), initialized_(true) {}
// This method is to be used in the benchmark() method.
// Note that there is no result. This way we don't have to call this under the
// GIL, even when running in the nn.Module mode. Otherwise the destructor of
// the result would race with Python
void runOnce(Input&&) const;
// This method is to be used when calling from Python directly
Output runOnce(py::args&&, py::kwargs&&) const;
// Aggregate input in the format Model expects in order to avoid further
// conversions at benchmark time
void addInput(py::args&&, py::kwargs&&);
BenchmarkExecutionStats benchmark(const BenchmarkConfig& config) const;
bool initialized() const { return initialized_; }
// The destructor doesn't require the GIL because it is going to be executed
// on the Python thread
std::vector<Input> inputs_;
Model model_;
bool initialized_{false};
};
struct C10_HIDDEN ModuleInput {
ModuleInput(ModuleInput&& other) = default;
ModuleInput(const ModuleInput&) = delete;
ModuleInput& operator=(ModuleInput& other) = delete;
ModuleInput& operator=(ModuleInput&& other) = delete;
ModuleInput(py::args&& args, py::kwargs&& kwargs)
: args(std::move(args)), kwargs(std::move(kwargs)) {}
py::args args;
py::kwargs kwargs;
};
typedef py::object ModuleOutput;
typedef std::vector<at::IValue> ScriptModuleInput;
typedef at::IValue ScriptModuleOutput;
template<class Input>
Input cloneInput(const Input& input);
typedef BenchmarkHelper<
ScriptModuleInput,
at::IValue,
std::shared_ptr<jit::script::Module>>
ScriptModuleBenchmark;
typedef BenchmarkHelper<ModuleInput, py::object, py::object> ModuleBenchmark;
template <>
void ScriptModuleBenchmark::runOnce(
ScriptModuleInput&& input) const;
template <>
ScriptModuleOutput ScriptModuleBenchmark::runOnce(
py::args&& args,
py::kwargs&& kwargs) const;
template <>
void ModuleBenchmark::runOnce(ModuleInput&& input) const;
template <>
ModuleOutput ModuleBenchmark::runOnce(py::args&& args, py::kwargs&& kwargs)
const;
template <>
void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs);
template <>
void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs);
} // namespace detail
/**
* This class is a small C++ component responsible for executing a PyTorch
* module under an inference-server-like load. It can emulate multiple calling
* threads to a single provided module. In the future we plan to enhance this
* component to support inter- and intra-op parallelism as well as multiple
* models running in a single process.
*
* For the currently available configurations refer to the BenchmarkConfig
* documentation.
*
* The class supports working with either an nn.Module or a ScriptModule.
* Under the hood it just dispatches to the corresponding specialization of
* class BenchmarkHelper<Input, Output, Model>
*/
class C10_HIDDEN ThroughputBenchmark {
public:
explicit ThroughputBenchmark(std::shared_ptr<jit::script::Module> module);
explicit ThroughputBenchmark(py::object module);
// Add one more input example. This input example should be in the exact
// format the module under test expects. It is the responsibility of the module
// to perform any such format checks; the benchmark doesn't perform any
// validation of its own
void addInput(py::args args, py::kwargs kwargs);
// Equivalent to just running the model directly on the given input
py::object runOnce(py::args&& args, py::kwargs&& kwargs);
// The main method of the class: it performs a multi-threaded benchmark.
// It returns a BenchmarkExecutionStats object with a lot of useful statistics
// about runtime execution. We can enhance this class in the future to provide
// more information to the user
BenchmarkExecutionStats benchmark(const BenchmarkConfig& config) const;
private:
detail::ScriptModuleBenchmark script_module_;
detail::ModuleBenchmark module_;
};
} // namespace throughput_benchmark
} // namespace torch
#include <torch/csrc/utils/throughput_benchmark-inl.h>
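
To tie the header's API notes back to the Python surface, a short hedged usage sketch (nn.Linear is an arbitrary stand-in): every input passed to add_input is kept and sampled uniformly at random during the run, as the uniform_int_distribution in throughput_benchmark-inl.h above shows, and run_once mirrors a direct module call:

import torch
from torch.utils import ThroughputBenchmark

model = torch.nn.Linear(10, 10)  # plain nn.Module path, purely for brevity
bench = ThroughputBenchmark(model)

# Every calling thread draws uniformly at random from all inputs added here.
for _ in range(16):
    bench.add_input(torch.randn(8, 10))

# run_once is equivalent to running the model directly on the given input,
# which makes it easy to sanity-check what the benchmark will execute.
x = torch.randn(8, 10)
assert torch.allclose(bench.run_once(x), model(x))

stats = bench.benchmark(num_calling_threads=2, num_warmup_iters=10, num_iters=100)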


@@ -1 +1,3 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from .throughput_benchmark import ThroughputBenchmark # noqa: F401


@@ -0,0 +1,87 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch._C
class ThroughputBenchmark(object):
'''
This class is a wrapper around a C++ component throughput_benchmark::ThroughputBenchmark
responsible for executing a PyTorch module (nn.Module or ScriptModule)
under an inference-server-like load. It can emulate multiple calling threads
to a single provided module. In the future we plan to enhance this component
to support inter- and intra-op parallelism as well as multiple models
running in a single process.
Please note that even though nn.Module is supported, it might incur an overhead
from the need to hold the GIL every time we execute Python code or pass around
inputs as Python objects. As soon as you have a ScriptModule version of your
model for inference deployment, it is better to switch to using it in this
benchmark.
Example::
>>> from torch.utils import ThroughputBenchmark
>>> bench = ThroughputBenchmark(my_module)
>>> # Pre-populate benchmark's data set with the inputs
>>> for input in inputs:
# Both args and kwargs work, same as any PyTorch Module / ScriptModule
bench.add_input(input[0], x2=input[1])
>>> # Inputs supplied above are randomly used during the execution
>>> stats = bench.benchmark(
num_calling_threads=4,
num_warmup_iters = 100,
num_iters = 1000,
)
>>> print("Avg latency (ms): {}".format(stats.latency_avg_ms))
>>> print("Number of iterations: {}".format(stats.num_iters))
'''
def __init__(self, module):
if isinstance(module, torch.jit.ScriptModule):
self._benchmark = torch._C.ThroughputBenchmark(module._c)
else:
self._benchmark = torch._C.ThroughputBenchmark(module)
def run_once(self, *args, **kwargs):
'''
Given module inputs (args and kwargs, exactly as for a direct module
call), run the module once and return its prediction. This is useful
for testing that the benchmark actually runs the module you want it
to run.
'''
return self._benchmark.run_once(*args, **kwargs)
def add_input(self, *args, **kwargs):
'''
Store a single input to a module into the benchmark memory and keep it
there. During the benchmark execution every thread is going to pick up a
random input from all the inputs ever supplied to the benchmark via
this function.
'''
self._benchmark.add_input(*args, **kwargs)
def benchmark(self, num_calling_threads=1, num_warmup_iters=10, num_iters=100):
'''
Args:
num_warmup_iters (int): Warmup iters are used to make sure we run a module
a few times before actually measuring things. This way we avoid cold
caches and any other similar problems. This is the number of warmup
iterations for each of the threads separately.
num_iters (int): Number of iterations the benchmark should run with.
This number is separate from the warmup iterations. Also, the number is
shared across all the threads: once num_iters iterations across all
the threads are reached, we stop execution (though the total number of
iterations might be slightly larger, which is reported as
stats.num_iters, where stats is the result of this function).
This function returns a BenchmarkExecutionStats object, which is defined via pybind11.
It currently has two fields:
- num_iters - the number of actual iterations the benchmark has made
- latency_avg_ms - the average time it took to infer on one input example, in milliseconds
'''
config = torch._C.BenchmarkConfig()
config.num_calling_threads = num_calling_threads
config.num_warmup_iters = num_warmup_iters
config.num_iters = num_iters
return self._benchmark.benchmark(config)
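
To make the reported statistics concrete, a small worked example with made-up numbers (the formula is the one used in BenchmarkHelper::benchmark earlier in this commit, latency_avg_ms = total_time_ms * num_calling_threads / num_forwards):

num_calling_threads = 4
num_forwards = 1000    # roughly the requested num_iters
total_time_ms = 500.0  # wall-clock duration of the measured phase
latency_avg_ms = total_time_ms * num_calling_threads / num_forwards
print(latency_avg_ms)  # 2.0 -> on average each call took about 2 ms on its calling thread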