From f2f285c240efa2743f54653b83c03cc236a1fb27 Mon Sep 17 00:00:00 2001 From: Supriya Rao Date: Thu, 21 Nov 2019 16:00:58 -0800 Subject: [PATCH] Add arguments to benchmark to run pytext models. Output results in ms. (#30273) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30273 Pytext models expect input of the form `1xlength` and another input specifying the length. Add the `pytext_len` argument to specify this. ghstack-source-id: 94383501 Test Plan: ./speed_benchmark_torch --model model.pt --input_dims "1,4" --input_type int64 --warmup 10 --iter 10 --report_pep=true --pytext_len=4 Reviewed By: iseeyuan Differential Revision: D18646028 fbshipit-source-id: 7d5fe0c36da6e5f7b0261619ce4784a46b70f3d8 --- binaries/speed_benchmark_torch.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/binaries/speed_benchmark_torch.cc b/binaries/speed_benchmark_torch.cc index 41688b0fdb89..6470d5acfcdd 100644 --- a/binaries/speed_benchmark_torch.cc +++ b/binaries/speed_benchmark_torch.cc @@ -48,6 +48,8 @@ C10_DEFINE_bool( false, "Whether to print performance stats for AI-PEP."); +C10_DEFINE_int(pytext_len, 0, "Length of input sequence."); + std::vector split(char separator, const std::string& string, bool ignore_empty = true) { std::vector pieces; @@ -97,11 +99,18 @@ int main(int argc, char** argv) { inputs.push_back(torch::ones(input_dims, at::ScalarType::Float)); } else if (input_type_list[i] == "uint8_t") { inputs.push_back(torch::ones(input_dims, at::ScalarType::Byte)); + } else if (input_type_list[i] == "int64") { + inputs.push_back(torch::ones(input_dims, torch::kI64)); } else { CAFFE_THROW("Unsupported input type: ", input_type_list[i]); } } + if (FLAGS_pytext_len > 0) { + auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64); + inputs.push_back(stensor); + } + auto qengines = at::globalContext().supportedQEngines(); if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) { at::globalContext().setQEngine(at::QEngine::QNNPACK); @@ -138,7 +147,7 @@ int main(int argc, char** argv) { auto start = high_resolution_clock::now(); module.forward(inputs); auto stop = high_resolution_clock::now(); - auto duration = duration_cast(stop - start); + auto duration = duration_cast(stop - start); times.push_back(duration.count()); } millis = timer.MilliSeconds();