[5/N] Fix clang-tidy warnings in aten/src/ATen (#132565)

Follows #132001

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132565
Approved by: https://github.com/Skylion007
Author: cyy
Date: 2024-08-04 14:39:16 +00:00
Committed by: PyTorch MergeBot
Parent: 908d2a153b
Commit: 105ba7b58c

19 changed files with 62 additions and 75 deletions


@@ -100,7 +100,7 @@ class TORCH_API Context {
const void* data,
std::optional<c10::DeviceType> device_type = std::nullopt) {
auto opt_device_type =
- device_type.has_value() ? device_type.value() : at::getAccelerator();
+ device_type.has_value() ? device_type : at::getAccelerator();
if (!opt_device_type.has_value() || // there is no accelerator
!at::isAccelerator(
opt_device_type.value())) { // passed device not an accelerator
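Both arms of the rewritten ternary are std::optional values, so the optional can be returned as-is instead of being unwrapped with .value() and re-wrapped. A minimal standalone sketch of the pattern (DeviceType, getAccelerator and resolveDevice are illustrative stand-ins, not the ATen declarations):

    #include <optional>

    enum class DeviceType { CPU, CUDA };

    // Stand-in for the accelerator query: may or may not report a device.
    std::optional<DeviceType> getAccelerator() {
      return DeviceType::CUDA;
    }

    std::optional<DeviceType> resolveDevice(std::optional<DeviceType> requested) {
      // Both branches already have type std::optional<DeviceType>, so there is
      // no need for requested.value(); the optional is passed through unchanged
      // and the unchecked .value() call disappears.
      return requested.has_value() ? requested : getAccelerator();
    }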


@@ -131,21 +131,21 @@ static Device getATenDevice(const DLDevice& ctx, void* data) {
#ifndef USE_ROCM
// if we are compiled under HIP, we cannot do cuda
case DLDeviceType::kDLCUDA:
- return at::Device(DeviceType::CUDA, ctx.device_id);
+ return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
#endif
case DLDeviceType::kDLOpenCL:
- return at::Device(DeviceType::OPENCL, ctx.device_id);
+ return at::Device(DeviceType::OPENCL, static_cast<c10::DeviceIndex>(ctx.device_id));
case DLDeviceType::kDLROCM:
#ifdef USE_ROCM
// this looks funny, we need to return CUDA here to masquerade
- return at::Device(DeviceType::CUDA, ctx.device_id);
+ return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
#else
- return at::Device(DeviceType::HIP, ctx.device_id);
+ return at::Device(DeviceType::HIP, static_cast<c10::DeviceIndex>(ctx.device_id));
#endif
case DLDeviceType::kDLOneAPI:
return at::detail::getXPUHooks().getDeviceFromPtr(data);
case DLDeviceType::kDLMAIA:
- return at::Device(DeviceType::MAIA, ctx.device_id);
+ return at::Device(DeviceType::MAIA, static_cast<c10::DeviceIndex>(ctx.device_id));
default:
TORCH_CHECK(
false, "Unsupported device_type: ", std::to_string(ctx.device_type));
@@ -286,7 +286,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
device_id = src.get_device();
}
atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
- atDLMTensor->tensor.dl_tensor.ndim = src.dim();
+ atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
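DLPack describes the device id and rank with plain int/int32_t, while ATen uses the narrower c10::DeviceIndex and the wider int64_t, so each conversion is now spelled out. A rough sketch of the idea, assuming DeviceIndex is a small signed integer as c10::DeviceIndex currently is:

    #include <cstdint>
    #include <iostream>

    using DeviceIndex = int8_t;  // stand-in for c10::DeviceIndex

    // DLPack hands back the device id as a plain int; the explicit cast documents
    // the narrowing and silences clang-tidy's implicit-conversion warnings.
    DeviceIndex toDeviceIndex(int device_id) {
      return static_cast<DeviceIndex>(device_id);
    }

    int main() {
      std::cout << int{toDeviceIndex(3)} << '\n';  // prints 3
    }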


@@ -20,7 +20,7 @@ namespace at {
* The method should_include_kernel_dtype() returns true/false
* based on whether the switching code for a specific dtype should be
* included based on build time constants generated from tracing model
- * execution. This method will be implmeneted via code-generation and
+ * execution. This method will be implemented via code-generation and
* included in this file when code-gen is ready.
*/
inline constexpr bool should_include_kernel_dtype(


@@ -29,7 +29,7 @@ static Tensor permute_inverse(const Tensor& self, IntArrayRef dims, InverseRetur
static Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, InverseReturnMode inverse_return_mode) {
auto result = self;
bool need_alias = (inverse_return_mode == InverseReturnMode::AlwaysView);
- int64_t nDims = sizes.size();
+ int64_t nDims = static_cast<int64_t>(sizes.size());
for(const auto dim : c10::irange(nDims)) {
if (sizes[dim] == 1) {
need_alias = false;


@@ -1,8 +1,5 @@
#include <ATen/TensorGeometry.h>
- #include <limits>
- #include <cstddef>
namespace at {
// See TensorGeometry.h on why this is useful now that we cache is_contiguous.


@@ -498,10 +498,10 @@ static void gemm_batched_mkl_impl(
template <typename scalar_t>
using is_blas_library_type = std::integral_constant<bool,
- std::is_same<scalar_t, double>::value ||
- std::is_same<scalar_t, float>::value ||
- std::is_same<scalar_t, c10::complex<double>>::value ||
- std::is_same<scalar_t, c10::complex<float>>::value>;
+ std::is_same_v<scalar_t, double> ||
+ std::is_same_v<scalar_t, float> ||
+ std::is_same_v<scalar_t, c10::complex<double>> ||
+ std::is_same_v<scalar_t, c10::complex<float>>>;
template <typename scalar_t>
void gemm_batched_generic(
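The change above is the usual modernize-type-traits rewrite: std::is_same<T, U>::value becomes the C++17 variable template std::is_same_v<T, U>. A compile-time sketch, with std::complex standing in for c10::complex:

    #include <complex>
    #include <type_traits>

    // Roughly the same predicate as is_blas_library_type, written with the _v shorthand.
    template <typename T>
    inline constexpr bool is_blas_like_type_v =
        std::is_same_v<T, double> || std::is_same_v<T, float> ||
        std::is_same_v<T, std::complex<double>> ||
        std::is_same_v<T, std::complex<float>>;

    static_assert(is_blas_like_type_v<float>);
    static_assert(!is_blas_like_type_v<int>);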


@@ -33,7 +33,7 @@ static std::vector<T> to_cpu(const std::vector<T>& tensors) {
// Explicitly handling undefined tensors here instead of letting `at::_to_cpu` handle it.
// Otherwise, we'd need to require all backends with their own implementation of _to_cpu
// to properly handle undefined tensors.
- if constexpr(std::is_same<T, std::optional<at::Tensor>>::value) {
+ if constexpr(std::is_same_v<T, std::optional<at::Tensor>>) {
if (tensors[i].has_value() && tensors[i].value().defined()) {
to_translate[i] = true;
valid_tensors.push_back(tensors[i].value());
@@ -58,7 +58,7 @@ static std::vector<T> to_cpu(const std::vector<T>& tensors) {
return cpu_tensors;
}
- static std::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args, std::vector<c10::List<at::Tensor>> tlist_args) {
+ static std::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args, const std::vector<c10::List<at::Tensor>>& tlist_args) {
// Decide what device to move the output tensor(s) to.
// The current convention is that we use the first tensor arg to pick the device
// Barring that, we take the first tensor from a TensorList arg.
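Taking tlist_args by value copied every c10::List on each call even though the function only reads them; the const reference removes that copy (performance-unnecessary-value-param). A generic sketch of the rule with standard containers:

    #include <cstddef>
    #include <string>
    #include <vector>

    // Pass-by-value would copy every element on each call; a const reference
    // is enough because the callee never modifies or stores the argument.
    std::size_t total_length(const std::vector<std::string>& names) {
      std::size_t n = 0;
      for (const auto& name : names) {
        n += name.size();
      }
      return n;
    }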


@@ -340,8 +340,8 @@ Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
// when self has zero storage.
// This kernel should never really be run, except with debugging using compile(backend="aot_eager")
for (const auto i : c10::irange(src.size())) {
- auto curr_src = src[i];
- auto curr_self = self[i];
+ const auto& curr_src = src[i];
+ const auto& curr_self = self[i];
outs.push_back(at::copy(curr_self, curr_src, non_blocking));
}
return outs;
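Binding the loop elements as const auto& avoids copying each element handle where the indexing operator yields a reference; with Tensor such a copy is a reference-count bump per iteration. A generic illustration with a standard container, for the pattern only:

    #include <cstddef>
    #include <string>
    #include <vector>

    int count_nonempty(const std::vector<std::string>& rows) {
      int count = 0;
      for (std::size_t i = 0; i < rows.size(); ++i) {
        // `auto row = rows[i];` would copy the string; the const reference
        // reads the element in place instead.
        const auto& row = rows[i];
        if (!row.empty()) {
          ++count;
        }
      }
      return count;
    }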


@@ -145,7 +145,7 @@ static Tensor& zero_cpu_(Tensor &self, int64_t nelements) {
if (nullptr == ptr) {
return self.fill_(0);
}
- int64_t size_bytes = nelements * self.dtype().itemsize();
+ auto size_bytes = nelements * self.dtype().itemsize();
if (size_bytes > 0) {
std::memset(ptr, 0, size_bytes);
}


@@ -14,9 +14,6 @@
#include <ATen/ops/fractional_max_pool2d_native.h>
#endif
- #include <tuple>
- #include <vector>
namespace at {
namespace meta {
@@ -61,7 +58,7 @@ TORCH_META_FUNC(fractional_max_pool2d) (
/* sizes */
int64_t numPlanes = input.size(planeDim);
int64_t inputH = input.size(heightDim);
- int inputW = input.size(widthDim);
+ auto inputW = input.size(widthDim);
TORCH_CHECK(outputH + poolSizeH - 1 <= inputH,
"fractional_max_pool2d(): pool height ", poolSizeH,
@@ -88,15 +85,15 @@ TORCH_META_FUNC(fractional_max_pool2d_backward)(
IntArrayRef output_size,
const at::Tensor& indices) {
- int numBatch = 1;
+ int64_t numBatch = 1;
int planeDim = 0;
int heightDim = 1;
int widthDim = 2;
- int outputH = output_size[0];
- int outputW = output_size[1];
+ auto outputH = output_size[0];
+ auto outputW = output_size[1];
- int ndims = input.ndimension();
+ auto ndims = input.ndimension();
if (ndims == 4) {
numBatch = input.size(0);
planeDim = 1;
@@ -105,9 +102,9 @@ TORCH_META_FUNC(fractional_max_pool2d_backward)(
}
/* sizes */
- int numPlanes = input.size(planeDim);
- int inputH = input.size(heightDim);
- int inputW = input.size(widthDim);
+ auto numPlanes = input.size(planeDim);
+ auto inputH = input.size(heightDim);
+ auto inputW = input.size(widthDim);
/* get contiguous gradOutput */
auto gradOutput = gradOutput_.contiguous();
@@ -236,13 +233,11 @@ static void fractional_max_pool2d_backward_out_single_batch_frame(
const scalar_t* gradOutputForPlane = gradOutput + plane * outputW * outputH;
const int64_t* indicesForPlane = indices + plane * outputW * outputH;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int h, w;
- for (h = 0; h < outputH; ++h) {
- for (w = 0; w < outputW; ++w) {
+ for (int h = 0; h < outputH; ++h) {
+ for (int w = 0; w < outputW; ++w) {
int outputIndex = h * outputW + w;
int64_t index = indicesForPlane[outputIndex];
- AT_ASSERT(index >= 0 && index < inputW * inputH);
+ AT_ASSERT(index >= 0 && index < static_cast<int64_t>(inputW) * inputH);
gradInputForPlane[index] += gradOutputForPlane[outputIndex];
}
@@ -353,15 +348,15 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cpu) (
gradInput.zero_();
- int numBatch = 1;
+ int64_t numBatch = 1;
int planeDim = 0;
int heightDim = 1;
int widthDim = 2;
- int outputH = output_size[0];
- int outputW = output_size[1];
+ auto outputH = output_size[0];
+ auto outputW = output_size[1];
- int ndims = input.ndimension();
+ auto ndims = input.ndimension();
if (ndims == 4) {
numBatch = input.size(0);
planeDim = 1;
@@ -370,9 +365,9 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cpu) (
}
/* sizes */
- int numPlanes = input.size(planeDim);
- int inputH = input.size(heightDim);
- int inputW = input.size(widthDim);
+ auto numPlanes = input.size(planeDim);
+ auto inputH = input.size(heightDim);
+ auto inputW = input.size(widthDim);
/* get contiguous gradOutput */
auto gradOutput = gradOutput_.contiguous();
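Most of the edits above replace int locals with int64_t/auto so that sizes coming from Tensor::size() (an int64_t) are never silently narrowed. The AT_ASSERT change is the one with a real correctness angle: widening one factor keeps the inputW * inputH product in 64-bit arithmetic. A small sketch of that bound check (hypothetical helper, not the ATen code):

    #include <cassert>
    #include <cstdint>

    // If inputW and inputH are 32-bit ints, their product is computed in 32 bits
    // and can overflow before the comparison; casting one operand first keeps
    // the whole expression in int64_t.
    bool index_in_bounds(int64_t index, int inputW, int inputH) {
      return index >= 0 && index < static_cast<int64_t>(inputW) * inputH;
    }

    int main() {
      assert(index_in_bounds(5'000'000'000, 100'000, 100'000));  // 1e10 elements: no overflow
      assert(!index_in_bounds(-1, 8, 8));
    }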


@@ -16,7 +16,6 @@
#include <ATen/ops/fractional_max_pool3d_native.h>
#endif
- #include <vector>
namespace at::meta {
TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(


@@ -6,12 +6,12 @@
namespace at::native {
template<typename scalar_t>
- inline std::vector<int> generate_intervals(
+ inline std::vector<int64_t> generate_intervals(
scalar_t sample,
int64_t inputSize,
int64_t outputSize,
int64_t poolSize) {
- std::vector<int> sequence(outputSize);
+ std::vector<int64_t> sequence(outputSize);
if (outputSize > 1) {
scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
static_cast<scalar_t>(outputSize - 1);
@@ -45,7 +45,7 @@ inline void fractional_max_pool_check_shape(
int64_t C = randomSamples.size(1);
int64_t D = randomSamples.size(2);
- int64_t input_batch, input_channel;
+ int64_t input_batch = 0, input_channel = 0;
if (ndim == 2) {
// fractional_max_pool2d
if (input.ndimension() == 3) {
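Besides widening generate_intervals to int64_t, the header change initialises input_batch and input_channel at declaration, which is what cppcoreguidelines-init-variables asks for. A tiny sketch of why the initialiser matters (illustrative function, not from the header):

    #include <cstdint>

    int64_t pick_batch_size(int64_t ndim, int64_t dim0) {
      // Without `= 1` the variable would be read uninitialised whenever the
      // branch below is skipped; initialising it makes every path well defined.
      int64_t batch = 1;
      if (ndim == 4) {
        batch = dim0;
      }
      return batch;
    }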


@@ -10,9 +10,9 @@
#include <ATen/ops/_fused_adagrad.h>
#include <ATen/ops/_fused_adagrad_native.h>
#endif
- namespace at {
- namespace native {
+ namespace at::native {
void _fused_adagrad_kernel_cpu_(
at::TensorList params,
@@ -56,4 +56,3 @@ void _fused_adagrad_kernel_cpu_(
DEFINE_DISPATCH(fused_adagrad_stub);
}
- }


@@ -1,9 +1,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
- namespace at {
- namespace native {
+ namespace at::native {
using fused_adagrad_fn = void (*)(
const at::Tensor& param,
@@ -19,5 +17,4 @@ using fused_adagrad_fn = void (*)(
DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub);
- }
- }
+ } // namespace at::native
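The namespace edits in these optimizer files are all the C++17 nested-namespace rewrite (modernize-concat-nested-namespaces): one declaration opens at::native, and a single brace with a trailing comment closes it. In isolation:

    // Before: namespace at { namespace native { ... } }
    // After: one nested-namespace definition with a commented closing brace.
    namespace at::native {

    inline int answer() { return 42; }

    } // namespace at::native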


@@ -12,9 +12,9 @@
#include <ATen/ops/_fused_adamw.h>
#include <ATen/ops/_fused_adamw_native.h>
#endif
- namespace at {
- namespace native {
+ namespace at::native {
void _fused_adam_kernel_cpu_(
at::TensorList params,
@@ -46,7 +46,7 @@ void _fused_adam_kernel_cpu_(
if (amsgrad) {
TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
} else {
- TORCH_CHECK(max_exp_avg_sqs.size() == 0);
+ TORCH_CHECK(max_exp_avg_sqs.empty());
}
TORCH_CHECK(state_steps.size() == n_tensors);
at::Tensor max_exp_avg_sq = at::Tensor();
@@ -122,7 +122,7 @@ void _fused_adamw_kernel_cpu_(
if (amsgrad) {
TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
} else {
- TORCH_CHECK(max_exp_avg_sqs.size() == 0);
+ TORCH_CHECK(max_exp_avg_sqs.empty());
}
TORCH_CHECK(state_steps.size() == n_tensors);
at::Tensor max_exp_avg_sq = at::Tensor();
@@ -172,4 +172,3 @@ void _fused_adamw_kernel_cpu_(
DEFINE_DISPATCH(fused_adam_stub);
}
- }
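The TORCH_CHECK edits here are readability-container-size-empty: asking a container whether it is empty states the intent directly. A standalone sketch:

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> buffers;
      // Preferred over `buffers.size() == 0`: clearer intent, and .empty() is
      // guaranteed O(1) for every standard container.
      assert(buffers.empty());
      buffers.push_back(1);
      assert(!buffers.empty());
    }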


@@ -1,9 +1,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
- namespace at {
- namespace native {
+ namespace at::native {
enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
@@ -26,5 +24,4 @@ using fused_adam_fn = void (*)(
DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);
- }
- }
+ } // namespace at::native


@@ -10,9 +10,9 @@
#include <ATen/ops/_fused_sgd.h>
#include <ATen/ops/_fused_sgd_native.h>
#endif
- namespace at {
- namespace native {
+ namespace at::native {
void _fused_sgd_kernel_cpu_(
@@ -39,7 +39,7 @@ void _fused_sgd_kernel_cpu_(
TORCH_CHECK(grads.size() == n_tensors);
bool no_momentum_buffer = momentum == 0.0;
if (no_momentum_buffer) {
- TORCH_CHECK(momentum_buffer_list.size() == 0);
+ TORCH_CHECK(momentum_buffer_list.empty());
} else {
TORCH_CHECK(momentum_buffer_list.size() == n_tensors);
}
@@ -83,4 +83,3 @@ void _fused_sgd_kernel_cpu_(
DEFINE_DISPATCH(fused_sgd_stub);
}
- }


@@ -1,9 +1,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
- namespace at {
- namespace native {
+ namespace at::native {
using fused_sgd_fn = void (*)(
const at::Tensor& param,
@@ -20,5 +18,4 @@ using fused_sgd_fn = void (*)(
DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub);
- }
- }
+ } // namespace at::native


@@ -654,7 +654,11 @@ void cpu_flash_attention_backward(
fill_stub(row_ptr + last_col + 1, static_cast<accum_t>(0), kvBlockSize - last_col - 1);
}
}
+ #ifdef _MSC_VER
+ if (is_reduced_type) {
+ #else
if constexpr (is_reduced_type) {
+ #endif
for (const auto row : c10::irange(qBlockSize)) {
convert<accum_t, scalar_t>(
attn_data + row * kvBlockSize,
@@ -706,7 +710,11 @@ void cpu_flash_attention_backward(
grad_attn_data + row * kvBlockSize,
kvBlockSize);
}
+ #ifdef _MSC_VER
+ if (is_reduced_type) {
+ #else
if constexpr (is_reduced_type) {
+ #endif
for (const auto row : c10::irange(qBlockSize)) {
convert<accum_t, scalar_t>(
grad_attn_data + row * kvBlockSize,
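is_reduced_type is a compile-time constant, so if constexpr lets non-MSVC builds discard the untaken branch entirely, while the MSVC path keeps a plain runtime if (both arms still compile for every scalar_t, so behaviour is unchanged). A self-contained sketch of the same guard; the trait below is a made-up stand-in for the kernel's is_reduced_type:

    #include <iostream>

    template <typename scalar_t>
    void run_block() {
      // Made-up criterion standing in for is_reduced_type in the kernel.
      constexpr bool is_reduced_type = sizeof(scalar_t) < sizeof(float);
    #ifdef _MSC_VER
      if (is_reduced_type) {           // runtime check; both branches must compile
    #else
      if constexpr (is_reduced_type) { // drops the untaken branch at compile time
    #endif
        std::cout << "reduced-precision path\n";
      } else {
        std::cout << "full-precision path\n";
      }
    }

    int main() {
      run_block<float>(); // full-precision path
      run_block<short>(); // reduced-precision path (16-bit stand-in)
    }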