[8/N] Fix extra warnings brought by clang-tidy-17 (#139151)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139151
Approved by: https://github.com/ezyang

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Author: cyyever
Date: 2024-10-30 14:20:06 +00:00
Committed by: PyTorch MergeBot
Parent: 44257c063e
Commit: 456c87c8a2
31 changed files with 72 additions and 54 deletions
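Most of the hunks below fall into a few recurring patterns: defaulted move operations marked noexcept, explicit static_casts where a signed/unsigned or width-narrowing conversion was previously implicit, const-reference parameters instead of by-value copies, and targeted NOLINT suppressions. As a rough illustration only (not code from this PR; the Buffer type is hypothetical), the minimal C++ sketch below shows the first two patterns, which correspond to checks such as performance-noexcept-move-constructor and the *-narrowing-conversions family:

#include <cstdint>
#include <vector>

struct Buffer {
  std::vector<int> data;

  Buffer() = default;
  Buffer(const Buffer&) = default;
  Buffer& operator=(const Buffer&) = default;
  // Defaulted moves are declared noexcept so standard containers can
  // move rather than copy elements on reallocation.
  Buffer(Buffer&&) noexcept = default;
  Buffer& operator=(Buffer&&) noexcept = default;
  ~Buffer() = default;

  // size() returns std::size_t (unsigned); the conversion to a signed
  // 64-bit value is spelled out with static_cast instead of being implicit.
  int64_t signed_size() const {
    return static_cast<int64_t>(data.size());
  }
};

int main() {
  Buffer b;
  b.data = {1, 2, 3};
  return b.signed_size() == 3 ? 0 : 1;
}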

View File

@@ -222,9 +222,11 @@ exclude_patterns = [
# caffe2_pb.h, otherwise we'd have to build protos as part of this CI job.
# CUDA files are also excluded.
'**/fb/**',
'**/generated/**',
'**/*pb.h',
'c10/xpu/**/*.h',
'c10/xpu/**/*.cpp',
'c10/benchmark/intrusive_ptr_benchmark.cpp',
'c10/cuda/CUDAAlgorithm.h',
'c10/util/complex_math.h',
'c10/util/complex_utils.h',
@@ -250,6 +252,8 @@ exclude_patterns = [
'torch/csrc/inductor/aoti_torch/c/shim.h',
'torch/csrc/jit/**/*',
'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
'torch/csrc/utils/pythoncapi_compat.h',
'torch/csrc/utils/throughput_benchmark-inl.h',
]
init_command = [
'python3',

View File

@@ -30,8 +30,8 @@ struct Array {
Array() = default;
Array(const Array&) = default;
Array& operator=(const Array&) = default;
Array(Array&&) = default;
Array& operator=(Array&&) = default;
Array(Array&&) noexcept = default;
Array& operator=(Array&&) noexcept = default;
~Array() = default;
#endif
static constexpr int size() {

View File

@@ -24,8 +24,8 @@ static PyObject* THPStream_pynew(
HANDLE_TH_ERRORS
int64_t stream_id = -1;
int64_t device_type = 0;
int64_t device_index = 0;
c10::DeviceType device_type{};
c10::DeviceIndex device_index{};
int64_t priority = 0;
static torch::PythonArgParser parser({
@@ -42,27 +42,25 @@ static PyObject* THPStream_pynew(
auto default_accelerator = at::getAccelerator(false);
auto device = r.deviceOptional(0);
if (device.has_value()) {
device_type = static_cast<int64_t>(device->type());
device_index = static_cast<int64_t>(device->index());
device_type = device->type();
device_index = device->index();
// Initialize device guard if device is not None.
device_guard_ptr = std::make_unique<c10::DeviceGuard>(device.value());
} else {
// If device is None, we will use the current accelerator and index.
// If the current accelerator is not set, we will use the CPU as device
// type.
device_type = static_cast<int64_t>(
default_accelerator.value_or(c10::DeviceType::CPU));
c10::impl::VirtualGuardImpl impl{
static_cast<c10::DeviceType>(device_type)};
device_type = default_accelerator.value_or(c10::DeviceType::CPU);
c10::impl::VirtualGuardImpl impl{device_type};
const auto current_device = impl.getDevice();
device_index = current_device.index();
}
priority = r.toInt64WithDefault(1, 0);
} else if (r.idx == 1) {
stream_id = r.toInt64WithDefault(0, -1);
device_index = r.toInt64WithDefault(1, 0);
device_type =
r.toInt64WithDefault(2, static_cast<int64_t>(c10::DeviceType::CPU));
device_index = static_cast<c10::DeviceIndex>(r.toInt64WithDefault(1, 0));
device_type = static_cast<c10::DeviceType>(
r.toInt64WithDefault(2, static_cast<int64_t>(c10::DeviceType::CPU)));
priority = r.toInt64WithDefault(3, 0);
} else {
TORCH_CHECK(
@@ -84,19 +82,16 @@ static PyObject* THPStream_pynew(
// manage the lifetime of streams.
std::optional<c10::Stream> stream_opt;
if (r.idx == 0) {
c10::impl::VirtualGuardImpl impl{static_cast<c10::DeviceType>(device_type)};
c10::impl::VirtualGuardImpl impl{device_type};
stream_opt = impl.getNewStream(
c10::Device(static_cast<c10::DeviceType>(device_type), device_index),
static_cast<int>(priority));
c10::Device(device_type, device_index), static_cast<int>(priority));
} else {
stream_opt = c10::Stream::unpack3(
stream_id,
static_cast<c10::DeviceIndex>(device_index),
static_cast<c10::DeviceType>(device_type));
stream_opt = c10::Stream::unpack3(stream_id, device_index, device_type);
}
TORCH_CHECK(stream_opt.has_value(), "Failed to create stream");
self->stream_id = static_cast<int64_t>(stream_opt->id());
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
self->device_index = static_cast<int64_t>(stream_opt->device_index());
self->device_type = static_cast<int64_t>(stream_opt->device_type());
@@ -139,7 +134,7 @@ static PyObject* THPStream_query(PyObject* _self, PyObject* noargs) {
return PyBool_FromLong(c10::Stream::unpack3(
self->stream_id,
self->device_index,
static_cast<c10::DeviceIndex>(self->device_index),
static_cast<c10::DeviceType>(self->device_type))
.query());
@@ -153,7 +148,7 @@ static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) {
c10::Stream::unpack3(
self->stream_id,
self->device_index,
static_cast<c10::DeviceIndex>(self->device_index),
static_cast<c10::DeviceType>(self->device_type))
.synchronize();
}
@@ -167,7 +162,7 @@ static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) {
auto event = (THPEvent*)_event;
c10::Stream::unpack3(
self->stream_id,
self->device_index,
static_cast<c10::DeviceIndex>(self->device_index),
static_cast<c10::DeviceType>(self->device_type))
.wait(event->event);
}
@@ -184,11 +179,11 @@ static PyObject* THPStream_wait_stream(PyObject* _self, PyObject* _other) {
c10::EventFlag::PYTORCH_DEFAULT);
new_event.record(c10::Stream::unpack3(
other_stream->stream_id,
other_stream->device_index,
static_cast<c10::DeviceIndex>(other_stream->device_index),
static_cast<c10::DeviceType>(other_stream->device_type)));
c10::Stream::unpack3(
self->stream_id,
self->device_index,
static_cast<c10::DeviceIndex>(self->device_index),
static_cast<c10::DeviceType>(self->device_type))
.wait(new_event);
}
@@ -229,7 +224,7 @@ static PyObject* THPStream_record_event(
TORCH_CHECK(new_event, "event must not be null");
new_event->event.record(c10::Stream::unpack3(
self->stream_id,
self->device_index,
static_cast<c10::DeviceIndex>(self->device_index),
static_cast<c10::DeviceType>(self->device_type)));
return (PyObject*)new_event;
END_HANDLE_TH_ERRORS

View File

@@ -17,9 +17,7 @@
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch::python {
namespace detail {

View File

@@ -16,6 +16,7 @@ struct PyAnomalyMetadata : public AnomalyMetadata {
// NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
dict_ = PyDict_New();
}
// NOLINTNEXTLINE(bugprone-exception-escape)
~PyAnomalyMetadata() override {
// If python is already dead, leak the wrapped python objects
if (Py_IsInitialized()) {

View File

@@ -1,5 +1,6 @@
#pragma once
#include <torch/csrc/utils/python_compat.h>
namespace torch::autograd {
PyMethodDef* get_nested_functions_manual();

View File

@@ -886,9 +886,11 @@ std::unique_ptr<ViewFunc> ChainedViewFunc::clone_and_set(
if (symints.has_value()) {
TORCH_INTERNAL_ASSERT(symints->size() == num_symints());
first_symints = std::vector<c10::SymInt>(
symints->begin(), symints->begin() + first->num_symints());
symints->begin(),
symints->begin() + static_cast<std::ptrdiff_t>(first->num_symints()));
second_symints = std::vector<c10::SymInt>(
symints->begin() + first->num_symints(), symints->end());
symints->begin() + static_cast<std::ptrdiff_t>(first->num_symints()),
symints->end());
}
std::optional<std::vector<at::Tensor>> first_tensors;
@@ -896,9 +898,11 @@ std::unique_ptr<ViewFunc> ChainedViewFunc::clone_and_set(
if (tensors.has_value()) {
TORCH_INTERNAL_ASSERT(tensors->size() == num_tensors());
first_tensors = std::vector<at::Tensor>(
tensors->begin(), tensors->begin() + first->num_tensors());
tensors->begin(),
tensors->begin() + static_cast<std::ptrdiff_t>(first->num_tensors()));
second_tensors = std::vector<at::Tensor>(
tensors->begin() + first->num_tensors(), tensors->end());
tensors->begin() + static_cast<std::ptrdiff_t>(first->num_tensors()),
tensors->end());
}
return std::make_unique<ChainedViewFunc>(

View File

@@ -31,8 +31,8 @@ void THCPGraph_init(PyObject* module) {
"capture_begin",
[](::at::cuda::CUDAGraph& self,
std::optional<c10::cuda::MempoolId_t> pool_opt,
std::string capture_error_mode) {
cudaStreamCaptureMode capture_mode;
const std::string& capture_error_mode) {
cudaStreamCaptureMode capture_mode{};
c10::cuda::MempoolId_t pool = pool_opt.has_value()
? pool_opt.value()
: c10::cuda::MempoolId_t{0, 0};

View File

@@ -150,8 +150,8 @@ PyObject* THCPModule_canDeviceAccessPeer_wrap(PyObject* self, PyObject* args) {
THPUtils_checkLong(arg1), "invalid argument to canDeviceAccessPeer");
TORCH_CHECK(
THPUtils_checkLong(arg2), "invalid argument to canDeviceAccessPeer");
int64_t device = THPUtils_unpackLong(arg1);
int64_t peer_device = THPUtils_unpackLong(arg2);
auto device = THPUtils_unpackDeviceIndex(arg1);
auto peer_device = THPUtils_unpackDeviceIndex(arg2);
torch::utils::device_lazy_init(at::kCUDA);
auto can_access = at::cuda::canDeviceAccessPeer(device, peer_device);
@@ -1719,7 +1719,7 @@ PyObject* THCPModule_cuda_tunableop_get_results(
for (const auto& [op_sig, kernelmap] : results) {
result_size += kernelmap.size();
}
THPObjectPtr outer_tuple(PyTuple_New(result_size));
THPObjectPtr outer_tuple(PyTuple_New(static_cast<Py_ssize_t>(result_size)));
if (!outer_tuple)
throw python_error();
size_t result_index = 0;
@@ -1759,7 +1759,8 @@ PyObject* THCPModule_cuda_tunableop_get_validators(
auto validators = at::cuda::tunable::getTuningContext()
->GetTuningResultsValidator()
.GetAllValidators();
THPObjectPtr outer_tuple(PyTuple_New(validators.size()));
THPObjectPtr outer_tuple(
PyTuple_New(static_cast<Py_ssize_t>(validators.size())));
if (!outer_tuple)
throw python_error();
size_t validator_index = 0;

View File

@@ -71,6 +71,7 @@ static PyObject* THCPStream_pynew(
THCPStream* self = (THCPStream*)ptr.get();
self->stream_id = static_cast<int64_t>(stream.id());
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
self->device_index = static_cast<int64_t>(stream.device_index());
self->device_type = static_cast<int64_t>(stream.device_type());
new (&self->cuda_stream) at::cuda::CUDAStream(stream);

View File

@@ -265,7 +265,9 @@ struct NcclCommList {
NcclCommList(const std::vector<int>& devices)
: comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
NCCL_CHECK(ncclCommInitAll(
to_nccl_comm(comms.get()), devices.size(), devices.data()));
to_nccl_comm(comms.get()),
static_cast<int>(devices.size()),
devices.data()));
}
NcclCommList(NcclCommList&& foo) = default;
~NcclCommList() {

View File

@@ -31,8 +31,8 @@ typedef void* ncclComm_t;
/** redefine nccl unique ID in torch scope. this should be identical to native
* nccl impp. */
#define NCCL_UNIQUE_ID_BYTES 128
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
typedef struct {
// NOLINTNEXTLINE(*array)
char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;

View File

@@ -28,7 +28,7 @@ void initCommMethods(PyObject* module) {
py::call_guard<py::gil_scoped_release>())
.def(
"_broadcast",
[](at::Tensor& tensor, std::vector<int64_t> devices) {
[](at::Tensor& tensor, const std::vector<int64_t>& devices) {
return broadcast(tensor, devices);
},
py::call_guard<py::gil_scoped_release>(),
@@ -46,7 +46,7 @@ void initCommMethods(PyObject* module) {
"_scatter",
[](at::Tensor& tensor,
std::vector<int64_t>& devices,
std::optional<std::vector<int64_t>> chunk_sizes,
const std::optional<std::vector<int64_t>>& chunk_sizes,
int64_t dim,
std::optional<py::object> py_streams) {
std::optional<std::vector<std::optional<at::cuda::CUDAStream>>>

View File

@@ -1,5 +1,6 @@
#pragma once
#include <torch/csrc/utils/pythoncapi_compat.h>
namespace torch::cuda::python {
void initCommMethods(PyObject* module);

View File

@@ -27,7 +27,8 @@ THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) {
// Spicy hot reinterpret cast!!
streams.emplace_back(at::cuda::CUDAStream::unpack3(
(reinterpret_cast<THCPStream*>(stream))->stream_id,
(reinterpret_cast<THCPStream*>(stream))->device_index,
static_cast<c10::DeviceIndex>(
reinterpret_cast<THCPStream*>(stream)->device_index),
static_cast<c10::DeviceType>(
(reinterpret_cast<THCPStream*>(stream))->device_type)));
} else if (stream == Py_None) {

View File

@@ -33,7 +33,9 @@
// https://man7.org/linux/man-pages/man1/objcopy.1.html
// todo: use #embed in C++ 23 once available
// The constants are NOT readonly because they may be mutated.
// NOLINTNEXTLINE(*array*)
extern uint8_t _binary_constants_bin_start[];
// NOLINTNEXTLINE(*array*)
extern uint8_t _binary_constants_bin_end[];
#define AOTI_CONST_GPU_ALIGNMENT 64

View File

@@ -225,7 +225,7 @@ class AOTInductorModelContainer {
}
bool _should_skip_update(const size_t idx) const {
auto constant_type = models_[0]->constant_type(idx);
auto constant_type = models_[0]->constant_type(static_cast<int64_t>(idx));
return constant_type == ConstantType::TensorConstant;
}

View File

@@ -63,6 +63,7 @@ struct ThreadLocalCachedOutputTensor<ArrayRefTensor<T>> {
private:
void realloc(const ArrayRefTensor<T>& t) {
capacity_ = t.numel();
// NOLINTNEXTLINE(*arrays*)
storage_ = std::make_unique<T[]>(t.numel());
AtenTensorHandle handle = nullptr;
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob(
@@ -78,6 +79,7 @@ struct ThreadLocalCachedOutputTensor<ArrayRefTensor<T>> {
tensor_ = handle;
}
// NOLINTNEXTLINE(*arrays*)
std::unique_ptr<T[]> storage_;
int64_t capacity_ = 0;
RAIIAtenTensorHandle tensor_;
@@ -140,6 +142,7 @@ struct ThreadLocalCachedOutputArray<ArrayRefTensor<T>> {
void copy_data_from(const ArrayRefTensor<T>& t) {
if (t.numel() > capacity_) {
capacity_ = t.numel();
// NOLINTNEXTLINE(*arrays*)
storage_ = std::make_unique<T[]>(capacity_);
}
std::copy(t.data(), t.data() + t.numel(), storage_.get());
@@ -148,6 +151,7 @@ struct ThreadLocalCachedOutputArray<ArrayRefTensor<T>> {
}
private:
// NOLINTNEXTLINE(*arrays*)
std::unique_ptr<T[]> storage_;
uint32_t capacity_ = 0;
ArrayRefTensor<T> tensor_;

View File

@@ -1000,6 +1000,7 @@ AOTITorchError aoti_torch_index_put_out(
AtenTensorHandle self,
const AtenTensorHandle* indices,
const uint32_t num_indices,
// NOLINTNEXTLINE(misc-misplaced-const)
const AtenTensorHandle values,
bool accumulate) {
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({

View File

@@ -169,6 +169,7 @@ static inline hash_t Hash(const at::Generator& value) {
// Use an arbitrary randomly-selected 64-bit integer rather than a
// small constant that we then hash at runtime so we don't have to
// repeatedly hash a constant at runtime.
// NOLINTNEXTLINE(*-narrowing-conversions)
static const int64_t kNullOpt = 0x8655d738f3678dda;
// Hashing for std::optional types contributes to hash

View File

@@ -260,7 +260,7 @@ class TORCH_API TimedSection {
~TimedSection() {
int64_t now = NowNs();
metric_->AddSample(now, now - start_);
metric_->AddSample(now, static_cast<double>(now - start_));
}
double Elapsed() const {

View File

@@ -30,7 +30,7 @@ class TORCH_API Shape {
}
int64_t dim() const {
return sizes_.size();
return static_cast<int64_t>(sizes_.size());
}
c10::ArrayRef<int64_t> sizes() const {
return sizes_;

View File

@@ -18,6 +18,7 @@ enum class Level : uint8_t {
kError,
};
// NOLINTNEXTLINE(*array*)
static constexpr const char* const kPyLevelNames[] = {
"NONE",
"NOTE",

View File

@@ -15,6 +15,6 @@ enum class TrainingMode {
TRAINING, // Training mode
};
constexpr char kOnnxNodeNameAttribute[] = "onnx_name";
constexpr auto kOnnxNodeNameAttribute = "onnx_name";
} // namespace torch::onnx

View File

@@ -312,13 +312,13 @@ std::string ivalueToStr(const c10::IValue& val, bool isString) {
// json only takes "true" and "false" so we convert the string to lower case
if (val.isBool()) {
for (char& c : mystr) {
c = std::tolower(c);
c = static_cast<char>(std::tolower(c));
}
}
// A double quote can cause issues with the chrome tracing so force
// all inputs to not contain more than the 2 we add in this function
int count = std::count(mystr.begin(), mystr.end(), '\"');
auto count = std::count(mystr.begin(), mystr.end(), '"');
return count > 2 ? "\"None\"" : mystr;
}
}

View File

@@ -4,7 +4,6 @@
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <iostream>
namespace torch::utils {
namespace {

View File

@@ -322,7 +322,6 @@ auto check_has_torch_function(PyObject* obj, bool ignore_mode) -> bool {
} // namespace torch
inline bool sequence_has_torch_function(PyObject* args) {
// NOLINTNEXTLINE(bugprone-branch-clone)
Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args);
for (Py_ssize_t i = 0; i < nargs; i++) {
PyObject* obj = PySequence_Fast_GET_ITEM(args, i);

View File

@@ -35,7 +35,7 @@ void initThroughputBenchmarkBindings(PyObject* module) {
.def(
"run_once",
[](ThroughputBenchmark& self,
py::args args,
const py::args& args,
const py::kwargs& kwargs) {
// Depending on this being ScriptModule of nn.Module we will release
// the GIL or not further down in the stack

View File

@@ -63,7 +63,8 @@ inline PyObject* THPUtils_packString(const char* str) {
}
inline PyObject* THPUtils_packString(const std::string& str) {
return PyUnicode_FromStringAndSize(str.c_str(), str.size());
return PyUnicode_FromStringAndSize(
str.c_str(), static_cast<Py_ssize_t>(str.size()));
}
inline PyObject* THPUtils_internString(const std::string& str) {

View File

@@ -19,7 +19,7 @@ inline void THPUtils_packInt64Array(
}
inline PyObject* THPUtils_packInt64Array(size_t size, const int64_t* sizes) {
THPObjectPtr tuple(PyTuple_New(size));
THPObjectPtr tuple(PyTuple_New(static_cast<Py_ssize_t>(size)));
if (!tuple)
throw python_error();
THPUtils_packInt64Array(tuple.get(), size, sizes);

View File

@@ -106,6 +106,7 @@ struct TORCH_API SchemaInfo {
// Alias map of outputs to inputs
std::vector<std::unordered_set<size_t>> output_alias_map_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const c10::FunctionSchema schema_;
bool alias_maps_current_;