diff --git a/binaries/speed_benchmark_torch.cc b/binaries/speed_benchmark_torch.cc index 41688b0fdb89..6470d5acfcdd 100644 --- a/binaries/speed_benchmark_torch.cc +++ b/binaries/speed_benchmark_torch.cc @@ -48,6 +48,8 @@ C10_DEFINE_bool( false, "Whether to print performance stats for AI-PEP."); +C10_DEFINE_int(pytext_len, 0, "Length of input sequence."); + std::vector split(char separator, const std::string& string, bool ignore_empty = true) { std::vector pieces; @@ -97,11 +99,18 @@ int main(int argc, char** argv) { inputs.push_back(torch::ones(input_dims, at::ScalarType::Float)); } else if (input_type_list[i] == "uint8_t") { inputs.push_back(torch::ones(input_dims, at::ScalarType::Byte)); + } else if (input_type_list[i] == "int64") { + inputs.push_back(torch::ones(input_dims, torch::kI64)); } else { CAFFE_THROW("Unsupported input type: ", input_type_list[i]); } } + if (FLAGS_pytext_len > 0) { + auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64); + inputs.push_back(stensor); + } + auto qengines = at::globalContext().supportedQEngines(); if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) { at::globalContext().setQEngine(at::QEngine::QNNPACK); @@ -138,7 +147,7 @@ int main(int argc, char** argv) { auto start = high_resolution_clock::now(); module.forward(inputs); auto stop = high_resolution_clock::now(); - auto duration = duration_cast(stop - start); + auto duration = duration_cast(stop - start); times.push_back(duration.count()); } millis = timer.MilliSeconds();