Mirror of https://github.com/pytorch/pytorch.git

CMake integration for int8 server operators

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/13558

Reviewed By: Maratyszcza
Differential Revision: D12945460
Pulled By: dskhudia
fbshipit-source-id: 1a91027b305fd6af77eebd9a4fad092a12f54712

Committed by: Facebook Github Bot
Commit: 18de330e86 (parent 76c1b5cd79)
@@ -84,6 +84,7 @@ option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
 cmake_dependent_option(
     USE_CUDNN "Use cuDNN" ON
     "USE_CUDA" OFF)
+option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
 option(USE_FFMPEG "Use ffmpeg" OFF)
 option(USE_GFLAGS "Use GFLAGS" ON)
 option(USE_GLOG "Use GLOG" ON)

@@ -91,6 +91,9 @@ if(NOT BUILD_ATEN_ONLY)
 if (BUILD_CAFFE2_OPS)
   add_subdirectory(operators)
   add_subdirectory(operators/rnn)
+  if (USE_FBGEMM)
+    add_subdirectory(operators/quantized/server)
+  endif()
   if (USE_QNNPACK)
     add_subdirectory(operators/quantized)
   endif()
caffe2/operators/quantized/server/CMakeLists.txt (new file, 70 lines)
@@ -0,0 +1,70 @@
# ---[ AVX2 Ops
set(caffe2_dnnlowp_avx2_ops_SRCS
    "${CMAKE_CURRENT_SOURCE_DIR}/conv_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/elementwise_sum_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/fully_connected_fake_lowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/group_norm_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/relu_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/transpose.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/dnnlowp.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/norm_minimization_avx2.cc")

# ---[ CPU files only
list(APPEND Caffe2_CPU_SRCS
    "${CMAKE_CURRENT_SOURCE_DIR}/activation_distribution_observer.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/batch_matmul_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/caffe2_dnnlowp_utils.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/channel_shuffle_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/concat_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/conv_dnnlowp_acc16_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/conv_relu_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/dequantize_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/dnnlowp_partition.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/elementwise_add_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/elementwise_linear_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/elementwise_mul_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/elementwise_sum_relu_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/fully_connected_dnnlowp_acc16_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/fully_connected_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/fully_connected_rowwise_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/lstm_unit_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/op_wrapper.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/pool_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/quantize_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/sigmoid_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/tanh_dnnlowp_op.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/utility_dnnlowp_ops.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/dynamic_histogram.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/kl_minimization.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/norm_minimization.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/p99.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/sigmoid.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/tanh.cc")

# Common sources

# ---[ CPU test files
# TODO: fc_fake_lowp_test.cc needs avx flags
# sigmoid_test.cc doesn't build; error: undefined Sigmoid and Compute
list(APPEND Caffe2_CPU_TEST_SRCS
    #"${CMAKE_CURRENT_SOURCE_DIR}/dynamic_histogram_test.cc"
    #"${CMAKE_CURRENT_SOURCE_DIR}/l2_minimization_test.cc"
    "${CMAKE_CURRENT_SOURCE_DIR}/requantization_test.cc")
    #"${CMAKE_CURRENT_SOURCE_DIR}/sigmoid_test.cc")
    #"${CMAKE_CURRENT_SOURCE_DIR}/tanh_test.cc")

if (NOT MSVC)
  add_library(caffe2_dnnlowp_avx2_ops OBJECT ${caffe2_dnnlowp_avx2_ops_SRCS})
  add_dependencies(caffe2_dnnlowp_avx2_ops fbgemm Caffe2_PROTO c10)
  target_include_directories(caffe2_dnnlowp_avx2_ops BEFORE
      PRIVATE $<BUILD_INTERFACE:${FBGEMM_SOURCE_DIR}/include>)
  set_property(SOURCE ${caffe2_dnnlowp_avx2_ops_SRCS}
      APPEND_STRING PROPERTY COMPILE_FLAGS " -mavx2 -mfma -mf16c -mxsave ")
  set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
      $<TARGET_OBJECTS:caffe2_dnnlowp_avx2_ops>)
endif()

# ---[ Send the lists to the parent scope.
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
@@ -83,7 +83,9 @@ void OutputMinMaxObserver::Stop() {
       continue;
     }

+#ifdef _OPENMP
 #pragma omp critical
+#endif
     {
       if (min_max_map_.find(out_name) == min_max_map_.end()) {
         min_max_map_[out_name] = make_pair(
@@ -164,7 +166,9 @@ void OutputMinMaxNetObserver::DumpAndReset_(
 OutputMinMaxNetObserver::~OutputMinMaxNetObserver() {
   DumpAndReset_(out_file_name_, true);

+#ifdef _OPENMP
 #pragma omp critical
+#endif
   {
     ofstream f;
     time_t rawtime;
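Aside: the guard seen above — an `#ifdef _OPENMP` pair around each `#pragma omp` — recurs throughout this diff. Compilers generally ignore unknown pragmas, but the guard keeps non-OpenMP builds (and builds that warn on unknown pragmas) clean. A minimal sketch of the guarded critical section, assuming only standard OpenMP semantics:

    #ifdef _OPENMP
    #pragma omp critical
    #endif
    {
      // With OpenMP: threads enter this block one at a time.
      // Without OpenMP: the braces are an ordinary scope and the body runs as-is.
    }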
@@ -16,7 +16,9 @@

 #include "batch_matmul_dnnlowp_op.h"

+#ifdef _OPENMP
 #include <omp.h>
+#endif

 // #define DNNLOWP_MEASURE_TIME_BREAKDOWN
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
@@ -432,16 +434,19 @@ bool BatchMatMulDNNLowPOp<T>::RunOnDevice() {
   if (!dequantize_output_) {
     auto row_offset_len_per_thread =
         PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
-    row_offsets_.resize(row_offset_len_per_thread * omp_get_max_threads());
+    row_offsets_.resize(
+        row_offset_len_per_thread * dnnlowp_get_max_threads());
     auto A_pack_buf_len_per_thread =
         PackAWithRowOffset<uint8_t>::packedBufferSize();
-    A_pack_buf_.resize(A_pack_buf_len_per_thread * omp_get_max_threads());
+    A_pack_buf_.resize(A_pack_buf_len_per_thread * dnnlowp_get_max_threads());
     Y_int32_.resize(Y->numel());

+#ifdef _OPENMP
 #pragma omp parallel for collapse(2)
+#endif
     for (int p = 0; p < num_outer_batches; ++p) {
       for (int i = 0; i < num_sub_batches; ++i) {
-        int tid = omp_get_thread_num();
+        int tid = dnnlowp_get_thread_num();

         PackAWithRowOffset<uint8_t> packA(
             trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
@@ -490,15 +495,18 @@ bool BatchMatMulDNNLowPOp<T>::RunOnDevice() {
       // Both input and output are float
       int row_offset_len_per_thread =
           PackAWithQuantRowOffset<uint8_t>::rowOffsetBufferSize();
-      row_offsets_.resize(row_offset_len_per_thread * omp_get_max_threads());
+      row_offsets_.resize(
+          row_offset_len_per_thread * dnnlowp_get_max_threads());
       int A_pack_len_per_thread =
           PackAWithQuantRowOffset<uint8_t>::packedBufferSize();
-      A_pack_buf_.resize(A_pack_len_per_thread * omp_get_max_threads());
+      A_pack_buf_.resize(A_pack_len_per_thread * dnnlowp_get_max_threads());

+#ifdef _OPENMP
 #pragma omp parallel for collapse(2)
+#endif
       for (int p = 0; p < num_outer_batches; ++p) {
         for (int i = 0; i < num_sub_batches; ++i) {
-          int tid = omp_get_thread_num();
+          int tid = dnnlowp_get_thread_num();

           PackAWithQuantRowOffset<uint8_t> packA(
               trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
@@ -540,15 +548,19 @@ bool BatchMatMulDNNLowPOp<T>::RunOnDevice() {
       // Input quantized and output float
       auto row_offset_len_per_thread =
           PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
-      row_offsets_.resize(row_offset_len_per_thread * omp_get_max_threads());
+      row_offsets_.resize(
+          row_offset_len_per_thread * dnnlowp_get_max_threads());
       auto A_pack_buf_len_per_thread =
           PackAWithRowOffset<uint8_t>::packedBufferSize();
-      A_pack_buf_.resize(A_pack_buf_len_per_thread * omp_get_max_threads());
+      A_pack_buf_.resize(
+          A_pack_buf_len_per_thread * dnnlowp_get_max_threads());

+#ifdef _OPENMP
 #pragma omp parallel for collapse(2)
+#endif
       for (int p = 0; p < num_outer_batches; ++p) {
         for (int i = 0; i < num_sub_batches; ++i) {
-          int tid = omp_get_thread_num();
+          int tid = dnnlowp_get_thread_num();

           PackAWithRowOffset<uint8_t> packA(
               trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
@@ -609,7 +621,9 @@ bool BatchMatMulDNNLowPOp<T>::RunOnDevice() {

   T* Y_quantized = GetQuantizedOutputData_();
   Y_int32_.resize(Y->numel());
+#ifdef _OPENMP
 #pragma omp parallel for collapse(2)
+#endif
   for (int p = 0; p < num_outer_batches; ++p) {
     for (int i = 0; i < num_sub_batches; ++i) {
       // Y_q = (scale_A * scale_B) / scale_Y * Y_int32
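The comment in this hunk states the requantization identity these operators rely on: the int32 accumulator carries the combined scale scale_A * scale_B, so mapping it into the output's quantized domain multiplies by (scale_A * scale_B) / scale_Y. A self-contained sketch of that arithmetic under the usual affine-quantization conventions (the names below are illustrative, not the operator's actual helpers):

    #include <cmath>
    #include <cstdint>

    // Y_q = round((scale_A * scale_B / scale_Y) * Y_int32) + zero_point_Y
    std::int32_t requantize(std::int32_t y_int32, float scale_A, float scale_B,
                            float scale_Y, std::int32_t zero_point_Y) {
      const float real_multiplier = scale_A * scale_B / scale_Y;
      return static_cast<std::int32_t>(std::lround(real_multiplier * y_int32)) +
          zero_point_Y;
    }

The production code replaces the float multiply with a fixed-point integer multiplier (see the dnnlowp_requantization_multiplier_precision flag elsewhere in this diff), but the value computed is the same.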
@@ -35,7 +35,9 @@ bool BatchPermutationDNNLowPOp<T>::RunOnDevice() {
   int N = X.dim32(0);
   int K = X.numel() / N;

+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < N; ++i) {
     int origIdx = i * K;
     int permuteIdx = indices_data[i] * K;

@@ -4,17 +4,19 @@
 #include "caffe2/operators/quantized/server/tanh.h"

 #include <map>
+#ifdef _OPENMP
 #include <omp.h>
+#endif

-DECLARE_int32(dnnlowp_activation_quantization_precision);
-DECLARE_int32(dnnlowp_weight_quantization_precision);
-DECLARE_int32(dnnlowp_requantization_multiplier_precision);
-DECLARE_int32(dnnlowp_eltwise_quantization_precision);
-DECLARE_bool(dnnlowp_force_scale_power_of_two);
-DECLARE_bool(dnnlowp_preserve_activation_sparsity);
-DECLARE_bool(dnnlowp_preserve_weight_sparsity);
-DECLARE_string(dnnlowp_activation_quantization_kind);
-DECLARE_string(dnnlowp_weight_quantization_kind);
+C10_DECLARE_int32(dnnlowp_activation_quantization_precision);
+C10_DECLARE_int32(dnnlowp_weight_quantization_precision);
+C10_DECLARE_int32(dnnlowp_requantization_multiplier_precision);
+C10_DECLARE_int32(dnnlowp_eltwise_quantization_precision);
+C10_DECLARE_bool(dnnlowp_force_scale_power_of_two);
+C10_DECLARE_bool(dnnlowp_preserve_activation_sparsity);
+C10_DECLARE_bool(dnnlowp_preserve_weight_sparsity);
+C10_DECLARE_string(dnnlowp_activation_quantization_kind);
+C10_DECLARE_string(dnnlowp_weight_quantization_kind);

 namespace dnnlowp {

@@ -1,8 +1,8 @@
 #include "caffe2/operators/quantized/server/channel_shuffle_dnnlowp_op.h"

-#include "caffe2/caffe2/utils/eigen_utils.h"
 #include "caffe2/operators/quantized/server/caffe2_dnnlowp_utils.h"
 #include "caffe2/operators/quantized/server/transpose.h"
+#include "caffe2/utils/eigen_utils.h"

 namespace caffe2 {
@@ -45,7 +45,9 @@ bool ChannelShuffleDNNLowPOp<T>::RunOnDeviceWithOrderNCHW() {
   const int stride = C * HxW;
   const T* X_data = X.template data<T>();
   T* Y_data = Y->template mutable_data<T>();
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < N; ++i) {
     ConstEigenMatrixMap<T> X_mat(X_data, K * HxW, G);
     for (int j = 0; j < K; ++j) {
@@ -87,14 +89,18 @@ bool ChannelShuffleDNNLowPOp<T>::RunOnDeviceWithOrderNHWC() {
   T* Y_data = Y->template mutable_data<T>();

   if (G == 4 && std::is_same<T, std::uint8_t>::value && GetCpuId().avx2()) {
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (auto i = 0; i < X.numel(); i += C) {
       // Transpose each C = GxK matrix
       fbgemm::transpose_4rows(
           K, (const std::uint8_t*)(X_data + i), (std::uint8_t*)(Y_data + i));
     }
   } else {
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (auto i = 0; i < X.numel(); i += C) {
       // Transpose each C = GxK matrix
       math::Transpose(
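The branch above is the diff's standard CPU-dispatch pattern: a kernel built with AVX2 flags (via the caffe2_dnnlowp_avx2_ops object library in the new CMakeLists.txt) is reached only behind a runtime GetCpuId().avx2() check, with a portable fallback next to it. A self-contained sketch of the idea, with assumed helper names:

    #include <cstdint>

    bool cpu_has_avx2();  // assumed helper, e.g. wrapping a cpuid query
    // Assumed name; defined in a translation unit compiled with -mavx2:
    void transpose_4rows_avx2(int k, const std::uint8_t* x, std::uint8_t* y);

    void transpose_4rows_dispatch(int k, const std::uint8_t* x, std::uint8_t* y) {
      if (cpu_has_avx2()) {
        transpose_4rows_avx2(k, x, y);  // vectorized path, only on AVX2 CPUs
        return;
      }
      for (int r = 0; r < 4; ++r) {     // portable fallback: transpose a 4 x k tile
        for (int c = 0; c < k; ++c) {
          y[c * 4 + r] = x[r * k + c];
        }
      }
    }

Keeping the vectorized code behind a runtime check is what lets the binary run on CPUs without AVX2 even though part of it was compiled with -mavx2.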
@@ -1,6 +1,8 @@
 #include "concat_dnnlowp_op.h"

+#ifdef _OPENMP
 #include <omp.h>
+#endif

 #include "dnnlowp_partition.h"

@@ -111,10 +113,12 @@ bool ConcatDNNLowPOp<T>::RunOnDevice() {
     auto axis_dim = add_axis_ ? 1 : input.dim32(axis_);

     vector<T> input_temp(input.numel());
+#ifdef _OPENMP
 #pragma omp parallel
+#endif
     {
-      int nthreads = omp_get_num_threads();
-      int tid = omp_get_thread_num();
+      int nthreads = dnnlowp_get_num_threads();
+      int tid = dnnlowp_get_thread_num();
       int before_begin, before_end;
       int after_begin, after_end;

@@ -1,17 +1,20 @@
 #include "conv_dnnlowp_acc16_op.h"
+#include "dnnlowp_op.h"

 // #define DNNLOWP_ACC16_IN_SLOW_PATH
 // #define DNNLOWP_MEASURE_TIME_BREAKDOWN
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
 #include <chrono>
 #endif
+#ifdef _OPENMP
 #include <omp.h>
+#endif

 #include "dnnlowp_partition.h"
 #include "im2col_dnnlowp.h"

-DECLARE_int32(dnnlowp_nbits_in_non_outlier);
-DECLARE_int32(dnnlowp_copy_to_32bit_frequency);
+C10_DECLARE_int32(dnnlowp_nbits_in_non_outlier);
+C10_DECLARE_int32(dnnlowp_copy_to_32bit_frequency);
+C10_DECLARE_bool(caffe2_dnnlowp_shared_int32_buffer);

 namespace caffe2 {
@@ -135,16 +138,23 @@ bool ConvDNNLowPAcc16Op<ReluFused>::GetQuantizationParameters_() {
     }

     if (!reason.empty()) {
-      LOG_FIRST_N(WARNING, 32) << "Conv with weight "
-                               << OperatorBase::debug_def().input(FILTER)
-                               << " falls back to slow path because "
-                               << reason;
+      static int log_occurences = 0;
+      if (log_occurences < 32) {
+        ++log_occurences;
+        LOG(WARNING) << "Conv with weight "
+                     << OperatorBase::debug_def().input(FILTER)
+                     << " falls back to slow path because " << reason;
+      }
     }
   }
   if (nbits_in_non_outlier_ < 8 &&
       ConvPoolOpBase<CPUContext>::order_ != StorageOrder::NHWC) {
-    LOG_FIRST_N(WARNING, 32) << "Outlier-aware quantization only supports "
-                             "NHWC layout";
+    static int log_occurences = 0;
+    if (log_occurences < 32) {
+      ++log_occurences;
+      LOG(WARNING) << "Outlier-aware quantization only supports "
+                   "NHWC layout";
+    }
   }
   first_invocation_ = false;
 }
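This hunk (and several like it below) replaces glog's LOG_FIRST_N with an explicit function-local counter, presumably because caffe2's logging macros must also work without glog, where LOG_FIRST_N is unavailable; the test files in this diff likewise swap <glog/logging.h> for "caffe2/core/logging.h". The same pattern factored into a reusable macro — a sketch, not something this commit adds:

    // Log at most `n` times per expansion site; the do/while scope gives
    // each use its own static counter (spelling kept from the diff).
    #define LOG_AT_MOST_N(severity, n, msg)  \
      do {                                   \
        static int log_occurences = 0;       \
        if (log_occurences < (n)) {          \
          ++log_occurences;                  \
          LOG(severity) << (msg);            \
        }                                    \
      } while (0)

Note the counter is not atomic; as in the diff itself, concurrent callers may log a few more or fewer lines, which is acceptable for warnings.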
@@ -215,7 +225,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
   buffer_shape.push_back(kernel_dim);
   buffer_shape.insert(
       buffer_shape.end(), output_dims.begin(), output_dims.end());
-  buffer_shape.insert(buffer_shape.begin(), omp_get_max_threads());
+  buffer_shape.insert(buffer_shape.begin(), dnnlowp_get_max_threads());

   if (BaseType::kernel_.size() != 2) {
     SetDeviceTensor(img_shape, &(BaseType::img_shape_device_));
@@ -236,7 +246,7 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
   InType* col_buffer_data = col_buffer_.template mutable_data<InType>();

   auto f = [&](vector<int32_t>* Y_int32) {
-    Y_int32->resize(M * output_image_size * omp_get_max_threads());
+    Y_int32->resize(M * output_image_size * dnnlowp_get_max_threads());
     vector<int> buffer_shape_per_thread(
         buffer_shape.begin() + 1, buffer_shape.end());

@@ -251,11 +261,14 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
    } else {
      Y_data = Y->template mutable_data<uint8_t>();
    }
-   BaseType::column_offsets_.resize(output_image_size * omp_get_max_threads());
+   BaseType::column_offsets_.resize(
+       output_image_size * dnnlowp_get_max_threads());

+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
    for (int image_id = 0; image_id < N; ++image_id) {
-     int tid = omp_get_thread_num();
+     int tid = dnnlowp_get_thread_num();
      for (int group_id = 0; group_id < group_; ++group_id) {
        if (BaseType::kernel_.size() == 2) {
          math::Im2ColNCHW<InType>(
@@ -315,9 +328,13 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
         int8_t* W_quantized_group =
             W_quantized_.data() + (M / group_) * group_id * kernel_dim;

-        LOG_FIRST_N(WARNING, 32)
-            << "Consider using DNNLOWP instead of DNNLOWP_ACC16 engine since "
-               "we're falling back to a slow path because of NCHW layout";
+        static int log_occurences = 0;
+        if (log_occurences < 32) {
+          ++log_occurences;
+          LOG(WARNING)
+              << "Consider using DNNLOWP instead of DNNLOWP_ACC16 engine since "
+                 "we're falling back to a slow path because of NCHW layout";
+        }

         for (int i = 0; i < M / group_; ++i) {
           for (int j = 0; j < output_image_size; ++j) {
@@ -484,7 +501,9 @@ void ConvDNNLowPAcc16Op<ReluFused>::ConvOutlier_(
         sizeof((*Y_int32)[0]) * M * N);
   }

+#ifdef _OPENMP
 #pragma omp parallel
+#endif
   {
     int group_begin, group_end, i_begin, i_end;
     BaseType::PartitionGroupedNHWCConv_(
@@ -494,8 +513,8 @@ void ConvDNNLowPAcc16Op<ReluFused>::ConvOutlier_(
         &i_end,
         group_,
         N * output_image_size,
-        omp_get_num_threads(),
-        omp_get_thread_num());
+        dnnlowp_get_num_threads(),
+        dnnlowp_get_thread_num());

     for (int group_id = group_begin; group_id < group_end; ++group_id) {
       assert(Wq_outlier_[group_id].NumOfRows() == kernel_dim);
@@ -589,13 +608,15 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
     } else {
       col_buffer_quantized.resize(
          group_ * kernel_dim * output_image_size * N);
+#ifdef _OPENMP
 #pragma omp parallel
+#endif
      {
        size_t begin, end;
        std::tie(begin, end) = Get1DPartition(
            col_buffer_quantized.size(),
-           omp_get_num_threads(),
-           omp_get_thread_num());
+           dnnlowp_get_num_threads(),
+           dnnlowp_get_thread_num());
        Quantize<uint8_t>(
            (const float*)col_buffer_data + begin,
            col_buffer_quantized.data() + begin,
@@ -626,13 +647,13 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
       x_pack_buf_size_per_thread =
           PackAWithRowOffset<uint8_t, int16_t>::packedBufferSize();
       row_offsets_.resize(
-          omp_get_max_threads() * row_offset_size_per_thread);
+          dnnlowp_get_max_threads() * row_offset_size_per_thread);
     } else {
       x_pack_buf_size_per_thread =
           PackAMatrix<uint8_t, int16_t>::packedBufferSize();
     }
     X_pack_buf_.resize(
-        omp_get_max_threads() * x_pack_buf_size_per_thread);
+        dnnlowp_get_max_threads() * x_pack_buf_size_per_thread);
   }

   if (nbits_in_non_outlier_ > 0) {
@@ -641,9 +662,11 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
       // fast path
       uint8_t* Y_uint8_data =
           OutputTensorCPU_(0)->template mutable_data<uint8_t>();
+#ifdef _OPENMP
 #pragma omp parallel
+#endif
       {
-        int tid = omp_get_thread_num();
+        int tid = dnnlowp_get_thread_num();
        int group_begin, group_end;
        int i_begin, i_end;

@@ -654,8 +677,8 @@ bool ConvDNNLowPAcc16Op<ReluFused>::RunOnDeviceWithOrderNHWCAndType_() {
             &i_end,
             group_,
             N * output_image_size,
-            omp_get_num_threads(),
-            omp_get_thread_num());
+            dnnlowp_get_num_threads(),
+            dnnlowp_get_thread_num());

         for (int group_id = group_begin; group_id < group_end; ++group_id) {
           if (fuse_output_pipeline) {
@@ -1,11 +1,14 @@
 #include "conv_dnnlowp_op.h"
+#include "dnnlowp_op.h"

 // #define DNNLOWP_MEASURE_TIME_BREAKDOWN
 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
 #include <chrono>
 #endif

+#ifdef _OPENMP
 #include <omp.h>
+#endif

 #include "caffe2/core/tensor_int8.h"
 #include "caffe2/utils/cpuid.h"
@@ -224,7 +227,9 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeBias_() {
     b_quantized_data_ = bias.template data<int32_t>();
     if (dequantize_output_) {
       b_dequantized_.resize(bias.numel());
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
       for (int i = 0; i < b_dequantized_.size(); ++i) {
         b_dequantized_[i] =
             Dequantize<int32_t>(b_quantized_data_[i], bias_qparams);
@@ -296,9 +301,13 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
   int signed_min = 1 << (qfactory_->GetWeightPrecision() - 1);
   if (OperatorBase::InputIsType<int8::Int8TensorCPU>(FILTER)) {
     if (quantize_groupwise_) {
-      LOG_FIRST_N(WARNING, 32) << "Cannot do group-wise quantization for "
-                                  "pre-quantized weight "
-                               << OperatorBase::debug_def().input(FILTER);
+      static int log_occurences = 0;
+      if (log_occurences < 32) {
+        ++log_occurences;
+        LOG(WARNING) << "Cannot do group-wise quantization for "
+                        "pre-quantized weight "
+                     << OperatorBase::debug_def().input(FILTER);
+      }
     }
     FilterQuantizationParams(0).scale =
         OperatorBase::Input<int8::Int8TensorCPU>(FILTER).scale;
@@ -367,10 +376,13 @@ void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
       assert(false);
     }
     if (!reason.empty()) {
-      LOG_FIRST_N(WARNING, 32) << "Conv with weight "
-                               << OperatorBase::debug_def().input(FILTER)
-                               << " falls back to slow path because "
-                               << reason;
+      static int log_occurences = 0;
+      if (log_occurences < 32) {
+        ++log_occurences;
+        LOG(WARNING) << "Conv with weight "
+                     << OperatorBase::debug_def().input(FILTER)
+                     << " falls back to slow path because " << reason;
+      }
     }
   }
 }
@@ -460,7 +472,7 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNCHW_(

   // See batch_matmul_dnnlowp_op.cc to why we compute column_offsets,
   // row_offset, and const_offset in this way.
-  int tid = omp_get_thread_num();
+  int tid = dnnlowp_get_thread_num();
   int32_t *column_offsets = column_offsets_.data() + tid * Y_HxW;

   const dnnlowp::TensorQuantizationParams&
@@ -559,7 +571,7 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
   buffer_shape.push_back(kernel_dim);
   buffer_shape.insert(
       buffer_shape.end(), output_dims.begin(), output_dims.end());
-  buffer_shape.insert(buffer_shape.begin(), omp_get_max_threads());
+  buffer_shape.insert(buffer_shape.begin(), dnnlowp_get_max_threads());

   if (BaseType::kernel_.size() != 2) {
     SetDeviceTensor(img_shape, &img_shape_device_);
@@ -585,7 +597,7 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
   else {
     Y_data_T = Y->template mutable_data<T>();
   }
-  column_offsets_.resize(Y_HxW * omp_get_max_threads());
+  column_offsets_.resize(Y_HxW * dnnlowp_get_max_threads());

   auto f = [&](Tensor* col_buffer) {
     col_buffer->Resize(buffer_shape);
@@ -594,12 +606,14 @@ bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHWAndType_() {
     InType *col_buffer_data = col_buffer->template mutable_data<InType>();

     auto f2 = [&](vector<int32_t>* Y_int32) {
-      Y_int32->resize(M * Y_HxW * omp_get_max_threads());
+      Y_int32->resize(M * Y_HxW * dnnlowp_get_max_threads());

       // Im2Col, followed by gemm.
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
       for (int image_id = 0; image_id < N; ++image_id) {
-        int tid = omp_get_thread_num();
+        int tid = dnnlowp_get_thread_num();
         for (int group_id = 0; group_id < group_; ++group_id) {
           if (BaseType::kernel_.size() == 2) {
             math::Im2ColNCHW<InType>(
@@ -729,7 +743,9 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
   if (dequantize_output_) {
     float* Ydata = Y->template mutable_data<float>();

+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (int i = 0; i < N * Y_HxW; ++i) {
       for (int group_id = 0; group_id < group_; ++group_id) {
         int32_t row_offset = 0;
@@ -758,15 +774,21 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(

     if (!dnnlowp::HasStaticQuantization(this)) {
       if (quantize_groupwise_) {
-        LOG_FIRST_N(WARNING, 32) << "Cannot do group-wise quantization without "
-                                    "static quantization of activations for "
-                                 << OperatorBase::debug_def().output(0);
+        static int log_occurences = 0;
+        if (log_occurences < 32) {
+          ++log_occurences;
+          LOG(WARNING) << "Cannot do group-wise quantization without "
+                          "static quantization of activations for "
+                       << OperatorBase::debug_def().output(0);
+        }
       }

       int32_t Y_int32_min = numeric_limits<int32_t>::max();
       int32_t Y_int32_max = numeric_limits<int32_t>::min();

+#ifdef _OPENMP
 #pragma omp parallel for reduction(min:Y_int32_min), reduction(max:Y_int32_max)
+#endif
       for (int i = 0; i < N * Y_HxW; ++i) {
         for (int group_id = 0; group_id < group_; ++group_id) {
           int32_t row_offset = 0;
@@ -812,7 +834,9 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
     using namespace fbgemm2;
 #ifdef __AVX2__
     if (is_same<T, uint8_t>::value && GetCpuId().avx2()) {
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
       for (int i = 0; i < N * Y_HxW; ++i) {
         for (int group_id = 0; group_id < group_; ++group_id) {
           int32_t row_offset;
@@ -853,7 +877,9 @@ void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
     } else
 #endif // __AVX2__
     {
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
       for (int i = 0; i < N * Y_HxW; ++i) {
         for (int group_id = 0; group_id < group_; ++group_id) {
           int32_t B_zero_point = FilterQuantizationParams(group_id).zero_point;
@@ -942,7 +968,9 @@ const InType* ConvDNNLowPOp<T, ReluFused>::Im2ColNHWC_(

   InType *col_buffer_data = col_buffer->template mutable_data<InType>();

+#ifdef _OPENMP
 #pragma omp parallel for if (N > 1)
+#endif
   for (int image_id = 0; image_id < N; ++image_id) {
     if (BaseType::kernel_.size() <= 2) {
       math::Im2ColNHWC<InType>(
@@ -1069,7 +1097,9 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
     const InType* Xdata = X.template data<InType>();
     Y_uint8_data = OutputTensorCPU_(0)->template mutable_data<uint8_t>();

+#ifdef _OPENMP
 #pragma omp parallel
+#endif
     fbgemm2::depthwise_3x3x3_pad_1(
         N,
         X.dim32(1),
@@ -1089,8 +1119,8 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
         column_offsets_.data(),
         b_quantized_data_,
         ReluFused,
-        omp_get_thread_num(),
-        omp_get_num_threads());
+        dnnlowp_get_thread_num(),
+        dnnlowp_get_num_threads());

     return;
   } else if (TakeDepthWise3x3FastPath_()) {
@@ -1098,7 +1128,9 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
     const InType* Xdata = X.template data<InType>();
     Y_uint8_data = OutputTensorCPU_(0)->template mutable_data<uint8_t>();

+#ifdef _OPENMP
 #pragma omp parallel
+#endif
     fbgemm2::depthwise_3x3_pad_1(
         N,
         H,
@@ -1115,8 +1147,8 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
         Y_uint8_data,
         column_offsets_.data(),
         b_quantized_data_,
-        omp_get_thread_num(),
-        omp_get_num_threads(),
+        dnnlowp_get_thread_num(),
+        dnnlowp_get_num_threads(),
         ReluFused);

     return;
@@ -1137,13 +1169,15 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
       x_pack_buf_size_per_thread =
           PackAWithRowOffset<uint8_t>::packedBufferSize();
     }
-    row_offsets_.resize(omp_get_max_threads() * row_offset_size_per_thread);
-    X_pack_buf_.resize(omp_get_max_threads() * x_pack_buf_size_per_thread);
+    row_offsets_.resize(dnnlowp_get_max_threads() * row_offset_size_per_thread);
+    X_pack_buf_.resize(dnnlowp_get_max_threads() * x_pack_buf_size_per_thread);
   }

+#ifdef _OPENMP
 #pragma omp parallel
+#endif
   {
-    int tid = omp_get_thread_num();
+    int tid = dnnlowp_get_thread_num();
     int group_begin, group_end;
     int i_begin, i_end;

@@ -1154,7 +1188,7 @@ void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
         &i_end,
         group_,
         N * Y_HxW,
-        omp_get_num_threads(),
+        dnnlowp_get_num_threads(),
         tid);

     for (int group_id = group_begin; group_id < group_end; ++group_id) {
@@ -36,10 +36,10 @@ class ConvPoolDNNLowPOpBase : public ConvPoolOpBase<CPUContext> {
   virtual ~ConvPoolDNNLowPOpBase() {
     if (measure_quantization_error_) {
       dnnlowp::ReportQuantizationError(this, quantization_error_stats_);
-      LOG(CRITICAL) << this->debug_def().output(0) << " with type "
-                    << this->debug_def().type() << " has output qparams : "
-                    << "scale " << out_qparams_.scale << " offset "
-                    << out_qparams_.zero_point << "; ";
+      LOG(WARNING) << this->debug_def().output(0) << " with type "
+                   << this->debug_def().type() << " has output qparams : "
+                   << "scale " << out_qparams_.scale << " offset "
+                   << out_qparams_.zero_point << "; ";
     }
   }
@@ -21,7 +21,9 @@ bool ConvReluOp<T, Context>::RunOnDeviceWithOrderNCHW() {
   Tensor *output = Operator<Context>::Output(0);
   output->ResizeLike(*local_output);
   T *output_data = output->template mutable_data<T>();
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < output->numel(); ++i) {
     output_data[i] = std::max(static_cast<T>(0), output_local_data[i]);
   }
@@ -48,7 +50,9 @@ bool ConvReluOp<T, Context>::RunOnDeviceWithOrderNHWC() {
   Tensor *output = Operator<Context>::Output(0);
   output->ResizeLike(*local_output);
   T *output_data = output->template mutable_data<T>();
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < output->numel(); ++i) {
     output_data[i] = std::max(static_cast<T>(0), output_local_data[i]);
   }
@@ -1,49 +1,52 @@
 #include "dnnlowp.h"
-#include "l2_minimization.h"
-#include "kl_minimization.h"
 #include "caffe2/core/logging.h"
+#include "dnnlowp_op.h"
+#include "kl_minimization.h"
+#include "l2_minimization.h"

 #include <cassert>
 #include <cctype>
+#ifdef _OPENMP
 #include <omp.h>
+#endif

-DEFINE_int32(
+C10_DEFINE_int32(
     dnnlowp_activation_quantization_precision, 8,
     "Precision used for activation tensors");
-DEFINE_int32(
+C10_DEFINE_int32(
     dnnlowp_weight_quantization_precision, 8,
     "Precision used for weight tensors");
-DEFINE_int32(
+C10_DEFINE_int32(
     dnnlowp_requantization_multiplier_precision, 32,
     "Precision of integer multipliers used for rescaling quantized numbers");
-DEFINE_int32(
+C10_DEFINE_int32(
     dnnlowp_eltwise_quantization_precision, 16,
     "Precision used for intermediate numbers during elementwise operations");
-DEFINE_bool(
+C10_DEFINE_bool(
     dnnlowp_force_scale_power_of_two, false,
     "When true, force quantization scales to a power of two");
-DEFINE_bool(
+C10_DEFINE_bool(
     dnnlowp_preserve_activation_sparsity, false,
     "When true, 0 is mapped to 0 after quantization: "
     "i.e., symmetric quantization");
-DEFINE_bool(
+C10_DEFINE_bool(
     dnnlowp_preserve_weight_sparsity, false,
     "When true, 0 is mapped to 0 after quantization: "
     "i.e., symmetric quantization");
-DEFINE_string(
+C10_DEFINE_string(
     dnnlowp_activation_quantization_kind, "min_max",
     "Quantization method for activation tensors. "
     "Allowed values: min_max, l2, l2_approx, kl, l1, p99");
-DEFINE_string(
+C10_DEFINE_string(
     dnnlowp_weight_quantization_kind, "min_max",
     "Quantization method for weight tensors. "
     "Allowed values: min_max, l2, l2_approx, kl, l1, p99");
-DEFINE_int32(
+C10_DEFINE_int32(
     dnnlowp_nbits_in_non_outlier, 8,
     "When outlier-aware quantization is used, if a quantized number can be "
    "represented by this number of bits, it is considered not an outlier so "
     "handled with 16-bit accumulation");
-DEFINE_int32(
+C10_DEFINE_int32(
     dnnlowp_copy_to_32bit_frequency, 32,
     "When outlier-aware quantization is used, this option specifies how often "
     "we spill 16-bit accumulated numbers to 32-bit during the first pass");
@@ -449,7 +452,7 @@ QuantizationFactory *QuantizationFactory::GetDefaultInstance() {
     LOG(INFO) << "nbits_in_non_outlier " << FLAGS_dnnlowp_nbits_in_non_outlier;
     LOG(INFO) <<
       "copy_to_32bit_frequency " << FLAGS_dnnlowp_copy_to_32bit_frequency;
-    LOG(INFO) << "omp_get_max_threads() " << omp_get_max_threads();
+    LOG(INFO) << "omp_get_max_threads() " << caffe2::dnnlowp_get_max_threads();

     log_printed = true;
   }
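This file carries the mechanical half of the flags migration: every gflags-style DEFINE_* becomes the matching C10_DEFINE_* (the c10 macros abstract over whether gflags is actually available). The define/declare pairing works the way gflags users expect — a sketch with a hypothetical flag name:

    // In exactly one .cc file:
    C10_DEFINE_int32(dnnlowp_example_precision, 8, "Hypothetical example flag");

    // In any other translation unit that reads it:
    C10_DECLARE_int32(dnnlowp_example_precision);

    int read_it() {
      // The macros generate a FLAGS_-prefixed variable, as the
      // FLAGS_dnnlowp_nbits_in_non_outlier uses above demonstrate.
      return FLAGS_dnnlowp_example_precision;
    }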
@@ -241,4 +241,28 @@ class DNNLowPOp : public Operator<CPUContext> {
   /* using override */ using BaseType::out_qparams_;  \
   /* using override */ using BaseType::qfactory_;

+inline int dnnlowp_get_num_threads() {
+#ifdef _OPENMP
+  return omp_get_num_threads();
+#else
+  return 1;
+#endif
+}
+
+inline int dnnlowp_get_max_threads() {
+#ifdef _OPENMP
+  return omp_get_max_threads();
+#else
+  return 1;
+#endif
+}
+
+inline int dnnlowp_get_thread_num() {
+#ifdef _OPENMP
+  return omp_get_thread_num();
+#else
+  return 0;
+#endif
+}
+
 } // namespace caffe2
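These three wrappers are the heart of the change: they report OpenMP's thread geometry when _OPENMP is defined and a fixed single-thread answer otherwise, so every call site above can partition work without caring which build it is in. A usage sketch mirroring the operators in this diff:

    #ifdef _OPENMP
    #pragma omp parallel
    #endif
    {
      // With OpenMP: each thread sees its own tid in [0, nthreads).
      // Without OpenMP: nthreads == 1 and tid == 0, so the block runs once, serially.
      int nthreads = dnnlowp_get_num_threads();
      int tid = dnnlowp_get_thread_num();
      // ... partition [0, n) by (tid, nthreads), e.g. via Get1DPartition ...
    }

One subtlety: dnnlowp_get_num_threads() forwards to omp_get_num_threads(), which reports the team size only inside a parallel region (it is 1 outside), whereas dnnlowp_get_max_threads() is the right call for sizing per-thread buffers up front — which is exactly how the resize() calls throughout this diff use it.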
@@ -1,10 +1,13 @@
 #include "dynamic_histogram.h"
+#include "dnnlowp_op.h"

 #include <cassert>
 #include <limits>
 #include <cmath>

+#ifdef _OPENMP
 #include <omp.h>
+#endif

 namespace dnnlowp {

@@ -28,18 +31,20 @@ void Histogram::Add(const float* f, int len) {
   if (bin_width > 0.0) {
     assert(per_thread_histogram_.size() % nbins == 0);

-    // Check if omp_get_max_threads has been reduced, and if so reduce
+    // Check if dnnlowp_get_max_threads has been reduced, and if so reduce
     // per-thread histogram and clear them.
     int old_nthreads = per_thread_histogram_.size() / nbins + 1;
-    if (omp_get_max_threads() < old_nthreads) {
+    if (caffe2::dnnlowp_get_max_threads() < old_nthreads) {
       Finalize();
     }

-    per_thread_histogram_.resize((omp_get_max_threads() - 1) * nbins);
+    per_thread_histogram_.resize((caffe2::dnnlowp_get_max_threads() - 1) * nbins);

+#ifdef _OPENMP
 #pragma omp parallel
+#endif
     {
-      int tid = omp_get_thread_num();
+      int tid = caffe2::dnnlowp_get_thread_num();

       uint64_t* my_histogram = nullptr;
       if (tid == 0) {
@@ -48,7 +53,9 @@ void Histogram::Add(const float* f, int len) {
       my_histogram = per_thread_histogram_.data() + (tid - 1) * nbins;
     }

+#ifdef _OPENMP
 #pragma omp for
+#endif
     for (auto i = 0; i < len; ++i) {
       int bin =
           std::min(static_cast<int>((f[i] - min_) / bin_width), nbins - 1);
@@ -125,7 +132,9 @@ void DynamicHistogram::Add(float f) {

 void DynamicHistogram::Add(const float* f, int len) {
   float minimum = min_, maximum = max_;
+#ifdef _OPENMP
 #pragma omp parallel for reduction(min : minimum) reduction(max : maximum)
+#endif
   for (int i = 0; i < len; ++i) {
     minimum = std::min(f[i], minimum);
     maximum = std::max(f[i], maximum);
@@ -3,8 +3,8 @@
 #include <random>
 #include <sstream>

-#include <glog/logging.h>
 #include <gtest/gtest.h>
+#include "caffe2/core/logging.h"

 #include "dynamic_histogram.h"
@@ -48,7 +48,9 @@ class AddDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, AddFp32Op> {
           real_multiplier, intermediate_qparams_);

       const T* input_data = InputTensorCPU_(i).template data<T>();
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
       for (int j = 0; j < InputTensorCPU_(i).numel(); ++j) {
         quantized_in[j] = Requantize<int32_t>(
             input_data[j] - in_qparams_[i].zero_point,
@@ -57,7 +59,9 @@ class AddDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, AddFp32Op> {
     } else {
       assert(A.template IsType<float>());
       const float* input_data = InputTensorCPU_(i).template data<float>();
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
       for (int j = 0; j < InputTensorCPU_(i).numel(); ++j) {
         quantized_in[j] = Quantize<uint32_t>(
             input_data[j],
@@ -78,13 +82,17 @@ class AddDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, AddFp32Op> {
         A.sizes(),
         B.sizes(),
         "Dimension mismatch - did you forget to set broadcast=1?");
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (int i = 0; i < C->numel(); ++i) {
       int32_t raw = A_quantized[i] + B_quantized[i] - intermediate_zero_point;
       C_quantized[i] = Requantize<T>(raw, requantization_params_);
     }
   } else if (B.numel() == 1) {
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (int i = 0; i < C->numel(); ++i) {
       int32_t raw = A_quantized[i] + B_quantized[0] - intermediate_zero_point;
       C_quantized[i] = Requantize<T>(raw, requantization_params_);
@@ -93,7 +101,9 @@ class AddDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, AddFp32Op> {
       size_t pre, n, post;
       std::tie(pre, n, post) =
          elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
      for (int i = 0; i < pre; ++i) {
        for (int j = 0; j < n; ++j) {
          for (int k = 0; k < post; ++k) {
@@ -56,7 +56,9 @@ bool ElementwiseLinearDNNLowPOp<T>::RunOnDevice() {
   // Quantize b
   vector<int32_t> b_quantized(b.numel());
   const float *b_data = b.template data<float>();
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < b.numel(); ++i) {
     b_quantized[i] = Quantize<int32_t>(
         b_data[i], 0, in_qparams_[0].scale * in_qparams_[1].scale,
@@ -64,7 +66,9 @@ bool ElementwiseLinearDNNLowPOp<T>::RunOnDevice() {
   }

   T *Y_quantized = GetQuantizedOutputData_();
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int n = 0; n < N; ++n) {
     for (int d = 0; d < D; ++d) {
       int32_t raw =
@@ -1,4 +1,4 @@
-#include "caffe2/caffe2/operators/elementwise_mul_op.h"
+#include "caffe2/operators/elementwise_mul_op.h"
 #include "caffe2/operators/quantized/server/elementwise_dnnlowp_op.h"
 #include "caffe2/operators/quantized/server/op_wrapper.h"
 #include "caffe2/operators/quantized/server/sigmoid.h"
@@ -50,14 +50,18 @@ class MulDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, MulFp32Op> {
         A.sizes(),
         B.sizes(),
         "Dimension mismatch - did you forget to set broadcast=1?");
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (int i = 0; i < C->size(); ++i) {
       int32_t raw = (A_quantized[i] - in_qparams_[0].zero_point) *
           (B_quantized[i] - in_qparams_[1].zero_point);
       C_quantized[i] = Requantize<T>(raw, requantization_params_);
     }
   } else if (B.size() == 1) {
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (int i = 0; i < C->size(); ++i) {
       int32_t raw = (A_quantized[i] - in_qparams_[0].zero_point) *
           (B_quantized[0] - in_qparams_[1].zero_point);
@@ -67,7 +71,9 @@ class MulDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, MulFp32Op> {
       size_t pre, n, post;
       std::tie(pre, n, post) =
          elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
      for (int i = 0; i < pre; ++i) {
        for (int j = 0; j < n; ++j) {
          for (int k = 0; k < post; ++k) {
@@ -91,15 +91,17 @@ bool SumDNNLowPOp<T, ReluFused>::RunOnDevice() {
       __m256i permute_mask_v = _mm256_set_epi32(
           0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);

+#ifdef _OPENMP
 #pragma omp parallel
+#endif
       {
         constexpr int VLEN = 8;

         int j_begin, j_end;
         tie(j_begin, j_end) = Get1DPartition(
             len,
-            omp_get_num_threads(),
-            omp_get_thread_num(),
+            dnnlowp_get_num_threads(),
+            dnnlowp_get_thread_num(),
             VLEN);

         int j = j_begin;
@@ -173,11 +175,13 @@ bool SumDNNLowPOp<T, ReluFused>::RunOnDevice() {
        input_data[i] = InputTensorCPU_(i).template data<T>();
      }

+#ifdef _OPENMP
 #pragma omp parallel
+#endif
      {
        int j_begin, j_end;
        tie(j_begin, j_end) = Get1DPartition(
-           len, omp_get_num_threads(), omp_get_thread_num());
+           len, dnnlowp_get_num_threads(), dnnlowp_get_thread_num());

        for (int j = j_begin; j < j_end; ++j) {
          int32_t acc = 0;
@@ -201,11 +205,13 @@ bool SumDNNLowPOp<T, ReluFused>::RunOnDevice() {
        input_data[i] = InputTensorCPU_(i).template data<float>();
      }

+#ifdef _OPENMP
 #pragma omp parallel
+#endif
      {
        int j_begin, j_end;
        tie(j_begin, j_end) = Get1DPartition(
-           len, omp_get_num_threads(), omp_get_thread_num());
+           len, dnnlowp_get_num_threads(), dnnlowp_get_thread_num());

        for (int j = j_begin; j < j_end; ++j) {
          int32_t acc = 0;
@@ -4,9 +4,8 @@
 #include <iomanip>
 #include <bitset>

-#include <glog/logging.h>
 #include <gtest/gtest.h>

+#include "caffe2/core/logging.h"
 #include "fully_connected_fake_lowp_op.h"

 constexpr size_t sz = 10000;
@@ -2,8 +2,8 @@

 #include <fbgemm/src/RefImplementations.h>

-DECLARE_int32(dnnlowp_nbits_in_non_outlier);
-DECLARE_int32(dnnlowp_copy_to_32bit_frequency);
+C10_DECLARE_int32(dnnlowp_nbits_in_non_outlier);
+C10_DECLARE_int32(dnnlowp_copy_to_32bit_frequency);

 namespace caffe2 {
@@ -46,8 +46,12 @@ bool FullyConnectedDNNLowPAcc16Op::RunOnDevice() {
   // Pack W if needed
   if (!Wq_acc16_packed_ || !is_weight_constant_) {
     if (!Wq_acc16_packed_ && nbits_in_non_outlier_ < 8) {
-      LOG_FIRST_N(WARNING, 32)
-          << "FC DNNLOWP_ACC16 using outlier-aware quantization";
+      static int log_occurences = 0;
+      if (log_occurences < 32) {
+        ++log_occurences;
+        LOG(WARNING)
+            << "FC DNNLOWP_ACC16 using outlier-aware quantization";
+      }

       // Separate out outliers
       CAFFE_ENFORCE(!W_quantized_.empty());
@@ -45,13 +45,21 @@ bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
        FLAGS_dnnlowp_enforce_default_caffe2_operators) &&
       dequantize_output_) {
     if (!GetCpuId().avx2()) {
-      LOG_FIRST_N(WARNING, 32) <<
-          "Falling back to the default Caffe2 operator because AVX2 "
-          "instruction is not available";
+      static int log_occurences = 0;
+      if (log_occurences < 32) {
+        ++log_occurences;
+        LOG(WARNING) <<
+            "Falling back to the default Caffe2 operator because AVX2 "
+            "instruction is not available";
+      }
     } else {
-      LOG_FIRST_N(WARNING, 32) <<
-          "Falling back to the default Caffe2 operator because "
-          "dnnlowp_enforce_default_caffe2_operators option is on";
+      static int log_occurences = 0;
+      if (log_occurences < 32) {
+        ++log_occurences;
+        LOG(WARNING) <<
+            "Falling back to the default Caffe2 operator because "
+            "dnnlowp_enforce_default_caffe2_operators option is on";
+      }
     }

     Fp32Op_()->DequantizeInput();
@@ -191,7 +191,11 @@ class FullyConnectedFakeLowpFPOp final : public Operator<Context> {
     CAFFE_ENFORCE(N == b.dim32(0), dimErrorString());
     CAFFE_ENFORCE(N == b.size(), dimErrorString());

-    LOG_EVERY_N(INFO, nlines_log) << "FAKE_FP16 fc running";
+    static int log_occurences = 0;
+    if (log_occurences % nlines_log == 0) {
+      ++log_occurences;
+      LOG(INFO) << "FAKE_FP16 fc running";
+    }

     Y_shape_cache_ = X.sizes().vec();
     // This is an invariant of canonical_axis, so we can DCHECK.
@@ -393,7 +397,11 @@ class FullyConnectedGradientFakeLowpFPOp : public Operator<Context> {
         dYh.template mutable_data<T_DY>()
     );

-    LOG_EVERY_N(INFO, nlines_log) << "FAKE_FP16 fc grad running";
+    static int log_occurences = 0;
+    if (log_occurences % nlines_log == 0) {
+      ++log_occurences;
+      LOG(INFO) << "FAKE_FP16 fc grad running";
+    }

     // Compute dW
     math::Gemm<T_DY, Context, Engine>(
@@ -5,7 +5,9 @@
 #include <cmath>
 #include <cstring>

+#ifdef _OPENMP
 #include <omp.h>
+#endif

 #include "caffe2/utils/eigen_utils.h"
 #include "caffe2/utils/math.h"
@@ -88,7 +90,9 @@ void ComputeQuantizedFusedParamsAVX2(
   const int k = K / kVLen * kVLen;
   const int r = K % kVLen;
   for (int n = N - 1; n >= 0; --n) {
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (int g = 0; g < G; ++g) {
       const __m256i mu_v = _mm256_set1_epi32(mu[n * G + g] + X_zero_point);
       const __m256i rsig_v = _mm256_set1_epi32(rsig[n * G + g]);
@@ -211,7 +215,9 @@ void AffineBatchChannelAndRequantizeNCHWAVX2<uint8_t>(
   const int outer_size = N * C;
   const int n = HxW / kVLen * kVLen;
   const int r = HxW % kVLen;
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < outer_size; ++i) {
     const uint8_t* X_ptr = X + i * HxW;
     uint8_t* Y_ptr = Y + i * HxW;
@@ -256,7 +262,9 @@ void AffineBatchChannelAndRequantizeNHWCAVX2<uint8_t>(
   INIT_REQUANTIZE_AVX2;
   constexpr int kVLen = 8;
   const int outer_size = N * HxW;
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < outer_size; ++i) {
     const int c = i / HxW * C;
     const int n = C / kVLen * kVLen;
@@ -382,7 +390,9 @@ void GroupNormDNNLowPOp<T>::QuantizeGammaImpl() {
   gamma_quantized_.resize(C);
   gamma_quantized_data_ = gamma_quantized_.data();
   gamma_dequantized_data_ = gamma.template data<float>();
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < C; ++i) {
     gamma_quantized_[i] = dnnlowp::Quantize<int32_t>(
         gamma_dequantized_data_[i],
@@ -423,7 +433,9 @@ void GroupNormDNNLowPOp<T>::QuantizeBeta() {
   beta_quantized_.resize(C);
   beta_quantized_data_ = beta_quantized_.data();
   beta_dequantized_data_ = beta.template data<float>();
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < C; ++i) {
     beta_quantized_[i] = dnnlowp::Quantize<int32_t>(
         beta_dequantized_data_[i],
@@ -452,7 +464,9 @@ void GroupNormDNNLowPOp<T>::QuantizedGroupMomentsNCHW(
   var_qparams.scale = X_qparams.scale * X_qparams.scale;
   var_qparams.zero_point = 0;
   rsig_dequantized_.resize(outer_size);
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < outer_size; ++i) {
     int64_t sum = 0;
     int64_t sumsq = 0;
@@ -490,7 +504,9 @@ void GroupNormDNNLowPOp<T>::QuantizedGroupMomentsNHWC(
   var_qparams.scale = X_qparams.scale * X_qparams.scale;
   var_qparams.zero_point = 0;
   rsig_dequantized_.resize(outer_size);
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < outer_size; ++i) {
     const int n = i / G;
     const int g = i % G;
@@ -715,7 +731,9 @@ void GroupNormDNNLowPOp<T>::ComputeQuantizedInvStd(
       qfactory_->GetWeightPrecision(),
       qfactory_->GetPreserveWeightSparsity());
   rsig_qparams_.zero_point = 0;
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < N; ++i) {
     rsig_quantized[i] = dnnlowp::Quantize<int32_t>(
         rsig[i], rsig_qparams_.zero_point, rsig_qparams_.scale, 32);
@@ -761,7 +779,9 @@ void GroupNormDNNLowPOp<T>::ComputeQuantizedFusedParams(
   } else {
     ConstEigenArrayMap<int32_t> beta_arr(bias, K, G);
     // Reverse order for-loop to avoid overriding bias data.
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (int i = N - 1; i >= 0; --i) {
       EigenArrayMap<int32_t> scale_arr(scale + i * C, K, G);
       scale_arr = gamma_arr.rowwise() *
@@ -788,7 +808,9 @@ void GroupNormDNNLowPOp<T>::ComputeDequantizedFusedParams(
   const int C = G * K;
   ConstEigenArrayMap<float> gamma_arr(gamma, K, G);
   ConstEigenArrayMap<float> beta_arr(beta, K, G);
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < N; ++i) {
     EigenArrayMap<float> scale_arr(scale + i * C, K, G);
     scale_arr = gamma_arr.rowwise() *
@@ -848,7 +870,9 @@ void GroupNormDNNLowPOp<T>::AffineBatchChannelQuantizedNHWC(
         N, C, HxW, out_requantization_params, X, scale, bias, Y);
   } else {
     Y_int32_.resize(size);
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
     for (int i = 0; i < N; ++i) {
       EigenArrayMap<int32_t>(Y_int32_.data() + i * stride, C, HxW) =
           (ConstEigenArrayMap<T>(X + i * stride, C, HxW)
@@ -888,7 +912,9 @@ void GroupNormDNNLowPOp<T>::AffineBatchChannelDequantizedNHWC(
     const float* bias,
     float* Y) {
   const int stride = HxW * C;
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int i = 0; i < N; ++i) {
     EigenArrayMap<float>(Y + i * stride, C, HxW) =
         (ConstEigenArrayMap<float>(X + i * stride, C, HxW).colwise() *
@@ -1,6 +1,8 @@
 #pragma once

+#ifdef _OPENMP
 #include <omp.h>
+#endif

 #include "caffe2/core/operator.h"
 #include "caffe2/utils/math.h"
@@ -206,7 +208,9 @@ static void Im2ColNHWC(
   int height_col = (height + pad_t + pad_b - dkernel_h) / stride_h + 1;
   int width_col = (width + pad_l + pad_r - dkernel_w) / stride_w + 1;

+#ifdef _OPENMP
 #pragma omp parallel for if (!omp_in_parallel())
+#endif
   for (int h = 0; h < height_col; ++h) {
     int h_pad = -pad_t + h * stride_h;
     T* data_col_temp =
@@ -284,7 +288,9 @@ static void Im2Col3DNHWC(
   int height_col = (height + pad_t + pad_b - dkernel_h) / stride_h + 1;
   int width_col = (width + pad_l + pad_r - dkernel_w) / stride_w + 1;

+#ifdef _OPENMP
 #pragma omp parallel for if (!omp_in_parallel())
+#endif
   for (int t = 0; t < frame_col; ++t) {
     int t_pad = -pad_p + t * stride_t;
     for (int h = 0; h < height_col; ++h) {
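Both Im2Col hunks above use an `if` clause rather than a bare parallel-for: omp_in_parallel() returns true when the caller is already inside a parallel region (the NCHW conv paths in this diff invoke Im2Col from per-image parallel loops), and the clause then runs the loop in the calling thread instead of spawning a nested region. A minimal sketch of the behavior:

    #ifdef _OPENMP
    #include <omp.h>
    #endif

    void process_rows(int height_col) {
    #ifdef _OPENMP
    #pragma omp parallel for if (!omp_in_parallel())
    #endif
      for (int h = 0; h < height_col; ++h) {
        // Parallel when called from serial code; serial in the calling thread
        // when already inside a parallel region (no nested teams by default).
        // ... per-row im2col work ...
      }
    }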
@@ -27,7 +27,9 @@ TensorQuantizationParams KLDivergenceMinimization::ChooseQuantizationParams(
   // Look at mapping [start_bin, start_bin + nbins_selected) to
   // [0, 1 << precision) for every (start_bin, nbins_selected) combination and
   // pick the one with smallest KL divergence
+#ifdef _OPENMP
 #pragma omp parallel for
+#endif
   for (int nbins_selected = 1; nbins_selected <= nbins; ++nbins_selected) {
     //if (nbins_selected % dst_nbins != 0) continue;
     double kl_min = numeric_limits<double>::max();
@@ -1,12 +1,11 @@
 #include "l2_minimization.h"
+#include "caffe2/core/logging.h"

 #include <chrono>
 #include <fstream>
 #include <sstream>

-#include <glog/logging.h>
 #include <gtest/gtest.h>
-#include "caffe2/core/logging.h"

 using namespace std;
 using namespace dnnlowp;
@@ -195,7 +195,9 @@ TensorQuantizationParams NormMinimization::ChooseQuantizationParams(
   // Look at mapping [start_bin, start_bin + nbins_selected) to
   // [0, 1 << precision) for every (start_bin, nbins_selected) combination and
   // pick the one with smallest L2 quantization error
+#ifdef _OPENMP
 #pragma omp parallel for schedule(dynamic)
+#endif
   for (int nbins_selected = 1; nbins_selected <= nbins; ++nbins_selected) {
     float norm_min = numeric_limits<float>::max();
     int best_start_bin = 0;

@@ -1,6 +1,7 @@
 #include "l2_minimization.h"

 #include <cmath>
+#include <array>

 #include <x86intrin.h>
@ -120,7 +120,9 @@ class AveragePoolDnnLowPOp final

    switch (BaseType::kernel_.size()) {
      case 2:
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (int n = 0; n < X.dim32(0); ++n) {
          for (int c = 0; c < channels; ++c) {
            const T *Xdata_temp =
@ -159,7 +161,9 @@ class AveragePoolDnnLowPOp final
        } // for each image
        break;
      case 3:
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (int n = 0; n < X.dim32(0); ++n) {
          for (int c = 0; c < channels; ++c) {
            const T *Xdata_temp =
@ -255,7 +259,9 @@ class AveragePoolDnnLowPOp final

    switch (BaseType::kernel_.size()) {
      case 2:
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (int n = 0; n < X.dim32(0); ++n) {
          const T* Xdata_temp = Xdata + n * height * width * channels;
          T* Ydata_temp =
@ -292,7 +298,9 @@ class AveragePoolDnnLowPOp final
        } // for each image
        break;
      case 3:
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (int n = 0; n < X.dim32(0); ++n) {
          const T* Xdata_temp = Xdata + n * height * width * depth * channels;
          T* Ydata_temp = Ydata +
@ -427,7 +435,9 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase<T, MaxPoolFp32Op> {
        }
        break;
      case 2:
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (int n = 0; n < X.dim32(0); ++n) {
          for (int c = 0; c < channels; ++c) {
            // Do offset.
@ -460,7 +470,9 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase<T, MaxPoolFp32Op> {
        }
        break;
      case 3:
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (int n = 0; n < X.dim32(0); ++n) {
          for (int c = 0; c < channels; ++c) {
            // Do offset.
@ -551,7 +563,9 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase<T, MaxPoolFp32Op> {

    switch (BaseType::kernel_.size()) {
      case 1:
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (int n = 0; n < X.dim32(0); ++n) {
          const T* Xdata_temp = Xdata + n * height * channels;
          T* Ydata_temp = Ydata + n * pooled_height * channels;
@ -573,7 +587,9 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase<T, MaxPoolFp32Op> {
        }
        break;
      case 2:
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (int n = 0; n < X.dim32(0); ++n) {
          const T* Xdata_temp = Xdata + n * height * width * channels;
          T* Ydata_temp = Ydata + n * pooled_height * pooled_width * channels;
@ -603,7 +619,9 @@ class MaxPoolDnnLowPOp final : public ConvPoolDNNLowPOpBase<T, MaxPoolFp32Op> {
        }
        break;
      case 3:
#ifdef _OPENMP
#pragma omp parallel for
#endif
        for (int n = 0; n < X.dim32(0); ++n) {
          const T* Xdata_temp = Xdata + n * height * width * depth * channels;
          T* Ydata_temp = Ydata +
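All of the pooling hunks above follow one pattern: images in the batch are independent, so the added OpenMP pragma splits the outermost n loop across threads while the inner spatial/channel loops stay serial per image. A rough, self-contained sketch of that pattern for 2-D max pooling over an NHWC uint8 tensor follows; shapes, names, and the no-padding assumption are invented for illustration and are not the operator's real interface.

#include <algorithm>
#include <cstdint>
#ifdef _OPENMP
#include <omp.h>
#endif

// Sketch: 2-D max pooling over an NHWC uint8 tensor with the batch loop
// split across OpenMP threads, mirroring the pragmas added in the diff.
// Assumes no padding and that (H - kernel) and (W - kernel) divide by stride.
void MaxPool2DNHWCSketch(
    const uint8_t* X, uint8_t* Y,
    int N, int H, int W, int C,
    int kernel, int stride) {
  const int PH = (H - kernel) / stride + 1;
  const int PW = (W - kernel) / stride + 1;
#ifdef _OPENMP
#pragma omp parallel for
#endif
  for (int n = 0; n < N; ++n) {  // each image is independent
    const uint8_t* Xn = X + static_cast<int64_t>(n) * H * W * C;
    uint8_t* Yn = Y + static_cast<int64_t>(n) * PH * PW * C;
    for (int ph = 0; ph < PH; ++ph) {
      for (int pw = 0; pw < PW; ++pw) {
        for (int c = 0; c < C; ++c) {
          uint8_t m = 0;  // minimum of the uint8 domain
          for (int kh = 0; kh < kernel; ++kh) {
            for (int kw = 0; kw < kernel; ++kw) {
              const int h = ph * stride + kh;
              const int w = pw * stride + kw;
              m = std::max(m, Xn[(h * W + w) * C + c]);
            }
          }
          Yn[(ph * PW + pw) * C + c] = m;
        }
      }
    }
  }
}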
@ -1,6 +1,9 @@
#include "quantize_dnnlowp_op.h"
#include "dnnlowp_op.h"

#ifdef _OPENMP
#include <omp.h>
#endif

#include "caffe2/core/tensor_int8.h"
#include "caffe2_dnnlowp_utils.h"
@ -39,11 +42,14 @@ bool QuantizeDNNLowPOp<T>::RunOnDevice() {

  const float* in_data = Input(0).template data<float>();
  T* out_data = output->t.template mutable_data<T>();

#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int i_begin, i_end;
    tie(i_begin, i_end) = Get1DPartition(
        Input(0).numel(), omp_get_num_threads(), omp_get_thread_num());
        Input(0).numel(), dnnlowp_get_num_threads(), dnnlowp_get_thread_num());
    Quantize<T>(
        in_data + i_begin,
        out_data + i_begin,
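Get1DPartition above hands each thread a contiguous slice of the flattened tensor, and the hunk swaps the raw omp_get_* calls for dnnlowp wrappers, presumably so the code degrades to a single thread when OpenMP is compiled out. A plausible sketch of both pieces follows; the _Sketch names are invented, and the real helpers in dnnlowp_partition.cc / caffe2_dnnlowp_utils.h may differ.

#include <algorithm>
#include <utility>
#ifdef _OPENMP
#include <omp.h>
#endif

// Wrapper sketch: report thread count/id under OpenMP, else behave as a
// single thread. This is the likely motivation for the rename in the diff.
static int dnnlowp_get_num_threads_sketch() {
#ifdef _OPENMP
  return omp_get_num_threads();
#else
  return 1;
#endif
}

static int dnnlowp_get_thread_num_sketch() {
#ifdef _OPENMP
  return omp_get_thread_num();
#else
  return 0;
#endif
}

// Even 1-D split of n items over num_threads: the first n % num_threads
// threads get one extra item, so slice sizes differ by at most one.
static std::pair<int, int> Get1DPartitionSketch(int n, int num_threads, int tid) {
  const int base = n / num_threads;
  const int rem = n % num_threads;
  const int begin = tid * base + std::min(tid, rem);
  const int end = begin + base + (tid < rem ? 1 : 0);
  return {begin, end};
}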
@ -112,7 +112,9 @@ bool ReluDNNLowPOp<T>::RunOnDevice() {
  if (GetCpuId().avx2()) {
    ReluAVX2<T>(N, in_qparams.zero_point, X_data, Y_data);
  } else {
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int i = 0; i < N; ++i) {
      Y_data[i] = std::max(X_data[i], static_cast<T>(in_qparams.zero_point));
    }
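The scalar fallback works because of how affine quantization represents zero: with real value r = scale * (q - zero_point), r <= 0 exactly when q <= zero_point, so clamping q at the zero point is max(r, 0) in real space. A tiny illustration (values invented):

#include <algorithm>
#include <cstdint>

// Quantized ReLU on one element: clamp the quantized value at the zero point.
// E.g. with scale = 0.1 and zero_point = 128, q = 120 encodes r = -0.8, and
// the clamp returns 128 (r = 0.0), which is exactly relu(-0.8).
inline uint8_t QuantizedRelu(uint8_t q, int32_t zero_point) {
  return static_cast<uint8_t>(std::max<int32_t>(q, zero_point));
}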
@ -4,8 +4,8 @@
#include <random>
#include <cmath>

#include <glog/logging.h>
#include <gtest/gtest.h>
#include "caffe2/core/logging.h"

using namespace std;
using namespace dnnlowp;

@ -2,8 +2,8 @@
#include <iostream>
#include <random>

#include <glog/logging.h>
#include <gtest/gtest.h>
#include "caffe2/core/logging.h"

using namespace dnnlowp;
using namespace std;

@ -2,8 +2,8 @@
#include <iostream>
#include <random>

#include <glog/logging.h>
#include <gtest/gtest.h>
#include "caffe2/core/logging.h"

using namespace dnnlowp;
using namespace std;
@ -295,6 +295,43 @@ if(BUILD_TEST)
  set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
endif()

# ---[ FBGEMM
if(USE_FBGEMM)
  set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
  if(NOT DEFINED FBGEMM_SOURCE_DIR)
    set(FBGEMM_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/fbgemm" CACHE STRING "FBGEMM source directory")
  endif()
  if(NOT CAFFE2_COMPILER_SUPPORTS_AVX512F_EXTENSIONS)
    message(WARNING
      "A compiler with AVX512 support is required for FBGEMM. "
      "Not compiling with FBGEMM. "
      "Turn this warning off by setting USE_FBGEMM=OFF.")
    set(USE_FBGEMM OFF)
  endif()
  if(MSVC)
    set(USE_FBGEMM OFF)
  endif()
  if(USE_FBGEMM AND NOT TARGET fbgemm)
    set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "")
    set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "")
    set(FBGEMM_LIBRARY_TYPE "static" CACHE STRING "")
    add_subdirectory("${FBGEMM_SOURCE_DIR}")
    set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON)
    set_property(TARGET fbgemm_avx512 PROPERTY POSITION_INDEPENDENT_CODE ON)
    set_property(TARGET fbgemm PROPERTY POSITION_INDEPENDENT_CODE ON)
  endif()

  if(USE_FBGEMM)
    list(APPEND Caffe2_DEPENDENCY_LIBS fbgemm)
  endif()
endif()

if(USE_FBGEMM)
  set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
  include_directories(SYSTEM "${CAFFE2_THIRD_PARTY_ROOT}")
endif()


# ---[ LMDB
if(USE_LMDB)
  find_package(LMDB)
@ -182,6 +182,33 @@ if (CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
endif()
cmake_pop_check_state()

# ---[ Check if the compiler has AVX512F support.
cmake_push_check_state(RESET)
if (MSVC)
  set(CMAKE_REQUIRED_FLAGS "/D__AVX512F__")
else()
  set(CMAKE_REQUIRED_FLAGS "-mavx512f")
endif()
CHECK_CXX_SOURCE_COMPILES(
  "#if defined(_MSC_VER)
   #include <intrin.h>
   #else
   #include <x86intrin.h>
   #endif
   __m512 addConstant(__m512 arg) {
     return _mm512_add_ps(arg, _mm512_set1_ps(1.f));
   }
   int main() {
     __m512i a = _mm512_set1_epi32(1);
     __m256i ymm = _mm512_extracti64x4_epi64(a, 0);
     __mmask16 m = _mm512_cmp_epi32_mask(a, a, _MM_CMPINT_EQ);
     __m512i r = _mm512_andnot_si512(a, a);
   }" CAFFE2_COMPILER_SUPPORTS_AVX512F_EXTENSIONS)
if (CAFFE2_COMPILER_SUPPORTS_AVX512F_EXTENSIONS)
  message(STATUS "Current compiler supports avx512f extension. Will build fbgemm.")
endif()
cmake_pop_check_state()

# ---[ Checks if compiler supports -fvisibility=hidden
check_cxx_compiler_flag("-fvisibility=hidden" COMPILER_SUPPORTS_HIDDEN_VISIBILITY)
check_cxx_compiler_flag("-fvisibility-inlines-hidden" COMPILER_SUPPORTS_HIDDEN_INLINE_VISIBILITY)
@ -81,6 +81,7 @@ function (caffe2_print_configuration_summary)
  endif()
  message(STATUS "  USE_ROCM              : ${USE_ROCM}")
  message(STATUS "  USE_EIGEN_FOR_BLAS    : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
  message(STATUS "  USE_FBGEMM            : ${USE_FBGEMM}")
  message(STATUS "  USE_FFMPEG            : ${USE_FFMPEG}")
  message(STATUS "  USE_GFLAGS            : ${USE_GFLAGS}")
  message(STATUS "  USE_GLOG              : ${USE_GLOG}")
setup.py
@ -27,6 +27,9 @@
#   NO_CUDNN
#     disables the cuDNN build
#
#   NO_FBGEMM
#     disables the FBGEMM build
#
#   NO_TEST
#     disables the test build
#
@ -153,7 +156,7 @@ def hotpatch_var(var, prefix='USE_'):

# Before we run the setup_helpers, let's look for NO_* and WITH_*
# variables and hotpatch environment with the USE_* equivalent
use_env_vars = ['CUDA', 'CUDNN', 'MIOPEN', 'MKLDNN', 'NNPACK', 'DISTRIBUTED',
use_env_vars = ['CUDA', 'CUDNN', 'FBGEMM', 'MIOPEN', 'MKLDNN', 'NNPACK', 'DISTRIBUTED',
                'OPENCV', 'QNNPACK', 'FFMPEG', 'SYSTEM_NCCL', 'GLOO_IBVERBS']
list(map(hotpatch_var, use_env_vars))

@ -168,6 +171,7 @@ from tools.setup_helpers.build import (BUILD_BINARY, BUILD_TEST,
from tools.setup_helpers.rocm import USE_ROCM, ROCM_HOME, ROCM_VERSION
from tools.setup_helpers.cudnn import (USE_CUDNN, CUDNN_LIBRARY,
                                       CUDNN_LIB_DIR, CUDNN_INCLUDE_DIR)
from tools.setup_helpers.fbgemm import USE_FBGEMM
from tools.setup_helpers.miopen import (USE_MIOPEN, MIOPEN_LIBRARY,
                                        MIOPEN_LIB_DIR, MIOPEN_INCLUDE_DIR)
from tools.setup_helpers.nccl import USE_NCCL, USE_SYSTEM_NCCL, NCCL_LIB_DIR, \
@ -379,6 +383,8 @@ def build_libs(libs):
        my_env["NVTOOLEXT_HOME"] = NVTOOLEXT_HOME
    if USE_CUDA_STATIC_LINK:
        build_libs_cmd += ['--cuda-static-link']
    if USE_FBGEMM:
        build_libs_cmd += ['--use-fbgemm']
    if USE_ROCM:
        build_libs_cmd += ['--use-rocm']
    if USE_NNPACK:
@ -454,6 +460,7 @@ class build_deps(PytorchCommand):
        check_file(os.path.join(third_party_path, 'catch', 'CMakeLists.txt'))
        check_file(os.path.join(third_party_path, 'onnx', 'CMakeLists.txt'))
        check_file(os.path.join(third_party_path, 'QNNPACK', 'CMakeLists.txt'))
        check_file(os.path.join(third_party_path, 'fbgemm', 'CMakeLists.txt'))

        check_pydep('yaml', 'pyyaml')
        check_pydep('typing', 'typing')
@ -22,6 +22,7 @@ if not exist torch\lib\tmp_install mkdir torch\lib\tmp_install

: Variable defaults
set /a USE_CUDA=0
set /a USE_FBGEMM=0
set /a USE_ROCM=0
set /a USE_NNPACK=0
set /a USE_QNNPACK=0
@ -43,6 +44,11 @@ if "%1"=="--use-cuda" (
  goto :process_args_processed
)

if "%1"=="--use-fbgemm" (
  set /a USE_FBGEMM=1
  goto :process_args_processed
)

if "%1"=="--use-rocm" (
  set /a USE_ROCM=1
  goto :process_args_processed
@ -216,6 +222,7 @@ goto:eof
  -DONNX_NAMESPACE=%ONNX_NAMESPACE% ^
  -DUSE_CUDA=%USE_CUDA% ^
  -DUSE_DISTRIBUTED=%USE_DISTRIBUTED% ^
  -DUSE_FBGEMM=%USE_FBGEMM% ^
  -DUSE_NUMPY=%USE_NUMPY% ^
  -DUSE_NNPACK=%USE_NNPACK% ^
  -DUSE_LEVELDB=%USE_LEVELDB% ^
@ -43,6 +43,7 @@ fi

# Options for building only a subset of the libraries
USE_CUDA=0
USE_FBGEMM=0
USE_ROCM=0
USE_NNPACK=0
USE_MKLDNN=0
@ -58,6 +59,9 @@ while [[ $# -gt 0 ]]; do
    --use-cuda)
      USE_CUDA=1
      ;;
    --use-fbgemm)
      USE_FBGEMM=1
      ;;
    --use-rocm)
      USE_ROCM=1
      ;;
@ -212,6 +216,7 @@ function build_caffe2() {
      -DONNX_NAMESPACE=$ONNX_NAMESPACE \
      -DUSE_CUDA=$USE_CUDA \
      -DUSE_DISTRIBUTED=$USE_DISTRIBUTED \
      -DUSE_FBGEMM=$USE_FBGEMM \
      -DUSE_NUMPY=$USE_NUMPY \
      -DCAFFE2_STATIC_LINK_CUDA=$CAFFE2_STATIC_LINK_CUDA \
      -DUSE_ROCM=$USE_ROCM \
tools/setup_helpers/fbgemm.py
Normal file
@ -0,0 +1,6 @@
from .env import check_env_flag

if check_env_flag('NO_FBGEMM'):
    USE_FBGEMM = False
else:
    USE_FBGEMM = True