[14/N] Fix clang-tidy warnings in aten/src/ATen (#132733)

Follows #133807

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132733
Approved by: https://github.com/ezyang
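
Most of the changes in this commit fall into a few recurring clang-tidy-driven patterns visible in the diff below: plain if (ReLUFused) tests on the compile-time template flag become if constexpr (ReLUFused); the runtime ReluFused bool argument of do_bn_compute is hoisted into a template parameter; constexpr int64_t VLEN and constexpr int kVLEN become constexpr auto; the buf_in std::array is value-initialized with {}; the nested namespace at { namespace native { is collapsed into namespace at::native {; and many per-line NOLINTNEXTLINE suppressions are dropped, with a file-wide NOLINTBEGIN(*-c-arrays) / NOLINTEND(*-c-arrays) pair added instead. The snippet below is a minimal, standalone sketch of the if-constexpr pattern only; the function name and types are made up for the example and it is not code from this diff:

    // Minimal sketch, not the actual PyTorch kernel: with ReLUFused as a
    // template parameter, if constexpr discards the untaken branch at
    // compile time instead of leaving an ordinary runtime test on a
    // constant condition.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    template <bool ReLUFused>
    int32_t requantize_one(int32_t value, int32_t zero_point) {
      if constexpr (ReLUFused) {
        // Only instantiated for requantize_one<true>.
        value = std::max(value, zero_point);
      }
      return value;
    }

    int main() {
      std::cout << requantize_one<true>(-5, 0) << ' '    // prints 0
                << requantize_one<false>(-5, 0) << '\n'; // prints -5
      return 0;
    }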
Author: cyy
Date: 2024-08-22 10:09:15 +00:00
Committed by: PyTorch MergeBot
Parent: 90c821814e
Commit: 4c8193b8f0


@@ -36,8 +36,9 @@
#include <arm_neon.h>
#endif
namespace at {
namespace native {
// NOLINTBEGIN(*-c-arrays)
namespace at::native {
namespace {
void check_tensor_memory_format(const Tensor& ref, const Tensor& other) {
@@ -70,7 +71,6 @@ Tensor qcat_nhwc_kernel(
std::vector<void*> data_ptrs;
std::vector<bool> is_fast_path;
// NOLINTNEXTLINE(performance-implicit-conversion-in-loop)
for (const at::Tensor& qx : qxs) {
TORCH_CHECK(
qx.dim() == qx0.dim(),
@@ -143,7 +143,7 @@ Tensor qcat_nhwc_kernel(
continue;
}
constexpr int64_t VLEN = Vec::size();
constexpr auto VLEN = Vec::size();
int64_t c = 0;
// Vectorized loop
@@ -157,7 +157,7 @@ Tensor qcat_nhwc_kernel(
curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
Vec::float_vec_return_type retvals;
for (int i = 0; i < Vec::float_num_vecs(); ++i) {
if (ReLUFused) {
if constexpr (ReLUFused) {
retvals[i] =
vec::maximum(float_values[i], Vectorized<float>(0.0f));
} else {
@@ -171,21 +171,21 @@ Tensor qcat_nhwc_kernel(
}
// Vectorized loop for channel between 8 and 32 (avx2)
constexpr int kVLEN = Vectorized<float>::size();
constexpr auto kVLEN = Vectorized<float>::size();
int64_t elem_size = curr_C - c;
if ((VLEN == 4 * kVLEN) && elem_size >= kVLEN) {
auto curr_scale_vec = Vectorized<float>(curr_scale);
auto curr_zero_pt_vec = Vectorized<float>((float)curr_zero_pt);
auto scale_neg_zp_premul = curr_scale_vec * curr_zero_pt_vec.neg();
int64_t vec_num = elem_size / kVLEN;
std::array<typename scalar_t::underlying, VLEN> buf_in;
std::array<typename scalar_t::underlying, VLEN> buf_in{};
memcpy(buf_in.data(), iptr + c, vec_num * kVLEN);
auto inp_vec = Vec::loadu(buf_in.data());
auto float_values = inp_vec.dequantize(
curr_scale_vec, curr_zero_pt_vec, scale_neg_zp_premul);
Vec::float_vec_return_type retvals;
for (int i = 0; i < vec_num; ++i) {
if (ReLUFused) {
if constexpr (ReLUFused) {
retvals[i] =
vec::maximum(float_values[i], Vectorized<float>(0.0f));
} else {
@@ -204,7 +204,7 @@ Tensor qcat_nhwc_kernel(
curr_scale,
curr_zero_pt,
reinterpret_cast<scalar_t*>(iptr)[c]);
if (ReLUFused) {
if constexpr (ReLUFused) {
float_val = std::max(0.0f, float_val);
}
optr[c] = at::native::quantize_val<scalar_t>(
@@ -593,7 +593,6 @@ void qrelu_kernel(const Tensor& qx, Tensor& qy) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
qx.q_scale(),
qx.q_zero_point(),
@@ -758,7 +757,6 @@ void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
output_scale,
output_zero_point,
@@ -797,7 +795,6 @@ void qgelu_kernel(const Tensor& qx, Tensor& qy, GeluType approximate) {
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qgelu", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
output_scale,
output_zero_point,
@@ -842,7 +839,6 @@ void qsigmoid_kernel(
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
output_scale,
output_zero_point,
@@ -885,7 +881,6 @@ void qhardsigmoid_kernel(const Tensor& qx, Tensor& qy) {
// - Output scale is set to 1.0 / 2^(BIT_NUM)
float output_scale = 0.00390625; // 1.0 / 2^8
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
if (SCALAR_TYPE == at::kQInt32) {
output_scale = 2.3283064365386963e-10; // 1.0 / 2^32
}
@@ -946,7 +941,6 @@ void qclamp_kernel(
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qclamp", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
qx.q_scale(),
qx.q_zero_point(),
@@ -980,7 +974,6 @@ void qclamp_min_kernel(const Tensor& qx, const Scalar& min_scalar, Tensor& qy) {
qy = at::_empty_affine_quantized(
qx.sizes(),
at::device(kCPU)
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
.dtype(SCALAR_TYPE)
.memory_format(qx.suggest_memory_format()),
qx.q_scale(),
@@ -1006,7 +999,6 @@ void qclamp_max_kernel(const Tensor& qx, const Scalar& max_scalar, Tensor& qy) {
qy = at::_empty_affine_quantized(
qx.sizes(),
at::device(kCPU)
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
.dtype(SCALAR_TYPE)
.memory_format(qx.suggest_memory_format()),
qx.q_scale(),
@@ -1049,7 +1041,6 @@ void qthreshold_kernel(
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qthreshold", [&]() {
qy = at::_empty_affine_quantized(
qx.sizes(),
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
at::device(kCPU).dtype(SCALAR_TYPE).memory_format(qx.suggest_memory_format()),
qx.q_scale(),
qx.q_zero_point(),
@@ -1158,7 +1149,6 @@ void qtanh_kernel(const Tensor& qx, Tensor& qy) {
// - For unsigned types output zero point is set to (qmax + qmin) / 2.0
float output_scale = 0.0078125; // 2.0 / 512
int64_t output_zero_point = 0;
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
if (SCALAR_TYPE == at::kQInt32) {
output_scale = 4.656612873077393e-10; // 2.0 / 2^32
} else if (SCALAR_TYPE == at::kQUInt8) {
@@ -1314,7 +1304,7 @@ void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) {
int32_t c = a_sub_z + other_val;
scalar_t res = at::native::requantize_from_int<scalar_t>(
multiplier, zero_point, c);
if (ReLUFused) {
if constexpr (ReLUFused) {
res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
}
return res;
@@ -1327,7 +1317,7 @@ void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) {
c[i] = a_sub_z[i] + other_vec;
}
Vec rv = Vec::requantize_from_int(c, multiplier, zero_point);
if (ReLUFused) {
if constexpr (ReLUFused) {
rv = rv.maximum(Vec(static_cast<scalar_t>(zero_point)));
}
return rv;
@@ -1386,7 +1376,7 @@ void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
Vec::float_vec_return_type retvals;
for (const auto i : c10::irange(Vec::float_num_vecs())) {
auto c = da[i] + db[i];
if (ReLUFused) {
if constexpr (ReLUFused) {
c = vec::maximum(c, Vectorized<float>(0.0f));
}
retvals[i] = c;
@@ -1435,7 +1425,7 @@ void qmul_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
int32_t c = a_sub_z * b_sub_z;
scalar_t res = at::native::requantize_from_int<scalar_t>(
multiplier, zero_point, c);
if (ReLUFused) {
if constexpr (ReLUFused) {
res.val_ = std::max<scalar_t::underlying>(res.val_, zero_point);
}
return res;
@@ -1450,7 +1440,7 @@ void qmul_kernel(Tensor& out, const Tensor& self, const Tensor& other) {
c[i] = a_sub_zp[i] * b_sub_zp[i];
}
Vec rv = Vec::requantize_from_int(c, multiplier, zero_point);
if (ReLUFused) {
if constexpr (ReLUFused) {
rv = rv.maximum(Vec(static_cast<scalar_t>(zero_point)));
}
return rv;
@@ -2410,7 +2400,7 @@ void qtopk_kernel(Tensor& values,
});
}
template <typename T>
template <typename T, bool ReluFused>
inline void do_bn_compute(
typename T::underlying* X_ptr,
typename T::underlying* Y_ptr,
@@ -2422,7 +2412,6 @@ inline void do_bn_compute(
float* alpha,
float* beta,
int64_t vec_num,
bool ReluFused,
int64_t kVLen
) {
using Vec = Vectorized<T>;
@@ -2437,7 +2426,7 @@ inline void do_bn_compute(
// NOLINTNEXTLINE(bugprone-argument-comment)
auto outputs_q = Vec::quantize(vals_dq, /*output_scale=*/1.0f, out_zero_point, /*inv_output_scale=*/1.0f);
// Fake scale again
if (ReluFused) {
if constexpr (ReluFused) {
outputs_q = outputs_q.maximum(out_zero_point_v);
}
outputs_q.store(Y_ptr, vec_num * kVLen);
@@ -2480,7 +2469,7 @@ void q_batch_norm_kernel(
int64_t ch = 0;
for(; ch + lanes <= C; ch += lanes) {
do_bn_compute<scalar_t>(
do_bn_compute<scalar_t, ReluFused>(
X_ptr + ch,
Y_ptr + ch,
fake_scale,
@@ -2491,7 +2480,6 @@ void q_batch_norm_kernel(
alpha + ch,
beta + ch,
Vec::float_num_vecs(),
ReluFused,
kVLen
);
}
@@ -2503,7 +2491,7 @@ void q_batch_norm_kernel(
int64_t vec_num = elem_size / kVLen;
std::vector<typename scalar_t::underlying> buf_in(lanes);
memcpy(buf_in.data(), X_ptr + ch, vec_num * kVLen); // 3 cycles
do_bn_compute<scalar_t>(
do_bn_compute<scalar_t, ReluFused>(
buf_in.data(),
Y_ptr + ch,
fake_scale,
@@ -2514,7 +2502,6 @@ void q_batch_norm_kernel(
alpha + ch,
beta + ch,
vec_num,
ReluFused,
kVLen
);
ch += vec_num * kVLen;
@@ -2524,7 +2511,7 @@ void q_batch_norm_kernel(
long quantized_down = out_zero_point +
lrintf(alpha[ch] * (X_ptr[ch] - in_zero_point) +
beta[ch]);
if (ReluFused) { // static if
if constexpr (ReluFused) { // static if
quantized_down = std::max<long>(quantized_down, out_zero_point);
}
Y_ptr[ch] = std::min<long>(
@@ -4048,7 +4035,6 @@ void dequantize_per_channel_affine_kernel(
c * elements_per_channel + e;
// We need to convert the qint8 value to float to ensure the
// subtraction subexpression returns a float
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
auto qvalue = qd[i / elem_per_byte].val_;
if (bit_width < 8) {
qvalue >>= (i % elem_per_byte) * bit_width;
@@ -4109,7 +4095,6 @@ void quantize_tensor_per_channel_float_qparams_cpu(
auto i = b * channel * elements_per_channel + e * channel + c;
qvalue = quantize_val_float_qparams(
scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
if (i % elem_per_byte == 0) {
qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
} else {
@@ -4126,7 +4111,6 @@ void quantize_tensor_per_channel_float_qparams_cpu(
c * elements_per_channel + e;
qvalue = quantize_val_float_qparams(
scales_data[c], zero_points_data[c], rdata[i], quant_min, quant_max);
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
if (i % elem_per_byte == 0) {
qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
} else {
@@ -4145,7 +4129,6 @@ void dequantize_tensor_per_channel_float_qparams_cpu(
const Tensor& scales,
const Tensor& zero_points,
int64_t axis) {
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_channel_float_qparams_cpu", [&]() {
dequantize_per_channel_affine_kernel<float, float, scalar_t>(qtensor, rtensor, scales, zero_points, axis, bit_width);
@@ -4173,7 +4156,6 @@ void quantize_tensor_per_tensor_affine_sub_byte_cpu(
// We pack sub_byte values and align them to a byte.
// Eg. for 4-bits Index 0 is packed in the lower 4-bits
// and index 1 is packed in the upper 4-bits.
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
if (i % elem_per_byte == 0) {
qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
} else {
@@ -4189,7 +4171,6 @@ void dequantize_tensor_per_tensor_affine_sub_byte_cpu(
float scale,
float zero_point) {
// TODO Use fbgemm kernel to pack values
// NOLINTNEXTLINE(clang-diagnostic-unused-variable)
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
check_tensor_memory_format(rtensor, qtensor);
@@ -4199,7 +4180,6 @@ void dequantize_tensor_per_tensor_affine_sub_byte_cpu(
const auto elem_per_byte = CHAR_BIT / bit_width;
for (const auto i : c10::irange(numel)) {
// NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
underlying_t qvalue = qdata[i / elem_per_byte];
qvalue >>= (i % elem_per_byte) * bit_width;
qvalue &= (1 << bit_width) - 1;
@@ -4347,5 +4327,5 @@ REGISTER_DISPATCH(
&index_put_kernel_quantized_cpu);
REGISTER_DISPATCH(qmean_inner_dim_stub, &qmean_inner_dim_kernel);
REGISTER_DISPATCH(qstd_inner_dim_stub, &qstd_inner_dim_kernel);
} // namespace native
} // namespace at
} // namespace at::native
// NOLINTEND(*-c-arrays)
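
A closing note on the do_bn_compute hunks above: the runtime bool ReluFused parameter is removed from the signature and passed as a second template argument instead (do_bn_compute<scalar_t, ReluFused>), so the ReLU clamp can sit behind if constexpr like the other kernels in this file. The sketch below shows that hoisting pattern in isolation, with deliberately simplified types; the names are illustrative and it is not the actual kernel code:

    // Illustrative only: hoisting a runtime bool into a template parameter.
    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Before (conceptually): bn_compute(..., bool relu_fused) with a runtime
    // branch. After: the flag is a compile-time template parameter.
    template <bool ReluFused>
    void bn_compute(const std::vector<int32_t>& x,
                    std::vector<int32_t>& y,
                    int32_t out_zero_point) {
      y.resize(x.size());
      for (std::size_t i = 0; i < x.size(); ++i) {
        int32_t v = x[i];
        if constexpr (ReluFused) {
          // Compiled only into bn_compute<true>; <false> has no dead branch.
          v = std::max(v, out_zero_point);
        }
        y[i] = v;
      }
    }

    // Callers forward their own compile-time flag, mirroring the
    // do_bn_compute<scalar_t, ReluFused>(...) call sites in the diff.
    template <bool ReluFused>
    void batch_norm_like(const std::vector<int32_t>& x,
                         std::vector<int32_t>& y,
                         int32_t out_zero_point) {
      bn_compute<ReluFused>(x, y, out_zero_point);
    }

    int main() {
      std::vector<int32_t> x{-3, 0, 7};
      std::vector<int32_t> y;
      batch_norm_like<true>(x, y, /*out_zero_point=*/0);
      for (int32_t v : y) std::cout << v << ' ';  // prints: 0 0 7
      std::cout << '\n';
      return 0;
    }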