Use caffe2::int8::Int8TensorCPU when input type is uint8_t (#12250)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12250

We use caffe2::int8::Int8TensorCPU for quantized tensors with uint8_t element type.

Reviewed By: llyfacebook

Differential Revision: D10121216

fbshipit-source-id: b63cd3a75f87e043cc3c83de4f3520b6ffbf1d07

Committed by: Facebook Github Bot
Parent: 7c678746ef
Commit: 04b0774964
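
For context, caffe2::int8::Int8TensorCPU (pulled in through the caffe2/core/tensor_int8.h includes added below) bundles a plain CPU tensor with its quantization parameters, which is why the patched code writes through tensor->t. A rough sketch of the struct, recalled from the header rather than quoted verbatim:

    // Approximate layout of caffe2::int8::Int8TensorCPU (see caffe2/core/tensor_int8.h);
    // field names and defaults are from memory and may not match the header exactly.
    namespace caffe2 {
    namespace int8 {
    struct Int8TensorCPU {
      float scale{1.0};       // affine quantization scale
      int32_t zero_point{0};  // uint8_t value that represents real 0.0
      Tensor t{CPU};          // underlying uint8_t storage, accessed as tensor->t in the diff
    };
    } // namespace int8
    } // namespace caffe2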
@@ -28,6 +28,7 @@
 #include "caffe2/core/logging.h"
 #include "caffe2/core/net.h"
 #include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
 #include "caffe2/utils/bench_utils.h"
 #include "caffe2/utils/string_utils.h"
 #include "observers/net_observer_reporter_print.h"
@@ -163,12 +164,16 @@ void loadInput(
       CAFFE_THROW("Not support GPU on mobile.");
 #endif
     } else {
-      caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
-      CHECK_NOTNULL(tensor);
-      tensor->Resize(input_dims);
       if (input_type_list[i] == "uint8_t") {
-        tensor->mutable_data<uint8_t>();
+        caffe2::int8::Int8TensorCPU* tensor =
+            blob->GetMutable<caffe2::int8::Int8TensorCPU>();
+        CHECK_NOTNULL(tensor);
+        tensor->t.Resize(input_dims);
+        tensor->t.mutable_data<uint8_t>();
       } else if (input_type_list[i] == "float") {
+        caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
+        CHECK_NOTNULL(tensor);
+        tensor->Resize(input_dims);
         tensor->mutable_data<float>();
       } else {
         CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
@@ -20,6 +20,7 @@
 #include "caffe2/core/init.h"
 #include "caffe2/core/logging.h"
 #include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
 #ifdef CAFFE2_OPTIMIZER
 #include "caffe2/opt/optimizer.h"
 #endif
@@ -137,14 +138,18 @@ int main(int argc, char** argv) {
     if (blob == nullptr) {
       blob = workspace->CreateBlob(input_names[i]);
     }
-    caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
-    CHECK_NOTNULL(tensor);
-    tensor->Resize(input_dims);
     if (input_type_list[i] == "uint8_t") {
-      tensor->mutable_data<uint8_t>();
+      caffe2::int8::Int8TensorCPU* tensor =
+          blob->GetMutable<caffe2::int8::Int8TensorCPU>();
+      CHECK_NOTNULL(tensor);
+      tensor->t.Resize(input_dims);
+      tensor->t.mutable_data<uint8_t>();
     } else if (input_type_list[i] == "float") {
+      caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
+      CHECK_NOTNULL(tensor);
+      tensor->Resize(input_dims);
       tensor->mutable_data<float>();
     } else {
       CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
     }
   }
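
To make the new code path concrete, here is a minimal, self-contained sketch of filling one benchmark input blob the way the patched binaries now do for uint8_t inputs. fillUint8Input is a hypothetical helper written for illustration; the flag parsing and the rest of the real binaries are omitted:

    // Hypothetical helper mirroring the uint8_t branch of this diff: store the
    // input as Int8TensorCPU so int8 operators receive the quantized tensor type.
    #include <string>
    #include <vector>

    #include "caffe2/core/logging.h"
    #include "caffe2/core/tensor_int8.h"
    #include "caffe2/core/workspace.h"

    void fillUint8Input(
        caffe2::Workspace* workspace,
        const std::string& name,
        const std::vector<int>& input_dims) {
      caffe2::Blob* blob = workspace->GetBlob(name);
      if (blob == nullptr) {
        blob = workspace->CreateBlob(name);
      }
      caffe2::int8::Int8TensorCPU* tensor =
          blob->GetMutable<caffe2::int8::Int8TensorCPU>();
      CAFFE_ENFORCE(tensor);
      tensor->t.Resize(input_dims);       // shape as parsed from the binary's input dims
      tensor->t.mutable_data<uint8_t>();  // allocate uint8_t storage
    }

As in the diff itself, the data is only allocated, not initialized, and scale/zero_point keep their defaults.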