Use caffe2::int8::Int8TensorCPU when input type is uint8_t (#12250)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12250

We use caffe2::int8::Int8TensorCPU for quantized tensor with uint8_t element type.

Reviewed By: llyfacebook

Differential Revision: D10121216

fbshipit-source-id: b63cd3a75f87e043cc3c83de4f3520b6ffbf1d07
This commit is contained in:
Jongsoo Park
2018-10-02 14:44:47 -07:00
committed by Facebook Github Bot
parent 7c678746ef
commit 04b0774964
2 changed files with 19 additions and 9 deletions

View File

@ -28,6 +28,7 @@
#include "caffe2/core/logging.h"
#include "caffe2/core/net.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor_int8.h"
#include "caffe2/utils/bench_utils.h"
#include "caffe2/utils/string_utils.h"
#include "observers/net_observer_reporter_print.h"
@ -163,12 +164,16 @@ void loadInput(
CAFFE_THROW("Not support GPU on mobile.");
#endif
} else {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
if (input_type_list[i] == "uint8_t") {
tensor->mutable_data<uint8_t>();
caffe2::int8::Int8TensorCPU* tensor =
blob->GetMutable<caffe2::int8::Int8TensorCPU>();
CHECK_NOTNULL(tensor);
tensor->t.Resize(input_dims);
tensor->t.mutable_data<uint8_t>();
} else if (input_type_list[i] == "float") {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
tensor->mutable_data<float>();
} else {
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);

View File

@ -20,6 +20,7 @@
#include "caffe2/core/init.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor_int8.h"
#ifdef CAFFE2_OPTIMIZER
#include "caffe2/opt/optimizer.h"
#endif
@ -137,14 +138,18 @@ int main(int argc, char** argv) {
if (blob == nullptr) {
blob = workspace->CreateBlob(input_names[i]);
}
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
if (input_type_list[i] == "uint8_t") {
tensor->mutable_data<uint8_t>();
caffe2::int8::Int8TensorCPU* tensor =
blob->GetMutable<caffe2::int8::Int8TensorCPU>();
CHECK_NOTNULL(tensor);
tensor->t.Resize(input_dims);
tensor->t.mutable_data<uint8_t>();
} else if (input_type_list[i] == "float") {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
tensor->mutable_data<float>();
} else {
} else {
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
}
}