add the int support (#15581)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15581

as title

Reviewed By: protonu

Differential Revision: D13556274

fbshipit-source-id: ba21f0970257d526e2fe7574eea4f89465b9c618
This commit is contained in:
Lingyi Liu
2018-12-27 17:13:50 -08:00
committed by Facebook Github Bot
parent 9bf7eb914d
commit c1643ec551

View File

@ -19,7 +19,7 @@
#include <string>
#include <thread>
#include "binaries/benchmark_helper.h"
#include <binaries/benchmark_helper.h>
#include "caffe2/core/blob_serialization.h"
#ifdef __CUDA_ARCH__
#include "caffe2/core/context_gpu.h"
@ -31,9 +31,9 @@
#include "caffe2/core/tensor_int8.h"
#include "caffe2/utils/bench_utils.h"
#include "caffe2/utils/string_utils.h"
#include "observers/net_observer_reporter_print.h"
#include "observers/observer_config.h"
#include "observers/perf_observer.h"
#include <observers/net_observer_reporter_print.h>
#include <observers/observer_config.h>
#include <observers/perf_observer.h>
using std::map;
using std::shared_ptr;
@ -186,6 +186,11 @@ int loadInput(
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
tensor->mutable_data<float>();
} else if (input_type_list[i] == "int") {
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
CHECK_NOTNULL(tensor);
tensor->Resize(input_dims);
tensor->mutable_data<int>();
} else {
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
}