diff --git a/binaries/benchmark_helper.cc b/binaries/benchmark_helper.cc index ecbae477282c..460c4008566e 100644 --- a/binaries/benchmark_helper.cc +++ b/binaries/benchmark_helper.cc @@ -28,6 +28,7 @@ #include "caffe2/core/logging.h" #include "caffe2/core/net.h" #include "caffe2/core/operator.h" +#include "caffe2/core/tensor_int8.h" #include "caffe2/utils/bench_utils.h" #include "caffe2/utils/string_utils.h" #include "observers/net_observer_reporter_print.h" @@ -163,12 +164,16 @@ void loadInput( CAFFE_THROW("Not support GPU on mobile."); #endif } else { - caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU); - CHECK_NOTNULL(tensor); - tensor->Resize(input_dims); if (input_type_list[i] == "uint8_t") { - tensor->mutable_data(); + caffe2::int8::Int8TensorCPU* tensor = + blob->GetMutable(); + CHECK_NOTNULL(tensor); + tensor->t.Resize(input_dims); + tensor->t.mutable_data(); } else if (input_type_list[i] == "float") { + caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU); + CHECK_NOTNULL(tensor); + tensor->Resize(input_dims); tensor->mutable_data(); } else { CAFFE_THROW("Unsupported input type: ", input_type_list[i]); diff --git a/binaries/speed_benchmark.cc b/binaries/speed_benchmark.cc index fd502cf3c078..772b44b1bce8 100644 --- a/binaries/speed_benchmark.cc +++ b/binaries/speed_benchmark.cc @@ -20,6 +20,7 @@ #include "caffe2/core/init.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" +#include "caffe2/core/tensor_int8.h" #ifdef CAFFE2_OPTIMIZER #include "caffe2/opt/optimizer.h" #endif @@ -137,14 +138,18 @@ int main(int argc, char** argv) { if (blob == nullptr) { blob = workspace->CreateBlob(input_names[i]); } - caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU); - CHECK_NOTNULL(tensor); - tensor->Resize(input_dims); if (input_type_list[i] == "uint8_t") { - tensor->mutable_data(); + caffe2::int8::Int8TensorCPU* tensor = + blob->GetMutable(); + CHECK_NOTNULL(tensor); + tensor->t.Resize(input_dims); + tensor->t.mutable_data(); } else if (input_type_list[i] == "float") { + caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU); + CHECK_NOTNULL(tensor); + tensor->Resize(input_dims); tensor->mutable_data(); - } else { + } else { CAFFE_THROW("Unsupported input type: ", input_type_list[i]); } }