mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add arguments to benchmark to run pytext models. Output results in ms. (#30273)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30273 Pytext models expect input of the form `1xlength` and another input specifying the length. Add the `pytext_len` argument to specify this. ghstack-source-id: 94383501 Test Plan: ./speed_benchmark_torch --model model.pt --input_dims "1,4" --input_type int64 --warmup 10 --iter 10 --report_pep=true --pytext_len=4 Reviewed By: iseeyuan Differential Revision: D18646028 fbshipit-source-id: 7d5fe0c36da6e5f7b0261619ce4784a46b70f3d8
This commit is contained in:
committed by
Facebook Github Bot
parent
b2b1601b30
commit
f2f285c240
@ -48,6 +48,8 @@ C10_DEFINE_bool(
|
|||||||
false,
|
false,
|
||||||
"Whether to print performance stats for AI-PEP.");
|
"Whether to print performance stats for AI-PEP.");
|
||||||
|
|
||||||
|
C10_DEFINE_int(pytext_len, 0, "Length of input sequence.");
|
||||||
|
|
||||||
std::vector<std::string>
|
std::vector<std::string>
|
||||||
split(char separator, const std::string& string, bool ignore_empty = true) {
|
split(char separator, const std::string& string, bool ignore_empty = true) {
|
||||||
std::vector<std::string> pieces;
|
std::vector<std::string> pieces;
|
||||||
@ -97,11 +99,18 @@ int main(int argc, char** argv) {
|
|||||||
inputs.push_back(torch::ones(input_dims, at::ScalarType::Float));
|
inputs.push_back(torch::ones(input_dims, at::ScalarType::Float));
|
||||||
} else if (input_type_list[i] == "uint8_t") {
|
} else if (input_type_list[i] == "uint8_t") {
|
||||||
inputs.push_back(torch::ones(input_dims, at::ScalarType::Byte));
|
inputs.push_back(torch::ones(input_dims, at::ScalarType::Byte));
|
||||||
|
} else if (input_type_list[i] == "int64") {
|
||||||
|
inputs.push_back(torch::ones(input_dims, torch::kI64));
|
||||||
} else {
|
} else {
|
||||||
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
|
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (FLAGS_pytext_len > 0) {
|
||||||
|
auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64);
|
||||||
|
inputs.push_back(stensor);
|
||||||
|
}
|
||||||
|
|
||||||
auto qengines = at::globalContext().supportedQEngines();
|
auto qengines = at::globalContext().supportedQEngines();
|
||||||
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
|
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) {
|
||||||
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
at::globalContext().setQEngine(at::QEngine::QNNPACK);
|
||||||
@ -138,7 +147,7 @@ int main(int argc, char** argv) {
|
|||||||
auto start = high_resolution_clock::now();
|
auto start = high_resolution_clock::now();
|
||||||
module.forward(inputs);
|
module.forward(inputs);
|
||||||
auto stop = high_resolution_clock::now();
|
auto stop = high_resolution_clock::now();
|
||||||
auto duration = duration_cast<microseconds>(stop - start);
|
auto duration = duration_cast<milliseconds>(stop - start);
|
||||||
times.push_back(duration.count());
|
times.push_back(duration.count());
|
||||||
}
|
}
|
||||||
millis = timer.MilliSeconds();
|
millis = timer.MilliSeconds();
|
||||||
|
Reference in New Issue
Block a user