mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[5/N] Fix clang-tidy warnings in aten/src/ATen (#132565)
Follows #132001

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132565
Approved by: https://github.com/Skylion007
@@ -100,7 +100,7 @@ class TORCH_API Context {
       const void* data,
       std::optional<c10::DeviceType> device_type = std::nullopt) {
     auto opt_device_type =
-        device_type.has_value() ? device_type.value() : at::getAccelerator();
+        device_type.has_value() ? device_type : at::getAccelerator();
     if (!opt_device_type.has_value() || // there is no accelerator
         !at::isAccelerator(
             opt_device_type.value())) { // passed device not an accelerator
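The hunk above keeps the ternary operating on std::optional values instead of unwrapping one side with .value(), so both branches have the same optional type and the fallback stays explicit. A standalone sketch of the pattern, with hypothetical stand-ins for the ATen names:

#include <iostream>
#include <optional>

enum class DeviceType { CPU, CUDA };

// Hypothetical stand-in for at::getAccelerator(): may or may not find one.
std::optional<DeviceType> getAccelerator() { return DeviceType::CUDA; }

std::optional<DeviceType> pickDevice(std::optional<DeviceType> requested) {
  // Both branches are std::optional<DeviceType>, so no .value() unwrap is
  // needed here; callers check has_value() once, afterwards.
  auto opt = requested.has_value() ? requested : getAccelerator();
  return opt;
}

int main() {
  std::cout << pickDevice(std::nullopt).has_value() << '\n';    // falls back to the accelerator
  std::cout << pickDevice(DeviceType::CPU).has_value() << '\n'; // keeps the requested device
}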
@@ -131,21 +131,21 @@ static Device getATenDevice(const DLDevice& ctx, void* data) {
 #ifndef USE_ROCM
     // if we are compiled under HIP, we cannot do cuda
     case DLDeviceType::kDLCUDA:
-      return at::Device(DeviceType::CUDA, ctx.device_id);
+      return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
 #endif
     case DLDeviceType::kDLOpenCL:
-      return at::Device(DeviceType::OPENCL, ctx.device_id);
+      return at::Device(DeviceType::OPENCL, static_cast<c10::DeviceIndex>(ctx.device_id));
     case DLDeviceType::kDLROCM:
 #ifdef USE_ROCM
       // this looks funny, we need to return CUDA here to masquerade
-      return at::Device(DeviceType::CUDA, ctx.device_id);
+      return at::Device(DeviceType::CUDA, static_cast<c10::DeviceIndex>(ctx.device_id));
 #else
-      return at::Device(DeviceType::HIP, ctx.device_id);
+      return at::Device(DeviceType::HIP, static_cast<c10::DeviceIndex>(ctx.device_id));
 #endif
     case DLDeviceType::kDLOneAPI:
       return at::detail::getXPUHooks().getDeviceFromPtr(data);
     case DLDeviceType::kDLMAIA:
-      return at::Device(DeviceType::MAIA, ctx.device_id);
+      return at::Device(DeviceType::MAIA, static_cast<c10::DeviceIndex>(ctx.device_id));
     default:
       TORCH_CHECK(
         false, "Unsupported device_type: ", std::to_string(ctx.device_type));
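Each of these edits makes the conversion from the DLPack device_id (an int) to c10::DeviceIndex explicit. A minimal standalone sketch of the same narrowing-conversion fix; DeviceIndex being a small integer type is an assumption here, mirroring c10:

#include <cstdint>
#include <iostream>

using DeviceIndex = int8_t;  // assumption: a narrow index type, as in c10::DeviceIndex

struct Device {
  int type;
  DeviceIndex index;
};

Device make_device(int type, int32_t device_id) {
  // The int32_t -> int8_t conversion narrows; clang-tidy flags it (and braced
  // initialization rejects it) unless the truncation is made explicit.
  return Device{type, static_cast<DeviceIndex>(device_id)};
}

int main() {
  Device d = make_device(/*type=*/1, /*device_id=*/0);
  std::cout << static_cast<int>(d.index) << '\n';
}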
@@ -286,7 +286,7 @@ DLManagedTensor* toDLPack(const Tensor& src) {
     device_id = src.get_device();
   }
   atDLMTensor->tensor.dl_tensor.device = getDLDevice(src, device_id);
-  atDLMTensor->tensor.dl_tensor.ndim = src.dim();
+  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
   atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
   atDLMTensor->tensor.dl_tensor.shape = view.sizes().data();
   atDLMTensor->tensor.dl_tensor.strides = view.strides().data();
@@ -20,7 +20,7 @@ namespace at {
  * The method should_include_kernel_dtype() returns true/false
  * based on whether the switching code for a specific dtype should be
  * included based on build time constants generated from tracing model
- * execution. This method will be implmeneted via code-generation and
+ * execution. This method will be implemented via code-generation and
  * included in this file when code-gen is ready.
  */
 inline constexpr bool should_include_kernel_dtype(
@@ -29,7 +29,7 @@ static Tensor permute_inverse(const Tensor& self, IntArrayRef dims, InverseRetur
 static Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, InverseReturnMode inverse_return_mode) {
   auto result = self;
   bool need_alias = (inverse_return_mode == InverseReturnMode::AlwaysView);
-  int64_t nDims = sizes.size();
+  int64_t nDims = static_cast<int64_t>(sizes.size());
   for(const auto dim : c10::irange(nDims)) {
     if (sizes[dim] == 1) {
       need_alias = false;
@@ -1,8 +1,5 @@
 #include <ATen/TensorGeometry.h>
 
-#include <limits>
-#include <cstddef>
-
 namespace at {
 
 // See TensorGeometry.h on why this is useful now that we cache is_contiguous.
@@ -498,10 +498,10 @@ static void gemm_batched_mkl_impl(
 
 template <typename scalar_t>
 using is_blas_library_type = std::integral_constant<bool,
-    std::is_same<scalar_t, double>::value ||
-    std::is_same<scalar_t, float>::value ||
-    std::is_same<scalar_t, c10::complex<double>>::value ||
-    std::is_same<scalar_t, c10::complex<float>>::value>;
+    std::is_same_v<scalar_t, double> ||
+    std::is_same_v<scalar_t, float> ||
+    std::is_same_v<scalar_t, c10::complex<double>> ||
+    std::is_same_v<scalar_t, c10::complex<float>>>;
 
 template <typename scalar_t>
 void gemm_batched_generic(
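std::is_same_v<A, B> is the C++17 variable-template shorthand for std::is_same<A, B>::value; the two are interchangeable. A self-contained sketch of the same kind of trait, using std::complex in place of c10::complex so it compiles on its own:

#include <complex>
#include <type_traits>

// Same shape as is_blas_library_type above, written with the _v helpers.
template <typename T>
inline constexpr bool is_blas_like =
    std::is_same_v<T, double> ||
    std::is_same_v<T, float> ||
    std::is_same_v<T, std::complex<double>> ||
    std::is_same_v<T, std::complex<float>>;

static_assert(is_blas_like<float>);
static_assert(!is_blas_like<int>);
static_assert(std::is_same<float, float>::value == std::is_same_v<float, float>);

int main() {}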
@@ -33,7 +33,7 @@ static std::vector<T> to_cpu(const std::vector<T>& tensors) {
     // Explicitly handling undefined tensors here instead of letting `at::_to_cpu` handle it.
     // Otherwise, we'd need to require all backends with their own implementation of _to_cpu
     // to properly handle undefined tensors.
-    if constexpr(std::is_same<T, std::optional<at::Tensor>>::value) {
+    if constexpr(std::is_same_v<T, std::optional<at::Tensor>>) {
       if (tensors[i].has_value() && tensors[i].value().defined()) {
         to_translate[i] = true;
         valid_tensors.push_back(tensors[i].value());
@@ -58,7 +58,7 @@ static std::vector<T> to_cpu(const std::vector<T>& tensors) {
   return cpu_tensors;
 }
 
-static std::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args, std::vector<c10::List<at::Tensor>> tlist_args) {
+static std::optional<c10::Device> compute_target_device(std::vector<at::Tensor>& t_args, const std::vector<c10::List<at::Tensor>>& tlist_args) {
   // Decide what device to move the output tensor(s) to.
   // The current convention is that we use the first tensor arg to pick the device
   // Barring that, we take the first tensor from a TensorList arg.
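Taking tlist_args by const reference means the vector of lists is no longer copied on every call; the function only reads it. A generic sketch of the pattern:

#include <iostream>
#include <string>
#include <vector>

// Read-only container parameters go by const reference: no copy, and the
// signature documents that the argument is not modified.
static std::size_t total_length(const std::vector<std::string>& args) {
  std::size_t n = 0;
  for (const auto& s : args) {
    n += s.size();
  }
  return n;
}

int main() {
  std::vector<std::string> v{"to", "cpu"};
  std::cout << total_length(v) << '\n';  // prints 5
}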
@@ -340,8 +340,8 @@ Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) {
   // when self has zero storage.
   // This kernel should never really be run, except with debugging using compile(backend="aot_eager")
   for (const auto i : c10::irange(src.size())) {
-    auto curr_src = src[i];
-    auto curr_self = self[i];
+    const auto& curr_src = src[i];
+    const auto& curr_self = self[i];
     outs.push_back(at::copy(curr_self, curr_src, non_blocking));
   }
   return outs;
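Binding src[i] and self[i] to const auto& avoids constructing a fresh object just to read it (for Tensor the copy is only a refcount bump, but clang-tidy still flags the unnecessary copy). A standalone sketch:

#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> src{"aaa", "bbbb"};
  for (std::size_t i = 0; i < src.size(); ++i) {
    // const auto& binds to the existing element (or extends a temporary's
    // lifetime) instead of copy-constructing a new string for read-only use.
    const auto& curr = src[i];
    std::cout << curr.size() << '\n';
  }
}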
@@ -145,7 +145,7 @@ static Tensor& zero_cpu_(Tensor &self, int64_t nelements) {
   if (nullptr == ptr) {
     return self.fill_(0);
   }
-  int64_t size_bytes = nelements * self.dtype().itemsize();
+  auto size_bytes = nelements * self.dtype().itemsize();
   if (size_bytes > 0) {
     std::memset(ptr, 0, size_bytes);
   }
@@ -14,9 +14,6 @@
 #include <ATen/ops/fractional_max_pool2d_native.h>
 #endif
 
-#include <tuple>
-#include <vector>
-
 namespace at {
 
 namespace meta {
@@ -61,7 +58,7 @@ TORCH_META_FUNC(fractional_max_pool2d) (
   /* sizes */
   int64_t numPlanes = input.size(planeDim);
   int64_t inputH = input.size(heightDim);
-  int inputW = input.size(widthDim);
+  auto inputW = input.size(widthDim);
 
   TORCH_CHECK(outputH + poolSizeH - 1 <= inputH,
     "fractional_max_pool2d(): pool height ", poolSizeH,
@@ -88,15 +85,15 @@ TORCH_META_FUNC(fractional_max_pool2d_backward)(
   IntArrayRef output_size,
   const at::Tensor& indices) {
 
-  int numBatch = 1;
+  int64_t numBatch = 1;
   int planeDim = 0;
   int heightDim = 1;
   int widthDim = 2;
 
-  int outputH = output_size[0];
-  int outputW = output_size[1];
+  auto outputH = output_size[0];
+  auto outputW = output_size[1];
 
-  int ndims = input.ndimension();
+  auto ndims = input.ndimension();
   if (ndims == 4) {
     numBatch = input.size(0);
     planeDim = 1;
@@ -105,9 +102,9 @@ TORCH_META_FUNC(fractional_max_pool2d_backward)(
   }
 
   /* sizes */
-  int numPlanes = input.size(planeDim);
-  int inputH = input.size(heightDim);
-  int inputW = input.size(widthDim);
+  auto numPlanes = input.size(planeDim);
+  auto inputH = input.size(heightDim);
+  auto inputW = input.size(widthDim);
 
   /* get contiguous gradOutput */
   auto gradOutput = gradOutput_.contiguous();
@@ -236,13 +233,11 @@ static void fractional_max_pool2d_backward_out_single_batch_frame(
     const scalar_t* gradOutputForPlane = gradOutput + plane * outputW * outputH;
     const int64_t* indicesForPlane = indices + plane * outputW * outputH;
 
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    int h, w;
-    for (h = 0; h < outputH; ++h) {
-      for (w = 0; w < outputW; ++w) {
+    for (int h = 0; h < outputH; ++h) {
+      for (int w = 0; w < outputW; ++w) {
         int outputIndex = h * outputW + w;
         int64_t index = indicesForPlane[outputIndex];
-        AT_ASSERT(index >= 0 && index < inputW * inputH);
+        AT_ASSERT(index >= 0 && index < static_cast<int64_t>(inputW) * inputH);
 
         gradInputForPlane[index] += gradOutputForPlane[outputIndex];
       }
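Declaring h and w in the for-init statements scopes them to the loops and initializes them at the point of declaration, which is what lets the NOLINT(cppcoreguidelines-init-variables) line go away. A minimal sketch:

#include <cstdio>

int main() {
  const int outputH = 2, outputW = 3;
  // Counters declared and initialized where they are used; neither h nor w
  // is visible (or left uninitialized) outside the loops.
  for (int h = 0; h < outputH; ++h) {
    for (int w = 0; w < outputW; ++w) {
      std::printf("%d ", h * outputW + w);
    }
  }
  std::printf("\n");
}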
@@ -353,15 +348,15 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cpu) (
 
   gradInput.zero_();
 
-  int numBatch = 1;
+  int64_t numBatch = 1;
   int planeDim = 0;
   int heightDim = 1;
   int widthDim = 2;
 
-  int outputH = output_size[0];
-  int outputW = output_size[1];
+  auto outputH = output_size[0];
+  auto outputW = output_size[1];
 
-  int ndims = input.ndimension();
+  auto ndims = input.ndimension();
   if (ndims == 4) {
     numBatch = input.size(0);
     planeDim = 1;
@@ -370,9 +365,9 @@ TORCH_IMPL_FUNC(fractional_max_pool2d_backward_cpu) (
   }
 
   /* sizes */
-  int numPlanes = input.size(planeDim);
-  int inputH = input.size(heightDim);
-  int inputW = input.size(widthDim);
+  auto numPlanes = input.size(planeDim);
+  auto inputH = input.size(heightDim);
+  auto inputW = input.size(widthDim);
 
   /* get contiguous gradOutput */
   auto gradOutput = gradOutput_.contiguous();
@@ -16,7 +16,6 @@
 #include <ATen/ops/fractional_max_pool3d_native.h>
 #endif
 
-#include <vector>
 
 namespace at::meta {
 TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)(
@@ -6,12 +6,12 @@
 namespace at::native {
 
 template<typename scalar_t>
-inline std::vector<int> generate_intervals(
+inline std::vector<int64_t> generate_intervals(
   scalar_t sample,
   int64_t inputSize,
   int64_t outputSize,
   int64_t poolSize) {
-  std::vector<int> sequence(outputSize);
+  std::vector<int64_t> sequence(outputSize);
   if (outputSize > 1) {
     scalar_t alpha = static_cast<scalar_t>(inputSize - poolSize) /
       static_cast<scalar_t>(outputSize - 1);
@@ -45,7 +45,7 @@ inline void fractional_max_pool_check_shape(
   int64_t C = randomSamples.size(1);
   int64_t D = randomSamples.size(2);
 
-  int64_t input_batch, input_channel;
+  int64_t input_batch = 0, input_channel = 0;
   if (ndim == 2) {
     // fractional_max_pool2d
     if (input.ndimension() == 3) {
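Giving input_batch and input_channel initializers addresses the same init-variables rule: even when every later branch assigns them, a defined starting value removes any window where an indeterminate value could be read. Sketch:

#include <cstdint>
#include <iostream>

int main() {
  // Initialize at declaration instead of relying on later branches to assign.
  int64_t input_batch = 0, input_channel = 0;
  const bool batched = true;
  if (batched) {
    input_batch = 8;
    input_channel = 3;
  }
  std::cout << input_batch << ' ' << input_channel << '\n';
}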
@@ -10,9 +10,9 @@
 #include <ATen/ops/_fused_adagrad.h>
 #include <ATen/ops/_fused_adagrad_native.h>
 #endif
-namespace at {
-
-namespace native {
+namespace at::native {
 
 void _fused_adagrad_kernel_cpu_(
   at::TensorList params,
@@ -56,4 +56,3 @@ void _fused_adagrad_kernel_cpu_(
 DEFINE_DISPATCH(fused_adagrad_stub);
 
 }
-}
@@ -1,9 +1,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/native/DispatchStub.h>
 
-namespace at {
-
-namespace native {
+namespace at::native {
 
 using fused_adagrad_fn = void (*)(
   const at::Tensor& param,
@@ -19,5 +17,4 @@ using fused_adagrad_fn = void (*)(
 
 DECLARE_DISPATCH(fused_adagrad_fn, fused_adagrad_stub);
 
-}
-}
+} // namespace at::native
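The headers now use a C++17 nested namespace definition, namespace at::native { ... }, instead of two nested blocks, so a single commented closing brace ends the scope. A sketch (the namespace contents here are made up):

#include <iostream>

// C++17: equivalent to namespace at { namespace native { ... } }
namespace at::native {
int answer() { return 42; }
} // namespace at::native

int main() {
  std::cout << at::native::answer() << '\n';
}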
@@ -12,9 +12,9 @@
 #include <ATen/ops/_fused_adamw.h>
 #include <ATen/ops/_fused_adamw_native.h>
 #endif
-namespace at {
-
-namespace native {
+namespace at::native {
 
 void _fused_adam_kernel_cpu_(
   at::TensorList params,
@@ -46,7 +46,7 @@ void _fused_adam_kernel_cpu_(
   if (amsgrad) {
     TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
   } else {
-    TORCH_CHECK(max_exp_avg_sqs.size() == 0);
+    TORCH_CHECK(max_exp_avg_sqs.empty());
   }
   TORCH_CHECK(state_steps.size() == n_tensors);
   at::Tensor max_exp_avg_sq = at::Tensor();
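Replacing size() == 0 with empty() is the usual readability fix: the two conditions are equivalent for these containers, and empty() states the intent directly. Sketch:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> max_exp_avg_sqs;  // stand-in for the tensor list argument
  // Equivalent checks; the second reads as a statement about emptiness.
  assert(max_exp_avg_sqs.size() == 0);
  assert(max_exp_avg_sqs.empty());
}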
@@ -122,7 +122,7 @@ void _fused_adamw_kernel_cpu_(
   if (amsgrad) {
     TORCH_CHECK(max_exp_avg_sqs.size() == n_tensors);
   } else {
-    TORCH_CHECK(max_exp_avg_sqs.size() == 0);
+    TORCH_CHECK(max_exp_avg_sqs.empty());
   }
   TORCH_CHECK(state_steps.size() == n_tensors);
   at::Tensor max_exp_avg_sq = at::Tensor();
@@ -172,4 +172,3 @@ void _fused_adamw_kernel_cpu_(
 DEFINE_DISPATCH(fused_adam_stub);
 
 }
-}
@@ -1,9 +1,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/native/DispatchStub.h>
 
-namespace at {
-
-namespace native {
+namespace at::native {
 
 enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
 
@@ -26,5 +24,4 @@ using fused_adam_fn = void (*)(
 
 DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub);
 
-}
-}
+} // namespace at::native
@@ -10,9 +10,9 @@
 #include <ATen/ops/_fused_sgd.h>
 #include <ATen/ops/_fused_sgd_native.h>
 #endif
-namespace at {
-
-namespace native {
+namespace at::native {
 
 
 void _fused_sgd_kernel_cpu_(
@@ -39,7 +39,7 @@ void _fused_sgd_kernel_cpu_(
   TORCH_CHECK(grads.size() == n_tensors);
   bool no_momentum_buffer = momentum == 0.0;
   if (no_momentum_buffer) {
-    TORCH_CHECK(momentum_buffer_list.size() == 0);
+    TORCH_CHECK(momentum_buffer_list.empty());
   } else {
     TORCH_CHECK(momentum_buffer_list.size() == n_tensors);
   }
@@ -83,4 +83,3 @@ void _fused_sgd_kernel_cpu_(
 DEFINE_DISPATCH(fused_sgd_stub);
 
 }
-}
@@ -1,9 +1,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/native/DispatchStub.h>
 
-namespace at {
-
-namespace native {
+namespace at::native {
 
 using fused_sgd_fn = void (*)(
   const at::Tensor& param,
@@ -20,5 +18,4 @@ using fused_sgd_fn = void (*)(
 
 DECLARE_DISPATCH(fused_sgd_fn, fused_sgd_stub);
 
-}
-}
+} // namespace at::native
@@ -654,7 +654,11 @@ void cpu_flash_attention_backward(
           fill_stub(row_ptr + last_col + 1, static_cast<accum_t>(0), kvBlockSize - last_col - 1);
         }
       }
+#ifdef _MSC_VER
+      if (is_reduced_type) {
+#else
       if constexpr (is_reduced_type) {
+#endif
         for (const auto row : c10::irange(qBlockSize)) {
           convert<accum_t, scalar_t>(
             attn_data + row * kvBlockSize,
@@ -706,7 +710,11 @@ void cpu_flash_attention_backward(
             grad_attn_data + row * kvBlockSize,
             kvBlockSize);
         }
+#ifdef _MSC_VER
+        if (is_reduced_type) {
+#else
         if constexpr (is_reduced_type) {
+#endif
          for (const auto row : c10::irange(qBlockSize)) {
            convert<accum_t, scalar_t>(
              grad_attn_data + row * kvBlockSize,
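Both hunks guard the constexpr branch so MSVC compiles a plain runtime if while other compilers keep if constexpr; the diff does not say why MSVC needs this, so the sketch below only reproduces the guard structure under that assumption (is_reduced_type is defined locally here for illustration):

#include <iostream>

template <typename scalar_t>
void process() {
  constexpr bool is_reduced_type = sizeof(scalar_t) < sizeof(float);
  // MSVC takes the runtime branch; elsewhere the untaken branch is discarded
  // at compile time.
#ifdef _MSC_VER
  if (is_reduced_type) {
#else
  if constexpr (is_reduced_type) {
#endif
    std::cout << "reduced-precision path\n";
  }
}

int main() {
  process<short>();   // takes the branch
  process<double>();  // skips it
}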