[3/N] apply clang-tidy in torch/csrc/autograd (#109368)

This PR applies clang-tidy fixes in torch/csrc/autograd/FunctionsManual.cpp, along with a few related cleanups in the other files touched by this commit.
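
Most of the changes follow a few recurring clang-tidy patterns: passing Tensor arguments by const reference instead of by value, making size_t-to-int64_t narrowing explicit with static_cast, and replacing declare-then-std::tie with C++17 structured bindings. Below is a minimal standalone sketch of those before/after shapes; the helper names are hypothetical and are not the PyTorch functions touched by this commit.

#include <cstdint>
#include <tuple>
#include <vector>

// Hypothetical stand-ins for illustration only; these are not the PyTorch
// functions changed in this commit.
std::tuple<bool, bool> parse_mode() {
  return {true, false};
}

// performance-unnecessary-value-param / bugprone-narrowing-conversions:
// take the argument by const reference and make the size_t -> int64_t
// conversion explicit rather than relying on an implicit narrowing.
int64_t count_dims(const std::vector<int64_t>& sizes) {
  return static_cast<int64_t>(sizes.size());
}

void example() {
  // Old pattern: declare uninitialized locals, then std::tie into them,
  // which needed a NOLINT for cppcoreguidelines-init-variables:
  //   bool compute_q, reduced;
  //   std::tie(compute_q, reduced) = parse_mode();
  // New pattern: a C++17 structured binding initializes both names at once.
  auto [compute_q, reduced] = parse_mode();
  (void)compute_q; // silence unused-variable warnings in this sketch
  (void)reduced;
}

The structured-binding rewrite also removes the need for the NOLINT(cppcoreguidelines-init-variables) suppressions that the old declare-then-tie pattern required, as seen in several hunks below.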

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109368
Approved by: https://github.com/Skylion007
Authored by cyy on 2023-09-17 07:26:59 +00:00
Committed by PyTorch MergeBot
Commit 51d2d825ab (parent d8da2a7c85)
9 changed files with 110 additions and 152 deletions

View File

@@ -88,7 +88,8 @@ void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
TORCH_CHECK(
range.second - range.first == t.size(),
"inconsistent range for TensorList output");
std::copy(t.begin(), t.end(), out.begin() + range.first);
std::copy(
t.begin(), t.end(), out.begin() + static_cast<int64_t>(range.first));
}
Tensor copysign_tensor_self_backward(
@@ -139,7 +140,7 @@ int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
return 1;
}
for (auto d : dim) {
d = at::maybe_wrap_dim(d, sizes.size());
d = at::maybe_wrap_dim(d, static_cast<int64_t>(sizes.size()));
size *= sizes[d];
}
return size;
@@ -151,7 +152,7 @@ static c10::SymInt _safe_size(c10::SymIntArrayRef sizes, c10::IntArrayRef dim) {
return 1;
}
for (auto d : dim) {
d = at::maybe_wrap_dim(d, sizes.size());
d = at::maybe_wrap_dim(d, static_cast<int64_t>(sizes.size()));
size *= sizes[d];
}
return size;
@@ -180,11 +181,11 @@ Tensor restore_reduced_dims(
if (keepdim) {
return output;
}
int64_t total_dims = output.dim() + dims.size();
auto total_dims = output.dim() + dims.size();
std::vector<c10::SymInt> target_shape(total_dims, 0);
for (int64_t i : dims) {
if (i < 0) {
i = total_dims + i;
i = static_cast<int64_t>(total_dims) + i;
}
target_shape[i] = 1;
}
@@ -522,7 +523,7 @@ Tensor pow_backward_exponent(
const Tensor& grad,
const Scalar& base,
const Tensor& exponent,
Tensor result) {
const Tensor& result) {
auto grad_lambda = [](const Tensor& a, const Scalar& b) {
return (a * b.log()).conj();
};
@@ -540,10 +541,10 @@ Tensor pow_backward_exponent(
auto out = grad *
at::where(cond(exponent),
at::zeros({}, grad.options()),
grad_lambda(std::move(result), base_));
grad_lambda(result, base_));
return handle_r_to_c(exponent, std::move(out));
} else {
auto out = grad * grad_lambda(std::move(result), base_);
auto out = grad * grad_lambda(result, base_);
return handle_r_to_c(exponent, std::move(out));
}
}
@@ -561,7 +562,8 @@ Tensor angle_backward(const Tensor& grad, const Tensor& self) {
}
Tensor mvlgamma_backward(const Tensor& grad, const Tensor& self, int64_t p) {
Tensor args = at::arange(-p / 2. + 0.5, 0.5, 0.5, self.options());
Tensor args =
at::arange(-static_cast<double>(p) / 2. + 0.5, 0.5, 0.5, self.options());
args = args.add(self.unsqueeze(-1));
return grad * args.digamma_().sum(-1);
}
@@ -585,16 +587,16 @@ Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask) {
}
template <typename T>
Tensor mul_tensor_backward(Tensor grad, T other, ScalarType self_st) {
Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st) {
auto out = grad * other.conj();
return handle_r_to_c(self_st, std::move(out));
}
template Tensor mul_tensor_backward(Tensor, Tensor, ScalarType);
template Tensor mul_tensor_backward(Tensor, Scalar, ScalarType);
template Tensor mul_tensor_backward(const Tensor&, Tensor, ScalarType);
template Tensor mul_tensor_backward(const Tensor&, Scalar, ScalarType);
template <typename T>
Tensor div_tensor_self_backward(
Tensor grad,
const Tensor& grad,
T other,
ScalarType self_st,
const c10::optional<c10::string_view>& rounding_mode) {
@@ -606,40 +608,45 @@ Tensor div_tensor_self_backward(
return handle_r_to_c(self_st, std::move(result));
}
template Tensor div_tensor_self_backward(
Tensor,
const Tensor&,
Tensor,
ScalarType,
const c10::optional<c10::string_view>&);
template Tensor div_tensor_self_backward(
Tensor,
const Tensor&,
Scalar,
ScalarType,
const c10::optional<c10::string_view>&);
template <typename T>
Tensor div_tensor_self_backward(Tensor grad, T other, ScalarType self_st) {
Tensor div_tensor_self_backward(
const Tensor& grad,
T other,
ScalarType self_st) {
return div_tensor_self_backward(
std::move(grad), std::move(other), self_st, c10::nullopt);
grad, std::move(other), self_st, c10::nullopt);
}
template Tensor div_tensor_self_backward(Tensor, Tensor, ScalarType);
template Tensor div_tensor_self_backward(Tensor, Scalar, ScalarType);
template Tensor div_tensor_self_backward(const Tensor&, Tensor, ScalarType);
template Tensor div_tensor_self_backward(const Tensor&, Scalar, ScalarType);
Tensor div_tensor_other_backward(
const Tensor& grad,
const Tensor& self,
Tensor other,
const Tensor& other,
const c10::optional<c10::string_view>& rounding_mode) {
if (rounding_mode.has_value()) {
return at::zeros_like(grad, grad.options().dtype(other.scalar_type()));
}
auto result = -grad * ((self / other) / other).conj();
return handle_r_to_c(std::move(other), std::move(result));
return handle_r_to_c(other, std::move(result));
}
Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) {
return div_tensor_other_backward(
std::move(grad), std::move(self), std::move(other), c10::nullopt);
Tensor div_tensor_other_backward(
const Tensor& grad,
const Tensor& self,
const Tensor& other) {
return div_tensor_other_backward(grad, self, other, c10::nullopt);
}
Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) {
@@ -647,7 +654,8 @@ Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) {
auto ndims = fwd_dims.size();
std::vector<int64_t> dims(ndims);
for (const auto i : c10::irange(ndims)) {
dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
dims[at::maybe_wrap_dim(fwd_dims[i], static_cast<int64_t>(ndims))] =
static_cast<int64_t>(i);
}
return grad.permute(dims);
}
@@ -682,7 +690,7 @@ Tensor unsqueeze_multiple(
Tensor res = t;
for (const auto i : c10::irange(n_dims)) {
if (dims_to_unsqueeze[i]) {
res = res.unsqueeze(i);
res = res.unsqueeze(static_cast<int64_t>(i));
}
}
return res;
@@ -833,7 +841,7 @@ Tensor prod_backward(
if (input.dim() == 0) {
return grad;
}
dim = at::maybe_wrap_dim(dim, input.sizes().size());
dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(input.sizes().size()));
if (!keepdim) {
// `prod` reduces the dimension at `dim`,
// so, unsqueeze `grad` and `result` at dim.
@@ -989,10 +997,10 @@ Tensor unbind_backward(const variable_list& grads, int64_t dim) {
Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
auto result = self;
int64_t nDims = sym_sizes.size();
auto nDims = sym_sizes.size();
for (const auto dim : c10::irange(nDims)) {
if (sym_sizes[dim] == 1) {
result = result.unsqueeze(dim);
result = result.unsqueeze(static_cast<int64_t>(dim));
}
}
return result;
@@ -1008,7 +1016,7 @@ Tensor unsqueeze_to(
Tensor result = self;
for (const auto d : c10::irange(ndim)) {
if (mask.test(d) && sym_sizes[d] == 1) {
result = result.unsqueeze(d);
result = result.unsqueeze(static_cast<int64_t>(d));
}
}
return result;
@@ -1069,7 +1077,7 @@ std::vector<Tensor> stack_tensors_backward(
}
bool grad_is_complex = grad.is_complex();
for (const auto i : c10::irange(dtypes.size())) {
auto gr = grad.select(dim, i);
auto gr = grad.select(dim, static_cast<int64_t>(i));
if (grad_is_complex && !at::isComplexType(dtypes[i])) {
gr = at::real(gr);
}
@@ -1108,8 +1116,8 @@ std::vector<Tensor> block_diag_backward(
continue;
}
// 0d case
auto dim0 = 1;
auto dim1 = 1;
int64_t dim0 = 1;
int64_t dim1 = 1;
// 2d case
if (shape.size() == 2) {
dim0 = shape[0];
@@ -1593,7 +1601,7 @@ Tensor renorm_jvp(
int64_t dim,
const Scalar& maxnorm) {
auto self_sizes = self_p.sizes();
dim = c10::maybe_wrap_dim(dim, self_sizes.size());
dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(self_sizes.size()));
at::DimVector reduce_dims(self_sizes.size());
std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
@@ -1643,7 +1651,7 @@ Tensor repeat_backward(
return at::zeros_symint(input_shape, grad.options());
}
const auto input_dims = input_shape.size();
int64_t num_unsqueezed = grad.dim() - input_dims;
auto num_unsqueezed = grad.dim() - input_dims;
for (const auto i : c10::irange(num_unsqueezed)) {
(void)i; // Suppress unused variable warning
grad = grad.sum(0, false);
@@ -1688,7 +1696,7 @@ Tensor repeat_backward(
// [g2_4, g2_5]]] [g2_4, g2_5]]
if (repeat != 1) {
grad_size.push_back(repeat);
sum_dims.push_back(grad_size.size() - 1);
sum_dims.push_back(static_cast<int64_t>(grad_size.size() - 1));
}
// Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat ==
// 1)
@@ -1805,7 +1813,6 @@ Tensor var_backward(
grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size());
}
const c10::SymFloat rnumel(_safe_size(self.sym_sizes(), dim));
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions)
return (c10::SymFloat(2.0) / (rnumel - correction)) * grad *
(self - self.mean(dim, /*keepdim=*/true));
}
@@ -2065,7 +2072,7 @@ Tensor split_with_sizes_backward(
int64_t dim,
c10::SymIntArrayRef sizes,
const at::TensorOptions& options) {
dim = at::maybe_wrap_dim(dim, sizes.size());
dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
// it's possible some of the grads are not defined (represents tensors of all
// 0s). Since at::cat can't handle those, let's define them
@@ -2099,9 +2106,9 @@ Tensor split_backward(
int64_t dim,
c10::SymIntArrayRef sym_sizes,
const at::TensorOptions& options) {
dim = at::maybe_wrap_dim(dim, sym_sizes.size());
dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sym_sizes.size()));
const auto& dim_size = sym_sizes[dim];
int64_t num_splits = grads.size();
auto num_splits = grads.size();
std::vector<c10::SymInt> split_sizes(num_splits, split_size);
split_sizes[num_splits - 1] =
split_size - (split_size * num_splits - dim_size);
@@ -3016,7 +3023,7 @@ static inline c10::SymInt _min_storage_size(
c10::SymIntArrayRef strides,
c10::SymInt storage_offset) {
c10::SymInt storage_size = storage_offset + 1;
int64_t dim = sizes.size();
auto dim = sizes.size();
for (const auto i : c10::irange(dim)) {
const auto& size_i = sizes[i];
if (size_i == 0) {
@@ -3833,9 +3840,7 @@ std::tuple<Tensor, Tensor> linalg_qr_jvp(
// trilImInv(trilIm(Q^H A_1 R_1^{-1}))
at::NoTF32Guard disable_tf32;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool compute_q, reduced;
std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode);
auto [compute_q, reduced] = at::native::_parse_qr_mode(mode);
TORCH_CHECK(
compute_q,
@@ -3929,9 +3934,7 @@ Tensor linalg_qr_backward(
// gA = QgR + pi*(Q trilImInv*(Q^H gQ - gR R^H)R_1^{-H})
at::NoTF32Guard disable_tf32;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool compute_q, reduced;
std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode);
auto [compute_q, reduced] = at::native::_parse_qr_mode(mode);
TORCH_CHECK(
compute_q,
@@ -4150,8 +4153,7 @@ Tensor linalg_det_backward(
auto singular = [](const Tensor& A,
const Tensor& /*d*/,
const Tensor& grad) {
Tensor U, S, Vh;
std::tie(U, S, Vh) = at::linalg_svd(A);
auto [U, S, Vh] = at::linalg_svd(A);
auto alpha = (at::linalg_det(U) * at::linalg_det(Vh)).conj() * grad;
auto D = prod_safe_zeros_backward(alpha.unsqueeze(-1), S, S.dim() - 1);
return (U * D.unsqueeze(-2)).matmul(Vh);
@@ -4500,7 +4502,8 @@ Tensor fft_r2c_backward(
// 3. discard the complex dim
auto half_sizes = grad.sym_sizes();
std::vector<c10::SymInt> new_grad_shape(half_sizes.begin(), half_sizes.end());
const auto last_dim = at::maybe_wrap_dim(dim.back(), half_sizes.size());
const auto last_dim =
at::maybe_wrap_dim(dim.back(), static_cast<int64_t>(half_sizes.size()));
new_grad_shape[last_dim] = last_dim_size;
const auto zero_length = last_dim_size - grad.sym_size(dim.back());
@@ -4607,10 +4610,10 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
if (ggI.defined() && training) {
auto ggI_sum = sum_exclude_dim1(ggI);
auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
auto all_sub =
((ggI_sum * gO_sum).div_(M))
auto all_sub = ((ggI_sum * gO_sum).div_(M))
.sub_(sum_exclude_dim1(gO * ggI))
.add_((sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M));
.add_((sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum)
.mul_(3. / static_cast<double>(M)));
auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
auto gI_1t =
(ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
@@ -4698,11 +4701,10 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
const Tensor& save_invstd_t,
c10::SymIntArrayRef normalized_shape,
std::array<bool, 3> output_mask) {
const int normalized_ndim = normalized_shape.size();
const auto normalized_ndim = normalized_shape.size();
const auto input_shape = input_t.sizes();
const auto input_ndim = input_t.dim();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const int axis = input_ndim - normalized_ndim;
const auto axis = input_ndim - normalized_ndim;
const int64_t M =
c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
const int64_t N =
@@ -4754,10 +4756,10 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
auto ggI_sum = ggI_expanded.sum(1, true);
auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true);
auto all_sub =
((ggI_sum * gxhat_sum).div_(N))
auto all_sub = ((ggI_sum * gxhat_sum).div_(N))
.sub_((ggI_expanded * gxhat).sum(1, true))
.add_((sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum).mul_(3. / N));
.add_((sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum)
.mul_(3. / static_cast<double>(N)));
auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N);
auto gI_1t =
(ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat);
@@ -4982,11 +4984,11 @@ Tensor sinc_backward(const Tensor& grad, const Tensor& self) {
// in pads])
Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad) {
auto negated_pad = pad.vec();
// NOLINTNEXTLINE(modernize-use-transparent-functors)
std::transform(
negated_pad.cbegin(),
negated_pad.cend(),
negated_pad.begin(),
// NOLINTNEXTLINE(modernize-use-transparent-functors)
std::negate<c10::SymInt>());
return at::constant_pad_nd_symint(grad, negated_pad, 0);
}
@@ -5265,10 +5267,7 @@ std::tuple<Tensor, Tensor> householder_product_backward(
auto v_i = input.narrow(-1, i, 1);
auto t_i = tau.narrow(-1, i, 1);
Tensor v_i_grad, tau_i_grad;
std::tie(v_i_grad, tau_i_grad) = update_grad(i, v_i, t_i, K);
input_grads[i] = v_i_grad;
tau_grads[i] = tau_i_grad;
std::tie(input_grads[i], tau_grads[i]) = update_grad(i, v_i, t_i, K);
// K <- H_{i + 1}^{-1} @ K @ H_i
if (i != flip_i(k - 1)) {
@@ -5302,8 +5301,7 @@ std::tuple<Tensor, Tensor> householder_product_backward(
auto v_i = input.narrow(-1, i, 1);
auto t_i = tau.narrow(-1, i, 1);
Tensor v_i_grad, tau_i_grad;
std::tie(v_i_grad, tau_i_grad) = update_grad(i, v_i, t_i, K);
auto [v_i_grad, tau_i_grad] = update_grad(i, v_i, t_i, K);
input_grad.select(-1, i).copy_(v_i_grad.squeeze(-1));
tau_grad.select(-1, i).copy_(tau_i_grad.squeeze(-1));
@@ -5626,8 +5624,7 @@ Tensor linalg_lu_solve_LU(
// gLU = gL + gU
at::NoTF32Guard disable_tf32;
Tensor P, L, U;
std::tie(P, L, U) = at::lu_unpack(
auto [P, L, U] = at::lu_unpack(
LU, pivots, /*unpack_data=*/true, /*unpack_pivots=*/left == adjoint);
// TODO Optimise the order of the operations to avoid operating on large
// tensors unnecessarily
@@ -5712,8 +5709,7 @@ Tensor linalg_lu_solve_jvp(
// op_3(B)^H A^{-1} = op_3(X)^H
// We can then rewrite the formula above in terms of X as
// dX = op_2(op_1(-op_3(X)^H P(LdUU^{-1} + dL)L^{-1} P^T)) + S
Tensor P, L, U;
std::tie(P, L, U) = at::lu_unpack(LU, pivots);
auto [P, L, U] = at::lu_unpack(LU, pivots);
// Compute V = op_3(X)^H
auto V = left ? X.mH() : X;
// Compute the inner parens LdUU^{-1} + dL
@@ -6103,9 +6099,9 @@ Tensor batch_norm_jvp(
int64_t numel = 1;
for (const auto dim : c10::irange(view_size.size())) {
if (dim != 1) {
numel *= input_t.size(dim);
numel *= input_t.size(static_cast<int64_t>(dim));
view_size[dim] = 1;
dims.push_back(dim);
dims.push_back(static_cast<int64_t>(dim));
}
}
Tensor mean_p;
@@ -6159,9 +6155,9 @@ Tensor layer_norm_jvp(
if (i < view_size.size() - normalized_shape.size()) {
view_size_affine[i] = 1;
} else {
numel *= input_t.size(i);
numel *= input_t.size(static_cast<int64_t>(i));
view_size[i] = 1;
dims.push_back(i);
dims.push_back(static_cast<int64_t>(i));
}
}
auto mean_p = saved_mean.view(view_size);
@@ -6478,8 +6474,7 @@ Tensor lu_factor_ex_backward(
const Tensor& LU,
const Tensor& pivs,
const bool pivot) {
Tensor P, L, U;
std::tie(P, L, U) =
auto [P, L, U] =
at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots*/ pivot);
// L.shape == (..., m, k)
@@ -6565,13 +6560,9 @@ Tensor lu_factor_ex_jvp(
const Tensor& LU,
const Tensor& pivs,
const bool pivot) {
Tensor dL, dU;
{
Tensor P, L, U;
std::tie(P, L, U) =
auto [P, L, U] =
at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots=*/pivot);
std::tie(dL, dU) = linalg_lu_jvp(dA, P, L, U, pivot);
}
auto [dL, dU] = linalg_lu_jvp(dA, P, L, U, pivot);
auto m = dA.size(-2);
auto n = dA.size(-1);
@@ -6975,10 +6966,10 @@ mkldnn_rnn_layer_differentiable_backward(
seq_length);
std::vector<std::tuple<Tensor, Tensor>> layer_states(seq_length + 1);
layer_states[0] = std::make_tuple(hx_, cx_tmp);
for (int seq = 1; seq < seq_length + 1; seq++) {
for (int64_t seq = 1; seq < seq_length + 1; seq++) {
auto hx = hx_prev;
auto cx = cx_prev;
int x_index = reverse ? seq_length - seq : seq - 1;
auto x_index = reverse ? seq_length - seq : seq - 1;
auto gate = at::linear(input_[x_index], weight0, bias_ih)
.add_(at::linear(hx, weight1, bias_hh));
auto chunked_gates = gate.unsafe_chunk(4, 1);
@@ -6997,8 +6988,8 @@ mkldnn_rnn_layer_differentiable_backward(
Tensor dx, dWx, dWh, db, db_, dprev_h, dprev_c, dWh_, dWx_;
Tensor new_grad_hy, d1, dgp, dip, dfp, dop, do_, dg, df, di, da;
std::vector<at::Tensor> layer_dx(seq_length);
for (int seq = seq_length - 1; seq >= 0; seq--) {
int x_index = reverse ? seq_length - seq - 1 : seq;
for (int64_t seq = seq_length - 1; seq >= 0; seq--) {
int64_t x_index = reverse ? seq_length - seq - 1 : seq;
auto i = std::get<0>(layer_gates[x_index]);
auto f = std::get<1>(layer_gates[x_index]);
auto g = std::get<2>(layer_gates[x_index]);

View File

@@ -139,23 +139,29 @@ at::Tensor pow_backward_exponent(
const at::Tensor& grad,
const at::Scalar& base,
const at::Tensor& exponent,
at::Tensor result);
const at::Tensor& result);
at::Tensor angle_backward(const at::Tensor& grad, const at::Tensor& self);
template <typename T>
at::Tensor mul_tensor_backward(Tensor grad, T other, ScalarType self_st);
template <typename T>
at::Tensor div_tensor_self_backward(Tensor grad, T other, ScalarType self_st);
at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other);
at::Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st);
template <typename T>
at::Tensor div_tensor_self_backward(
Tensor grad,
const Tensor& grad,
T other,
ScalarType self_st);
at::Tensor div_tensor_other_backward(
const Tensor& grad,
const Tensor& self,
const Tensor& other);
template <typename T>
at::Tensor div_tensor_self_backward(
const Tensor& grad,
T other,
ScalarType self_st,
const c10::optional<c10::string_view>& rounding_mode);
at::Tensor div_tensor_other_backward(
const Tensor& grad,
const Tensor& self,
Tensor other,
const Tensor& other,
const c10::optional<c10::string_view>& rounding_mode);
at::Tensor mvlgamma_backward(
const at::Tensor& grad,

View File

@@ -290,17 +290,14 @@ const Variable& AutogradMeta::fw_grad(
static_cast<const torch::autograd::DifferentiableViewMeta*>(this);
// This is ok to do as we ONLY modify fw_grad_ and this field is properly
// locked in all methods
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
auto this_view_meta =
const_cast<torch::autograd::DifferentiableViewMeta*>(const_view_meta);
if (this_view_meta->has_fw_view()) {
const auto& view_info = this_view_meta->get_forward_view();
if (const_view_meta->has_fw_view()) {
const auto& view_info = const_view_meta->get_forward_view();
const auto& base = view_info.base_;
const auto& base_val = base._fw_grad(level);
if (base_val.defined()) {
// Lazy initialization of fw_grad_
this_view_meta->fw_grad_ = std::make_shared<ForwardGrad>();
const_view_meta->fw_grad_ = std::make_shared<ForwardGrad>();
Variable new_val;
if (view_info.has_view_fn()) {
@@ -310,8 +307,8 @@ const Variable& AutogradMeta::fw_grad(
self.sizes(), self.strides(), self.storage_offset());
}
this_view_meta->fw_grad_->set_value(new_val, level);
return this_view_meta->fw_grad_->value(level);
const_view_meta->fw_grad_->set_value(new_val, level);
return const_view_meta->fw_grad_->value(level);
}
}
}

View File

@@ -24,11 +24,6 @@
#include <utility>
#include <vector>
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif
namespace torch {
namespace autograd {
@@ -177,9 +172,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
name(),
c10::ArrayRef<const c10::IValue>(
inputs_vec.data(), inputs_vec.size()),
sequence_nr());
static_cast<int64_t>(sequence_nr()));
} else {
guard.before(name(), sequence_nr());
guard.before(name(), static_cast<int64_t>(sequence_nr()));
}
return apply(std::move(inputs));
} else {
@@ -203,7 +198,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
c10::SymIntArrayRef shape,
bool is_tensor_subclass,
bool is_nested) noexcept {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t input_nr = input_metadata_.size();
auto meta_shape = MetadataShape{c10::in_place_type<SymIntSmallVec>, shape};
input_metadata_.emplace_back(
@@ -212,7 +206,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
}
uint32_t add_input_metadata(const at::Tensor& t) noexcept {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t input_nr = input_metadata_.size();
input_metadata_.emplace_back(t);
return input_nr;
@@ -220,7 +213,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
/// Adds a placeholder for an input that will not be used.
uint32_t add_input_metadata(undefined_input u) noexcept {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t input_nr = input_metadata_.size();
input_metadata_.emplace_back();
return input_nr;
@@ -597,21 +589,18 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// Sequence number used to correlate backward nodes with forward ops in the
// profiler and provide determinism in the engine.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const uint64_t sequence_nr_;
// See NOTE [ Topological Number ]
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
uint64_t topological_nr_ = 0;
// Tracks whether this node has been added as the next_edge of another node
// via set_next_edge(s), which always calls topological_nr() of all its
// children See NOTE [ Topological Number ] for why we need this.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
mutable bool has_parent_ = false;
// Id of the thread that created the instance
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
uint64_t thread_id_ = 0;
// Note [Thread Safety on Autograd Node]
@@ -656,14 +645,10 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// hooks are automatically thread safe), we rely on the user to write thread
// safe C++ hooks if they want the hook to be correctly applied in
// multithreading environment.
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::mutex mutex_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
edge_list next_edges_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
PyObject* pyobj_ = nullptr; // weak reference
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr;
// NOTE [Hooks ordering]
@@ -676,15 +661,11 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// even if that node won't be executed.
// - retains_grad_hook are like tensor_pre_hooks except they are always
// ordered after all other tensor pre hooks
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_;
std::unordered_map<size_t, std::unique_ptr<FunctionPreHook>>
retains_grad_hooks_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_;
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
at::SmallVector<InputMetadata, 2> input_metadata_;
};
@@ -706,7 +687,6 @@ struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
edge_list next_edges;
using IterArgs<MakeNextFunctionList>::operator();
void operator()(const Variable& variable) {
// NOLINTNEXTLINE(bugprone-branch-clone)
if (variable.defined()) {
next_edges.emplace_back(impl::gradient_edge(variable));
} else {
@@ -714,17 +694,11 @@ struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> {
}
}
void operator()(const Variable* variable) {
// NOLINTNEXTLINE(bugprone-branch-clone)
if (variable->defined()) {
next_edges.emplace_back(impl::gradient_edge(*variable));
} else {
next_edges.emplace_back();
}
operator()(*variable);
}
void operator()(const c10::optional<Variable>& variable) {
// NOLINTNEXTLINE(bugprone-branch-clone)
if (variable.has_value() && variable->defined()) {
next_edges.emplace_back(impl::gradient_edge(*variable));
if (variable.has_value()) {
operator()(*variable);
} else {
next_edges.emplace_back();
}
@@ -783,5 +757,3 @@ struct TypeAndSize {
} // namespace autograd
} // namespace torch
C10_CLANG_DIAGNOSTIC_POP()

View File

@@ -107,6 +107,7 @@ static void accumulate(
// 5) The other Tensor is not a Tensor subclass (except sparse), since
// it's hard to predict the semantics of arbitrary subclass behavior.
// NOLINTNEXTLINE(bugprone-branch-clone)
if (at::GradMode::is_enabled()) {
buffer[pos] = old_var + var;
} else if (

View File

@@ -64,8 +64,7 @@ void PyDefaultSavedVariableHooks::push_hooks(
}
void PyDefaultSavedVariableHooks::pop_hooks() {
PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::get_hooks();
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
if (Py_IsInitialized()) {
py::gil_scoped_acquire gil;
@@ -76,8 +75,7 @@ void PyDefaultSavedVariableHooks::pop_hooks() {
}
std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::get_hooks();
if (!pack_hook || !unpack_hook) {
return nullptr;
}

View File

@@ -222,7 +222,7 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
// shared by multiple Tensors. See Note [ Using ForwardGrad ]
// Any transition from not_initialized to initialized
// must be protected by mutex_
std::shared_ptr<ForwardGrad> fw_grad_;
mutable std::shared_ptr<ForwardGrad> fw_grad_;
// The hooks_ field is actually reused by both python and cpp logic
// For both cases, we have a data structure, cpp_hooks_list_ (cpp)
@@ -308,7 +308,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
// set_requires_grad also checks error conditions.
if (requires_grad) {
TORCH_INTERNAL_ASSERT(self_impl);
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
set_requires_grad(requires_grad, self_impl);
}
TORCH_CHECK(
@@ -765,7 +764,6 @@ inline Variable make_variable(
data.getIntrusivePtr()->unique_version()) {
auto data_impl = data.unsafeReleaseIntrusivePtr();
data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
// NOLINTNEXTLINE(bugprone-branch-clone)
if (requires_grad) {
data_impl->set_autograd_meta(
std::make_unique<AutogradMeta>(data_impl.get(), requires_grad));
@@ -777,7 +775,6 @@ inline Variable make_variable(
auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
/*version_counter=*/0,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
// NOLINTNEXTLINE(bugprone-branch-clone)
if (requires_grad) {
data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
data_impl_copy.get(), requires_grad));

View File

@@ -83,8 +83,7 @@ void initializeDtypes() {
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)};
for (at::ScalarType scalarType : all_scalar_types) {
std::string primary_name, legacy_name;
std::tie(primary_name, legacy_name) = getDtypeNames(scalarType);
auto [primary_name, legacy_name] = getDtypeNames(scalarType);
PyObject* dtype = THPDtype_New(scalarType, primary_name);
torch::registerDtypeObject((THPDtype*)dtype, scalarType);
Py_INCREF(dtype);

View File

@@ -385,10 +385,7 @@ Tensor internal_new_from_data(
at::tracer::impl::NoTracerDispatchMode tracer_guard;
if (isStorage(data)) {
bool is_typed_storage = false;
ScalarType storage_scalar_type{ScalarType::Undefined};
Storage storage;
std::tie(storage, storage_scalar_type, is_typed_storage) =
auto [storage, storage_scalar_type, is_typed_storage] =
createStorageGetType(data);
TORCH_CHECK(