[14/N] Fix clang-tidy warnings in aten/src/ATen (#132733)
Follows #133807

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132733
Approved by: https://github.com/ezyang
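The recurring change in this diff: branches on the compile-time template parameters `ReLUFused`/`ReluFused` are switched from plain `if` to C++17 `if constexpr`, so the untaken branch is discarded at instantiation time instead of being compiled into every specialization. A minimal sketch of the pattern, with hypothetical names (not code from this file):

```cpp
#include <algorithm>

// `if constexpr` evaluates the condition during template instantiation and
// drops the untaken branch entirely, so each specialization contains only
// the code it actually uses.
template <bool ReLUFused>
float postprocess(float v) {
  if constexpr (ReLUFused) {
    return std::max(v, 0.0f);  // present only in postprocess<true>
  } else {
    return v;                  // present only in postprocess<false>
  }
}
```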
@@ -36,8 +36,9 @@
 #include <arm_neon.h>
 #endif
 
-namespace at {
-namespace native {
+
+// NOLINTBEGIN(*-c-arrays)
+namespace at::native {
 namespace {
 
 void check_tensor_memory_format(const Tensor& ref, const Tensor& other) {
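Two things happen in the hunk above: the old two-level namespace is collapsed into a C++17 nested namespace definition, and a file-wide `NOLINTBEGIN(*-c-arrays)` suppression is opened (its matching `NOLINTEND` appears in the final hunk). A standalone sketch of the namespace syntax, using a hypothetical placeholder symbol:

```cpp
// C++17 nested namespace definition: one declaration opens both scopes,
// replacing the pre-C++17 form `namespace at { namespace native { ... } }`.
namespace at::native {
inline int answer() { return 42; }  // hypothetical placeholder
} // namespace at::native

int caller() { return at::native::answer(); }
```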
@@ -70,7 +71,6 @@ Tensor qcat_nhwc_kernel(
   std::vector<void*> data_ptrs;
   std::vector<bool> is_fast_path;
 
-  // NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
   for (const at::Tensor& qx : qxs) {
     TORCH_CHECK(
         qx.dim() == qx0.dim(),
@@ -143,7 +143,7 @@ Tensor qcat_nhwc_kernel(
         continue;
       }
 
-      constexpr int64_t VLEN = Vec::size();
+      constexpr auto VLEN = Vec::size();
       int64_t c = 0;
 
       // Vectorized loop
@@ -157,7 +157,7 @@ Tensor qcat_nhwc_kernel(
           curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
       Vec::float_vec_return_type retvals;
       for (int i = 0; i < Vec::float_num_vecs(); ++i) {
-        if (ReLUFused) {
+        if constexpr (ReLUFused) {
           retvals[i] =
               vec::maximum(float_values[i], Vectorized<float>(0.0f));
         } else {
@@ -171,21 +171,21 @@ Tensor qcat_nhwc_kernel(
       }
 
       // Vectorized loop for channel between 8 and 32 (avx2)
-      constexpr int kVLEN = Vectorized<float>::size();
+      constexpr auto kVLEN = Vectorized<float>::size();
       int64_t elem_size = curr_C - c;
       if ((VLEN == 4 * kVLEN) && elem_size >= kVLEN) {
         auto curr_scale_vec = Vectorized<float>(curr_scale);
         auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
         auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
         int64_t vec_num = elem_size / kVLEN;
-        std::array<typename scalar_t::underlying, VLEN> buf_in;
+        std::array<typename scalar_t::underlying, VLEN> buf_in{};
         memcpy(buf_in.data(), iptr + c, vec_num * kVLEN);
         auto inp_vec = Vec::loadu(buf_in.data());
         auto float_values = inp_vec.dequantize(
             curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
         Vec::float_vec_return_type retvals;
         for (int i = 0; i < vec_num; ++i) {
-          if (ReLUFused) {
+          if constexpr (ReLUFused) {
            retvals[i] =
                vec::maximum(float_values[i], Vectorized<float>(0.0f));
          } else {
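One non-cosmetic fix in the hunk above is `buf_in{}`: the `memcpy` that follows fills only `vec_num * kVLEN` bytes, but `Vec::loadu` then reads the full array, so value-initializing the buffer makes the unwritten tail well-defined and silences the analyzer's uninitialized-read warning. A minimal standalone sketch, with a hypothetical name and an arbitrary 32-byte size:

```cpp
#include <array>
#include <cstddef>
#include <cstring>

// Value-initialization ("buf{}") zero-fills the whole array up front, so the
// bytes memcpy leaves untouched are still well-defined when the full buffer
// is read back later.
std::array<unsigned char, 32> load_partial(const unsigned char* src, std::size_t n) {
  std::array<unsigned char, 32> buf{};  // all 32 bytes zeroed
  std::memcpy(buf.data(), src, n);      // overwrite only the first n bytes
  return buf;                           // tail bytes are zeros, not garbage
}
```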
@@ -204,7 +204,7 @@ Tensor qcat_nhwc_kernel(
             curr_scale,
             curr_zero_pt,
             reinterpret_cast<scalar_t*>(iptr)[c]);
-        if (ReLUFused) {
+        if constexpr (ReLUFused) {
           float_val = std::max(0.0f, float_val);
         }
         optr[c] = at::native::quantize_val<scalar_t>(
@@ -593,7 +593,6 @@ void qrelu_kernel(const Tensor& qx, Tensor& qy) {
   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
     qy = at::_empty_affine_quantized(
         qx.sizes(),
-        // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
         at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
         qx.q_scale(),
         qx.q_zero_point(),
@@ -758,7 +757,6 @@ void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
     AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
       qy = at::_empty_affine_quantized(
           qx.sizes(),
-          // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
           at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
           output_scale,
           output_zero_point,
@@ -797,7 +795,6 @@ void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
     AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
       qy = at::_empty_affine_quantized(
           qx.sizes(),
-          // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
           at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
           output_scale,
           output_zero_point,
@@ -842,7 +839,6 @@ void qsigmoid_kernel(
 
     qy = at::_empty_affine_quantized(
         qx.sizes(),
-        // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
         at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
         output_scale,
         output_zero_point,
@@ -885,7 +881,6 @@ void qhardsigmoid_kernel(const Tensor& qx, Tensor& qy) {
 
     // - Output scale is set to 1.0 / 2^(BIT_NUM)
     float output_scale = 0.00390625; // 1.0 / 2^8
-    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
     if (SCALAR_TYPE == at::kQInt32) {
       output_scale = 2.3283064365386963e-10; // 1.0 / 2^32
     }
@@ -946,7 +941,6 @@ void qclamp_kernel(
   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
     qy = at::_empty_affine_quantized(
         qx.sizes(),
-        // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
         at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
         qx.q_scale(),
         qx.q_zero_point(),
@@ -980,7 +974,6 @@ void qclamp_min_kernel(const Tensor& qx, const Scalar& min_scalar, Tensor& qy) {
     qy = at::_empty_affine_quantized(
         qx.sizes(),
         at::device(kCPU)
-            // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
             .dtype(SCALAR_TYPE)
             .memory_format(qx.suggest_memory_format()),
         qx.q_scale(),
@@ -1006,7 +999,6 @@ void qclamp_max_kernel(const Tensor& qx, const Scalar& max_scalar, Tensor& qy) {
     qy = at::_empty_affine_quantized(
         qx.sizes(),
         at::device(kCPU)
-            // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
             .dtype(SCALAR_TYPE)
             .memory_format(qx.suggest_memory_format()),
         qx.q_scale(),
@@ -1049,7 +1041,6 @@ void qthreshold_kernel(
   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qthreshold", [&]() {
     qy = at::_empty_affine_quantized(
         qx.sizes(),
-        // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
         at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
         qx.q_scale(),
         qx.q_zero_point(),
@@ -1158,7 +1149,6 @@ void qtanh_kernel(const Tensor& qx, Tensor& qy) {
     // - For unsigned types output zero point is set to (qmax + qmin) / 2.0
     float output_scale = 0.0078125; // 2.0 / 512
     int64_t output_zero_point = 0;
-    // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
     if (SCALAR_TYPE == at::kQInt32) {
       output_scale = 4.656612873077393e-10; // 2.0 / 2^32
     } else if (SCALAR_TYPE == at::kQUInt8) {
@@ -1314,7 +1304,7 @@ void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) {
           int32_t c = a_sub_z + other_val;
           scalar_t res = at::native::requantize_from_int<scalar_t>(
               multiplier, zero_point, c);
-          if (ReLUFused) {
+          if constexpr (ReLUFused) {
             res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
           }
           return res;
@@ -1327,7 +1317,7 @@ void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) {
             c[i] = a_sub_z[i] + other_vec;
           }
           Vec rv = Vec::requantize_from_int(c, multiplier, zero_point);
-          if (ReLUFused) {
+          if constexpr (ReLUFused) {
             rv = rv.maximum(Vec(static_cast<scalar_t>(zero_point)));
           }
           return rv;
@@ -1386,7 +1376,7 @@ void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
           Vec::float_vec_return_type retvals;
           for (const auto i : c10::irange(Vec::float_num_vecs())) {
             auto c = da[i] + db[i];
-            if (ReLUFused) {
+            if constexpr (ReLUFused) {
               c = vec::maximum(c, Vectorized<float>(0.0f));
             }
             retvals[i] = c;
@@ -1435,7 +1425,7 @@ void qmul_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
           int32_t c = a_sub_z * b_sub_z;
           scalar_t res = at::native::requantize_from_int<scalar_t>(
               multiplier, zero_point, c);
-          if (ReLUFused) {
+          if constexpr (ReLUFused) {
             res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
           }
           return res;
@@ -1450,7 +1440,7 @@ void qmul_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
             c[i] = a_sub_zp[i] * b_sub_zp[i];
           }
           Vec rv = Vec::requantize_from_int(c, multiplier, zero_point);
-          if (ReLUFused) {
+          if constexpr (ReLUFused) {
             rv = rv.maximum(Vec(static_cast<scalar_t>(zero_point)));
           }
           return rv;
@@ -2410,7 +2400,7 @@ void qtopk_kernel(Tensor& values,
   });
 }
 
-template <typename T>
+template <typename T, bool ReluFused>
 inline void do_bn_compute(
     typename T::underlying* X_ptr,
     typename T::underlying* Y_ptr,
@@ -2422,7 +2412,6 @@ inline void do_bn_compute(
     float* alpha,
     float* beta,
     int64_t vec_num,
-    bool ReluFused,
     int64_t kVLen
 ) {
   using Vec = Vectorized<T>;
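The `do_bn_compute` hunks here and below move the fused-ReLU flag from a runtime `bool` argument into a non-type template parameter, which is what allows the `if constexpr` in its body. A simplified sketch of the resulting shape (hypothetical signature, not the real one):

```cpp
#include <algorithm>
#include <cstdint>

// The flag is now a template parameter, so the ReLU clamp is compiled in or
// out per instantiation instead of being a runtime branch on an argument.
template <typename T, bool ReluFused>
inline void bn_postprocess(float* vals, int64_t n, float out_zero_point) {
  for (int64_t i = 0; i < n; ++i) {
    float v = vals[i];
    if constexpr (ReluFused) {
      v = std::max(v, out_zero_point);  // clamp exists only when ReluFused
    }
    vals[i] = v;
  }
}
// Call sites spell the flag explicitly, mirroring the diff:
//   bn_postprocess<float, /*ReluFused=*/true>(vals, n, zp);
```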
@@ -2437,7 +2426,7 @@ inline void do_bn_compute(
   // NOLINTNEXTLINE(bugprone-argument-comment)
   auto outputs_q = Vec::quantize(vals_dq, /*output_scale=*/1.0f, out_zero_point, /*inv_output_scale=*/1.0f);
   // Fake scale again
-  if (ReluFused) {
+  if constexpr (ReluFused) {
     outputs_q = outputs_q.maximum(out_zero_point_v);
   }
   outputs_q.store(Y_ptr, vec_num * kVLen);
@@ -2480,7 +2469,7 @@ void q_batch_norm_kernel(
       int64_t ch = 0;
 
       for(; ch + lanes <= C; ch += lanes) {
-        do_bn_compute<scalar_t>(
+        do_bn_compute<scalar_t, ReluFused>(
            X_ptr + ch,
            Y_ptr + ch,
            fake_scale,
@@ -2491,7 +2480,6 @@ void q_batch_norm_kernel(
            alpha + ch,
            beta + ch,
            Vec::float_num_vecs(),
-           ReluFused,
            kVLen
         );
       }
@@ -2503,7 +2491,7 @@ void q_batch_norm_kernel(
         int64_t vec_num = elem_size / kVLen;
         std::vector<typename scalar_t::underlying> buf_in(lanes);
         memcpy(buf_in.data(), X_ptr + ch, vec_num * kVLen); // 3 cycles
-        do_bn_compute<scalar_t>(
+        do_bn_compute<scalar_t, ReluFused>(
            buf_in.data(),
            Y_ptr + ch,
            fake_scale,
@@ -2514,7 +2502,6 @@ void q_batch_norm_kernel(
            alpha + ch,
            beta + ch,
            vec_num,
-           ReluFused,
            kVLen
         );
         ch += vec_num * kVLen;
@@ -2524,7 +2511,7 @@ void q_batch_norm_kernel(
         long quantized_down = out_zero_point +
             lrintf(alpha[ch] * (X_ptr[ch] - in_zero_point) +
                    beta[ch]);
-        if (ReluFused) { // static if
+        if constexpr (ReluFused) { // static if
           quantized_down = std::max<long>(quantized_down, out_zero_point);
         }
         Y_ptr[ch] = std::min<long>(
@@ -4048,7 +4035,6 @@ void dequantize_per_channel_affine_kernel(
             c * elements_per_channel + e;
         // We need to convert the qint8 value to float to ensure the
         // subtraction subexpression returns a float
-        // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
         auto qvalue = qd[i / elem_per_byte].val_;
         if (bit_width < 8) {
           qvalue >>= (i % elem_per_byte) * bit_width;
@@ -4109,7 +4095,6 @@ void quantize_tensor_per_channel_float_qparams_cpu(
         auto i = b * channel * elements_per_channel + e * channel + c;
         qvalue = quantize_val_float_qparams(
             scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
-        // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
         if (i % elem_per_byte == 0) {
           qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
         } else {
@@ -4126,7 +4111,6 @@ void quantize_tensor_per_channel_float_qparams_cpu(
             c * elements_per_channel + e;
         qvalue = quantize_val_float_qparams(
             scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
-        // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
         if (i % elem_per_byte == 0) {
           qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
         } else {
@@ -4145,7 +4129,6 @@ void dequantize_tensor_per_channel_float_qparams_cpu(
     const Tensor& scales,
     const Tensor& zero_points,
     int64_t axis) {
-  // NOLINTNEXTLINE(clang-diagnostic-unused-variable)
   AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
       qtensor.scalar_type(), "dequantize_tensor_per_channel_float_qparams_cpu", [&]() {
         dequantize_per_channel_affine_kernel<float, float, scalar_t>(qtensor, rtensor, scales, zero_points, axis, bit_width);
@@ -4173,7 +4156,6 @@ void quantize_tensor_per_tensor_affine_sub_byte_cpu(
         // We pack sub_byte values and align them to a byte.
         // Eg. for 4-bits Index 0 is packed in the lower 4-bits
         // and index 1 is packed in the upper 4-bits.
-        // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
         if (i % elem_per_byte == 0) {
           qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
         } else {
@@ -4189,7 +4171,6 @@ void dequantize_tensor_per_tensor_affine_sub_byte_cpu(
     float scale,
     float zero_point) {
   // TODO Use fbgemm kernel to pack values
-  // NOLINTNEXTLINE(clang-diagnostic-unused-variable)
   AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
       qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
         check_tensor_memory_format(rtensor, qtensor);
@@ -4199,7 +4180,6 @@ void dequantize_tensor_per_tensor_affine_sub_byte_cpu(
         const auto elem_per_byte = CHAR_BIT / bit_width;
 
         for (const auto i : c10::irange(numel)) {
-          // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
           underlying_t qvalue = qdata[i / elem_per_byte];
           qvalue >>= (i % elem_per_byte) * bit_width;
           qvalue &= (1 << bit_width) - 1;
@@ -4347,5 +4327,5 @@ REGISTER_DISPATCH(
     &index_put_kernel_quantized_cpu);
 REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel);
 REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel);
-} // namespace native
-} // namespace at
+} // namespace at::native
+// NOLINTEND(*-c-arrays)