Enable modernize-use-default-member-init (#149046)

``modernize-use-default-member-init`` prefers initialisation in class members, that make more ``= default`` constructors possible. Some violations or modernize rules have been fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149046
Approved by: https://github.com/zou3519
This commit is contained in:
cyy
2025-04-09 11:57:24 +00:00
committed by PyTorch MergeBot
parent 81f60f3880
commit 142f0f86ce
20 changed files with 45 additions and 60 deletions

View File

@ -52,7 +52,6 @@ modernize-*,
-modernize-macro-to-enum, -modernize-macro-to-enum,
-modernize-return-braced-init-list, -modernize-return-braced-init-list,
-modernize-use-auto, -modernize-use-auto,
-modernize-use-default-member-init,
-modernize-use-using, -modernize-use-using,
-modernize-use-trailing-return-type, -modernize-use-trailing-return-type,
-modernize-use-nodiscard, -modernize-use-nodiscard,

View File

@ -116,10 +116,7 @@ public:
DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {} DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {}
DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {} DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {}
DictIterator& operator=(const DictIterator& rhs) { DictIterator& operator=(const DictIterator& rhs) = default;
entryRef_ = rhs.entryRef_;
return *this;
}
DictIterator& operator=(DictIterator&& rhs) noexcept { DictIterator& operator=(DictIterator&& rhs) noexcept {
entryRef_ = std::move(rhs.entryRef_); entryRef_ = std::move(rhs.entryRef_);
return *this; return *this;

View File

@ -225,8 +225,7 @@ struct TORCH_API DispatchKeyExtractor final {
explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse) explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
: dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse), : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse),
nonFallthroughKeys_(DispatchKeySet::FULL), nonFallthroughKeys_(DispatchKeySet::FULL) {
requiresBitsetPerBackend_(false) {
for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL; nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
} }
@ -252,7 +251,7 @@ struct TORCH_API DispatchKeyExtractor final {
// Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast
// path), or if we need to fall back to the slower path and check // path), or if we need to fall back to the slower path and check
// nonFallthroughKeysPerBackend_ // nonFallthroughKeysPerBackend_
bool requiresBitsetPerBackend_; bool requiresBitsetPerBackend_{false};
}; };
} // namespace c10 } // namespace c10

View File

