[13/N] Fix extra warnings brought by clang-tidy-17 (#140897)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140897
Approved by: https://github.com/ezyang
Author: cyy
Date: 2024-11-27 00:35:16 +00:00
Committed by: PyTorch MergeBot
Parent: 1df440dc4e
Commit: 2f082e1e56
14 changed files with 85 additions and 90 deletions

View File

@@ -92,6 +92,7 @@ class MatrixRef {
   /// The declaration here is extra complicated so that "arrayRef = {}"
   /// continues to select the move assignment operator.
   template <typename U>
+  // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
   std::enable_if_t<std::is_same_v<U, T>, MatrixRef<T>>& operator=(
       U&& Temporary) = delete;

View File

@@ -574,12 +574,7 @@ static std::ostream& printMaybeAnnotatedDict(
 static std::ostream& printComplex(std::ostream & out, const IValue & v) {
   c10::complex<double> d = v.toComplexDouble();
   IValue real(d.real()), imag(std::abs(d.imag()));
-  auto sign = "";
-  if (d.imag() >= 0) {
-    sign = "+";
-  } else {
-    sign = "-";
-  }
+  auto sign = d.imag() >= 0 ? '+' : '-';
   return out << real << sign << imag << "j";
 }

View File

@@ -68,9 +68,11 @@ TensorBase empty_strided_cuda(
     std::optional<Device> device_opt,
     std::optional<bool> pin_memory_opt) {
   TORCH_CHECK(!pin_memory_opt.has_value() || !*pin_memory_opt, "Only dense CPU tensors can be pinned");
+#ifndef NDEBUG
   // TODO: remove check for jagged, see https://github.com/pytorch/pytorch/issues/130073
   const auto layout = layout_or_default(layout_opt);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout == Layout::Strided || layout == Layout::Jagged);
+#endif
   const auto dtype = dtype_or_default(dtype_opt);
   return at::detail::empty_strided_cuda(size, stride, dtype, device_opt);

View File

@@ -12,6 +12,7 @@
 #include <string>
 #include <ATen/cuda/tunable/TunableOp.h>
 #include <ATen/cuda/CUDABlas.h>
+#include <ATen/cuda/Exceptions.h>
 #include <c10/util/StringUtil.h>
@@ -22,6 +23,7 @@
 #include <ATen/ops/allclose.h>
 #include <ATen/ops/from_blob.h>
 #endif
+#include <ATen/OpMathType.h>
 #include <fmt/printf.h>
 namespace at::cuda::tunable {
@@ -150,19 +152,19 @@ struct GemmParams : OpParams {
     return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
   }
-  char transa;
-  char transb;
-  int64_t m;
-  int64_t n;
-  int64_t k;
+  char transa{};
+  char transb{};
+  int64_t m{};
+  int64_t n{};
+  int64_t k{};
   at::opmath_type<T> alpha;
-  const T* a;
-  int64_t lda;
-  const T* b;
-  int64_t ldb;
+  const T* a{};
+  int64_t lda{};
+  const T* b{};
+  int64_t ldb{};
   at::opmath_type<T> beta;
-  T* c;
-  int64_t ldc;
+  T* c{};
+  int64_t ldc{};
  private:
   bool duplicate_inputs_{false};
 };
@@ -223,7 +225,9 @@ struct GemmAndBiasParams : OpParams {
   void Delete() {
     c10::cuda::CUDACachingAllocator::raw_delete(c);
     if (duplicate_inputs_) {
+      // NOLINTNEXTLINE(*const-cast)
       c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
+      // NOLINTNEXTLINE(*const-cast)
       c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
     }
   }
@@ -233,30 +237,26 @@ struct GemmAndBiasParams : OpParams {
     return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
   }
-  char transa;
-  char transb;
-  int64_t m;
-  int64_t n;
-  int64_t k;
-  at::opmath_type<T> alpha;
-  const T* a;
-  int64_t lda;
-  const T* b;
-  int64_t ldb;
-  T* c;
-  int64_t ldc;
-  const T* bias;
-  at::cuda::blas::GEMMAndBiasActivationEpilogue activation;
+  char transa{};
+  char transb{};
+  int64_t m{};
+  int64_t n{};
+  int64_t k{};
+  at::opmath_type<T> alpha{};
+  const T* a{};
+  int64_t lda{};
+  const T* b{};
+  int64_t ldb{};
+  T* c{};
+  int64_t ldc{};
+  const T* bias{};
+  at::cuda::blas::GEMMAndBiasActivationEpilogue activation{};
  private:
   bool duplicate_inputs_{false};
 };
 template <typename T>
 struct GemmStridedBatchedParams : OpParams {
-  GemmStridedBatchedParams() = default;
-  GemmStridedBatchedParams(const GemmStridedBatchedParams&) = default;
-  GemmStridedBatchedParams& operator=(const GemmStridedBatchedParams&) = default;
   std::string Signature() const override {
     return fmt::sprintf("%c%c_%ld_%ld_%ld_B_%ld", transa, transb, m, n, k, batch);
   }
@@ -325,23 +325,23 @@ struct GemmStridedBatchedParams : OpParams {
     return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
   }
-  char transa;
-  char transb;
-  int64_t m;
-  int64_t n;
-  int64_t k;
-  at::opmath_type<T> alpha;
-  const T* a;
-  int64_t lda;
-  int64_t stride_a;
-  const T* b;
-  int64_t ldb;
-  int64_t stride_b;
+  char transa{};
+  char transb{};
+  int64_t m{};
+  int64_t n{};
+  int64_t k{};
+  at::opmath_type<T> alpha{};
+  const T* a{};
+  int64_t lda{};
+  int64_t stride_a{};
+  const T* b{};
+  int64_t ldb{};
+  int64_t stride_b{};
   at::opmath_type<T> beta;
-  T* c;
-  int64_t ldc;
-  int64_t stride_c;
-  int64_t batch;
+  T* c{};
+  int64_t ldc{};
+  int64_t stride_c{};
+  int64_t batch{};
  private:
   bool duplicate_inputs_{false};
 };
@@ -415,27 +415,27 @@ struct ScaledGemmParams : OpParams {
     return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC()/sizeof(T)) ? OK : FAIL;
   }
-  char transa;
-  char transb;
-  int64_t m;
-  int64_t n;
-  int64_t k;
-  const void* a;
-  const void* a_scale_ptr;
-  int64_t lda;
-  ScalarType a_dtype;
-  const void* b;
-  const void* b_scale_ptr;
-  int64_t ldb;
-  ScalarType b_dtype;
-  const void* bias_ptr;
-  ScalarType bias_dtype;
-  void* c;
-  const void* c_scale_ptr;
-  int64_t ldc;
-  ScalarType c_dtype;
-  void* amax_ptr;
-  bool use_fast_accum;
+  char transa{};
+  char transb{};
+  int64_t m{};
+  int64_t n{};
+  int64_t k{};
+  const void* a{};
+  const void* a_scale_ptr{};
+  int64_t lda{};
+  ScalarType a_dtype{};
+  const void* b{};
+  const void* b_scale_ptr{};
+  int64_t ldb{};
+  ScalarType b_dtype{};
+  const void* bias_ptr{};
+  ScalarType bias_dtype{};
+  void* c{};
+  const void* c_scale_ptr{};
+  int64_t ldc{};
+  ScalarType c_dtype{};
+  void* amax_ptr{};
+  bool use_fast_accum{};
  private:
   bool duplicate_inputs_{false};
 };

View File

@@ -9,9 +9,10 @@
 //
 #include <cuda_runtime.h>
-#include <c10/cuda/CUDAStream.h>
+#include <ATen/cuda/Exceptions.h>
 #include <ATen/cuda/tunable/StreamTimer.h>
+#include <c10/cuda/CUDAStream.h>
 #include <cmath>
 namespace at::cuda::tunable {
@@ -20,8 +21,7 @@ StreamTimer::StreamTimer() {
   AT_CUDA_CHECK(cudaEventCreate(&end_));
 }
-StreamTimer::~StreamTimer() {
-}
+StreamTimer::~StreamTimer() = default;
 void StreamTimer::Start() {
   AT_CUDA_CHECK(cudaDeviceSynchronize());
@@ -34,7 +34,7 @@ void StreamTimer::End() {
 }
 float StreamTimer::Duration() {
-  float time;
+  auto time = std::numeric_limits<float>::quiet_NaN();
   // time is in ms with a resolution of 1 us
   AT_CUDA_CHECK(cudaEventElapsedTime(&time, start_, end_));
   return time;

View File

@@ -27,8 +27,8 @@ class StreamTimer : public ITimer {
   float Duration() override;
  private:
-  cudaEvent_t start_;
-  cudaEvent_t end_;
+  cudaEvent_t start_{};
+  cudaEvent_t end_{};
 };
 } // namespace at::cuda::tunable

View File

@@ -26,8 +26,6 @@ namespace at::cuda::tunable {
 template <typename ParamsT>
 class Callable {
  public:
-  Callable() = default;
-  Callable(Callable&&) = default;
   virtual ~Callable() = default;
   virtual TuningStatus Call(const ParamsT*) {
     return FAIL;
@@ -40,8 +38,6 @@ class Callable {
 template <typename ParamsT, typename TimerT>
 class TunableOp {
  public:
-  TunableOp() = default;
-  TunableOp(TunableOp&&) = default;
   virtual ~TunableOp() = default;
   TuningStatus operator()(const ParamsT* params) {

View File

@@ -126,7 +126,7 @@ c10::intrusive_ptr<TensorImpl> TensorWrapper::shallow_copy_and_detach(
     c10::VariableVersion&& version_counter,
     bool allow_tensor_metadata_change) const {
   auto dest_impl = makeTensorWrapperPtr(value(), level_, is_alive_);
-  dest_impl->set_version_counter(version_counter);
+  dest_impl->set_version_counter(std::move(version_counter));
   // TODO: is this even right?
   dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);

View File

@@ -264,8 +264,8 @@ class _map;
 template <class F, class A, class... Args>
 class _map<F, A, c10::guts::typelist::typelist<Args...>> {
  public:
-  static A function_one(F&& fn, const Args&... nested_node) {
-    return std::forward<F>(fn)(nested_node...);
+  static A function_one(const F& fn, const Args&... nested_node) {
+    return fn(nested_node...);
   }
   static NestedNode<A> function(
       const F& fn,

View File

@@ -66,6 +66,7 @@ struct C10_API AutogradState {
   bool inference_mode_ : 1;
   bool fw_grad_mode_ : 1;
   bool multithreading_enabled_ : 1;
+  // NOLINTNEXTLINE(cppcoreguidelines-use-default-member-init)
   bool view_replay_enabled_ : 1;
 };

View File

@@ -30,7 +30,7 @@ using namespace torch;
 PyObject* THPGeneratorClass = nullptr;
-PyObject* THPGenerator_initDefaultGenerator(at::Generator cdata) {
+PyObject* THPGenerator_initDefaultGenerator(const at::Generator& cdata) {
   auto type = (PyTypeObject*)THPGeneratorClass;
   auto self = THPObjectPtr{type->tp_alloc(type, 0)};
   if (!self)
@@ -391,7 +391,7 @@ PyObject* pyobj(const Generator& self) {
   return self.pyobj();
 }
-PyObject* THPGenerator_Wrap(Generator gen) {
+PyObject* THPGenerator_Wrap(const Generator& gen) {
   if (!gen.defined()) {
     Py_RETURN_NONE;
   }

View File

@@ -14,7 +14,7 @@ struct THPGenerator {
 // is borrowed. The caller should ensure that the at::Generator object lifetime
 // last at least as long as the Python wrapper.
 TORCH_PYTHON_API PyObject* THPGenerator_initDefaultGenerator(
-    at::Generator cdata);
+    const at::Generator& cdata);
 #define THPGenerator_Check(obj) PyObject_IsInstance(obj, THPGeneratorClass)
@@ -22,7 +22,7 @@ TORCH_PYTHON_API extern PyObject* THPGeneratorClass;
 bool THPGenerator_init(PyObject* module);
-TORCH_PYTHON_API PyObject* THPGenerator_Wrap(at::Generator gen);
+TORCH_PYTHON_API PyObject* THPGenerator_Wrap(const at::Generator& gen);
 TORCH_PYTHON_API at::Generator THPGenerator_Unwrap(PyObject* state);

View File

@@ -7,6 +7,7 @@ void initExportBindings(PyObject* module) {
   auto rootModule = py::handle(module).cast<py::module>();
   auto m = rootModule.def_submodule("_export");
+  // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<ExportedProgram>(m, "CppExportedProgram");
   m.def("deserialize_exported_program", [](const std::string& serialized) {

View File

@@ -7,7 +7,6 @@
 #include <torch/csrc/lazy/core/lazy_graph_executor.h>
 #include <torch/csrc/lazy/core/shape.h>
 #include <torch/csrc/lazy/core/tensor.h>
-#include <atomic>
 namespace torch::lazy {