[AI Bench] Restore speed_benchmark_torch.cc to its original state

Summary: We removed all assistant-specific code.

Test Plan:
```
buck run aibench:run_bench -- -b aibench/specifications/models/pytorch/fbnet/fbnet_mobile_inference.json --platform android/full_jit --framework pytorch --remote --devices  SM-G950U-7.0-24
```

https://our.intern.facebook.com/intern/aibench/details/940147322057842

Reviewed By: kimishpatel

Differential Revision: D20686220

fbshipit-source-id: b7336d5ea15fa11be01abf4ad12747feaaf22ea8
This commit is contained in:
Peng Xia
2020-04-02 08:33:12 -07:00
committed by Facebook GitHub Bot
parent 1bd68eafb5
commit c3abcf83aa

View File

@ -14,7 +14,6 @@
* limitations under the License.
*/
#include <fstream>
#include <string>
#include <vector>
@ -38,7 +37,6 @@ C10_DEFINE_string(
"semicolon to separate the dimension of different "
"tensors.");
C10_DEFINE_string(input_type, "", "Input type (uint8_t/float)");
C10_DEFINE_string(input_file, "", "Input file");
C10_DEFINE_bool(
print_output,
false,
@ -65,22 +63,6 @@ split(char separator, const std::string& string, bool ignore_empty = true) {
return pieces;
}
// Builds one model input per line of the text file at `file_path`.
//
// Every line is tokenized with split(' ', line) and turned into a
// three-element input vector:
//   [0] the tokens, as a c10::List<std::string>
//   [1] a default-constructed c10::IValue
//   [2] the token count, as an int64_t IValue
// Each line is also echoed to stdout as it is consumed.
//
// @param file_path  Path of the text file to read, one sample per line.
// @return One input vector per line, in file order. The result is empty when
//         the file cannot be opened or has no lines, because the read loop
//         then never executes.
std::vector<std::vector<c10::IValue>> nlu_process(std::string file_path) {
  std::vector<std::vector<c10::IValue>> nlu_inputs;
  // Open the path supplied by the caller. The previous code ignored
  // `file_path` and always read the global FLAGS_input_file instead.
  std::ifstream input_file(file_path);
  for (std::string line; std::getline(input_file, line);) {
    std::vector<c10::IValue> nlu_input;
    c10::List<std::string> tokens(split(' ', line));
    nlu_input.push_back(tokens);
    auto len = torch::jit::IValue(static_cast<int64_t>(tokens.size()));
    // Second slot is a default-constructed IValue.
    // NOTE(review): presumably an optional model argument left unset — confirm
    // against the model's forward() signature.
    nlu_input.push_back({});
    nlu_input.push_back(len);
    nlu_inputs.emplace_back(std::move(nlu_input));
    std::cout << line << std::endl;
  }
  return nlu_inputs;
}
int main(int argc, char** argv) {
c10::SetUsageMessage(
"Run speed benchmark for pytorch model.\n"
@ -106,32 +88,27 @@ int main(int argc, char** argv) {
input_type_list.size(),
"Input dims and type should have the same number of items.");
std::vector<std::vector<c10::IValue>> inputs;
if (input_type_list[0] == "NLUType"){
inputs = nlu_process(FLAGS_input_file);
} else {
inputs.push_back(std::vector<c10::IValue>());
for (size_t i = 0; i < input_dims_list.size(); ++i) {
auto input_dims_str = split(',', input_dims_list[i]);
std::vector<int64_t> input_dims;
for (const auto& s : input_dims_str) {
input_dims.push_back(c10::stoi(s));
}
if (input_type_list[i] == "float") {
inputs[0].push_back(torch::ones(input_dims, at::ScalarType::Float));
} else if (input_type_list[i] == "uint8_t") {
inputs[0].push_back(torch::ones(input_dims, at::ScalarType::Byte));
} else if (input_type_list[i] == "int64") {
inputs[0].push_back(torch::ones(input_dims, torch::kI64));
} else {
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
}
std::vector<c10::IValue> inputs;
for (size_t i = 0; i < input_dims_list.size(); ++i) {
auto input_dims_str = split(',', input_dims_list[i]);
std::vector<int64_t> input_dims;
for (const auto& s : input_dims_str) {
input_dims.push_back(c10::stoi(s));
}
if (input_type_list[i] == "float") {
inputs.push_back(torch::ones(input_dims, at::ScalarType::Float));
} else if (input_type_list[i] == "uint8_t") {
inputs.push_back(torch::ones(input_dims, at::ScalarType::Byte));
} else if (input_type_list[i] == "int64") {
inputs.push_back(torch::ones(input_dims, torch::kI64));
} else {
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
}
}
if (FLAGS_pytext_len > 0) {
auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64);
inputs[0].push_back(stensor);
inputs.push_back(stensor);
}
torch::autograd::AutoGradMode guard(false);
@ -140,7 +117,7 @@ int main(int argc, char** argv) {
module.eval();
if (FLAGS_print_output) {
std::cout << module.forward(inputs[0]) << std::endl;
std::cout << module.forward(inputs) << std::endl;
}
std::cout << "Starting benchmark." << std::endl;
@ -150,10 +127,8 @@ int main(int argc, char** argv) {
"Number of warm up runs should be non negative, provided ",
FLAGS_warmup,
".");
for (unsigned int i = 0; i < FLAGS_warmup; ++i) {
for (const auto& input : inputs) {
module.forward(input);
}
for (int i = 0; i < FLAGS_warmup; ++i) {
module.forward(inputs);
}
std::cout << "Main runs." << std::endl;
@ -166,13 +141,11 @@ int main(int argc, char** argv) {
std::vector<float> times;
auto millis = timer.MilliSeconds();
for (int i = 0; i < FLAGS_iter; ++i) {
for (const std::vector<c10::IValue>& input: inputs) {
auto start = high_resolution_clock::now();
module.forward(input);
auto stop = high_resolution_clock::now();
auto duration = duration_cast<milliseconds>(stop - start);
times.push_back(duration.count());
}
auto start = high_resolution_clock::now();
module.forward(inputs);
auto stop = high_resolution_clock::now();
auto duration = duration_cast<milliseconds>(stop - start);
times.push_back(duration.count());
}
millis = timer.MilliSeconds();
if (FLAGS_report_pep) {