add -Wmissing-prototypes to clang-tidy (#96805)

This PR enables the **-Wmissing-prototypes** diagnostic (as the `clang-diagnostic-missing-prototypes` check) in clang-tidy to prevent further coding errors such as the one fixed by PR #96714.
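
For illustration, here is a minimal sketch of the pattern the warning catches and the `static` fix applied to most definitions in this diff (hypothetical file and function names, not taken from the PR):

```cpp
// example.cpp -- illustrative only
// Defined with external linkage but never declared in any header:
// clang's -Wmissing-prototypes flags this, because other translation
// units could only call it through an ad-hoc, possibly mismatched,
// declaration.
int scale(int x) {
  return 2 * x;
}

// Giving a file-local helper internal linkage silences the warning and
// prevents accidental external use:
static int scale_fixed(int x) {
  return 2 * x;
}
```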

### <samp>🤖 Generated by Copilot at fd2cf2a</samp>

This pull request makes several internal functions `static`, giving them internal linkage so they cannot clash with symbols in other translation units (and giving the optimizer more freedom). It also fixes some typos, formatting issues, and missing includes in various files, and adds a new .clang-tidy check that warns about missing prototypes for non-static functions.
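
Where a function genuinely needs external linkage, the PR gives it a proper prototype in a header instead; a condensed view of how that plays out for `initVerboseBindings` in this diff:

```cpp
// Before (torch/csrc/Module.cpp): a file-local declaration of a function
// defined in another translation unit, which can silently drift out of
// sync with the definition.
namespace torch {
void initVerboseBindings(PyObject* module);
} // namespace torch

// After: the declaration lives in the new header
// torch/csrc/utils/verbose.h, which both the defining file and
// Module.cpp now include, so there is a single authoritative prototype.
#include <torch/csrc/utils/verbose.h>
```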

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96805
Approved by: https://github.com/malfet, https://github.com/albanD
cyy authored on 2023-04-25 18:20:32 +00:00; committed by PyTorch MergeBot
parent 39ff87c6a4
commit dbc7e919b8
26 changed files with 67 additions and 56 deletions

View File

@ -9,6 +9,7 @@ bugprone-*,
-bugprone-lambda-function-name,
-bugprone-reserved-identifier,
-bugprone-swapped-arguments,
clang-diagnostic-missing-prototypes,
cppcoreguidelines-*,
-cppcoreguidelines-avoid-do-while,
-cppcoreguidelines-avoid-magic-numbers,

View File

@ -221,7 +221,7 @@ struct ATenDLMTensor {
DLManagedTensor tensor;
};
void deleter(DLManagedTensor* arg) {
static void deleter(DLManagedTensor* arg) {
delete static_cast<ATenDLMTensor*>(arg->manager_ctx);
}

View File

@ -1108,7 +1108,7 @@ void TensorImpl::ShareExternalPointer(
}
}
void clone_symvec(SymIntArrayRef src, SymDimVector& dst) {
static void clone_symvec(SymIntArrayRef src, SymDimVector& dst) {
dst.clear();
dst.reserve(src.size());
for (const auto& i : src) {

View File

@ -5,7 +5,7 @@
using namespace c10;
#ifndef C10_MOBILE
void check(int64_t value) {
static void check(int64_t value) {
const auto i = SymInt(value);
EXPECT_EQ(i.maybe_as_int(), c10::make_optional(value));
}

View File

@ -78,6 +78,7 @@
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>
#include <torch/csrc/utils/tensor_qschemes.h>
#include <torch/csrc/utils/verbose.h>
#ifdef USE_DISTRIBUTED
#ifdef USE_C10D
@ -1229,10 +1230,6 @@ void initIttBindings(PyObject* module);
} // namespace torch
#endif
namespace torch {
void initVerboseBindings(PyObject* module);
} // namespace torch
static std::vector<PyMethodDef> methods;
// In Python we can't use the trick of C10_LOG_API_USAGE_ONCE

View File

@ -59,7 +59,7 @@ Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction) {
return unreduced;
}
bool isDefined(const c10::optional<Tensor>& t) {
static bool isDefined(const c10::optional<Tensor>& t) {
return t.has_value() && t->defined();
}
@ -145,7 +145,7 @@ int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
return size;
}
c10::SymInt _safe_size(c10::SymIntArrayRef sizes, c10::IntArrayRef dim) {
static c10::SymInt _safe_size(c10::SymIntArrayRef sizes, c10::IntArrayRef dim) {
c10::SymInt size = 1;
if (sizes.empty()) {
return 1;
@ -165,7 +165,7 @@ Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) {
return gradient_result;
}
Tensor handle_r_to_c(Tensor self, Tensor gradient_result) {
static Tensor handle_r_to_c(Tensor self, Tensor gradient_result) {
if (!self.is_complex() && gradient_result.is_complex()) {
// R -> C
return at::real(gradient_result);
@ -4365,7 +4365,7 @@ Tensor fft_r2c_backward(
}
// Helper for batchnorm_double_backward
Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim = true) {
static Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim = true) {
auto r = to_sum.sum(0, keepdim);
int64_t start_point_exclusive = keepdim ? 1 : 0;
for (int64_t dim = r.dim() - 1; dim > start_point_exclusive; dim--) {
@ -4377,7 +4377,7 @@ Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim = true) {
// Helper for batchnorm_double_backward
// similar to expand_as below, but doesn't do the expand_as; operates as if
// reductions were done with keepdim=True
Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
static Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
auto src_expanded = src;
while (src_expanded.sizes().size() < target.sizes().size() - 1) {
src_expanded = src_expanded.unsqueeze(1);
@ -4391,7 +4391,7 @@ Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
// Helper for batchnorm_double_backward
// because gamma/ggG/ggB are 1-dimensional and represent dim==1, we can't
// do a straight expansion because it won't follow the broadcasting rules.
Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
static Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
auto src_expanded = src;
while (src_expanded.sizes().size() < target.sizes().size() - 1) {
src_expanded = src_expanded.unsqueeze(1);
@ -4947,7 +4947,7 @@ bool any_variable_defined(const variable_list& variables) {
// from the right.
// Additionally, when the computation is done in-place, we exploit that the
// first `k` coordinates of `u_full/v_full` are zeros.
Tensor apply_simple_transformation(
static Tensor apply_simple_transformation(
int64_t m,
int64_t k,
const Tensor& u_full,

View File

@ -158,7 +158,9 @@ TORCH_LIBRARY_IMPL(aten, Tracer, m) {
namespace torch {
namespace jit {
void general_trace_function(const c10::OperatorHandle& op, Stack* stack) {
static void general_trace_function(
const c10::OperatorHandle& op,
Stack* stack) {
const auto input_size = op.schema().arguments().size();
const auto output_size = op.schema().returns().size();

View File

@ -24,7 +24,7 @@ namespace autograd {
// maintain the logic equality of this file and the python file together if one
// changes.
// TODO: Make the Python API above to just call this C++ API.
variable_list _make_grads(
static variable_list _make_grads(
const variable_list& outputs,
const variable_list& grad_outputs) {
size_t num_tensors = outputs.size();
@ -88,7 +88,7 @@ variable_list _make_grads(
}
return new_grads;
}
variable_list run_backward(
static variable_list run_backward(
const variable_list& outputs,
const variable_list& grad_outputs,
bool keep_graph,

View File

@ -47,7 +47,7 @@ void _foreach_tensor(
} // namespace
void autogradNotImplementedFallbackImpl(
static void autogradNotImplementedFallbackImpl(
const c10::OperatorHandle& op,
c10::DispatchKeySet dispatch_keys,
torch::jit::Stack* stack) {
@ -252,7 +252,7 @@ torch::CppFunction autogradNotImplementedFallback() {
&autogradNotImplementedFallbackImpl>();
}
void autogradNotImplementedInplaceOrViewFallbackImpl(
static void autogradNotImplementedInplaceOrViewFallbackImpl(
const c10::OperatorHandle& op,
c10::DispatchKeySet dispatch_keys,
torch::jit::Stack* stack) {

View File

@ -46,7 +46,7 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
// sure that its
// forward grad was also modified inplace and already present on the
// corresponding output.
void _process_forward_mode_AD(
static void _process_forward_mode_AD(
const variable_list& inputs,
std::unordered_map<at::TensorImpl*, size_t> inputs_mapping,
const at::ArrayRef<c10::optional<Variable>> raw_outputs,
@ -247,7 +247,7 @@ void _process_forward_mode_AD(
}
}
at::Tensor _view_as_self_with_no_grad(at::Tensor self) {
static at::Tensor _view_as_self_with_no_grad(at::Tensor self) {
// This is called below in _process_backward_mode_ad in two places:
//
// (1) An input has been returned, but it wasn't modified. Return it as a view
@ -268,7 +268,7 @@ at::Tensor _view_as_self_with_no_grad(at::Tensor self) {
return self.view_as(self);
}
optional_variable_list _process_backward_mode_ad(
static optional_variable_list _process_backward_mode_ad(
const std::unordered_map<at::TensorImpl*, size_t>& inputs_mapping,
const std::unordered_set<at::TensorImpl*>& non_differentiable,
const std::unordered_set<at::TensorImpl*>& dirty_inputs,

View File

@ -18,7 +18,7 @@ namespace profiler {
// Creates a new profiling scope using RecordFunction and invokes its starting
// callbacks.
void record_function_enter(
static void record_function_enter(
const std::string& name,
const c10::optional<std::string>& args,
at::RecordFunction& rec) {
@ -33,7 +33,7 @@ void record_function_enter(
}
// Legacy signature using cpp_custom_type_hack
at::Tensor record_function_enter_legacy(
static at::Tensor record_function_enter_legacy(
const std::string& name,
const c10::optional<std::string>& args) {
auto rec = std::make_unique<at::RecordFunction>(at::RecordScope::USER_SCOPE);
@ -51,18 +51,19 @@ c10::intrusive_ptr<PythonRecordFunction> record_function_enter_new(
return rec;
}
at::RecordFunction& getRecordFunctionFromTensor(const at::Tensor& handle) {
static at::RecordFunction& getRecordFunctionFromTensor(
const at::Tensor& handle) {
auto& rec = at::cpp_custom_type_hack::cast<at::RecordFunction>(handle);
return rec;
}
// Ends the profiling scope created with record_function_enter.
void record_function_exit(at::RecordFunction& rec) {
static void record_function_exit(at::RecordFunction& rec) {
rec.end();
}
// Legacy signature using cpp_custom_type_hack
void record_function_exit_legacy(const at::Tensor& handle) {
static void record_function_exit_legacy(const at::Tensor& handle) {
// We don't actually need to do anything with handle just need to persist the
// lifetime until now.
auto& rec = getRecordFunctionFromTensor(handle);
@ -70,7 +71,7 @@ void record_function_exit_legacy(const at::Tensor& handle) {
}
// New signature using custom_class
void record_function_exit_new(
static void record_function_exit_new(
const c10::intrusive_ptr<PythonRecordFunction>& record) {
record_function_exit(record->record);
}
@ -100,7 +101,7 @@ c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut(
}
// Legacy signature using cpp_custom_type_hack
c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
static c10::intrusive_ptr<c10::ivalue::Future> _call_end_callbacks_on_fut_legacy(
const at::Tensor& handle,
const c10::intrusive_ptr<c10::ivalue::Future>& fut) {
return _call_end_callbacks_on_fut(

View File

@ -157,7 +157,7 @@ AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
return get_autograd_meta(self);
}
void update_tensor_hooks_on_new_gradfn(
static void update_tensor_hooks_on_new_gradfn(
const at::TensorBase& self,
const std::shared_ptr<torch::autograd::Node>& old_fn,
const std::shared_ptr<torch::autograd::Node>& new_fn) {

View File

@ -76,7 +76,7 @@ ContextPtr addRecvRpcBackward(
return autogradContext;
}
c10::intrusive_ptr<Message> getMessageWithProfiling(
static c10::intrusive_ptr<Message> getMessageWithProfiling(
c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMessage,
MessageType msgType,
torch::autograd::profiler::ProfilerConfig&& profilerConfig) {

View File

@ -46,7 +46,7 @@ ErrorReport::CallStack::CallStack(
ErrorReport::CallStack::~CallStack() {}
#endif // C10_MOBILE
std::string get_stacked_errors(const std::vector<Call>& error_stack) {
static std::string get_stacked_errors(const std::vector<Call>& error_stack) {
std::stringstream msg;
if (!error_stack.empty()) {
for (auto it = error_stack.rbegin(); it != error_stack.rend() - 1; ++it) {

View File

@ -15,13 +15,13 @@ namespace jit {
static const auto countsAttribute = Symbol::attr("none_counts");
bool hasGradSumToSizeUses(Value* v) {
static bool hasGradSumToSizeUses(Value* v) {
return std::any_of(v->uses().begin(), v->uses().end(), [](const Use& use) {
return use.user->kind() == aten::_grad_sum_to_size;
});
}
void insertProfileNodesForSpecializeAutogradZero(
static void insertProfileNodesForSpecializeAutogradZero(
Block* block,
ProfilingRecord* pr) {
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {

View File

@ -206,7 +206,7 @@ void ProfilingRecord::insertShapeProfile(
n->replaceInput(offset, pn->output());
}
bool needsProfiledInputs(Node* n) {
static bool needsProfiledInputs(Node* n) {
if (tensorexpr::isSupported(n) ||
#ifndef C10_MOBILE
(fuser::cuda::isEnabled() && fuser::cuda::profileNode(n))
@ -243,7 +243,7 @@ bool needsProfiledInputs(Node* n) {
}
}
bool needsProfiledOutput(Node* n) {
static bool needsProfiledOutput(Node* n) {
if (tensorexpr::isSupported(n) ||
#ifndef C10_MOBILE
(fuser::cuda::isEnabled() && fuser::cuda::profileNode(n))

View File

@ -113,15 +113,15 @@ static PyObject* Tensor_instancecheck(PyObject* _self, PyObject* arg) {
END_HANDLE_TH_ERRORS
}
PyObject* Tensor_dtype(PyTensorType* self, void* unused) {
static PyObject* Tensor_dtype(PyTensorType* self, void* unused) {
return torch::autograd::utils::wrap(self->dtype);
}
PyObject* Tensor_layout(PyTensorType* self, void* unused) {
static PyObject* Tensor_layout(PyTensorType* self, void* unused) {
return torch::autograd::utils::wrap(self->layout);
}
PyObject* Tensor_is_cuda(PyTensorType* self, void* unused) {
static PyObject* Tensor_is_cuda(PyTensorType* self, void* unused) {
if (self->is_cuda) {
Py_RETURN_TRUE;
} else {
@ -129,7 +129,7 @@ PyObject* Tensor_is_cuda(PyTensorType* self, void* unused) {
}
}
PyObject* Tensor_is_sparse(PyTensorType* self, void* unused) {
static PyObject* Tensor_is_sparse(PyTensorType* self, void* unused) {
if (self->layout->layout == at::Layout::Strided) {
Py_RETURN_FALSE;
} else {
@ -137,7 +137,7 @@ PyObject* Tensor_is_sparse(PyTensorType* self, void* unused) {
}
}
PyObject* Tensor_is_sparse_csr(PyTensorType* self, void* unused) {
static PyObject* Tensor_is_sparse_csr(PyTensorType* self, void* unused) {
if (self->layout->layout == at::Layout::SparseCsr) {
Py_RETURN_TRUE;
} else {
@ -302,7 +302,7 @@ static THPObjectPtr get_tensor_dict() {
// importing torch.
static std::vector<PyTensorType*> tensor_types;
void set_default_storage_type(Backend backend, ScalarType dtype) {
static void set_default_storage_type(Backend backend, ScalarType dtype) {
THPObjectPtr storage = get_storage_obj(backend, dtype);
auto torch_module = THPObjectPtr(PyImport_ImportModule("torch"));
@ -314,7 +314,7 @@ void set_default_storage_type(Backend backend, ScalarType dtype) {
}
}
void set_default_tensor_type(
static void set_default_tensor_type(
c10::optional<Backend> backend,
c10::optional<ScalarType> dtype) {
if (backend.has_value()) {

View File

@ -261,14 +261,15 @@ void THP_decodeComplexDoubleBuffer(
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint32_t x;
uint64_t x;
double re;
};
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint32_t y;
uint64_t y;
double im;
};
static_assert(sizeof(uint64_t) == sizeof(double));
x = (do_byte_swap ? decodeUInt64BE(src) : decodeUInt64LE(src));
src += sizeof(double);
@ -462,7 +463,7 @@ void THP_encodeComplexFloatBuffer(
}
}
void THP_encodeCompelxDoubleBuffer(
void THP_encodeComplexDoubleBuffer(
uint8_t* dst,
const c10::complex<double>* src,
THPByteOrder order,

View File

@ -190,7 +190,7 @@ TORCH_API void THP_encodeDoubleBuffer(
const double* src,
THPByteOrder order,
size_t len);
TORCH_API void THP_encodeComplexloatBuffer(
TORCH_API void THP_encodeComplexFloatBuffer(
uint8_t* dst,
const c10::complex<float>* src,
THPByteOrder order,

View File

@ -13,7 +13,7 @@ namespace torch {
namespace utils {
// NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs
c10::TensorOptions typeIdWithDefault(
static c10::TensorOptions typeIdWithDefault(
PythonArgs& r,
int device_idx,
c10::DispatchKey dispatch_key) {

View File

@ -195,7 +195,7 @@ auto handle_torch_function_setter(
}
// Combines self and args into one tuple.
auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple {
static auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple {
if (args == nullptr) {
return py::make_tuple(py::handle(self));
} else if (self == nullptr) {
@ -601,7 +601,7 @@ bool is_tensor_and_append_overloaded(
return false;
}
bool is_scalar_list(PyObject* obj) {
static bool is_scalar_list(PyObject* obj) {
auto tuple = six::isTuple(obj);
if (!(tuple || PyList_Check(obj))) {
return false;
@ -646,7 +646,7 @@ bool is_tensor_list_and_append_overloaded(
return true;
}
bool is_float_or_complex_list(PyObject* obj) {
static bool is_float_or_complex_list(PyObject* obj) {
auto tuple = six::isTuple(obj);
if (!(tuple || PyList_Check(obj))) {
return false;

View File

@ -36,7 +36,7 @@ static ska::flat_hash_map<
ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
python_registrations_;
torch::Library::Kind parseKind(const std::string& k) {
static torch::Library::Kind parseKind(const std::string& k) {
static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
{"DEF", torch::Library::DEF},
{"IMPL", torch::Library::IMPL},
@ -46,7 +46,7 @@ torch::Library::Kind parseKind(const std::string& k) {
TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
return it->second;
}
c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
{"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
{"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
@ -180,7 +180,7 @@ class PythonKernelHolder : public c10::OperatorKernel {
}
};
torch::_RegisterOrVerify register_or_verify() {
static torch::_RegisterOrVerify register_or_verify() {
if (isMainPyInterpreter()) {
return torch::_RegisterOrVerify::REGISTER;
} else {

View File

@ -820,7 +820,7 @@ class CheckSparseTensorInvariantsContext {
bool state;
};
Tensor sparse_compressed_tensor_ctor_worker(
static Tensor sparse_compressed_tensor_ctor_worker(
std::string name,
c10::DispatchKey dispatch_key,
at::ScalarType scalar_type,

View File

@ -19,7 +19,7 @@ using namespace at;
namespace torch {
namespace utils {
const char* parse_privateuseone_backend() {
static const char* parse_privateuseone_backend() {
static std::string backend_name = "torch." + get_privateuse1_backend();
return backend_name.c_str();
}

View File

@ -1,5 +1,6 @@
#include <ATen/native/verbose_wrapper.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/verbose.h>
namespace torch {

View File

@ -0,0 +1,8 @@
#pragma once
#include <torch/csrc/python_headers.h>
namespace torch {
void initVerboseBindings(PyObject* module);
} // namespace torch