Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Back out "Revert D15435461: [pytorch][PR] PyTorch ThroughputBenchmark" (#22185)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22185
Original commit changeset: 72a0eac1658b

Differential Revision: D15981928

fbshipit-source-id: d2455d79e81c26ee90d41414cde8ac0f9b703bc3
Committed by: Facebook Github Bot
Parent: 3f2a839dda
Commit: f51de8b61a
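Before the diff itself, here is a minimal usage sketch of the API this commit brings back, assembled from the test and docstring added below. The tiny Linear model and tensor shapes are made up for illustration and are not part of the diff.

# Minimal sketch (not part of the diff): benchmark a module under a
# multi-threaded, inference-server-like load, as the new test does below.
import torch
from torch.utils import ThroughputBenchmark

module = torch.nn.Linear(10, 15)             # illustrative model only
bench = ThroughputBenchmark(module)          # accepts an nn.Module or a ScriptModule
for _ in range(2):
    bench.add_input(torch.randn(8, 10))      # pre-populate the benchmark's input pool

stats = bench.benchmark(
    num_calling_threads=4,
    num_warmup_iters=100,
    num_iters=1000,
)
print("Avg latency (ms): {}".format(stats.latency_avg_ms))
print("Number of iterations: {}".format(stats.num_iters))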
@@ -45,6 +45,7 @@
#endif // C10_USING_CUSTOM_GENERATED_MACROS

#ifdef _WIN32
#define C10_HIDDEN
#if defined(C10_BUILD_SHARED_LIBS)
#define C10_EXPORT __declspec(dllexport)
#define C10_IMPORT __declspec(dllimport)
@@ -55,8 +56,10 @@
#else // _WIN32
#if defined(__GNUC__)
#define C10_EXPORT __attribute__((__visibility__("default")))
#define C10_HIDDEN __attribute__((__visibility__("hidden")))
#else // defined(__GNUC__)
#define C10_EXPORT
#define C10_HIDDEN
#endif // defined(__GNUC__)
#define C10_IMPORT C10_EXPORT
#endif // _WIN32
test/test_throughput_benchmark.py (new file, 79 lines)
@@ -0,0 +1,79 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
from torch.utils import ThroughputBenchmark
from torch.testing import assert_allclose

from common_utils import run_tests, TestCase

class TwoLayerNet(torch.jit.ScriptModule):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    @torch.jit.script_method
    def forward(self, x1, x2):
        h1_relu = self.linear1(x1).clamp(min=0)
        h2_relu = self.linear1(x2).clamp(min=0)
        cat = torch.cat((h1_relu, h2_relu), 1)
        y_pred = self.linear2(cat)
        return y_pred

class TwoLayerNetModule(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNetModule, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(2 * H, D_out)

    def forward(self, x1, x2):
        h1_relu = self.linear1(x1).clamp(min=0)
        h2_relu = self.linear1(x2).clamp(min=0)
        cat = torch.cat((h1_relu, h2_relu), 1)
        y_pred = self.linear2(cat)
        return y_pred

class TestThroughputBenchmark(TestCase):
    def linear_test(self, Module):
        D_in = 10
        H = 5
        D_out = 15
        B = 8
        NUM_INPUTS = 2

        module = Module(D_in, H, D_out)

        inputs = []

        for i in range(NUM_INPUTS):
            inputs.append([torch.randn(B, D_in), torch.randn(B, D_in)])
        bench = ThroughputBenchmark(module)

        for input in inputs:
            # can do both args and kwargs here
            bench.add_input(input[0], x2=input[1])

        for i in range(NUM_INPUTS):
            # or just unpack the list of inputs
            module_result = module(*inputs[i])
            bench_result = bench.run_once(*inputs[i])
            assert_allclose(bench_result, module_result)

        stats = bench.benchmark(
            num_calling_threads=4,
            num_warmup_iters=100,
            num_iters=1000,
        )

        print("Avg latency (ms): {}".format(stats.latency_avg_ms))
        print("Number of iterations: {}".format(stats.num_iters))


    def test_script_module(self):
        self.linear_test(TwoLayerNet)

    def test_module(self):
        self.linear_test(TwoLayerNetModule)

if __name__ == '__main__':
    run_tests()
@@ -248,6 +248,8 @@ def add_torch_libs():
        "torch/csrc/onnx/init.cpp",
        "torch/csrc/serialization.cpp",
        "torch/csrc/tensor/python_tensor.cpp",
        "torch/csrc/utils/init.cpp",
        "torch/csrc/utils/throughput_benchmark.cpp",
        "torch/csrc/utils.cpp",
        "torch/csrc/utils/cuda_lazy_init.cpp",
        "torch/csrc/utils/invalid_arguments.cpp",
@@ -84,6 +84,8 @@ set(TORCH_PYTHON_SRCS
    ${TORCH_SRC_DIR}/csrc/jit/script/python_tree_views.cpp
    ${TORCH_SRC_DIR}/csrc/multiprocessing/init.cpp
    ${TORCH_SRC_DIR}/csrc/onnx/init.cpp
    ${TORCH_SRC_DIR}/csrc/utils/init.cpp
    ${TORCH_SRC_DIR}/csrc/utils/throughput_benchmark.cpp
    ${TORCH_SRC_DIR}/csrc/serialization.cpp
    ${TORCH_SRC_DIR}/csrc/tensor/python_tensor.cpp
    ${TORCH_SRC_DIR}/csrc/utils.cpp
@@ -44,6 +44,7 @@
#include <torch/csrc/jit/init.h>
#include <torch/csrc/jit/python_ir.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/utils/init.h>
#include <torch/csrc/api/include/torch/python/init.h>

#ifdef USE_CUDNN
@@ -644,6 +645,7 @@ PyObject* initModule() {
  // init.
  torch::onnx::initONNXBindings(module);
  torch::jit::initJITBindings(module);
  torch::throughput_benchmark::initThroughputBenchmarkBindings(module);
  torch::autograd::initNNFunctions(module);
  torch::autograd::init_legacy_variable(module);
  torch::python::init_bindings(module);
@@ -1,5 +1,7 @@
#pragma once

#include <torch/csrc/utils/pybind.h>

namespace torch {
namespace jit {
torch/csrc/utils/init.cpp (new file, 50 lines)
@@ -0,0 +1,50 @@
#include <ATen/core/ivalue.h>
#include <torch/csrc/utils/init.h>
#include <torch/csrc/utils/throughput_benchmark.h>

#include <pybind11/functional.h>

namespace torch {
namespace throughput_benchmark {

void initThroughputBenchmarkBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  using namespace torch::throughput_benchmark;
  py::class_<BenchmarkConfig>(m, "BenchmarkConfig")
      .def(py::init<>())
      .def_readwrite(
          "num_calling_threads", &BenchmarkConfig::num_calling_threads)
      .def_readwrite("num_worker_threads", &BenchmarkConfig::num_worker_threads)
      .def_readwrite("num_warmup_iters", &BenchmarkConfig::num_warmup_iters)
      .def_readwrite("num_iters", &BenchmarkConfig::num_iters);

  py::class_<BenchmarkExecutionStats>(m, "BenchmarkExecutionStats")
      .def_readonly("latency_avg_ms", &BenchmarkExecutionStats::latency_avg_ms)
      .def_readonly("num_iters", &BenchmarkExecutionStats::num_iters);

  py::class_<ThroughputBenchmark>(m, "ThroughputBenchmark", py::dynamic_attr())
      .def(py::init<std::shared_ptr<jit::script::Module>>())
      .def(py::init<py::object>())
      .def(
          "add_input",
          [](ThroughputBenchmark& self, py::args args, py::kwargs kwargs) {
            self.addInput(std::move(args), std::move(kwargs));
          })
      .def(
          "run_once",
          [](ThroughputBenchmark& self, py::args args, py::kwargs kwargs) {
            // Depending on whether this is a ScriptModule or an nn.Module, we
            // will or will not release the GIL further down the stack.
            return self.runOnce(std::move(args), std::move(kwargs));
          })
      .def("benchmark", [](ThroughputBenchmark& self, BenchmarkConfig config) {
        // The benchmark always runs without the GIL. The GIL will be acquired
        // where needed. This happens only in the nn.Module mode, when
        // manipulating inputs and running the actual inference.
        AutoNoGIL no_gil_guard;
        return self.benchmark(config);
      });
}

} // namespace throughput_benchmark
} // namespace torch
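For reference, this is roughly how the bindings registered above look from Python. It mirrors what the torch/utils/throughput_benchmark.py wrapper later in this diff does under the hood; the Linear module and shapes are illustrative only, not part of the diff.

# Sketch of driving the raw pybind11 bindings directly.
import torch
import torch._C

config = torch._C.BenchmarkConfig()
config.num_calling_threads = 4
config.num_warmup_iters = 100
config.num_iters = 1000

raw_bench = torch._C.ThroughputBenchmark(torch.nn.Linear(10, 15))  # nn.Module path
raw_bench.add_input(torch.randn(8, 10))   # forwarded to ThroughputBenchmark::addInput
stats = raw_bench.benchmark(config)       # returns a BenchmarkExecutionStats
print(stats.latency_avg_ms, stats.num_iters)

For a ScriptModule, the Python wrapper passes the underlying C++ module (module._c) to the first constructor overload instead.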
torch/csrc/utils/init.h (new file, 11 lines)
@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/utils/pybind.h>

namespace torch {
namespace throughput_benchmark {

void initThroughputBenchmarkBindings(PyObject* module);

} // namespace throughput_benchmark
} // namespace torch
torch/csrc/utils/throughput_benchmark-inl.h (new file, 130 lines)
@@ -0,0 +1,130 @@
#pragma once

#include <random>
#include <thread>

#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/pybind_utils.h>

namespace torch {
namespace throughput_benchmark {
namespace detail {

template <class Input, class Output, class Model>
BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
    const BenchmarkConfig& config) const {
  CHECK(initialized_);
  TORCH_CHECK(
      config.num_worker_threads == 1,
      "Only parallelization by callers is supported");

  // We pre-generate inputs here for each of the threads. This allows us to
  // safely move inputs out for each of the threads independently and thus
  // avoid overhead from the benchmark runner itself.
  std::vector<std::vector<Input>> thread_inputs(config.num_calling_threads);
  std::vector<size_t> input_iters(config.num_calling_threads);
  {
    std::random_device seeder;
    std::mt19937 engine(seeder());
    TORCH_CHECK(
        !inputs_.empty(),
        "Please provide benchmark inputs. "
        "Did you forget to call add_input()?");
    std::uniform_int_distribution<int> dist(0, inputs_.size() - 1);

    for (int thread_id = 0; thread_id < config.num_calling_threads;
         ++thread_id) {
      // Just in case, we generate num_iters inputs for each of the threads.
      // This way, even if one thread does all the work, we will be fine.
      for (int i = 0; i < config.num_iters + config.num_warmup_iters; ++i) {
        thread_inputs[thread_id].push_back(cloneInput(inputs_[dist(engine)]));
      }
      input_iters[thread_id] = 0;
    }
  }

  std::mutex m;
  std::condition_variable worker_main_cv;
  std::condition_variable main_worker_cv;
  // TODO: add GUARDED_BY once it is available
  int64_t initialized{0};
  int64_t finished{0};
  bool start{false};
  std::atomic<int64_t> num_forwards{0};
  std::vector<std::thread> callers;

  for (auto thread_id = 0; thread_id < config.num_calling_threads;
       ++thread_id) {
    callers.emplace_back([&, thread_id]() {
      // We use a condition variable as a barrier to make sure each thread
      // performs the required warmup iterations before we start measuring.
      for (auto j = 0; j < config.num_warmup_iters; ++j) {
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }
      {
        std::unique_lock<std::mutex> lock(m);
        ++initialized;
        worker_main_cv.notify_one();
        while (!start) {
          main_worker_cv.wait(lock);
        }
      }
      LOG(INFO) << "Starting forward thread " << thread_id;
      while (num_forwards.fetch_add(1) < config.num_iters) {
        runOnce(std::move(thread_inputs[thread_id][input_iters[thread_id]]));
        ++input_iters[thread_id];
      }

      {
        std::unique_lock<std::mutex> lock(m);
        ++finished;
        worker_main_cv.notify_one();
        LOG(INFO) << "Shutting down forward thread " << thread_id
                  << ". Total number of finished threads: " << finished;
      }
    });
  }

  using Clock = std::chrono::high_resolution_clock;
  using TimePoint = std::chrono::time_point<Clock>;
  TimePoint start_time;

  {
    std::unique_lock<std::mutex> lock(m);
    while (initialized != config.num_calling_threads) {
      worker_main_cv.wait(lock);
    }
    LOG(INFO) << "Starting threads";
    start = true;
    start_time = Clock::now();
  }

  main_worker_cv.notify_all();
  {
    std::unique_lock<std::mutex> lock(m);
    worker_main_cv.wait(
        lock, [&]() { return finished == config.num_calling_threads; });
  }
  auto end_time = std::chrono::high_resolution_clock::now();
  LOG(INFO) << "Finished benchmark";

  BenchmarkExecutionStats stats;
  float total_time_ms = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            end_time - start_time)
                            .count() /
      1000.0 / 1000.0;
  stats.latency_avg_ms =
      total_time_ms * config.num_calling_threads / num_forwards;
  stats.num_iters = num_forwards;

  for (auto& t : callers) {
    t.join();
  }
  return stats;
}

} // namespace detail
} // namespace throughput_benchmark
} // namespace torch
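A simplified Python illustration of the measurement protocol implemented above, not part of the diff: each calling thread draws from a pre-generated input pool, warms up, waits for a common start signal, then shares a global iteration budget, and the average latency is wall time times thread count divided by total forwards. The helper name toy_benchmark and the plain lock standing in for the C++ atomic counter are illustrative choices.

# Simplified sketch of the benchmark loop above, in Python threads.
import random
import threading
import time

def toy_benchmark(run_once, inputs, num_calling_threads=4,
                  num_warmup_iters=100, num_iters=1000):
    assert inputs, "Please provide benchmark inputs. Did you forget add_input()?"
    # Pre-generate per-thread inputs so the runner adds no overhead while timing.
    thread_inputs = [[random.choice(inputs)
                      for _ in range(num_iters + num_warmup_iters)]
                     for _ in range(num_calling_threads)]

    ready = threading.Barrier(num_calling_threads + 1)  # workers + main thread
    start = threading.Event()
    lock = threading.Lock()
    num_forwards = [0]    # shared iteration counter (an atomic in the C++ code)

    def caller(tid):
        it = 0
        for _ in range(num_warmup_iters):    # warmup iterations are not measured
            run_once(thread_inputs[tid][it])
            it += 1
        ready.wait()                         # report "initialized" to the main thread
        start.wait()                         # block until the clock has started
        while True:
            with lock:
                if num_forwards[0] >= num_iters:
                    break
                num_forwards[0] += 1
            run_once(thread_inputs[tid][it])
            it += 1

    threads = [threading.Thread(target=caller, args=(i,))
               for i in range(num_calling_threads)]
    for t in threads:
        t.start()
    ready.wait()                             # all workers finished warming up
    start_time = time.perf_counter()
    start.set()
    for t in threads:
        t.join()
    total_time_ms = (time.perf_counter() - start_time) * 1000.0
    # Same formula as above: average per-call latency across all calling threads.
    return total_time_ms * num_calling_threads / max(num_forwards[0], 1)

With the formula above, if 4 calling threads take 500 ms of wall time to complete 1000 forwards in total, the reported latency_avg_ms is 500 * 4 / 1000 = 2 ms per call.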
torch/csrc/utils/throughput_benchmark.cpp (new file, 129 lines)
@@ -0,0 +1,129 @@
#include <torch/csrc/utils/throughput_benchmark.h>

#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/utils/auto_gil.h>

namespace torch {
namespace throughput_benchmark {

void ThroughputBenchmark::addInput(py::args args, py::kwargs kwargs) {
  CHECK(script_module_.initialized() ^ module_.initialized());
  if (script_module_.initialized()) {
    script_module_.addInput(std::move(args), std::move(kwargs));
  } else {
    CHECK(module_.initialized());
    module_.addInput(std::move(args), std::move(kwargs));
  }
}

py::object ThroughputBenchmark::runOnce(py::args&& args, py::kwargs&& kwargs) {
  CHECK(script_module_.initialized() ^ module_.initialized());
  if (script_module_.initialized()) {
    c10::IValue result;
    {
      AutoNoGIL no_gil_guard;
      result = script_module_.runOnce(std::move(args), std::move(kwargs));
    }
    return jit::toPyObject(std::move(result));
  } else {
    CHECK(module_.initialized());
    return module_.runOnce(std::move(args), std::move(kwargs));
  }
}

ThroughputBenchmark::ThroughputBenchmark(
    std::shared_ptr<jit::script::Module> script_module)
    : script_module_(std::move(script_module)) {}

ThroughputBenchmark::ThroughputBenchmark(py::object module)
    : module_(std::move(module)) {}

BenchmarkExecutionStats ThroughputBenchmark::benchmark(
    const BenchmarkConfig& config) const {
  CHECK(script_module_.initialized() ^ module_.initialized());
  // The main benchmark thread doesn't hold the GIL after scheduling worker
  // threads. But for now we don't release it here, as we will be implicitly
  // manipulating py::object reference counts in the case of nn.Module
  // benchmarking.
  if (script_module_.initialized()) {
    return script_module_.benchmark(config);
  } else {
    CHECK(module_.initialized());
    TORCH_WARN(
        "Starting benchmark on an nn.Module. This can be slow due "
        "to the Python GIL. For proper inference simulation you might want to "
        "switch to a ScriptModule instead");
    return module_.benchmark(config);
  }
}

namespace detail {

template <>
void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const {
  CHECK(initialized_);
  // TODO: provide guarantees that compiler won't optimize this out
  model_->get_method("forward").function()(std::move(input));
}

template <>
ScriptModuleOutput ScriptModuleBenchmark::runOnce(
    py::args&& args,
    py::kwargs&& kwargs) const {
  CHECK(initialized_);
  auto& function = model_->get_method("forward").function();
  ScriptModuleInput stack = jit::createStackForSchema(
      function.getSchema(),
      std::move(args),
      std::move(kwargs),
      model_->module_object());
  return function(std::move(stack));
}

template <>
void ModuleBenchmark::runOnce(ModuleInput&& input) const {
  CHECK(initialized_);
  AutoGIL gil_guard;
  model_(*input.args, **input.kwargs);
}

template <>
ModuleOutput ModuleBenchmark::runOnce(py::args&& args, py::kwargs&& kwargs)
    const {
  CHECK(initialized_);
  AutoGIL gil_guard;
  return model_(*args, **kwargs);
}

template <>
void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
  jit::Stack stack = jit::createStackForSchema(
      model_->get_method("forward").function().getSchema(),
      std::move(args),
      std::move(kwargs),
      model_->module_object());
  inputs_.emplace_back(std::move(stack));
}

template <>
void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs) {
  inputs_.emplace_back(std::move(args), std::move(kwargs));
}

template <>
ModuleInput cloneInput<ModuleInput>(const ModuleInput& input) {
  AutoGIL gil_guard;
  py::args args = input.args;
  py::kwargs kwargs = input.kwargs;
  return {std::move(args), std::move(kwargs)};
}

template <>
ScriptModuleInput cloneInput<ScriptModuleInput>(
    const ScriptModuleInput& input) {
  return input;
}

} // namespace detail

} // namespace throughput_benchmark
} // namespace torch
torch/csrc/utils/throughput_benchmark.h (new file, 178 lines)
@@ -0,0 +1,178 @@
#pragma once

#include <ATen/core/ivalue.h>
#include <torch/csrc/jit/script/module.h>
#include <pybind11/pybind11.h>

#include <torch/csrc/jit/pybind_utils.h>

#include <vector>
#include <memory>

namespace py = pybind11;

namespace torch {
namespace throughput_benchmark {

/**
 * This struct is used to provide the results of a benchmark to the caller.
 * In the future all additional statistics should be added here.
 */
struct BenchmarkExecutionStats {
  float latency_avg_ms{-1};
  int64_t num_iters{-1};
};

/**
 * Use this struct in order to configure a throughput benchmark run.
 * This struct should include parameters related to threading, batching, number
 * of iterations, warm-up, etc. More configs can be added as needed.
 * The general rule here is that only things that C++ must(!) be aware of should
 * be here. If we can keep other parts in Python, we should keep them there.
 * This is typical for things that are not perf critical and don't affect
 * the execution statistics the benchmark returns.
 */
struct BenchmarkConfig {
 public:
  // Calling threads are those threads that are calling into a module in
  // parallel.
  int num_calling_threads{1};
  // Worker threads are not supported yet. This is just an example that we plan
  // to support some sort of multi-threaded forward calls. We may change this
  // setting in the future to support different intra- and inter-op parallelism,
  // which is not available in PyTorch yet.
  int num_worker_threads{1};
  // Warmup iters are used to make sure we run a module a few times before
  // actually measuring things. This way we avoid cold caches and any other
  // similar problems.
  int num_warmup_iters{1};
  // Number of iterations the benchmark should run with. This number is
  // separate from the warmup iterations.
  int64_t num_iters{100};
};

namespace detail {

/**
 * A helper class to abstract out the different models we test the throughput of.
 */
template <class Input, class Output, class Model>
class BenchmarkHelper {
 public:
  BenchmarkHelper(): initialized_{false} {}
  explicit BenchmarkHelper(Model model): model_(model), initialized_(true) {}

  // This method is to be used in the benchmark() method.
  // Note that there is no result. This way we don't have to call this under
  // the GIL even when running in the nn.Module mode. Otherwise the destructor
  // of the result would race with Python.
  void runOnce(Input&&) const;
  // This method is to be used when calling from Python directly.
  Output runOnce(py::args&&, py::kwargs&&) const;
  // Aggregate input in the format Model expects in order to avoid further
  // conversions at benchmark time.
  void addInput(py::args&&, py::kwargs&&);
  BenchmarkExecutionStats benchmark(const BenchmarkConfig& config) const;

  bool initialized() const { return initialized_; }

  // The destructor doesn't require the GIL because it is going to be executed
  // on the Python thread.
  std::vector<Input> inputs_;
  Model model_;
  bool initialized_{false};
};

struct C10_HIDDEN ModuleInput {
  ModuleInput(ModuleInput&& other) = default;

  ModuleInput(const ModuleInput&) = delete;
  ModuleInput& operator=(ModuleInput& other) = delete;
  ModuleInput& operator=(ModuleInput&& other) = delete;

  ModuleInput(py::args&& args, py::kwargs&& kwargs)
      : args(std::move(args)), kwargs(std::move(kwargs)) {}

  py::args args;
  py::kwargs kwargs;
};
typedef py::object ModuleOutput;
typedef std::vector<at::IValue> ScriptModuleInput;
typedef at::IValue ScriptModuleOutput;

template<class Input>
Input cloneInput(const Input& input);

typedef BenchmarkHelper<
    ScriptModuleInput,
    at::IValue,
    std::shared_ptr<jit::script::Module>>
    ScriptModuleBenchmark;
typedef BenchmarkHelper<ModuleInput, py::object, py::object> ModuleBenchmark;

template <>
void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const;

template <>
ScriptModuleOutput ScriptModuleBenchmark::runOnce(
    py::args&& args,
    py::kwargs&& kwargs) const;

template <>
void ModuleBenchmark::runOnce(ModuleInput&& input) const;

template <>
ModuleOutput ModuleBenchmark::runOnce(py::args&& args, py::kwargs&& kwargs)
    const;

template <>
void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs);

template <>
void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs);

} // namespace detail

/**
 * This class is a small C++ component responsible for executing a PyTorch
 * module under an inference-server-like load. It can emulate multiple calling
 * threads to a single module provided. In the future we plan to enhance this
 * component to support inter- and intra-op parallelism as well as multiple
 * models running in a single process.
 *
 * For the currently available configurations refer to the BenchmarkConfig
 * documentation.
 *
 * The class supports working with either an nn.Module or a ScriptModule.
 * Under the hood it just dispatches to the corresponding specialization of
 * class BenchmarkHelper<Input, Output, Model>.
 */
class C10_HIDDEN ThroughputBenchmark {
 public:
  explicit ThroughputBenchmark(std::shared_ptr<jit::script::Module> module);
  explicit ThroughputBenchmark(py::object module);

  // Add one more input example. This input example should be in the exact
  // format the module under test expects. It is the responsibility of the
  // module to perform any such format checks; the benchmark doesn't perform
  // any validation of its own.
  void addInput(py::args args, py::kwargs kwargs);

  // Equivalent to just running the model directly on the given input.
  py::object runOnce(py::args&& args, py::kwargs&& kwargs);

  // The main method of the class: perform a multi-threaded benchmark.
  // It returns a BenchmarkExecutionStats object with useful statistics about
  // runtime execution. We can enhance this class in the future to provide
  // more information to the user.
  BenchmarkExecutionStats benchmark(const BenchmarkConfig& config) const;

 private:
  detail::ScriptModuleBenchmark script_module_;
  detail::ModuleBenchmark module_;
};
} // namespace throughput_benchmark
} // namespace torch

#include <torch/csrc/utils/throughput_benchmark-inl.h>
@@ -1 +1,3 @@
from __future__ import absolute_import, division, print_function, unicode_literals

from .throughput_benchmark import ThroughputBenchmark  # noqa: F401
torch/utils/throughput_benchmark.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch._C


class ThroughputBenchmark(object):
    '''
    This class is a wrapper around the C++ component throughput_benchmark::ThroughputBenchmark
    responsible for executing a PyTorch module (nn.Module or ScriptModule)
    under an inference-server-like load. It can emulate multiple calling threads
    to a single module provided. In the future we plan to enhance this component
    to support inter- and intra-op parallelism as well as multiple models
    running in a single process.

    Please note that even though nn.Module is supported, it might incur an overhead
    from the need to hold the GIL every time we execute Python code or pass around
    inputs as Python objects. As soon as you have a ScriptModule version of your
    model for inference deployment it is better to switch to using it in this
    benchmark.

    Example::

        >>> from torch.utils import ThroughputBenchmark
        >>> bench = ThroughputBenchmark(my_module)
        >>> # Pre-populate benchmark's data set with the inputs
        >>> for input in inputs:
                # Both args and kwargs work, same as any PyTorch Module / ScriptModule
                bench.add_input(input[0], x2=input[1])
        >>> # Inputs supplied above are randomly used during the execution
        >>> stats = bench.benchmark(
                num_calling_threads=4,
                num_warmup_iters=100,
                num_iters=1000,
            )
        >>> print("Avg latency (ms): {}".format(stats.latency_avg_ms))
        >>> print("Number of iterations: {}".format(stats.num_iters))
    '''

    def __init__(self, module):
        if isinstance(module, torch.jit.ScriptModule):
            self._benchmark = torch._C.ThroughputBenchmark(module._c)
        else:
            self._benchmark = torch._C.ThroughputBenchmark(module)

    def run_once(self, *args, **kwargs):
        '''
        Given the module inputs (args and kwargs), run the module once and
        return its prediction. This is useful for testing that the benchmark
        actually runs the module you want it to run.
        '''
        return self._benchmark.run_once(*args, **kwargs)

    def add_input(self, *args, **kwargs):
        '''
        Store a single input to the module in the benchmark memory and keep it
        there. During the benchmark execution every thread is going to pick up
        a random input from all the inputs ever supplied to the benchmark via
        this function.
        '''
        self._benchmark.add_input(*args, **kwargs)

    def benchmark(self, num_calling_threads=1, num_warmup_iters=10, num_iters=100):
        '''
        Args:
            num_warmup_iters (int): Warmup iters are used to make sure we run a module
                a few times before actually measuring things. This way we avoid cold
                caches and any other similar problems. This is the number of warmup
                iterations for each thread separately.

            num_iters (int): Number of iterations the benchmark should run with.
                This number is separate from the warmup iterations and is shared
                across all the threads: once num_iters iterations are reached across
                all the threads, we stop execution. The actual total number of
                iterations might be slightly larger; it is reported as stats.num_iters,
                where stats is the result of this function.

        This function returns a BenchmarkExecutionStats object, which is defined via pybind11.
        It currently has two fields:
            - num_iters - number of actual iterations the benchmark has made
            - latency_avg_ms - average time it took to infer on one input example, in milliseconds
        '''
        config = torch._C.BenchmarkConfig()
        config.num_calling_threads = num_calling_threads
        config.num_warmup_iters = num_warmup_iters
        config.num_iters = num_iters
        return self._benchmark.benchmark(config)
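Following the docstring's advice to prefer a ScriptModule, here is a hedged sketch of benchmarking a traced model; it assumes a traced module passes the isinstance(module, torch.jit.ScriptModule) check above and thus takes the GIL-free path. The Sequential model, shapes, and iteration counts are made up for illustration.

# Sketch (not part of the diff): trace the model first so the benchmark
# dispatches to the ScriptModule specialization.
import torch
from torch.utils import ThroughputBenchmark

model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU())
example = torch.randn(8, 10)
traced = torch.jit.trace(model, example)    # traced modules are ScriptModules

bench = ThroughputBenchmark(traced)
bench.add_input(example)
stats = bench.benchmark(num_calling_threads=4, num_warmup_iters=10, num_iters=100)
print("Avg latency (ms): {}".format(stats.latency_avg_ms))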