Make TensorIterator, SparseTensorMath and UnaryOps clang-tidy clean (#55087)

Summary:
Disable `cppcoreguidelines-macro-usage`, as the PyTorch codebase uses a lot
of macros that violate this rule.

Disable `bugprone-reserved-identifier` and
`performance-unnecessary-value-param`, as those checks are very slow.

Add `NOLINT` to `DEFINE_DISPATCH` uses, as the macro introduces non-const global variables.
Replace `for(auto i = 0; i < lim; ++i)` with `for(auto i: c10::irange(lim))` throughout the modified files (see the sketches below).
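
Two quick sketches of what the last two items mean in practice. First, `DEFINE_DISPATCH` (roughly `#define DEFINE_DISPATCH(name) struct name name` in ATen's DispatchStub machinery; paraphrased, not the exact definition) defines a mutable namespace-scope object, which is exactly what `cppcoreguidelines-avoid-non-const-global-variables` flags, so each use gets a trailing `NOLINT` suppression:

    // DEFINE_DISPATCH(abs_stub) expands to a definition of a mutable global:
    struct abs_stub abs_stub; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)

Second, the loop conversion. `c10::irange` (from `c10/util/irange.h`) yields an integer range, so index loops become range-based for loops with no mutable counter. A minimal sketch, assuming only `lim`, `start`, and `end` are in scope:

    #include <c10/util/irange.h>

    // before:
    for (auto i = 0; i < lim; ++i) { /* body */ }
    for (auto i = start; i < end; i++) { /* body */ }

    // after: same iteration space, counter is const within each iteration
    for (const auto i : c10::irange(lim)) { /* body */ }
    for (const auto i : c10::irange(start, end)) { /* body */ }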

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55087

Reviewed By: samestep

Differential Revision: D27475822

Pulled By: malfet

fbshipit-source-id: 2651a4b3dc062066a15e69380354414a198fb279
Author: Nikita Shulga
Date: 2021-04-01 09:03:13 -07:00
Committed by: Facebook GitHub Bot
Parent: f0dafeb0cb
Commit: 8d5df95551
5 changed files with 116 additions and 110 deletions

View File

@@ -6,8 +6,10 @@ bugprone-*,
 -bugprone-forward-declaration-namespace,
 -bugprone-macro-parentheses,
 -bugprone-lambda-function-name,
+-bugprone-reserved-identifier,
 cppcoreguidelines-*,
 -cppcoreguidelines-interfaces-global-init,
+-cppcoreguidelines-macro-usage,
 -cppcoreguidelines-owning-memory,
 -cppcoreguidelines-pro-bounds-array-to-pointer-decay,
 -cppcoreguidelines-pro-bounds-constant-array-index,
@@ -30,6 +32,7 @@ modernize-*,
 -modernize-use-trailing-return-type,
 performance-*,
 -performance-noexcept-move-constructor,
+-performance-unnecessary-value-param,
 '
 HeaderFilterRegex: 'torch/csrc/.*'
 AnalyzeTemporaryDtors: false
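
(For readers unfamiliar with the `.clang-tidy` format: `Checks:` is a single comma-separated string of globs, and a leading `-` disables whatever the glob matches, so each entry added above turns a check off rather than on. A minimal illustration, enabling all `bugprone-*` checks except one:

    Checks: '
    bugprone-*,
    -bugprone-reserved-identifier,
    '
)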

View File

@@ -300,7 +300,7 @@ void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
   // Promotes common dtype to the default float scalar type, if needed
   if (config.promote_integer_inputs_to_float_ &&
-      c10::isIntegralType(common_dtype_, /*include_bool=*/true)) {
+      c10::isIntegralType(common_dtype_, /*includeBool=*/true)) {
     common_dtype_ = c10::typeMetaToScalarType(c10::get_default_dtype());
   }
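
The `include_bool` to `includeBool` renames here and throughout the diff satisfy clang-tidy's `bugprone-argument-comment` check, which requires a `/*name=*/` comment to match the parameter name in the callee's declaration; in c10 the parameter is declared `includeBool` (paraphrased declaration below):

    // c10/core/ScalarType.h, paraphrased:
    static inline bool isIntegralType(ScalarType t, bool includeBool);

    isIntegralType(dtype, /*includeBool=*/true);   // OK: comment matches parameter
    isIntegralType(dtype, /*include_bool=*/true);  // bugprone-argument-comment warning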
@@ -570,7 +570,7 @@ bool TensorIteratorBase::is_dim_reduced(int dim) const {
 }
 
 void TensorIteratorBase::permute_dimensions(IntArrayRef perm) {
-  TORCH_INTERNAL_ASSERT(perm.size() == ndim());
+  TORCH_INTERNAL_ASSERT(perm.size() == static_cast<unsigned>(ndim()));
 
   auto reorder = [perm](IntArrayRef data) {
     auto res = DimVector(data.size(), 0);
@@ -637,7 +637,7 @@ void TensorIteratorBase::serial_for_each(loop2d_t loop, Range range) const {
     return;
   }
   auto strides = get_strides();
-  while (strides.size() < 2 * ntensors()) {
+  while (strides.size() < 2U * ntensors()) {
     strides.push_back(0);
   }
@@ -1100,7 +1100,7 @@ bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
     case FastSetupType::NON_OVERLAPPING_DENSE:
       {
         // find the index of a defined tensor in operands_ start from input tensor
-        int i_defined;
+        int i_defined; // NOLINT(cppcoreguidelines-init-variables)
         for (i_defined = ntensors() - 1; i_defined >= 0; --i_defined) {
           if (operands_[i_defined].tensor.defined()) break;
         }
@@ -1187,7 +1187,7 @@ FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorCo
   return FastSetupType::NONE;
 }
 
-TensorIteratorBase::TensorIteratorBase() {}
+TensorIteratorBase::TensorIteratorBase() = default;
 
 void TensorIteratorBase::build(TensorIteratorConfig& config) {
   // populate some persistent configuration fields
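
The casts in this file (`static_cast<unsigned>(ndim())`, `2U * ntensors()`) address signed/unsigned comparison warnings: `perm.size()` and `strides.size()` return an unsigned `size_t`, while `ndim()` and `ntensors()` return signed ints. A minimal standalone sketch, with a hypothetical `count()` and container `v`:

    #include <cstddef>
    #include <vector>

    int count();           // hypothetical signed query, known non-negative here
    std::vector<int> v;    // hypothetical container

    bool same_size() {
      // `v.size() == count()` would compare size_t with int and draw a warning;
      // cast the signed side once it is known to be non-negative:
      return v.size() == static_cast<std::size_t>(count());
    }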

View File

@@ -10,12 +10,12 @@
 
 namespace at { namespace native {
 
-DEFINE_DISPATCH(where_kernel);
-DEFINE_DISPATCH(max_stub);
-DEFINE_DISPATCH(min_stub);
-DEFINE_DISPATCH(_aminmax_stub);
-DEFINE_DISPATCH(isposinf_stub);
-DEFINE_DISPATCH(isneginf_stub);
+DEFINE_DISPATCH(where_kernel); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(max_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(min_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(_aminmax_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(isposinf_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(isneginf_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 
 bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
   return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
@@ -66,7 +66,7 @@ Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol
   // Computes allowed and actual error
   Tensor cast_other;
-  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     cast_other = other.to(at::get_default_dtype());
   } else {
     cast_other = other;
@@ -86,7 +86,7 @@ Tensor isnan(const Tensor& self) {
 
 Tensor isreal(const Tensor& self) {
   // Note: Integral and Floating tensor values are always real
-  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true) ||
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true) ||
       c10::isFloatingType(self.scalar_type())) {
     return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
   }
@@ -96,7 +96,7 @@ Tensor isreal(const Tensor& self) {
 
 Tensor isinf(const Tensor &self) {
   // Note: Integral tensor values are never infinite
-  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     return at::zeros_like(self, at::kBool, at::MemoryFormat::Preserve);
   }
@@ -122,7 +122,7 @@ Tensor& isposinf_out(const Tensor& self, Tensor& result) {
   TORCH_CHECK(result.scalar_type() == at::kBool, "isposinf does not support non-boolean outputs.");
   result.resize_(self.sizes());
 
-  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     result.fill_(false);
   } else {
     auto iter = TensorIteratorConfig()
@@ -146,7 +146,7 @@ Tensor& isneginf_out(const Tensor& self, Tensor& result) {
   TORCH_CHECK(result.scalar_type() == at::kBool, "isneginf does not support non-boolean outputs.");
   result.resize_(self.sizes());
 
-  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     result.fill_(false);
   } else {
     auto iter = TensorIteratorConfig()
@@ -161,7 +161,7 @@ Tensor& isneginf_out(const Tensor& self, Tensor& result) {
 
 Tensor isfinite(const Tensor& self) {
   // Note: Integral tensor values are always finite
-  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     return at::ones_like(self, at::kBool, at::MemoryFormat::Preserve);
   }

View File

@@ -157,7 +157,7 @@ Tensor rad2deg(const Tensor& self) {
   // Note: int-> float promotion handled differently from other Unary ops,
   // as it does not use the usual TensorIterator + Kernel Dispatch pattern.
   auto options = self.options();
-  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     options = options.dtype(c10::get_default_dtype());
   }
   auto result = at::empty_like(self, options);
@@ -175,7 +175,7 @@ Tensor deg2rad(const Tensor& self) {
   // Note: int-> float promotion handled differently from other Unary ops,
   // as it does not use the usual TensorIterator + Kernel Dispatch pattern.
   auto options = self.options();
-  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     options = options.dtype(c10::get_default_dtype());
   }
   auto result = at::empty_like(self, options);
@@ -485,7 +485,7 @@ Tensor& nan_to_num_out(const Tensor& self,
               " should be same as input: ",
               self.scalar_type());
-  if (c10::isIntegralType(self.scalar_type(), /*include_bool=*/true)) {
+  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
     result.resize_as_(self);
     result.copy_(self);
     return result;
@@ -678,10 +678,15 @@ Tensor& polygamma_out(int64_t n, const Tensor& self, Tensor& result) {
   return result;
 }
 
+namespace {
+constexpr double HALF = 0.5;
+constexpr double QUARTER = 0.25;
+}
+
 static inline void mvlgamma_check(const Tensor& self, int64_t p) {
   TORCH_CHECK(at::isFloatingType(self.scalar_type()),
               "mvlgamma is not implemented for ", self.scalar_type());
-  TORCH_CHECK((self > 0.5f * (p - 1)).all().item<bool>(),
+  TORCH_CHECK((self > HALF * (p - 1)).all().item<bool>(),
               "All elements must be greater than (p-1)/2");
   TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1");
 }
@@ -689,30 +694,31 @@ static inline void mvlgamma_check(const Tensor& self, int64_t p) {
 Tensor mvlgamma(const Tensor& self, int64_t p) {
   mvlgamma_check(self, p);
   Tensor args = native::arange(
-      -p / 2. + 0.5,
-      0.5,
-      0.5,
+      -p * HALF + HALF,
+      HALF,
+      HALF,
       optTypeMetaToScalarType(self.options().dtype_opt()),
       self.options().layout_opt(),
       self.options().device_opt(),
       self.options().pinned_memory_opt());
   args = args.add(self.unsqueeze(-1));
-  return args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(c10::pi<double>) / 4.);
+  const auto p2_sub_p = static_cast<double>(p * (p - 1));
+  return args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER);
 }
 
 Tensor& mvlgamma_(Tensor& self, int64_t p) {
   mvlgamma_check(self, p);
   auto dtype_opt = self.options().dtype_opt();
   Tensor args = native::arange(
-      -p / 2. + 0.5,
-      0.5,
-      0.5,
+      -p * HALF + HALF,
+      HALF,
+      HALF,
       optTypeMetaToScalarType(self.options().dtype_opt()),
       self.options().layout_opt(),
       self.options().device_opt(),
       self.options().pinned_memory_opt());
   args = args.add(self.unsqueeze(-1));
-  return self.copy_(args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(c10::pi<double>) / 4.));
+  const auto p2_sub_p = static_cast<double>(p * (p - 1));
+  return self.copy_(args.lgamma_().sum(-1).add_(p2_sub_p * std::log(c10::pi<double>) * QUARTER));
 }
 
 Tensor& lgamma_out(const Tensor& self, Tensor& result) { return unary_op_impl_float_out(result, self, lgamma_stub); }
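
For reference, the standard identity these functions implement (background, not part of the diff) is the multivariate log-gamma:

\[
\log \Gamma_p(a) \;=\; \frac{p(p-1)}{4}\,\log\pi \;+\; \sum_{j=1}^{p} \log\Gamma\!\left(a + \frac{1-j}{2}\right)
\]

The `arange(-p * HALF + HALF, HALF, HALF)` call produces exactly the p offsets (1-j)/2 for j = 1..p, and `p2_sub_p * std::log(c10::pi<double>) * QUARTER` is the p(p-1)/4 * log(pi) term.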
@@ -761,60 +767,60 @@ Tensor& special_gammaln_out(const Tensor& self, Tensor& result) { return at::lga
 
 Tensor special_entr(const Tensor& self) { return unary_op_impl_float(self, entr_stub); }
 Tensor& special_entr_out(const Tensor& self, Tensor& result) { return unary_op_impl_float_out(result, self, entr_stub);}
 
-DEFINE_DISPATCH(abs_stub);
-DEFINE_DISPATCH(angle_stub);
-DEFINE_DISPATCH(real_stub);
-DEFINE_DISPATCH(imag_stub);
-DEFINE_DISPATCH(conj_stub);
-DEFINE_DISPATCH(acos_stub);
-DEFINE_DISPATCH(acosh_stub);
-DEFINE_DISPATCH(asinh_stub);
-DEFINE_DISPATCH(atanh_stub);
-DEFINE_DISPATCH(asin_stub);
-DEFINE_DISPATCH(atan_stub);
-DEFINE_DISPATCH(bitwise_not_stub);
-DEFINE_DISPATCH(ceil_stub);
-DEFINE_DISPATCH(clamp_stub);
-DEFINE_DISPATCH(clamp_max_stub);
-DEFINE_DISPATCH(clamp_min_stub);
-DEFINE_DISPATCH(cos_stub);
-DEFINE_DISPATCH(cosh_stub);
-DEFINE_DISPATCH(digamma_stub);
-DEFINE_DISPATCH(entr_stub);
-DEFINE_DISPATCH(erf_stub);
-DEFINE_DISPATCH(erfc_stub);
-DEFINE_DISPATCH(erfinv_stub);
-DEFINE_DISPATCH(exp_stub);
-DEFINE_DISPATCH(exp2_stub);
-DEFINE_DISPATCH(expm1_stub);
-DEFINE_DISPATCH(floor_stub);
-DEFINE_DISPATCH(frac_stub);
-DEFINE_DISPATCH(frexp_stub);
-DEFINE_DISPATCH(i0_stub);
-DEFINE_DISPATCH(log_stub);
-DEFINE_DISPATCH(log10_stub);
-DEFINE_DISPATCH(log1p_stub);
-DEFINE_DISPATCH(log2_stub);
-DEFINE_DISPATCH(logical_not_stub);
-DEFINE_DISPATCH(neg_stub);
-DEFINE_DISPATCH(nan_to_num_stub);
-DEFINE_DISPATCH(polygamma_stub);
-DEFINE_DISPATCH(reciprocal_stub);
-DEFINE_DISPATCH(round_stub);
-DEFINE_DISPATCH(rsqrt_stub);
-DEFINE_DISPATCH(sigmoid_stub);
-DEFINE_DISPATCH(logit_stub);
-DEFINE_DISPATCH(sign_stub);
-DEFINE_DISPATCH(signbit_stub);
-DEFINE_DISPATCH(sgn_stub);
-DEFINE_DISPATCH(sin_stub);
-DEFINE_DISPATCH(sinc_stub);
-DEFINE_DISPATCH(sinh_stub);
-DEFINE_DISPATCH(sqrt_stub);
-DEFINE_DISPATCH(tan_stub);
-DEFINE_DISPATCH(tanh_stub);
-DEFINE_DISPATCH(trigamma_stub);
-DEFINE_DISPATCH(trunc_stub);
-DEFINE_DISPATCH(lgamma_stub);
+DEFINE_DISPATCH(abs_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(angle_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(real_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(imag_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(conj_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(acos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(acosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(asinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(atanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(asin_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(atan_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(bitwise_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(ceil_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(clamp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(clamp_max_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(clamp_min_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(cos_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(cosh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(digamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(entr_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(erf_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(erfc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(erfinv_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(exp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(exp2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(expm1_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(floor_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(frac_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(frexp_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(i0_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(log_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(log10_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(log1p_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(log2_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(logical_not_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(neg_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(nan_to_num_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(polygamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(reciprocal_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(round_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(rsqrt_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(sigmoid_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(logit_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(sign_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(signbit_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(sgn_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(sin_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(sinc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(sinh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(sqrt_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(tan_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(tanh_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(trigamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(trunc_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
+DEFINE_DISPATCH(lgamma_stub); // NOLINT(cppcoreguidelines-avoid-non-const-global-variables)
 
 } // namespace native
 } // namespace at

View File

@@ -1,5 +1,6 @@
 #include <ATen/native/sparse/SparseTensorMath.h>
+#include <c10/util/irange.h>
 
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
 #include <ATen/SparseTensorImpl.h>
@@ -283,7 +284,7 @@ SparseTensor& div_out_sparse_zerodim(const SparseTensor& t, const Tensor& value,
 
 Tensor div_sparse(const Tensor& self, const Tensor& value) {
   auto commonDtype = at::result_type(self, value);
-  if (c10::isIntegralType(commonDtype, /*include_bool=*/true)) {
+  if (c10::isIntegralType(commonDtype, /*includeBool=*/true)) {
     commonDtype = typeMetaToScalarType(at::get_default_dtype());
   }
   Tensor result = at::empty({0}, self.options().dtype(commonDtype));
@@ -376,7 +377,7 @@ Tensor norm_sparse(const SparseTensor& self, const optional<Scalar>& p, IntArray
   if (dim.size() > 0) {
     // Only full reductions are supported, so check if that is the case
     int64_t ndim = self.dim();
-    bool passed_full_reduction_check = ndim == dim.size();
+    bool passed_full_reduction_check = static_cast<size_t>(ndim) == dim.size();
     if (passed_full_reduction_check) {
       auto dim_ = dim.vec();
       maybe_wrap_dims(dim_, ndim);
@@ -395,7 +396,8 @@
   }
   TORCH_CHECK(keepdim == false, "norm_sparse currently does not support keepdim=True");
   TORCH_CHECK(!dtype.has_value(), "norm_sparse currently does not support 'dtype' argument");
-  auto p_ = p.value_or(2.0);
+  constexpr auto TWO = 2.0;
+  auto p_ = p.value_or(TWO);
   return self.coalesce()._values().norm(p_);
 }
@@ -468,7 +470,6 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t,
   Tensor r_values = new_values_with_size_of(s_values, max_nnz).zero_();
 
   int64_t blockSize = r_values.stride(0);
-  int64_t cmp, d;
   int64_t r_i = 0, t_i = 0, s_i = 0;
 
   auto t_indices = t._indices();
   auto src_indices = src._indices();
@@ -485,13 +486,14 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t,
       scalar_t* r_values_ptr = r_values.data_ptr<scalar_t>();
       scalar_t cast_value = value.to<scalar_t>();
       while (t_i < t_nnz || s_i < s_nnz) {
+        int64_t cmp;
         if (t_i >= t_nnz) {
           cmp = -1;
         } else if (s_i >= s_nnz) {
           cmp = 1;
         } else {
           cmp = 0;
-          for (d = 0; d < sparse_dim; d++) {
+          for (auto d: c10::irange(sparse_dim)) {
             if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
               cmp = 1;
               break;
@@ -503,7 +505,7 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t,
           }
         }
         if (cmp >= 0) {
-          for (d = 0; d < sparse_dim; d++) {
+          for (auto d: c10::irange(sparse_dim)) {
             r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
           }
           if (t_values.numel() > 0) { // We add all elements from t_values to r_values only if t_values is not an empty tensor
@@ -514,7 +516,7 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t,
           t_i++;
         }
         if (cmp <= 0) {
-          for (d = 0; d < sparse_dim; d++) {
+          for (auto d: c10::irange(sparse_dim)) {
             r_indices_accessor[d][r_i] = src_indices_accessor[d][s_i];
           }
           if (s_values.numel() > 0) { // We add all elements from s_values to r_values only if s_values is not an empty tensor
@@ -619,9 +621,9 @@ void add_dense_sparse_worker_cpu(Tensor& r, const Scalar& value, const SparseTen
   scalar_t cast_value = value.to<scalar_t>();
   at::parallel_for(0, sparse._nnz(), 0, [&](int64_t start, int64_t end) {
-    for (auto k = start; k < end; k++) {
+    for (auto k: c10::irange(start, end)) {
       int64_t index = r.storage_offset();
-      for (int64_t d = 0; d < sparse.sparse_dim(); d++) {
+      for (auto d: c10::irange(sparse.sparse_dim())) {
         index += r.stride(d) * indices_accessor[d][k];
       }
       r_ptr[index] += cast_value * values_accessor[k];
@@ -732,7 +734,6 @@ SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, SparseTen
   Tensor src_indices = src._indices();
   Tensor r_indices = at::empty({sparse_dim, max_nnz}, t_indices.options());
-  int64_t match, d;
   int64_t r_i = 0, t_i = 0, s_i = 0;
 
   auto commonDtype = promoteTypes(t_.scalar_type(), src_.scalar_type());
@@ -752,21 +753,17 @@ SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, SparseTen
   // entry to the result indices vector. Returns true if matching
   // indices were found.
   auto index_preamble = [&]() {
-    match = 1;
-    for (d = 0; d < sparse_dim; d++) {
+    for (auto d: c10::irange(sparse_dim)) {
       if (t_indices_accessor[d][t_i] < src_indices_accessor[d][s_i]) {
         t_i++;
-        match = 0;
-        break;
+        return false;
       }
       if (t_indices_accessor[d][t_i] > src_indices_accessor[d][s_i]) {
         s_i++;
-        match = 0;
-        break;
+        return false;
      }
     }
-    if (!match) return false;
-    for (d = 0; d < sparse_dim; d++) {
+    for (auto d: c10::irange(sparse_dim)) {
       r_indices_accessor[d][r_i] = t_indices_accessor[d][t_i];
     }
     return true;
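
(The hunk above is the classic flag-to-early-return simplification: instead of writing a `match` flag inside the loop and testing it afterwards, the lambda returns the answer at the point it becomes known, which also removes the uninitialized `match` variable that clang-tidy flagged. A generic sketch of the same transformation, with a hypothetical `matches(d)` predicate:

    auto all_match = [&](int64_t sparse_dim) {
      for (const auto d : c10::irange(sparse_dim)) {
        if (!matches(d)) {
          return false; // early exit replaces `match = 0; break;`
        }
      }
      return true;
    };
)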
@@ -813,7 +810,6 @@ SparseTensor& mul_out_sparse_cpu(const Tensor& t_, const Tensor& src_, SparseTen
 
 template <typename scalar_t>
 void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, const Scalar& beta, const Tensor& t, const Scalar& alpha, const Tensor& indices, const Tensor& values, const Tensor& dense) {
-  int64_t i;
 
   // r_ = alpha * sparse * dense
   scalar_t cast_alpha = alpha.to<scalar_t>();
@@ -838,7 +834,7 @@ void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j,
   int64_t dense_stride1 = dense.stride(1);
   int64_t r_stride0 = r.stride(0);
   int64_t r_stride1 = r.stride(1);
-  for (i = 0; i < nnz; i++) {
+  for (auto i: c10::irange(nnz)) {
     scalar_t val = values_accessor[i];
     int64_t row = indices_accessor[0][i];
     int64_t col = indices_accessor[1][i];
@@ -1281,7 +1277,7 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
   }
   else {
     new_indices = at::empty({sparse_dim - sparse_dims_to_sum_size, input._nnz()}, indices.options());
-    for (int64_t i = 0; i < dims_to_keep_v.size(); i++) {
+    for (auto i: c10::irange(dims_to_keep_v.size())) {
       int64_t d = dims_to_keep_v[i];
       if (d < sparse_dim) new_indices[i].copy_(indices[d]);
       else break;
@@ -1292,6 +1288,7 @@ Tensor _sparse_sum(const SparseTensor& input, IntArrayRef dims_to_sum) {
   int64_t new_sparse_dim = new_indices.size(0);
   int64_t new_dense_dim = new_values.dim() - 1; // exclude nnz dim
   std::vector<int64_t> new_sizes;
+  new_sizes.reserve(dims_to_keep_v.size());
   for (auto d : dims_to_keep_v) new_sizes.emplace_back(sizes[d]);
   if (sum_all_sparse_dim) new_sizes.emplace(new_sizes.begin(), 1);
@@ -1367,7 +1364,7 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_,
   int64_t sparse_dims_to_sum_size = 0;
   auto sparse_dims_to_keep_v = std::vector<int64_t>();
   auto dense_dims_to_sum_v = std::vector<int64_t>();
-  for (int64_t d = 0; d < input_dim; d++) {
+  for (auto d: c10::irange(input_dim)) {
     if (dims_to_sum_b[d]) {
       if (d < input_sparse_dim) sparse_dims_to_sum_size ++;
       else dense_dims_to_sum_v.emplace_back(d + 1 - input_sparse_dim);
@@ -1388,7 +1385,7 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_,
   if (sum_dense_dim) {
     auto dense_expand_size = std::vector<int64_t>(expand_size);
     dense_expand_size.erase(dense_expand_size.begin());
-    AT_ASSERT(dense_expand_size.size() == (input_values.dim() - 1));
+    AT_ASSERT(dense_expand_size.size() == static_cast<size_t>(input_values.dim() - 1));
     for (auto d : dense_dims_to_sum_v) grad_input_values = grad_input_values.unsqueeze(d - 1); // -1 since grad has no nnz dim
     grad_input_values = grad_input_values.expand(dense_expand_size);
   }
@@ -1428,7 +1425,7 @@ Tensor _sparse_sum_backward_cpu(const Tensor& grad_, const SparseTensor& input_,
 
   // binary search to find matching indices
   at::parallel_for(0, input_nnz, 0, [&](int64_t start, int64_t end) {
-    for (auto i = start; i < end; i++) {
+    for (auto i: c10::irange(start, end)) {
       int64_t input_idx = input_indices_1D_accessor[i];
       int64_t l = 0, r = grad_nnz - 1;
       while (l <= r) {
@@ -1492,7 +1489,7 @@ scalar_t binary_search_strided_rightmost(scalar_t search_val, TensorAccessor<sca
   int64_t left_ind = 0;
   int64_t right_ind = length - 1;
-  int64_t mid_ind;
+  int64_t mid_ind; // NOLINT(cppcoreguidelines-init-variables)
   bool done_searching = false;
 
   while (!done_searching) {