Fix clang-tidy bugprone* warnings (#148529)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148529
Approved by: https://github.com/ezyang
Author: Yuanyuan Chen
Committed by: PyTorch MergeBot
Date: 2025-06-23 23:09:56 +00:00
Parent: 3f920f3d8f
Commit: 07bb097698
28 changed files with 74 additions and 79 deletions


@ -222,8 +222,7 @@ void set_num_threads(int nthreads) {
int stored_nthreads = num_intraop_threads.load();
if (stored_nthreads <= 0) {
// plus one because of master thread
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
stored_nthreads = _get_intraop_pool().size() + 1;
stored_nthreads = static_cast<int>(_get_intraop_pool().size() + 1);
}
if (stored_nthreads != nthreads) {
TORCH_WARN(
@ -251,8 +250,7 @@ int get_num_threads() {
return intraop_default_num_threads();
} else {
TORCH_INTERNAL_ASSERT(nthreads == CONSUMED);
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
return _get_intraop_pool().size() + 1;
return static_cast<int>(_get_intraop_pool().size() + 1);
}
#else
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
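
Most of the hunks in this commit follow the same recipe: drop a NOLINTNEXTLINE(*-narrowing-conversions) suppression and spell the conversion out with static_cast. A minimal standalone sketch of the pattern, with an illustrative helper in place of the real _get_intraop_pool() API:

```cpp
// Sketch of the *-narrowing-conversions fix: make the size_t -> int
// truncation explicit instead of suppressing the warning.
#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative stand-in for querying the intra-op pool size.
int stored_num_threads(const std::vector<int>& pool) {
  // Flagged form: int n = pool.size() + 1;   (implicit size_t -> int)
  return static_cast<int>(pool.size() + 1);  // +1 for the calling thread
}

int main() {
  std::vector<int> pool(7);
  std::cout << stored_num_threads(pool) << '\n';  // prints 8
}
```

The cast does not change the generated code; it records that the truncation is intentional, which is what the check asks for.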


@ -111,12 +111,15 @@ static cublasOperation_t _cublasOpFromChar(char op) {
// NOLINTNEXTLINE(bugprone-switch-missing-default-case)
switch (op) {
case 'n':
[[fallthrough]];
case 'N':
return CUBLAS_OP_N;
case 't':
[[fallthrough]];
case 'T':
return CUBLAS_OP_T;
case 'c':
[[fallthrough]];
case 'C':
return CUBLAS_OP_C;
}
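
Here the switch keeps its NOLINT for bugprone-switch-missing-default-case, and explicit [[fallthrough]] statements are added between the lower- and upper-case labels. A self-contained sketch of the same idiom (the enum and the default branch below are illustrative, not the cuBLAS types):

```cpp
// Sketch of explicit [[fallthrough]] between case labels that share a body.
#include <iostream>

enum class Op { N, T, C };

Op op_from_char(char op) {
  switch (op) {
    case 'n':
      [[fallthrough]];
    case 'N':
      return Op::N;
    case 't':
      [[fallthrough]];
    case 'T':
      return Op::T;
    case 'c':
      [[fallthrough]];
    case 'C':
      return Op::C;
    default:  // the original relies on a NOLINT instead of a default branch
      return Op::N;
  }
}

int main() {
  std::cout << static_cast<int>(op_from_char('t')) << '\n';  // prints 1
}
```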


@ -156,8 +156,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
default:
TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters");
}
// NOLINTNEXTLINE(*narrowing-conversions)
set(getDataType(t), static_cast<int64_t>(dim), size, filter_format);
set(getDataType(t), static_cast<int>(dim), size, filter_format);
}
std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {


@ -98,13 +98,13 @@ void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool
const auto arguments_begin = stack->size() - num_arguments;
std::vector<at::Tensor> tensor_args;
std::vector<int> tensor_args_indices;
std::vector<size_t> tensor_args_indices;
std::vector<c10::List<at::Tensor>> tensorlist_args;
std::vector<int> tensorlist_args_indices;
std::vector<size_t> tensorlist_args_indices;
std::vector<c10::List<std::optional<at::Tensor>>> optional_tensorlist_args;
std::vector<int> optional_tensorlist_args_indices;
std::vector<size_t> optional_tensorlist_args_indices;
std::optional<c10::Device> tgt_device = std::nullopt;
// save converted cpu tensor for TensorList and optional TensorList
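
Switching the index vectors from int to size_t keeps them in the same unsigned domain as stack->size() and the offsets derived from it, so storing an index no longer narrows. A distilled sketch with illustrative element types:

```cpp
// Sketch: keep indices derived from .size() in a size_t vector so nothing
// narrows when they are stored.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> stack(10);
  std::vector<std::size_t> tensor_args_indices;  // was std::vector<int>
  for (std::size_t idx = 0; idx < stack.size(); ++idx) {
    if (idx % 3 == 0) {
      tensor_args_indices.push_back(idx);  // no size_t -> int conversion
    }
  }
  std::cout << tensor_args_indices.size() << '\n';  // prints 4
}
```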


@ -162,8 +162,7 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
ideep::tensor saved_mean;
ideep::tensor saved_var;
ideep::batch_normalization_forward_training::compute(
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
x, w, b, y, saved_mean, saved_var, momentum, eps);
x, w, b, y, saved_mean, saved_var, static_cast<float>(momentum), static_cast<float>(eps));
if (use_running_stat) {
auto len = x.get_nelems() / w.get_nelems(); // n*h*w
ideep::tensor m = itensor_from_tensor(running_mean);
@ -171,8 +170,7 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
const std::vector<float> scales_mean{static_cast<float>(1 - momentum),
static_cast<float>(momentum)};
const std::vector<float> scales_var{static_cast<float>(1 - momentum),
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
static_cast<float>(momentum * len / (len - 1))};
static_cast<float>(momentum * static_cast<double>(len) / (static_cast<double>(len) - 1))};
ideep::sum::compute(scales_mean, {m, saved_mean}, m);
ideep::sum::compute(scales_var, {v, saved_var}, v);
}


@ -85,8 +85,9 @@ at::Tensor quantized_convolution(
std::optional<std::string_view> unary_attr,
torch::List<std::optional<at::Scalar>> unary_scalars,
std::optional<std::string_view> unary_algorithm) {
Attr attr =
Attr(/*q_scale=*/1.0 / inv_output_scale, /*zp=*/output_zero_point);
Attr attr = Attr(
/*q_scale=*/static_cast<float>(1.0 / inv_output_scale),
/*zp=*/output_zero_point);
auto ndim = act.ndimension();
construct_attr_by_post_op(


@ -112,7 +112,7 @@ void quantized_matmul(
// config we support:
// activation: s8&u8; per tensor calibrated; symmetric&asymmetric
// weight: s8; per_tensor/per_channel calibrated; symmetric
auto attr = Attr(1.0 / output_scale, output_zero_point);
auto attr = Attr(static_cast<float>(1.0 / output_scale), output_zero_point);
construct_attr_by_post_op(
binary_post_op,
binary_alpha,


@ -81,8 +81,8 @@ std::vector<Tensor> quantize_per_tensor_list_cpu(
for (const auto i : c10::irange(tensors.size())) {
quantized_tensors.push_back(at::quantize_per_tensor(
tensors[i],
scales[i].item<double>(),
zero_points[i].item<int64_t>(),
scales[static_cast<int64_t>(i)].item<double>(),
zero_points[static_cast<int64_t>(i)].item<int64_t>(),
dtype));
}
return quantized_tensors;
@ -293,18 +293,16 @@ std::tuple<double, int64_t> _choose_qparams_per_tensor(
static float calculate_quant_loss(
const float* input,
int numel,
int64_t numel,
float xmin,
float xmax,
float* q_input,
int bit_width) {
int64_t bit_width) {
xmin = static_cast<at::Half>(xmin);
float data_range = xmax - xmin;
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
float qmax = (1 << bit_width) - 1;
float qmax = static_cast<float>((1 << bit_width) - 1);
float scale = data_range == 0
? 1.0
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
? 1.0f
: static_cast<float>(static_cast<at::Half>(data_range / qmax));
float inverse_scale = scale == 0 ? 1.0f : 1.0f / scale;
@ -347,10 +345,10 @@ std::tuple<Tensor, Tensor> choose_qparams_optimized(
const float* input_row = input_tensor.const_data_ptr<float>();
float xmin = *std::min_element(input_row, input_row + numel);
float xmax = *std::max_element(input_row, input_row + numel);
float n_bins_float = static_cast<float>(n_bins);
float stepsize = (xmax - xmin) / n_bins;
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
int min_bins = n_bins * (1.0 - (float) ratio);
float stepsize = (xmax - xmin) / n_bins_float;
float min_bins = static_cast<float>(n_bins_float* (1.0 - ratio));
Tensor input_tensor_contig = input_tensor.contiguous();
const float* input = input_tensor_contig.const_data_ptr<float>();
std::vector<float> q_input(numel);
@ -363,7 +361,6 @@ std::tuple<Tensor, Tensor> choose_qparams_optimized(
float cur_max = xmax;
float cur_loss = loss;
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
float thr = min_bins * stepsize;
while (cur_min + thr < cur_max) {
// move left


@ -84,7 +84,7 @@ std::tuple<at::Tensor, std::optional<at::Tensor>> PackedLinearWeightsQnnp::
at::device(c10::kCPU).dtype(c10::kFloat));
at::Tensor zero_points = at::empty(
w_zero_points.size() - kPaddingChannels, at::device(c10::kCPU).dtype(c10::kLong));
static_cast<int64_t>(w_zero_points.size() - kPaddingChannels), at::device(c10::kCPU).dtype(c10::kLong));
for (const auto i : c10::irange(zero_points.numel())) {
zero_points[i] = ((int64_t)w_zero_points[i] - 128);
}


@ -108,8 +108,7 @@ Tensor qcat_nhwc_kernel(
const int64_t N = qx0.size(0);
const int64_t H = qx0.size(2);
const int64_t W = qx0.size(3);
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float inv_scale = 1.0 / scale;
float inv_scale = static_cast<float>(1.0 / scale);
auto output = at::_empty_affine_quantized(
{N, C_out, H, W},
@ -1282,12 +1281,10 @@ void qelu_kernel(
template <bool ReLUFused = false>
void qadd_scalar_kernel(Tensor& out, const Tensor& self, const Scalar& other) {
int64_t zero_point = out.q_zero_point();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float scale = out.q_scale();
float inv_scale = 1.0f / scale;
float scale = static_cast<float>(out.q_scale());
float inv_scale = static_cast<float>(1.0f / scale);
int64_t self_zero_point = self.q_zero_point();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
float self_scale = self.q_scale();
float self_scale = static_cast<float>(self.q_scale());
float multiplier = self_scale * inv_scale;


@ -133,7 +133,7 @@ void NnapiCompilation::run(
t.nbytes());
}
for (const auto i : c10::irange(outputs.size())) {
for (const auto i : c10::irange(static_cast<int32_t>(outputs.size()))) {
auto& t = outputs[i];
// TODO: Check contiguous and dtype.
check_nnapi->Execution_setOutput(
@ -147,7 +147,7 @@ void NnapiCompilation::run(
check_nnapi->Execution_compute(execution);
// TODO: Maybe skip this for fixed-size outputs?
for (const auto i : c10::irange(outputs.size())) {
for (const auto i : c10::irange(static_cast<int32_t>(outputs.size()))) {
auto& t = outputs[i];
uint32_t rank = 0;
check_nnapi->Execution_getOutputOperandRank(execution, i, &rank);
@ -177,9 +177,8 @@ void NnapiCompilation::get_operand_type(const at::Tensor& t, ANeuralNetworksOper
if (t.scalar_type() == c10::kQUInt8) {
TORCH_CHECK(t.is_quantized());
operand->type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM;
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
operand->scale = t.q_scale();
operand->zeroPoint = t.q_zero_point();
operand->scale = static_cast<float>(t.q_scale());
operand->zeroPoint = static_cast<int32_t>(t.q_zero_point());
return;
}
if (t.scalar_type() == c10::kInt) {
@ -194,7 +193,6 @@ void NnapiCompilation::get_operand_type(const at::Tensor& t, ANeuralNetworksOper
"testing with fixed scale, zero_point. Please change your ",
"inputs if you see this in production");
operand->type = ANEURALNETWORKS_TENSOR_QUANT16_ASYMM;
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
operand->scale = 0.125;
operand->zeroPoint = 0;
return;


@ -257,22 +257,22 @@ static struct PyGetSetDef THPEvent_properties[] = {
// NOLINTNEXTLINE(*c-arrays*, *global-variables)
static PyMethodDef THPEvent_methods[] = {
{(char*)"from_ipc_handle",
{"from_ipc_handle",
castPyCFunctionWithKeywords(THPEvent_from_ipc_handle),
METH_CLASS | METH_VARARGS | METH_KEYWORDS,
nullptr},
{(char*)"record",
{"record",
castPyCFunctionWithKeywords(THPEvent_record),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{(char*)"wait",
{"wait",
castPyCFunctionWithKeywords(THPEvent_wait),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{(char*)"query", THPEvent_query, METH_NOARGS, nullptr},
{(char*)"elapsed_time", THPEvent_elapsed_time, METH_O, nullptr},
{(char*)"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
{(char*)"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
{"query", THPEvent_query, METH_NOARGS, nullptr},
{"elapsed_time", THPEvent_elapsed_time, METH_O, nullptr},
{"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr},
{"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr},
{nullptr}};
PyTypeObject THPEventType = {
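
The (char*) casts on the method names are dropped. In CPython 3, PyMethodDef::ml_name is a pointer to const, so a string literal binds to it directly and the cast only served to strip constness. A stand-in struct (not the real Python header) showing why the literal works unchanged:

```cpp
// Stand-in mirroring the shape of CPython's PyMethodDef (not the real header).
struct MethodDef {
  const char* ml_name;  // pointer to const: a string literal binds directly
  int (*ml_meth)();
  int ml_flags;
  const char* ml_doc;
};

static int query_impl() { return 1; }

static MethodDef methods[] = {
    {"query", query_impl, 0, nullptr},  // no (char*) cast required
    {nullptr, nullptr, 0, nullptr},
};

int main() { return methods[0].ml_meth() == 1 ? 0 : 1; }
```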


@ -280,6 +280,7 @@ static PyObject* THPModule_crashIfvptrUBSAN(PyObject* module, PyObject* noarg) {
virtual ~Baz() = default;
};
Baz x{};
// NOLINTNEXTLINE(bugprone-casting*)
auto y = static_cast<Foo*>(static_cast<void*>(&x));
auto rc = y->bar();
return THPUtils_packInt32(rc);
@ -2371,7 +2372,7 @@ Call this whenever a new thread is created in order to propagate values from
auto acc = at::getAccelerator(check.value_or(false));
if (acc.has_value()) {
bool is_available = at::globalContext()
.getAcceleratorHooksInterface(acc.value())
.getAcceleratorHooksInterface(acc)
.isAvailable();
if (!is_available) {


@ -451,12 +451,12 @@ static PyObject* THPEngine_new(
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMethodDef THPEngine_methods[] = {
{(char*)"run_backward",
{"run_backward",
castPyCFunctionWithKeywords(THPEngine_run_backward),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{(char*)"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
{(char*)"is_checkpoint_valid",
{"queue_callback", THPEngine_queue_callback, METH_O, nullptr},
{"is_checkpoint_valid",
THPEngine_is_checkpoint_valid,
METH_NOARGS,
nullptr},


@ -260,7 +260,6 @@ auto PyNode::apply_with_saved_impl(
Py_CLEAR(py_fn->compiled_autograd_backward_state);
}
THPObjectPtr r(PyObject_CallMethod(
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
saved.get_py_compiler(),
"proxy_call_backward",
"OOOiOO",


@ -817,6 +817,7 @@ static PyObject* THPVariable_get_python_dispatch(
// - static Tensor fn(const Tensor&);
// - This function calls the relevant ATen on the tensor
template <typename T>
// NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility)
struct GetterBase {
static PyObject* getter(THPVariable* self, void* /*unused*/) {
HANDLE_TH_ERRORS
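
bugprone-crtp-constructor-accessibility flags CRTP base classes whose constructors are callable by arbitrary code, since that allows instantiating the base with the wrong Derived. The hunk suppresses the warning; for reference, a sketch of the shape the check prefers, using a private constructor plus a friend declaration (illustrative names, not the GetterBase internals):

```cpp
// Sketch of the constructor-accessibility idiom the check recommends:
// only the intended Derived can construct Base<Derived>.
#include <iostream>

template <typename Derived>
class Base {
 public:
  int get() const { return static_cast<const Derived&>(*this).value(); }

 private:
  Base() = default;
  friend Derived;  // any other class deriving as Base<Derived> cannot construct it
};

class Getter : public Base<Getter> {
 public:
  int value() const { return 42; }
};

int main() {
  Getter g;
  std::cout << g.get() << '\n';  // prints 42
}
```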


@ -420,7 +420,7 @@ static PyObject* reduceopmeta___instancecheck__(
// NOLINTNEXTLINE(*c-arrays)
static PyMethodDef reduceopmeta_methods[] = {
{"__instancecheck__",
(PyCFunction)reduceopmeta___instancecheck__,
reduceopmeta___instancecheck__,
METH_O,
"Custom `__instancecheck__` for ReduceOp"},
{nullptr, nullptr}};


@ -131,6 +131,7 @@ std::vector<c10::Device> getDevicesOfTensors(
devices.reserve(deviceCount);
for (const auto idx : c10::irange(indexBitset.size())) {
if (indexBitset[idx]) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
devices.emplace_back(impl->type(), static_cast<c10::DeviceIndex>(idx));
}
}


@ -1273,11 +1273,11 @@ inline at::TensorOptions unpack_TensorOptions(
at::TensorOptions result;
auto maybe_requires_grad = std::get<0>(tuple);
if (maybe_requires_grad.has_value()) {
result = result.requires_grad(maybe_requires_grad.value());
result = result.requires_grad(maybe_requires_grad);
}
auto maybe_memory_format = std::get<1>(tuple);
if (maybe_memory_format.has_value()) {
result = result.memory_format(maybe_memory_format.value());
result = result.memory_format(maybe_memory_format);
}
auto maybe_device = std::get<2>(tuple);
if (maybe_device.has_value()) {
@ -1290,11 +1290,11 @@ inline at::TensorOptions unpack_TensorOptions(
}
auto maybe_layout = std::get<4>(tuple);
if (maybe_layout.has_value()) {
result = result.layout(maybe_layout.value());
result = result.layout(maybe_layout);
}
auto maybe_pinned_memory = std::get<5>(tuple);
if (maybe_pinned_memory.has_value()) {
result = result.pinned_memory(maybe_pinned_memory.value());
result = result.pinned_memory(maybe_pinned_memory);
}
return result;
}
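
Here the .value() unwraps are dropped and the optionals are forwarded as-is, which only works because these setters accept optional arguments (as the merged code implies). A sketch of the idea, with an illustrative builder class standing in for at::TensorOptions:

```cpp
// Sketch: when a builder-style setter already takes std::optional<T>,
// forward the optional instead of unwrapping it.
#include <iostream>
#include <optional>

class Options {
 public:
  Options requires_grad(std::optional<bool> v) const {
    Options r = *this;
    r.requires_grad_ = v.value_or(false);
    return r;
  }
  bool requires_grad_value() const { return requires_grad_; }

 private:
  bool requires_grad_ = false;
};

Options unpack(std::optional<bool> maybe_requires_grad) {
  Options result;
  if (maybe_requires_grad.has_value()) {
    result = result.requires_grad(maybe_requires_grad);  // no .value() unwrap
  }
  return result;
}

int main() {
  std::cout << unpack(true).requires_grad_value() << '\n';  // prints 1
}
```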


@ -1412,7 +1412,7 @@ class StorageOverlapChecker {
*/
std::vector<Tensor> _tensors_from(
const std::vector<PyObject*>& objects,
int64_t size) {
size_t size) {
std::vector<Tensor> tensors;
tensors.reserve(size);
std::transform(


@ -335,7 +335,7 @@ PyTypeObject NodeBaseType = {
"torch._C._NodeBase", /* tp_name */
sizeof(NodeBase), /* tp_basicsize */
0, /* tp_itemsize */
(destructor)NodeBase_dealloc, /* tp_dealloc */
NodeBase_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */


@ -421,6 +421,7 @@ std::shared_ptr<AOTIModelContainerRunner> AOTIPythonKernelHolder::
"AOTI for eager does not support ",
c10::DeviceTypeName(device_.type()),
" now.");
// NOLINTNEXTLINE(bugprone-branch-clone)
if (device_.type() == c10::DeviceType::CUDA) {
#ifdef USE_CUDA
return std::make_shared<AOTIModelContainerRunnerCuda>(so_path);
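
bugprone-branch-clone reports if/else chains whose branches end up token-identical, which commonly happens when backend-specific bodies are compiled out. The hunk opts to suppress it; a standalone sketch of the situation the check reacts to (illustrative enum and return type):

```cpp
// Sketch of what bugprone-branch-clone reports: branches that become
// token-identical once the #ifdef blocks are compiled out.
#include <memory>
#include <string>

enum class DeviceType { CUDA, XPU, CPU };

std::shared_ptr<std::string> make_runner(DeviceType t) {
  // With neither USE_CUDA nor USE_XPU defined, both branches reduce to
  // "return nullptr;" and the check flags them as clones.
  if (t == DeviceType::CUDA) {
#ifdef USE_CUDA
    return std::make_shared<std::string>("cuda runner");
#else
    return nullptr;
#endif
  } else if (t == DeviceType::XPU) {
#ifdef USE_XPU
    return std::make_shared<std::string>("xpu runner");
#else
    return nullptr;
#endif
  }
  return nullptr;
}

int main() { return make_runner(DeviceType::CPU) == nullptr ? 0 : 1; }
```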


@ -445,10 +445,9 @@ void OSSProxyExecutor::get_input_info_from_serialized(
// If an argument is not filled and has a default value, we should
// also prefill the default value.
for (size_t index = 0; index < schema_args.size(); index++) {
if (!filled[index] && schema_args[index].default_value()) {
// @lint-ignore CLANGTIDY bugprone-unchecked-optional-access
auto default_value = *schema_args[index].default_value();
op_kernel.stack_.at(index) = default_value;
auto default_value = schema_args[index].default_value();
if (!filled[index] && default_value.has_value()) {
op_kernel.stack_.at(index) = std::move(default_value.value());
}
}
}
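
Instead of a lint-ignore for bugprone-unchecked-optional-access, the loop now fetches the optional once, guards it with has_value(), and moves the payload out. A distilled, self-contained version of that shape (the schema types below are stand-ins):

```cpp
// Distilled version of the rewritten default-value loop: fetch the optional,
// guard it, then move the payload.
#include <cstddef>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct Arg {
  std::optional<std::string> default_value;
};

int main() {
  std::vector<Arg> schema_args{{std::nullopt}, {std::string("relu")}};
  std::vector<bool> filled{false, false};
  std::vector<std::string> stack(schema_args.size());

  for (std::size_t index = 0; index < schema_args.size(); index++) {
    auto default_value = schema_args[index].default_value;
    if (!filled[index] && default_value.has_value()) {
      stack.at(index) = std::move(default_value.value());  // guarded access
    }
  }
  std::cout << stack[1] << '\n';  // prints relu
}
```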


@ -59,8 +59,10 @@ AOTITorchError aoti_torch_get_current_xpu_stream(
}
AOTITorchError aoti_torch_get_current_xpu_device(int32_t* device_index) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *device_index = static_cast<int32_t>(c10::xpu::current_device()); });
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
*device_index =
static_cast<int32_t>(static_cast<uint16_t>(c10::xpu::current_device()));
});
}
AOTITorchError aoti_torch_set_current_xpu_device(const int32_t& device_index) {
@ -70,7 +72,8 @@ AOTITorchError aoti_torch_set_current_xpu_device(const int32_t& device_index) {
AOTITorchError aoti_torch_get_current_sycl_queue(void** ret) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
int32_t device_index = static_cast<int32_t>(c10::xpu::current_device());
int32_t device_index =
static_cast<int32_t>(static_cast<uint16_t>(c10::xpu::current_device()));
*ret = &(at::xpu::getCurrentXPUStream(device_index).queue());
});
}
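
The extra hop through uint16_t is presumably there to appease bugprone-signed-char-misuse, since c10::DeviceIndex is a signed 8-bit type and widening it straight to int32_t is the pattern that check flags; this reading is an inference, and current_device() below is a stand-in:

```cpp
// Sketch of the widening pattern (current_device() is a stand-in and the
// check name is inferred, not taken from the commit).
#include <cstdint>
#include <iostream>

using DeviceIndex = int8_t;  // same width as c10::DeviceIndex

DeviceIndex current_device() { return 3; }

int main() {
  // Flagged form: a signed 8-bit value converted straight to int32_t.
  //   int32_t device_index = static_cast<int32_t>(current_device());
  int32_t device_index =
      static_cast<int32_t>(static_cast<uint16_t>(current_device()));
  std::cout << device_index << '\n';  // prints 3
}
```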


@ -438,12 +438,12 @@ PyObject* launch_kernel(PyObject* self, PyObject* args) {
std::array<PyMethodDef, 2> StaticCudaLauncherMethods = {
PyMethodDef{
"_launch_kernel",
(PyCFunction)launch_kernel,
launch_kernel,
METH_VARARGS,
"Statically launch triton compiled CUDA kernels"},
PyMethodDef{
"_load_kernel",
(PyCFunction)load_kernel,
load_kernel,
METH_VARARGS,
"Load CUDA kernel from cubin file"}};


@ -399,7 +399,7 @@ static void InferShapeTypeForUninitializedOutput(
} else {
const_node->t_(attr::value, at::zeros({}, elem_type));
const_node->output()->setType(
TensorType::create(*(output_type->scalarType()), at::kCPU, {}, {}));
TensorType::create(output_type->scalarType(), at::kCPU, {}, {}));
}
} else if (auto output_type = other_output->type()->cast<ListType>()) {
TypePtr elem = output_type->getElementType();


@ -32,8 +32,8 @@ class EventHandlers {
}
static EventHandlers& get() noexcept {
static auto ehsPtr = new EventHandlers();
return *ehsPtr;
static auto ehs = EventHandlers();
return ehs;
}
private:
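
The leaked new EventHandlers() becomes a function-local static, i.e. a Meyers singleton: initialization is still thread-safe (guaranteed since C++11) and nothing is heap-allocated just to dodge a destructor. A compact sketch of the same construct, with an illustrative handler payload:

```cpp
// Sketch of the function-local static ("Meyers singleton") the hunk switches
// to: initialization is thread-safe since C++11 and no allocation is leaked.
#include <cstddef>
#include <iostream>
#include <vector>

class EventHandlers {
 public:
  static EventHandlers& get() noexcept {
    static auto ehs = EventHandlers();  // constructed once, on first use
    return ehs;
  }
  void add(int handler) { handlers_.push_back(handler); }
  std::size_t count() const { return handlers_.size(); }

 private:
  EventHandlers() = default;
  std::vector<int> handlers_;
};

int main() {
  EventHandlers::get().add(1);
  EventHandlers::get().add(2);
  std::cout << EventHandlers::get().count() << '\n';  // prints 2
}
```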


@ -172,13 +172,13 @@ ScalarType infer_scalar_type(PyObject* obj) {
Py_TYPE(obj)->tp_name,
"'");
if (PySequence_Check(obj)) {
std::optional<ScalarType> scalarType;
auto length = PySequence_Length(obj);
if (length < 0)
throw python_error();
// match NumPy semantics, except use default tensor type instead of double.
if (length == 0)
return torch::tensors::get_default_scalar_type();
ScalarType scalarType{};
for (const auto i : c10::irange(length)) {
THPObjectPtr handle(PySequence_GetItem(obj, i));
if (!handle)
@ -187,16 +187,15 @@ ScalarType infer_scalar_type(PyObject* obj) {
TORCH_CHECK_TYPE(
cur_item != obj, "new(): self-referential lists are incompatible");
ScalarType item_scalarType = infer_scalar_type(cur_item);
scalarType = (scalarType) ? at::promoteTypes(*scalarType, item_scalarType)
: item_scalarType;
scalarType = (i > 0) ? at::promoteTypes(scalarType, item_scalarType)
: item_scalarType;
if (scalarType == ScalarType::ComplexDouble) {
// this won't change (unless we hit undefined, but that will fail
// later).
return *scalarType;
return scalarType;
}
}
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
return *scalarType;
return scalarType;
}
TORCH_CHECK(false, "Could not infer dtype of ", Py_TYPE(obj)->tp_name);
}
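
The optional<ScalarType> accumulator is replaced by a plain value that is only read after the first iteration seeds it, so no unguarded *scalarType remains for bugprone-unchecked-optional-access to flag. A distilled sketch of the reworked loop, with a stand-in for at::promoteTypes (the real caller also returns early for empty sequences):

```cpp
// Distilled version of the promotion loop without the optional accumulator:
// the first element seeds the result, later ones are promoted into it.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

enum class ScalarType { Long, Double, ComplexDouble };

ScalarType promote(ScalarType a, ScalarType b) {
  return std::max(a, b);  // stand-in: the real promotion is a lookup table
}

ScalarType infer(const std::vector<ScalarType>& items) {
  ScalarType scalarType{};  // only read once i > 0 has seeded it
  for (std::size_t i = 0; i < items.size(); ++i) {
    ScalarType item_scalarType = items[i];
    scalarType = (i > 0) ? promote(scalarType, item_scalarType)
                         : item_scalarType;
    if (scalarType == ScalarType::ComplexDouble) {
      return scalarType;  // promotion cannot go any higher
    }
  }
  return scalarType;
}

int main() {
  std::cout << static_cast<int>(infer({ScalarType::Long, ScalarType::Double}))
            << '\n';  // prints 1
}
```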