Add arguments to benchmark to run pytext models. Output results in ms. (#30273)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30273

Pytext models expect input of the form `1xlength` and another input specifying the length.
Add the `pytext_len` argument to specify this.
ghstack-source-id: 94383501

Test Plan: ./speed_benchmark_torch --model model.pt --input_dims "1,4" --input_type int64 --warmup 10 --iter 10 --report_pep=true --pytext_len=4

Reviewed By: iseeyuan

Differential Revision: D18646028

fbshipit-source-id: 7d5fe0c36da6e5f7b0261619ce4784a46b70f3d8
This commit is contained in:
Supriya Rao
2019-11-21 16:00:58 -08:00
committed by Facebook Github Bot
parent b2b1601b30
commit f2f285c240

View File

@ -48,6 +48,8 @@ C10_DEFINE_bool(
false, false,
"Whether to print performance stats for AI-PEP."); "Whether to print performance stats for AI-PEP.");
C10_DEFINE_int(pytext_len, 0, "Length of input sequence.");
std::vector<std::string> std::vector<std::string>
split(char separator, const std::string& string, bool ignore_empty = true) { split(char separator, const std::string& string, bool ignore_empty = true) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
@ -97,11 +99,18 @@ int main(int argc, char** argv) {
inputs.push_back(torch::ones(input_dims, at::ScalarType::Float)); inputs.push_back(torch::ones(input_dims, at::ScalarType::Float));
} else if (input_type_list[i] == "uint8_t") { } else if (input_type_list[i] == "uint8_t") {
inputs.push_back(torch::ones(input_dims, at::ScalarType::Byte)); inputs.push_back(torch::ones(input_dims, at::ScalarType::Byte));
} else if (input_type_list[i] == "int64") {
inputs.push_back(torch::ones(input_dims, torch::kI64));
} else { } else {
CAFFE_THROW("Unsupported input type: ", input_type_list[i]); CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
} }
} }
if (FLAGS_pytext_len > 0) {
auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64);
inputs.push_back(stensor);
}
auto qengines = at::globalContext().supportedQEngines(); auto qengines = at::globalContext().supportedQEngines();
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) { if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
at::globalContext().setQEngine(at::QEngine::QNNPACK); at::globalContext().setQEngine(at::QEngine::QNNPACK);
@ -138,7 +147,7 @@ int main(int argc, char** argv) {
auto start = high_resolution_clock::now(); auto start = high_resolution_clock::now();
module.forward(inputs); module.forward(inputs);
auto stop = high_resolution_clock::now(); auto stop = high_resolution_clock::now();
auto duration = duration_cast<microseconds>(stop - start); auto duration = duration_cast<milliseconds>(stop - start);
times.push_back(duration.count()); times.push_back(duration.count());
} }
millis = timer.MilliSeconds(); millis = timer.MilliSeconds();