diff --git a/binaries/CMakeLists.txt b/binaries/CMakeLists.txt
index a98754eea2c3..b683ee002280 100644
--- a/binaries/CMakeLists.txt
+++ b/binaries/CMakeLists.txt
@@ -4,6 +4,7 @@ if(INTERN_BUILD_MOBILE)
   caffe2_binary_target("speed_benchmark.cc")
 else()
   caffe2_binary_target("speed_benchmark_torch.cc")
+  caffe2_binary_target("load_benchmark_torch.cc")
   if(NOT BUILD_LITE_INTERPRETER)
     caffe2_binary_target("compare_models_torch.cc")
   endif()
diff --git a/binaries/load_benchmark_torch.cc b/binaries/load_benchmark_torch.cc
new file mode 100644
index 000000000000..330955657ece
--- /dev/null
+++ b/binaries/load_benchmark_torch.cc
@@ -0,0 +1,93 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <string>
+#include <vector>
+
+#include <ATen/ATen.h>
+#include "caffe2/core/timer.h"
+#include "caffe2/utils/string_utils.h"
+#include <torch/csrc/autograd/grad_mode.h>
+#include <torch/csrc/jit/mobile/import.h>
+#include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/serialization/import.h>
+#include <torch/script.h>
+
+#include <c10/mobile/CPUCachingAllocator.h>
+
+#include <chrono>
+using namespace std::chrono;
+
+C10_DEFINE_string(model, "", "The given torch script model to benchmark.");
+C10_DEFINE_int(iter, 10, "The number of iterations to run.");
+C10_DEFINE_bool(
+    report_pep,
+    true,
+    "Whether to print performance stats for AI-PEP.");
+
+int main(int argc, char** argv) {
+  c10::SetUsageMessage(
+    "Run model load time benchmark for pytorch model.\n"
+    "Example usage:\n"
+    "./load_benchmark_torch"
+    " --model=<model_file>"
+    " --iter=20");
+  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
+    std::cerr << "Failed to parse command line flags!" << std::endl;
+    return 1;
+  }
+
+  std::cout << "Starting benchmark." << std::endl;
+  CAFFE_ENFORCE(
+      FLAGS_iter >= 0,
+      "Number of main runs should be non negative, provided ",
+      FLAGS_iter,
+      ".");
+
+  caffe2::Timer timer;
+  std::vector<long> times;
+
+  for (int i = 0; i < FLAGS_iter; ++i) {
+    auto start = high_resolution_clock::now();
+
+#if BUILD_LITE_INTERPRETER
+    auto module = torch::jit::_load_for_mobile(FLAGS_model);
+#else
+    auto module = torch::jit::load(FLAGS_model);
+#endif
+
+    auto stop = high_resolution_clock::now();
+    auto duration = duration_cast<microseconds>(stop - start);
+    times.push_back(duration.count());
+  }
+
+  const double micros = static_cast<double>(timer.MicroSeconds());
+  if (FLAGS_report_pep) {
+    for (auto t : times) {
+      std::cout << R"(PyTorchObserver {"type": "NET", "unit": "us", )"
+                << R"("metric": "latency", "value": ")"
+                << t << R"("})" << std::endl;
+    }
+  }
+
+  const double iters = static_cast<double>(FLAGS_iter);
+  std::cout << "Main run finished. Microseconds per iter: "
+            << micros / iters
+            << ". Iters per second: " << 1000.0 * 1000 * iters / micros
+            << std::endl;
+
+  return 0;
+}
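
A quick way to exercise the new binary, taken from the usage message above (the exact build location of load_benchmark_torch depends on how the binaries targets are built in your setup):

    ./load_benchmark_torch --model=<model_file> --iter=20

With --report_pep left at its default of true, each iteration also prints a PyTorchObserver JSON line reporting the load latency in microseconds, in addition to the final microseconds-per-iteration summary. When the lite interpreter is enabled (BUILD_LITE_INTERPRETER), the model is loaded via torch::jit::_load_for_mobile, so the model file should be a lite-interpreter export rather than a regular TorchScript archive.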