[Reland] [1/N] fix clang-tidy warnings in torch/csrc (#108114)

Reland of PR #107648 with auto replaced with Py_ssize_t in eval_frame.c. This PR applies fixes to some found issues by clang-tidy in torch/csrc.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108114
Approved by: https://github.com/Skylion007
This commit is contained in:
cyy
2023-08-30 17:11:11 +00:00
committed by PyTorch MergeBot
parent 7be233f3a5
commit 01fc6466d1
13 changed files with 72 additions and 90 deletions

View File

@ -90,7 +90,7 @@ struct C10_API Storage {
return storage_impl_->mutable_data();
}
at::DataPtr& mutable_data_ptr() {
at::DataPtr& mutable_data_ptr() const {
return storage_impl_->mutable_data_ptr();
}

View File

@ -124,24 +124,18 @@ static PyObject* THPModule_errorIfAnyWorkerFails(
PyObject* module,
PyObject* noargs) {
HANDLE_TH_ERRORS
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int error;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::set<pid_t>* pid_set;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
pid_t worker_pid;
siginfo_t infop;
// Only check the pids we care about
for (auto& w : worker_pids) {
pid_set = &(w.second);
for (auto pid_it = pid_set->begin(); pid_it != pid_set->end(); ++pid_it) {
worker_pid = *pid_it;
auto& pid_set = w.second;
for (auto worker_pid : pid_set) {
// Use waitid rather than waitpid so that we can set NOWAIT, and that
// Python and other handlers can get whatever info they want about the
// child.
siginfo_t infop{};
infop.si_pid = 0;
error = waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
auto error =
waitid(P_PID, worker_pid, &infop, WEXITED | WNOHANG | WNOWAIT);
// ignore errors and case with no waitable child
if (error < 0 || infop.si_pid == 0)
continue;
@ -154,7 +148,7 @@ static PyObject* THPModule_errorIfAnyWorkerFails(
<< "num_workers=0 may give better error trace.";
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();
pid_set.clear();
throw std::runtime_error(oss.str());
} else if (
infop.si_code == CLD_KILLED ||
@ -168,7 +162,7 @@ static PyObject* THPModule_errorIfAnyWorkerFails(
}
// This is necessary. Otherwise, the runtime error will kill the other
// workers, and trigger this again.
pid_set->clear();
pid_set.clear();
throw std::runtime_error(oss.str());
}
}

View File

@ -66,21 +66,22 @@ PyObject* THPDevice_pynew(
return THPDevice_New(device);
} else if (r.idx == 1) {
auto as_device = r.device(0); // this works, because device can take strings
auto device_type = r.string(0);
if (as_device.has_index()) {
auto device_type = r.string(0);
throw std::runtime_error(
"type (string) must not include an index because index "
"was passed explicitly: " +
device_type);
}
int32_t device_index = -1;
int64_t device_index = -1;
if (!r.isNone(1)) {
device_index = r.toInt64(1);
// -1 is allowed in ATen/C++, to mean the default device, but not in
// Python.
TORCH_CHECK(device_index >= 0, "Device index must not be negative");
}
at::Device device(as_device.type(), device_index);
at::Device device(
as_device.type(), static_cast<c10::DeviceIndex>(device_index));
return THPDevice_New(device);
}
Py_RETURN_NONE;
@ -163,8 +164,8 @@ PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
std::ostringstream oss;
oss << self->device.type();
if (self->device.has_index()) {
args = THPObjectPtr{
Py_BuildValue("(si)", oss.str().c_str(), self->device.index())};
args = THPObjectPtr{Py_BuildValue(
"(si)", oss.str().c_str(), static_cast<int>(self->device.index()))};
} else {
args = THPObjectPtr{Py_BuildValue("(s)", oss.str().c_str())};
}

View File

@ -5,6 +5,7 @@
#include <ATen/Device.h>
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct TORCH_API THPDevice {
PyObject_HEAD at::Device device;
};

View File

@ -39,7 +39,8 @@ PyObject* THPDtype_is_floating_point(THPDtype* self, PyObject* noargs) {
PyObject* THPDtype_itemsize(THPDtype* self, PyObject* noargs) {
HANDLE_TH_ERRORS
return THPUtils_packInt64(scalarTypeToTypeMeta(self->scalar_type).itemsize());
return THPUtils_packUInt64(
scalarTypeToTypeMeta(self->scalar_type).itemsize());
END_HANDLE_TH_ERRORS
}
@ -115,8 +116,7 @@ static PyMethodDef THPDtype_methods[] = {
};
PyObject* THPDtype_repr(THPDtype* self) {
std::string name = self->name;
return THPUtils_packString("torch." + name);
return THPUtils_packString(std::string("torch.") + self->name);
}
PyTypeObject THPDtypeType = {

View File

@ -119,8 +119,7 @@ static PyObject* THPGenerator_setState(PyObject* _self, PyObject* _new_state) {
}
uint64_t unpack_uint64(PyObject* pyobj) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint64_t unsigned_obj;
uint64_t unsigned_obj = 0;
try {
// First try to interpret as unsigned long
unsigned_obj = THPUtils_unpackUInt64(pyobj);
@ -223,11 +222,7 @@ static PyMethodDef THPGenerator_methods[] = {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMemberDef THPGenerator_members[] = {
{(char*)"_cdata",
T_ULONGLONG,
offsetof(THPGenerator, cdata),
READONLY,
nullptr},
{"_cdata", T_ULONGLONG, offsetof(THPGenerator, cdata), READONLY, nullptr},
{nullptr}};
PyTypeObject THPGeneratorType = {

View File

@ -38,7 +38,7 @@ PyObject* THPSize_New(const torch::autograd::Variable& var) {
return self.release();
}
PyObject* THPSize_NewFromSizes(int dim, const int64_t* sizes) {
PyObject* THPSize_NewFromSizes(int64_t dim, const int64_t* sizes) {
auto self = THPObjectPtr(THPSizeType.tp_alloc(&THPSizeType, dim));
if (!self)
throw python_error();
@ -49,7 +49,8 @@ PyObject* THPSize_NewFromSizes(int dim, const int64_t* sizes) {
PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
auto sym_sizes = self_.sym_sizes();
auto ret = THPObjectPtr(THPSizeType.tp_alloc(&THPSizeType, sym_sizes.size()));
auto ret = THPObjectPtr(THPSizeType.tp_alloc(
&THPSizeType, static_cast<Py_ssize_t>(sym_sizes.size())));
if (!ret)
throw python_error();
@ -70,8 +71,8 @@ PyObject* THPSize_NewFromSymSizes(const at::Tensor& self_) {
// Otherwise, we know that it is an actual integer value.
auto m = si.maybe_as_int();
if (torch::jit::tracer::isTracing()) {
PyObject* py_size_tensor =
THPVariable_Wrap(torch::jit::tracer::getSizeOf(self_, i));
PyObject* py_size_tensor = THPVariable_Wrap(
torch::jit::tracer::getSizeOf(self_, static_cast<int64_t>(i)));
if (!py_size_tensor)
throw python_error();
PyTuple_SET_ITEM(ret.get(), i, py_size_tensor);

View File

@ -9,7 +9,7 @@ extern PyTypeObject THPSizeType;
#define THPSize_Check(obj) (Py_TYPE(obj) == &THPSizeType)
PyObject* THPSize_New(const torch::autograd::Variable& t);
PyObject* THPSize_NewFromSizes(int dim, const int64_t* sizes);
PyObject* THPSize_NewFromSizes(int64_t dim, const int64_t* sizes);
PyObject* THPSize_NewFromSymSizes(const at::Tensor& t);
void THPSize_init(PyObject* module);

View File

@ -57,7 +57,7 @@ static PyObject* THPStorage_dataPtr(PyObject* self, PyObject* noargs) {
TORCH_CHECK(
!invalid,
"Attempted to access the data pointer on an invalid python storage.")
return PyLong_FromVoidPtr(const_cast<void*>(self_.data()));
return PyLong_FromVoidPtr(self_.mutable_data());
END_HANDLE_TH_ERRORS
}
@ -109,7 +109,6 @@ static PyObject* THPStorage_new(PyObject* self, PyObject* noargs) {
allocator,
/*resizable=*/true);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
return THPStorage_New(std::move(new_storage));
END_HANDLE_TH_ERRORS
}
@ -168,7 +167,7 @@ static PyObject* THPStorage_resize_(PyObject* self, PyObject* number_arg) {
auto new_tensor = at::empty(src_tensor.sizes(), src_tensor.options());
new_tensor.copy_(src_tensor);
storage.set_data_ptr_noswap(
std::move(const_cast<at::DataPtr&>(new_tensor.storage().data_ptr())));
std::move(new_tensor.storage().mutable_data_ptr()));
storage.unsafeGetStorageImpl()->set_allocator(
new_tensor.storage().unsafeGetStorageImpl()->allocator());
storage.set_nbytes(new_tensor.storage().nbytes());
@ -224,6 +223,7 @@ static PyObject* THPStorage_fromBuffer(
args,
keywds,
argtypes,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<char**>(kwlist),
&obj,
&byte_order_str,
@ -248,8 +248,7 @@ static PyObject* THPStorage_fromBuffer(
"function missing required argument 'byte_order' (pos 2)");
size_t element_size = c10::elementSize(scalar_type);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool do_byte_swap;
bool do_byte_swap = false;
if (!is_endian_independent) {
if (strcmp(byte_order_str, "native") == 0) {
do_byte_swap = false;
@ -283,8 +282,7 @@ static PyObject* THPStorage_fromBuffer(
return nullptr;
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t size_bytes;
size_t size_bytes = 0;
if (count < 0) {
if ((buffer.len - offset) % element_size != 0) {
PyErr_SetString(
@ -297,7 +295,7 @@ static PyObject* THPStorage_fromBuffer(
return nullptr;
}
size_bytes = buffer.len - offset;
count = size_bytes / element_size;
count = static_cast<Py_ssize_t>(size_bytes / element_size);
} else {
size_bytes = count * element_size;
}
@ -400,8 +398,7 @@ static PyObject* THPStorage_fromFile(
PyObject* args,
PyObject* keywds) {
HANDLE_TH_ERRORS
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const char* filename;
const char* filename = nullptr;
Py_ssize_t nbytes = 0;
int shared = 0;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
@ -410,6 +407,7 @@ static PyObject* THPStorage_fromFile(
args,
keywds,
"s|in",
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<char**>(kwlist),
&filename,
&shared,

View File

@ -196,10 +196,9 @@ static PyObject* THPStorage_shareFd(PyObject* self, PyObject* noargs) {
const auto& storage = THPStorage_Unpack(self);
TORCH_CHECK(
storage.device_type() == at::kCPU, "_share_fd_: only available on CPU");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
at::MapAllocator* ctx;
at::MapAllocator* ctx = at::MapAllocator::fromDataPtr(storage.data_ptr());
// Storage is already in shared memory, just return a handle
if ((ctx = at::MapAllocator::fromDataPtr(storage.data_ptr()))) {
if (ctx) {
// done
} else {
at::Storage new_storage(at::new_shm_fd_storage(storage.nbytes()));
@ -248,11 +247,10 @@ static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) {
"a file descriptor (int) and storage size (int)");
return nullptr;
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int fd;
int tmp_fd = (int)THPUtils_unpackLong(_tmp_fd);
int64_t size = THPUtils_unpackLong(_size);
if ((fd = dup(tmp_fd)) == -1) {
int fd = dup(tmp_fd);
if (fd == -1) {
THPUtils_setError("could not duplicate a shared memory file descriptor");
return nullptr;
}
@ -405,16 +403,14 @@ static PyObject* THPStorage_releaseIPCCounter(
#ifdef USE_CUDA
static std::string THPStorage_bytesAsHandleString(PyObject* handle) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
char* buffer;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Py_ssize_t handle_size;
char* buffer = nullptr;
Py_ssize_t handle_size = 0;
if (PyBytes_AsStringAndSize(handle, &buffer, &handle_size) == -1) {
// NOLINTNEXTLINE(bugprone-string-constructor)
return nullptr;
THPUtils_assertRet(
"", handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle");
}
// NOLINTNEXTLINE(bugprone-string-constructor)
THPUtils_assert(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
THPUtils_assertRet(
"", handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
return std::string(buffer, handle_size);
}
#endif
@ -457,6 +453,9 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
// Ensure that producer prepared all tensor's data
std::string s_ipc_event_handle =
THPStorage_bytesAsHandleString(_event_handle);
if (s_ipc_event_handle.empty()) {
return nullptr;
}
auto ipc_event_handle = reinterpret_cast<const cudaIpcEventHandle_t*>(
s_ipc_event_handle.c_str());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -467,12 +466,14 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
}
std::string s_handle = THPStorage_bytesAsHandleString(_handle);
if (s_handle.empty()) {
return nullptr;
}
std::shared_ptr<void> basePtr =
c10::cuda::CUDACachingAllocator::getIpcDevPtr(s_handle);
// Offset the basePtr to reconstruct the real storage
// devPtr = basePtr + storage_offset
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
void* devPtr = basePtr.get();
devPtr = (char*)devPtr + storage_offset_bytes;

View File

@ -23,6 +23,7 @@ static PyObject* THPStream_pynew(
args,
kwargs,
"|LLL",
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<char**>(kwlist),
&stream_id,
&device_index,
@ -53,7 +54,8 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) {
THPStream* self = (THPStream*)ptr.get();
self->stream_id = stream.id();
self->device_index = stream.device_index();
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
self->device_index = static_cast<int64_t>(stream.device_index());
self->device_type = static_cast<int64_t>(stream.device_type());
return ptr.release();
END_HANDLE_TH_ERRORS
@ -65,11 +67,9 @@ static void THPStream_dealloc(THPStream* self) {
static PyObject* THPStream_get_device(THPStream* self, void* unused) {
HANDLE_TH_ERRORS
return THPDevice_New(c10::Stream::unpack3(
self->stream_id,
self->device_index,
static_cast<c10::DeviceType>(self->device_type))
.device());
return THPDevice_New(c10::Device(
static_cast<c10::DeviceType>(self->device_type),
static_cast<c10::DeviceIndex>(self->device_index)));
END_HANDLE_TH_ERRORS
}
@ -84,17 +84,17 @@ static PyObject* THPStream_eq(THPStream* self, THPStream* other) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMemberDef THPStream_members[] = {
{(char*)"stream_id",
{"stream_id",
T_LONGLONG,
offsetof(THPStream, stream_id),
READONLY,
nullptr},
{(char*)"device_index",
{"device_index",
T_LONGLONG,
offsetof(THPStream, device_index),
READONLY,
nullptr},
{(char*)"device_type",
{"device_type",
T_LONGLONG,
offsetof(THPStream, device_type),
READONLY,
@ -108,7 +108,7 @@ static struct PyGetSetDef THPStream_properties[] = {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPStream_methods[] = {
{(char*)"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr},
{"__eq__", (PyCFunction)THPStream_eq, METH_O, nullptr},
{nullptr}};
PyTypeObject THPStreamType = {

View File

@ -131,7 +131,7 @@ PyObject* THCPModule_getDevice_wrap(PyObject* self, PyObject* noargs) {
HANDLE_TH_ERRORS
torch::utils::cuda_lazy_init();
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
auto device = static_cast<int>(c10::cuda::current_device());
auto device = static_cast<int32_t>(c10::cuda::current_device());
return THPUtils_packInt32(device);
END_HANDLE_TH_ERRORS
}
@ -269,8 +269,7 @@ PyObject* THCPModule_setStream_wrap(
auto stream = at::cuda::CUDAStream::unpack3(
stream_id, device_index, static_cast<c10::DeviceType>(device_type));
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
auto device = static_cast<int>(c10::cuda::current_device());
auto device = c10::cuda::current_device();
if (device != stream.device_index()) {
THCPModule_setDevice(stream.device_index());
}
@ -310,9 +309,7 @@ PyObject* THCPModule_cudaCachingAllocator_raw_alloc(
return nullptr;
}
auto size = PyLong_AsSsize_t(size_o);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
cudaStream_t stream = static_cast<cudaStream_t>(PyLong_AsVoidPtr(stream_o));
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
void* mem =
c10::cuda::CUDACachingAllocator::raw_alloc_with_stream(size, stream);
return PyLong_FromVoidPtr(mem);

View File

@ -184,7 +184,7 @@ static PyObject* profiler_end_hook = NULL;
static PyObject* guard_profiler_name_str = NULL; /* cached py str */
// Points to the extra scratch space on the code object
static size_t extra_index = -1;
static Py_ssize_t extra_index = -1;
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
@ -341,8 +341,7 @@ static struct PyGetSetDef CacheEntry_properties[] = {
static PyObject* cache_entry_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
CacheEntry *self;
self = (CacheEntry*) type->tp_alloc(type, 0);
CacheEntry *self = (CacheEntry*) type->tp_alloc(type, 0);
if (self != NULL) {
// The corresponding decrefs for Py_None are in cache_entry_init.
Py_INCREF(Py_None);
@ -586,7 +585,7 @@ Debugger helper functions.
PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
// TODO(anijain2305) - CacheEntry being the first class Python object might
// obviate the need of this function. Revisit.
PyObject* object;
PyObject* object = NULL;
if (!PyArg_ParseTuple(args, "O", &object)) {
return NULL;
}
@ -688,7 +687,7 @@ static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEn
PyObject* valid = PyObject_CallOneArg(e->check_fn, f_locals);
if (unlikely(valid == NULL)) {
if (guard_error_hook != NULL) {
PyObject *type, *value, *traceback;
PyObject *type = NULL, *value = NULL, *traceback = NULL;
PyErr_Fetch(&type, &value, &traceback);
PyObject* r = call_guard_fail_hook(guard_error_hook, e, index, f_locals);
if (r == NULL) {
@ -729,24 +728,13 @@ inline static PyObject* eval_custom_code(
THP_EVAL_API_FRAME_OBJECT* frame,
PyCodeObject* code,
int throw_flag) {
Py_ssize_t ncells = 0;
Py_ssize_t nfrees = 0;
Py_ssize_t nlocals_new = code->co_nlocals;
Py_ssize_t nlocals_old = frame->f_code->co_nlocals;
ncells = PyCode_GetNCellvars(code);
nfrees = PyCode_GetNFreevars(code);
DEBUG_NULL_CHECK(tstate);
DEBUG_NULL_CHECK(frame);
DEBUG_NULL_CHECK(code);
DEBUG_CHECK(nlocals_new >= nlocals_old);
#if IS_PYTHON_3_11_PLUS
DEBUG_CHECK(ncells == frame->f_code->co_ncellvars);
DEBUG_CHECK(nfrees == frame->f_code->co_nfreevars);
// Generate Python function object and _PyInterpreterFrame in a way similar to
// https://github.com/python/cpython/blob/e715da6db1d1d70cd779dc48e1ba8110c51cc1bf/Python/ceval.c#L1130
#if IS_PYTHON_3_12_PLUS
@ -829,6 +817,12 @@ inline static PyObject* eval_custom_code(
Py_DECREF(name_to_idx);
#else
Py_ssize_t nlocals_new = code->co_nlocals;
Py_ssize_t nlocals_old = frame->f_code->co_nlocals;
DEBUG_CHECK(nlocals_new >= nlocals_old);
Py_ssize_t ncells = PyCode_GetNCellvars(code);
Py_ssize_t nfrees = PyCode_GetNFreevars(code);
DEBUG_CHECK(ncells == PyTuple_GET_SIZE(frame->f_code->co_cellvars));
DEBUG_CHECK(nfrees == PyTuple_GET_SIZE(frame->f_code->co_freevars));