@ -40,7 +40,7 @@ enum TORCH_CUDA_CPP_API TuningStatus {
class TORCH_CUDA_CPP_API ResultEntry { class TORCH_CUDA_CPP_API ResultEntry {
public: public:
explicit ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {} explicit ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {}
explicit ResultEntry(std::string key, double time, const std::string& blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(blas_sig) {} explicit ResultEntry(std::string key, double time, std::string blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(std::move(blas_sig)) {}
bool operator==(const ResultEntry& other) const { return key_ == other.key_; } bool operator==(const ResultEntry& other) const { return key_ == other.key_; }
bool operator!=(const ResultEntry& other) const { return key_ != other.key_; } bool operator!=(const ResultEntry& other) const { return key_ != other.key_; }
operator std::string () { return key_; } operator std::string () { return key_; }

View File

@ -2,9 +2,9 @@
#include <c10/core/Scalar.h> #include <c10/core/Scalar.h>
#include <limits> #include <limits>
namespace at {
namespace native {
namespace at::native {
template <typename scalar_t> template <typename scalar_t>
int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) { int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) {
@ -42,4 +42,4 @@ int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar
return static_cast<int64_t>(size_d); return static_cast<int64_t>(size_d);
} }
}} // namespace at::native } // namespace at::native

View File

@ -756,7 +756,7 @@ static DimVector default_alldims(const Tensor& self, at::OptionalIntArrayRef dim
IntArrayRef dim_unwrapped = *dim_opt; IntArrayRef dim_unwrapped = *dim_opt;
dim.resize(dim_unwrapped.size()); dim.resize(dim_unwrapped.size());
for (const auto i : c10::irange(dim.size())) { for (const auto i : c10::irange(dim.size())) {
dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim(), /*wrap_scalars=*/false); dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim(), /*wrap_scalar=*/false);
} }
} else { } else {
dim.resize(self.dim()); dim.resize(self.dim());

View File

@ -887,7 +887,7 @@ static inline void mvlgamma_check(const Tensor& self, int64_t p) {
Tensor mvlgamma(const Tensor& self, int64_t p) { Tensor mvlgamma(const Tensor& self, int64_t p) {
mvlgamma_check(self, p); mvlgamma_check(self, p);
auto dtype = c10::scalarTypeToTypeMeta(self.scalar_type()); auto dtype = c10::scalarTypeToTypeMeta(self.scalar_type());
if (at::isIntegralType(self.scalar_type(), /*include_bool=*/true)) { if (at::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
// int -> float promotion // int -> float promotion
dtype = c10::get_default_dtype(); dtype = c10::get_default_dtype();
} }

View File

@ -16,7 +16,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
namespace at { namespace native { namespace detail { namespace at::native::detail {
// Enum representing the FFT type // Enum representing the FFT type
enum class CuFFTTransformType : int8_t { enum class CuFFTTransformType : int8_t {
@ -58,7 +58,7 @@ struct CuFFTParams
} }
}; };
static_assert(std::is_trivial_v<CuFFTParams>, ""); static_assert(std::is_trivial_v<CuFFTParams> );
// Returns true if the transform type has complex input // Returns true if the transform type has complex input
inline bool cufft_complex_input(CuFFTTransformType type) { inline bool cufft_complex_input(CuFFTTransformType type) {
@ -491,4 +491,4 @@ void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_si
int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index); int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index);
void cufft_clear_plan_cache_impl(DeviceIndex device_index); void cufft_clear_plan_cache_impl(DeviceIndex device_index);
}}} // namespace at::native::detail } // namespace at::native::detail

View File

@ -4,8 +4,8 @@
#include <ATen/cuda/CUDAConfig.h> #include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/PinnedMemoryAllocator.h> #include <ATen/cuda/PinnedMemoryAllocator.h>
namespace at {
namespace native { namespace at::native {
static inline int cuda_int_cast(int64_t value, const char* varname) { static inline int cuda_int_cast(int64_t value, const char* varname) {
auto result = static_cast<int>(value); auto result = static_cast<int>(value);
@ -28,5 +28,4 @@ static inline Storage pin_memory(int64_t size) {
/*resizable=*/false); /*resizable=*/false);
} }
} // namespace native } // namespace at::native
} // namespace at

View File

@ -5,7 +5,7 @@
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
namespace at { namespace native { namespace at::native {
TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes); TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
@ -50,4 +50,4 @@ inline TensorImpl* resize_impl_cuda_(
return self; return self;
} }
}} }

View File

@ -36,8 +36,8 @@
// The current pytorch implementation sets gesvdj tolerance to epsilon of a C++ data type to target the best possible precision. // The current pytorch implementation sets gesvdj tolerance to epsilon of a C++ data type to target the best possible precision.
constexpr int cusolver_gesvdj_max_sweeps = 400; constexpr int cusolver_gesvdj_max_sweeps = 400;
namespace at {
namespace native { namespace at::native {
void geqrf_batched_cublas(const Tensor& input, const Tensor& tau); void geqrf_batched_cublas(const Tensor& input, const Tensor& tau);
void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular); void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular);
@ -90,4 +90,4 @@ C10_EXPORT void registerLinalgDispatch(const LinalgDispatch&);
}} // namespace cuda::detail }} // namespace cuda::detail
#endif #endif
}} // namespace at::native } // namespace at::native

View File

