Mirror of https://github.com/pytorch/pytorch.git
[AI Bench] Restore speed_benchmark_torch.cc to origin
Summary: we removed all assistant-specific code.

Test Plan:
```
buck run aibench:run_bench -- -b aibench/specifications/models/pytorch/fbnet/fbnet_mobile_inference.json --platform android/full_jit --framework pytorch --remote --devices SM-G950U-7.0-24
```
https://our.intern.facebook.com/intern/aibench/details/940147322057842

Reviewed By: kimishpatel

Differential Revision: D20686220

fbshipit-source-id: b7336d5ea15fa11be01abf4ad12747feaaf22ea8
Commit c3abcf83aa (parent 1bd68eafb5), committed by Facebook GitHub Bot.
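Taken together, the hunks below revert the benchmark to building one flat `std::vector<c10::IValue>` and passing it straight to `forward()`. As a minimal standalone sketch of that call pattern, assuming a placeholder model path and input shape (neither is taken from this commit):

```cpp
#include <torch/script.h>

#include <iostream>
#include <vector>

int main() {
  // Placeholder path; the real benchmark takes the model via a flag.
  torch::jit::script::Module module = torch::jit::load("model.pt");
  module.eval();

  // One flat argument list, as in the restored speed_benchmark_torch.cc.
  std::vector<c10::IValue> inputs;
  inputs.push_back(torch::ones({1, 3, 224, 224}, at::ScalarType::Float));

  torch::NoGradGuard no_grad; // inference only; mirrors AutoGradMode(false)
  std::cout << module.forward(inputs) << std::endl;
  return 0;
}
```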
```diff
@@ -14,7 +14,6 @@
  * limitations under the License.
  */
 
-#include <fstream>
 #include <string>
 #include <vector>
 
```
```diff
@@ -38,7 +37,6 @@ C10_DEFINE_string(
     "semicolon to separate the dimension of different "
     "tensors.");
 C10_DEFINE_string(input_type, "", "Input type (uint8_t/float)");
-C10_DEFINE_string(input_file, "", "Input file");
 C10_DEFINE_bool(
     print_output,
     false,
```
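Note on the flags above: `--input_dims` uses semicolons to separate tensors and commas to separate dimensions within a tensor, and `--input_type` must list one type per tensor. A sketch of parsing such a string into shapes; `parse_dims` is a hypothetical helper, while the file itself uses its own `split` plus `c10::stoi`:

```cpp
#include <sstream>
#include <string>
#include <vector>

// Simplified stand-in for the split() helper defined in this file.
static std::vector<std::string> split(char sep, const std::string& s) {
  std::vector<std::string> pieces;
  std::stringstream ss(s);
  for (std::string item; std::getline(ss, item, sep);) {
    if (!item.empty()) {
      pieces.push_back(item);
    }
  }
  return pieces;
}

// Hypothetical helper: "1,3,224,224;1" -> {{1,3,224,224}, {1}}.
std::vector<std::vector<int64_t>> parse_dims(const std::string& input_dims) {
  std::vector<std::vector<int64_t>> shapes;
  for (const auto& tensor_dims : split(';', input_dims)) {
    std::vector<int64_t> shape;
    for (const auto& d : split(',', tensor_dims)) {
      shape.push_back(std::stoll(d));
    }
    shapes.push_back(std::move(shape));
  }
  return shapes;
}
```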
```diff
@@ -65,22 +63,6 @@ split(char separator, const std::string& string, bool ignore_empty = true) {
   return pieces;
 }
 
-std::vector<std::vector<c10::IValue>> nlu_process(std::string file_path) {
-  std::vector<std::vector<c10::IValue>> nlu_inputs;
-  std::ifstream input_file(FLAGS_input_file);
-  for (std::string line; getline(input_file, line);) {
-    std::vector<c10::IValue> nlu_input;
-    c10::List<std::string> tokens(split(' ', line));
-    nlu_input.push_back(tokens);
-    auto len = torch::jit::IValue(static_cast<int64_t>(tokens.size()));
-    nlu_input.push_back({});
-    nlu_input.push_back(len);
-    nlu_inputs.emplace_back(std::move(nlu_input));
-    std::cout << line << std::endl;
-  }
-  return nlu_inputs;
-}
-
 int main(int argc, char** argv) {
   c10::SetUsageMessage(
       "Run speed benchmark for pytorch model.\n"
```
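The deleted `nlu_process` built one argument list per line of the input file, packing the tokens and their count as IValues (with an empty IValue between them), which is why `inputs` had to be nested. A simplified sketch of that packing, using a hypothetical `pack_tokens` helper and omitting the empty slot:

```cpp
#include <torch/script.h>

#include <string>
#include <vector>

// Rough sketch of the removed per-line packing (hypothetical helper).
std::vector<c10::IValue> pack_tokens(const std::vector<std::string>& words) {
  // c10::List<std::string> converts to an IValue the module can consume.
  c10::List<std::string> tokens(words);
  std::vector<c10::IValue> input;
  input.push_back(tokens);
  input.push_back(static_cast<int64_t>(tokens.size())); // sequence length
  return input;
}
```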
```diff
@@ -106,32 +88,27 @@ int main(int argc, char** argv) {
       input_type_list.size(),
       "Input dims and type should have the same number of items.");
 
-  std::vector<std::vector<c10::IValue>> inputs;
-  if (input_type_list[0] == "NLUType"){
-    inputs = nlu_process(FLAGS_input_file);
-  } else {
-    inputs.push_back(std::vector<c10::IValue>());
-    for (size_t i = 0; i < input_dims_list.size(); ++i) {
-      auto input_dims_str = split(',', input_dims_list[i]);
-      std::vector<int64_t> input_dims;
-      for (const auto& s : input_dims_str) {
-        input_dims.push_back(c10::stoi(s));
-      }
-      if (input_type_list[i] == "float") {
-        inputs[0].push_back(torch::ones(input_dims, at::ScalarType::Float));
-      } else if (input_type_list[i] == "uint8_t") {
-        inputs[0].push_back(torch::ones(input_dims, at::ScalarType::Byte));
-      } else if (input_type_list[i] == "int64") {
-        inputs[0].push_back(torch::ones(input_dims, torch::kI64));
-      } else {
-        CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
-      }
-    }
-  }
+  std::vector<c10::IValue> inputs;
+  for (size_t i = 0; i < input_dims_list.size(); ++i) {
+    auto input_dims_str = split(',', input_dims_list[i]);
+    std::vector<int64_t> input_dims;
+    for (const auto& s : input_dims_str) {
+      input_dims.push_back(c10::stoi(s));
+    }
+    if (input_type_list[i] == "float") {
+      inputs.push_back(torch::ones(input_dims, at::ScalarType::Float));
+    } else if (input_type_list[i] == "uint8_t") {
+      inputs.push_back(torch::ones(input_dims, at::ScalarType::Byte));
+    } else if (input_type_list[i] == "int64") {
+      inputs.push_back(torch::ones(input_dims, torch::kI64));
+    } else {
+      CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
+    }
+  }
 
   if (FLAGS_pytext_len > 0) {
     auto stensor = FLAGS_pytext_len * at::ones({1}, torch::kI64);
-    inputs[0].push_back(stensor);
+    inputs.push_back(stensor);
   }
 
   torch::autograd::AutoGradMode guard(false);
```
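After this hunk, the optional pytext length is appended to the same flat vector as a one-element int64 tensor. The same expression, wrapped in a hypothetical helper for illustration:

```cpp
#include <torch/torch.h>

#include <vector>

// Hypothetical wrapper around the inline flag check in main().
void append_pytext_len(std::vector<c10::IValue>& inputs, int64_t len) {
  if (len > 0) {
    // Same expression as the diff: a 1-element int64 tensor holding len.
    auto stensor = len * at::ones({1}, torch::kI64);
    inputs.push_back(stensor);
  }
}
```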
```diff
@@ -140,7 +117,7 @@ int main(int argc, char** argv) {
 
   module.eval();
   if (FLAGS_print_output) {
-    std::cout << module.forward(inputs[0]) << std::endl;
+    std::cout << module.forward(inputs) << std::endl;
   }
 
   std::cout << "Starting benchmark." << std::endl;
```
```diff
@@ -150,10 +127,8 @@ int main(int argc, char** argv) {
       "Number of warm up runs should be non negative, provided ",
       FLAGS_warmup,
       ".");
-  for (unsigned int i = 0; i < FLAGS_warmup; ++i) {
-    for (const auto& input : inputs) {
-      module.forward(input);
-    }
+  for (int i = 0; i < FLAGS_warmup; ++i) {
+    module.forward(inputs);
   }
 
   std::cout << "Main runs." << std::endl;
```
```diff
@@ -166,13 +141,11 @@ int main(int argc, char** argv) {
   std::vector<float> times;
   auto millis = timer.MilliSeconds();
   for (int i = 0; i < FLAGS_iter; ++i) {
-    for (const std::vector<c10::IValue>& input: inputs) {
-      auto start = high_resolution_clock::now();
-      module.forward(input);
-      auto stop = high_resolution_clock::now();
-      auto duration = duration_cast<milliseconds>(stop - start);
-      times.push_back(duration.count());
-    }
+    auto start = high_resolution_clock::now();
+    module.forward(inputs);
+    auto stop = high_resolution_clock::now();
+    auto duration = duration_cast<milliseconds>(stop - start);
+    times.push_back(duration.count());
   }
   millis = timer.MilliSeconds();
   if (FLAGS_report_pep) {
```
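The restored loop records one latency sample per iteration into `times`. Purely as an illustration (this code is not part of the file), such samples could be summarized like so:

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Hypothetical post-processing of per-iteration latencies in ms.
// Both helpers assume a non-empty sample vector.
float mean_latency(const std::vector<float>& times) {
  return std::accumulate(times.begin(), times.end(), 0.0f) / times.size();
}

float p50_latency(std::vector<float> times) {
  std::sort(times.begin(), times.end());
  return times[times.size() / 2];
}
```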