mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
add c2 benchmark runs in cpp (#20108)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20108 Add cpp runs for c2, hooked up via pybinds. Print output to terminal. This is not hooked up with the pep output yet because I'd like to verify the numbers first. Note that this isn't quite the same mechanism as the pytorch cpp hookup, which uses cpp_python_extensions. If I can use the same mechanism to pull all the inputs for c2 through cpp and do FeedBlobs in cpp, then I'll switch to that. Reviewed By: zheng-xq Differential Revision: D15155976 fbshipit-source-id: 708079dacd3e19aacfe43d70c5e5bc54da2cf9e3
This commit is contained in:
committed by
Facebook Github Bot
parent
d2da3ee601
commit
a9aaf698a4
@ -190,6 +190,15 @@ int ExecutorHelper::GetNumWorkers() const {
|
||||
CAFFE_THROW("Not implemented");
|
||||
}
|
||||
|
||||
// benchmark an individual run so that we can FeedBlobs with new inputs
|
||||
// no warmup
|
||||
// return time taken in microseconds
|
||||
float NetBase::TEST_Benchmark_One_Run() {
|
||||
Timer timer;
|
||||
CAFFE_ENFORCE(Run(), "Run has failed.");
|
||||
return timer.MicroSeconds();
|
||||
}
|
||||
|
||||
std::vector<float> NetBase::TEST_Benchmark(
|
||||
const int warmup_runs,
|
||||
const int main_runs,
|
||||
|
@ -63,6 +63,13 @@ class CAFFE2_API NetBase : public Observable<NetBase> {
|
||||
|
||||
virtual bool RunAsync();
|
||||
|
||||
/* Benchmarks a network for one individual run so that we can feed new
|
||||
* inputs on additional calls.
|
||||
* This function returns the number of microseconds spent
|
||||
* during the benchmark
|
||||
*/
|
||||
virtual float TEST_Benchmark_One_Run();
|
||||
|
||||
/**
|
||||
* Benchmarks a network.
|
||||
*
|
||||
|
@ -1263,6 +1263,14 @@ void addGlobalMethods(py::module& m) {
|
||||
net->TEST_Benchmark(warmup_runs, main_runs, run_individual);
|
||||
return stat;
|
||||
});
|
||||
m.def("benchmark_net_once", [](const std::string& name) {
|
||||
CAFFE_ENFORCE(gWorkspace);
|
||||
auto* net = gWorkspace->GetNet(name);
|
||||
CAFFE_ENFORCE(net, "Didn't find net: ", name);
|
||||
py::gil_scoped_release g;
|
||||
float stat = net->TEST_Benchmark_One_Run();
|
||||
return stat;
|
||||
});
|
||||
|
||||
m.def("delete_net", [](const std::string& name) {
|
||||
CAFFE_ENFORCE(gWorkspace);
|
||||
|
@ -36,6 +36,7 @@ SwitchWorkspace = C.switch_workspace
|
||||
RootFolder = C.root_folder
|
||||
Workspaces = C.workspaces
|
||||
BenchmarkNet = C.benchmark_net
|
||||
BenchmarkNetOnce = C.benchmark_net_once
|
||||
GetStats = C.get_stats
|
||||
|
||||
operator_tracebacks = defaultdict(dict)
|
||||
|
Reference in New Issue
Block a user