@ -6,9 +6,8 @@
#include <ATen/cudnn/cudnn-wrapper.h> #include <ATen/cudnn/cudnn-wrapper.h>
// Declares utilities used by RNN.cpp and also needed by external consumers // Declares utilities used by RNN.cpp and also needed by external consumers
namespace at {
namespace native { namespace at::native::cudnn_rnn {
namespace cudnn_rnn {
TORCH_CUDA_CPP_API std::tuple<Tensor, std::vector<Tensor>> TORCH_CUDA_CPP_API std::tuple<Tensor, std::vector<Tensor>>
copy_weights_to_flat_buf_views( copy_weights_to_flat_buf_views(
@ -27,6 +26,4 @@ copy_weights_to_flat_buf_views(
bool allow_type_change = false, bool allow_type_change = false,
bool include_bias = true); bool include_bias = true);
} // namespace cudnn_rnn } // namespace at::native::cudnn_rnn
} // namespace native
} // namespace at

View File

@ -20,7 +20,7 @@
#endif #endif
#endif #endif
namespace at { namespace native { namespace at::native {
// Mapping ScalarType to ideep tensor data_type // Mapping ScalarType to ideep tensor data_type
TORCH_API ideep::tensor::data_type get_mkldnn_dtype(ScalarType type); TORCH_API ideep::tensor::data_type get_mkldnn_dtype(ScalarType type);
@ -62,6 +62,6 @@ TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor, bool from_cons
// Set MKLDNN verbose level // Set MKLDNN verbose level
TORCH_API int set_verbose(int level); TORCH_API int set_verbose(int level);
}} }
#endif // AT_MKLDNN_ENABLED #endif // AT_MKLDNN_ENABLED

View File

@ -131,7 +131,7 @@ struct PostOpParam {
class Attr { class Attr {
public: public:
Attr() : q_scale_(1.f), q_zero_point_(0) {} Attr() : q_scale_(1.f) {}
Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {} Attr(float q_scale, int64_t zp = 0) : q_scale_(q_scale), q_zero_point_(zp) {}
/***** eltwise *****/ /***** eltwise *****/

View File

@ -51,8 +51,8 @@ struct TORCH_API QTensorImpl : public c10::TensorImpl {
auto impl = c10::make_intrusive<QTensorImpl>( auto impl = c10::make_intrusive<QTensorImpl>(
Storage(storage()), key_set(), data_type_, quantizer_); Storage(storage()), key_set(), data_type_, quantizer_);
copy_tensor_metadata( copy_tensor_metadata(
/*src_impl=*/this, /*src_q_impl=*/this,
/*dest_impl=*/impl.get(), /*dest_q_impl=*/impl.get(),
/*version_counter=*/version_counter, /*version_counter=*/version_counter,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change); /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel(); impl->refresh_numel();
@ -72,8 +72,8 @@ struct TORCH_API QTensorImpl : public c10::TensorImpl {
auto impl = c10::make_intrusive<QTensorImpl>( auto impl = c10::make_intrusive<QTensorImpl>(
Storage(storage()), key_set(), data_type_, quantizer_); Storage(storage()), key_set(), data_type_, quantizer_);
copy_tensor_metadata( copy_tensor_metadata(
/*src_impl=*/this, /*src_q_impl=*/this,
/*dest_impl=*/impl.get(), /*dest_q_impl=*/impl.get(),
/*version_counter=*/std::move(version_counter), /*version_counter=*/std::move(version_counter),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change); /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel(); impl->refresh_numel();
@ -91,8 +91,8 @@ struct TORCH_API QTensorImpl : public c10::TensorImpl {
AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set()));
auto q_impl = static_cast<const QTensorImpl*>(impl.get()); auto q_impl = static_cast<const QTensorImpl*>(impl.get());
copy_tensor_metadata( copy_tensor_metadata(
/*src_impl=*/q_impl, /*src_q_impl=*/q_impl,
/*dest_impl=*/this, /*dest_q_impl=*/this,
/*version_counter=*/version_counter(), /*version_counter=*/version_counter(),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
refresh_numel(); refresh_numel();

View File

@ -86,7 +86,7 @@ struct WelfordHelper {
std::vector<Welford<T>> welford_stk; std::vector<Welford<T>> welford_stk;
uint64_t depth; // depth of welford_stk. uint64_t depth; // depth of welford_stk.
uint64_t num_chunks; // number of chunks stored in welford_stk. uint64_t num_chunks; // number of chunks stored in welford_stk.
WelfordHelper() {} WelfordHelper() = default;
WelfordHelper(uint64_t N) { WelfordHelper(uint64_t N) {
uint64_t m = (N + kChunkSize - 1) / kChunkSize; //div up uint64_t m = (N + kChunkSize - 1) / kChunkSize; //div up
depth = m > 0 ? ceil(log2(m)) : 0; depth = m > 0 ? ceil(log2(m)) : 0;

View File

@ -1152,16 +1152,13 @@ std::vector<std::shared_ptr<Result>> PythonTracer::getEvents(
// Assuming python_tracer::PythonMemoryTracerBase is defined elsewhere // Assuming python_tracer::PythonMemoryTracerBase is defined elsewhere
class PythonMemoryTracer final : public python_tracer::PythonMemoryTracerBase { class PythonMemoryTracer final : public python_tracer::PythonMemoryTracerBase {
public: public:
explicit PythonMemoryTracer(); explicit PythonMemoryTracer() = default;
~PythonMemoryTracer() override; ~PythonMemoryTracer() override = default;
void start() override; void start() override;
void stop() override; void stop() override;
void export_memory_history(const std::string path) override; void export_memory_history(const std::string path) override;
}; };
PythonMemoryTracer::PythonMemoryTracer() {}
PythonMemoryTracer::~PythonMemoryTracer() {}
static void toggle_memory_tracing(bool enable) { static void toggle_memory_tracing(bool enable) {
PyGILState_STATE gil_state = PyGILState_Ensure(); PyGILState_STATE gil_state = PyGILState_Ensure();
THPObjectPtr torch_cuda_memory_module( THPObjectPtr torch_cuda_memory_module(
@ -1182,9 +1179,9 @@ static void toggle_memory_tracing(bool enable) {
PyTuple_SetItem(args, 3, THPUtils_packInt64(100000)); // max_entries PyTuple_SetItem(args, 3, THPUtils_packInt64(100000)); // max_entries
PyTuple_SetItem(args, 4, Py_None); // device (None) PyTuple_SetItem(args, 4, Py_None); // device (None)
PyTuple_SetItem(args, 5, PyBool_FromLong(0)); // clear_history (False) PyTuple_SetItem(args, 5, PyBool_FromLong(0)); // clear_history (False)
PyObject* result = PyObject_Call(snapshot_func.get(), args, NULL); PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr);
Py_DECREF(args); Py_DECREF(args);
if (result == NULL) { if (result == nullptr) {
return; return;
} }
PyGILState_Release(gil_state); PyGILState_Release(gil_state);
@ -1209,9 +1206,9 @@ void PythonMemoryTracer::export_memory_history(const std::string path) {
PyObject* py_filename = PyUnicode_FromString(path.c_str()); PyObject* py_filename = PyUnicode_FromString(path.c_str());
// Call the function with arguments (e.g., a file path) // Call the function with arguments (e.g., a file path)
PyObject* args = PyTuple_Pack(1, py_filename); PyObject* args = PyTuple_Pack(1, py_filename);
PyObject* result = PyObject_Call(snapshot_func.get(), args, NULL); PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr);
Py_DECREF(args); Py_DECREF(args);
if (result == NULL) { if (result == nullptr) {
return; return;
} }
PyGILState_Release(gil_state); PyGILState_Release(gil_state);

View File

@ -31,7 +31,7 @@ dnnl::engine& Engine::getEngine() {
static dnnl::graph::allocator alloc{ static dnnl::graph::allocator alloc{
pytorch_default_allocator, pytorch_default_deallocator}; pytorch_default_allocator, pytorch_default_deallocator};
static dnnl::engine cpu_engine = dnnl::graph::make_engine_with_allocator( static dnnl::engine cpu_engine = dnnl::graph::make_engine_with_allocator(
dnnl::engine::kind::cpu, /* device_id = */ 0, alloc); dnnl::engine::kind::cpu, /* index = */ 0, alloc);
return cpu_engine; return cpu_engine;
} }

View File

@ -18,9 +18,7 @@
TORCH_DECLARE_bool(torch_jit_enable_expanded_stacks); TORCH_DECLARE_bool(torch_jit_enable_expanded_stacks);
TORCH_DECLARE_bool(torch_jit_expanded_stacks_mangled); TORCH_DECLARE_bool(torch_jit_expanded_stacks_mangled);
namespace torch::jit { namespace torch::jit::interpreter {
namespace interpreter {
template <class Ttarget, class Tsource> template <class Ttarget, class Tsource>
Ttarget safe_narrow_cast(Tsource v) { Ttarget safe_narrow_cast(Tsource v) {
@ -64,7 +62,7 @@ struct NodeSourceInfo {
const char* func_name_{nullptr}; const char* func_name_{nullptr};
const char* file_name_{nullptr}; const char* file_name_{nullptr};
size_t line_{0}; size_t line_{0};
NodeSourceInfo() {} NodeSourceInfo() = default;
}; };
struct CodeImpl { struct CodeImpl {
@ -1060,5 +1058,4 @@ struct MobileCodeImpl : CodeImpl {
bool emit_promoted_ops_; bool emit_promoted_ops_;
}; };
} // namespace interpreter } // namespace torch::jit::interpreter
} // namespace torch::jit

View File

@ -17,12 +17,12 @@
class Socket { class Socket {
public: public:
int socket_fd; int socket_fd;
Socket(const Socket& other) = delete;
protected: protected:
Socket() { Socket() {
SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0)); SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
} }
Socket(const Socket& other) = delete;
Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) { Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) {
other.socket_fd = -1; other.socket_fd = -1;
}; };
@ -122,7 +122,7 @@ class ManagerServerSocket : public Socket {
SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str())); SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str()));
} }
virtual ~ManagerServerSocket() { ~ManagerServerSocket() override {
unlink(socket_path.c_str()); unlink(socket_path.c_str());
} }