mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add the int support (#15581)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15581 as title Reviewed By: protonu Differential Revision: D13556274 fbshipit-source-id: ba21f0970257d526e2fe7574eea4f89465b9c618
This commit is contained in:
committed by
Facebook Github Bot
parent
9bf7eb914d
commit
c1643ec551
@ -19,7 +19,7 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#include "binaries/benchmark_helper.h"
|
||||
#include <binaries/benchmark_helper.h>
|
||||
#include "caffe2/core/blob_serialization.h"
|
||||
#ifdef __CUDA_ARCH__
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
@ -31,9 +31,9 @@
|
||||
#include "caffe2/core/tensor_int8.h"
|
||||
#include "caffe2/utils/bench_utils.h"
|
||||
#include "caffe2/utils/string_utils.h"
|
||||
#include "observers/net_observer_reporter_print.h"
|
||||
#include "observers/observer_config.h"
|
||||
#include "observers/perf_observer.h"
|
||||
#include <observers/net_observer_reporter_print.h>
|
||||
#include <observers/observer_config.h>
|
||||
#include <observers/perf_observer.h>
|
||||
|
||||
using std::map;
|
||||
using std::shared_ptr;
|
||||
@ -186,6 +186,11 @@ int loadInput(
|
||||
CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
tensor->mutable_data<float>();
|
||||
} else if (input_type_list[i] == "int") {
|
||||
caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
|
||||
CHECK_NOTNULL(tensor);
|
||||
tensor->Resize(input_dims);
|
||||
tensor->mutable_data<int>();
|
||||
} else {
|
||||
CAFFE_THROW("Unsupported input type: ", input_type_list[i]);
|
||||
}
|
||||
|
Reference in New Issue
Block a user