mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[3/N] apply clang-tidy in torch/csrc/autograd (#109368)
This PR applies clang-tidy fixes in torch/csrc/autograd/FunctionsManual.cpp. There are also other fixes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109368 Approved by: https://github.com/Skylion007
This commit is contained in:
		| @ -88,7 +88,8 @@ void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) { | |||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
|       range.second - range.first == t.size(), |       range.second - range.first == t.size(), | ||||||
|       "inconsistent range for TensorList output"); |       "inconsistent range for TensorList output"); | ||||||
|   std::copy(t.begin(), t.end(), out.begin() + range.first); |   std::copy( | ||||||
|  |       t.begin(), t.end(), out.begin() + static_cast<int64_t>(range.first)); | ||||||
| } | } | ||||||
|  |  | ||||||
| Tensor copysign_tensor_self_backward( | Tensor copysign_tensor_self_backward( | ||||||
| @ -139,7 +140,7 @@ int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) { | |||||||
|     return 1; |     return 1; | ||||||
|   } |   } | ||||||
|   for (auto d : dim) { |   for (auto d : dim) { | ||||||
|     d = at::maybe_wrap_dim(d, sizes.size()); |     d = at::maybe_wrap_dim(d, static_cast<int64_t>(sizes.size())); | ||||||
|     size *= sizes[d]; |     size *= sizes[d]; | ||||||
|   } |   } | ||||||
|   return size; |   return size; | ||||||
| @ -151,7 +152,7 @@ static c10::SymInt _safe_size(c10::SymIntArrayRef sizes, c10::IntArrayRef dim) { | |||||||
|     return 1; |     return 1; | ||||||
|   } |   } | ||||||
|   for (auto d : dim) { |   for (auto d : dim) { | ||||||
|     d = at::maybe_wrap_dim(d, sizes.size()); |     d = at::maybe_wrap_dim(d, static_cast<int64_t>(sizes.size())); | ||||||
|     size *= sizes[d]; |     size *= sizes[d]; | ||||||
|   } |   } | ||||||
|   return size; |   return size; | ||||||
| @ -180,11 +181,11 @@ Tensor restore_reduced_dims( | |||||||
|   if (keepdim) { |   if (keepdim) { | ||||||
|     return output; |     return output; | ||||||
|   } |   } | ||||||
|   int64_t total_dims = output.dim() + dims.size(); |   auto total_dims = output.dim() + dims.size(); | ||||||
|   std::vector<c10::SymInt> target_shape(total_dims, 0); |   std::vector<c10::SymInt> target_shape(total_dims, 0); | ||||||
|   for (int64_t i : dims) { |   for (int64_t i : dims) { | ||||||
|     if (i < 0) { |     if (i < 0) { | ||||||
|       i = total_dims + i; |       i = static_cast<int64_t>(total_dims) + i; | ||||||
|     } |     } | ||||||
|     target_shape[i] = 1; |     target_shape[i] = 1; | ||||||
|   } |   } | ||||||
| @ -522,7 +523,7 @@ Tensor pow_backward_exponent( | |||||||
|     const Tensor& grad, |     const Tensor& grad, | ||||||
|     const Scalar& base, |     const Scalar& base, | ||||||
|     const Tensor& exponent, |     const Tensor& exponent, | ||||||
|     Tensor result) { |     const Tensor& result) { | ||||||
|   auto grad_lambda = [](const Tensor& a, const Scalar& b) { |   auto grad_lambda = [](const Tensor& a, const Scalar& b) { | ||||||
|     return (a * b.log()).conj(); |     return (a * b.log()).conj(); | ||||||
|   }; |   }; | ||||||
| @ -540,10 +541,10 @@ Tensor pow_backward_exponent( | |||||||
|     auto out = grad * |     auto out = grad * | ||||||
|         at::where(cond(exponent), |         at::where(cond(exponent), | ||||||
|                   at::zeros({}, grad.options()), |                   at::zeros({}, grad.options()), | ||||||
|                   grad_lambda(std::move(result), base_)); |                   grad_lambda(result, base_)); | ||||||
|     return handle_r_to_c(exponent, std::move(out)); |     return handle_r_to_c(exponent, std::move(out)); | ||||||
|   } else { |   } else { | ||||||
|     auto out = grad * grad_lambda(std::move(result), base_); |     auto out = grad * grad_lambda(result, base_); | ||||||
|     return handle_r_to_c(exponent, std::move(out)); |     return handle_r_to_c(exponent, std::move(out)); | ||||||
|   } |   } | ||||||
| } | } | ||||||
| @ -561,7 +562,8 @@ Tensor angle_backward(const Tensor& grad, const Tensor& self) { | |||||||
| } | } | ||||||
|  |  | ||||||
| Tensor mvlgamma_backward(const Tensor& grad, const Tensor& self, int64_t p) { | Tensor mvlgamma_backward(const Tensor& grad, const Tensor& self, int64_t p) { | ||||||
|   Tensor args = at::arange(-p / 2. + 0.5, 0.5, 0.5, self.options()); |   Tensor args = | ||||||
|  |       at::arange(-static_cast<double>(p) / 2. + 0.5, 0.5, 0.5, self.options()); | ||||||
|   args = args.add(self.unsqueeze(-1)); |   args = args.add(self.unsqueeze(-1)); | ||||||
|   return grad * args.digamma_().sum(-1); |   return grad * args.digamma_().sum(-1); | ||||||
| } | } | ||||||
| @ -585,16 +587,16 @@ Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask) { | |||||||
| } | } | ||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
| Tensor mul_tensor_backward(Tensor grad, T other, ScalarType self_st) { | Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st) { | ||||||
|   auto out = grad * other.conj(); |   auto out = grad * other.conj(); | ||||||
|   return handle_r_to_c(self_st, std::move(out)); |   return handle_r_to_c(self_st, std::move(out)); | ||||||
| } | } | ||||||
| template Tensor mul_tensor_backward(Tensor, Tensor, ScalarType); | template Tensor mul_tensor_backward(const Tensor&, Tensor, ScalarType); | ||||||
| template Tensor mul_tensor_backward(Tensor, Scalar, ScalarType); | template Tensor mul_tensor_backward(const Tensor&, Scalar, ScalarType); | ||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
| Tensor div_tensor_self_backward( | Tensor div_tensor_self_backward( | ||||||
|     Tensor grad, |     const Tensor& grad, | ||||||
|     T other, |     T other, | ||||||
|     ScalarType self_st, |     ScalarType self_st, | ||||||
|     const c10::optional<c10::string_view>& rounding_mode) { |     const c10::optional<c10::string_view>& rounding_mode) { | ||||||
| @ -606,40 +608,45 @@ Tensor div_tensor_self_backward( | |||||||
|   return handle_r_to_c(self_st, std::move(result)); |   return handle_r_to_c(self_st, std::move(result)); | ||||||
| } | } | ||||||
| template Tensor div_tensor_self_backward( | template Tensor div_tensor_self_backward( | ||||||
|     Tensor, |     const Tensor&, | ||||||
|     Tensor, |     Tensor, | ||||||
|     ScalarType, |     ScalarType, | ||||||
|     const c10::optional<c10::string_view>&); |     const c10::optional<c10::string_view>&); | ||||||
| template Tensor div_tensor_self_backward( | template Tensor div_tensor_self_backward( | ||||||
|     Tensor, |     const Tensor&, | ||||||
|     Scalar, |     Scalar, | ||||||
|     ScalarType, |     ScalarType, | ||||||
|     const c10::optional<c10::string_view>&); |     const c10::optional<c10::string_view>&); | ||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
| Tensor div_tensor_self_backward(Tensor grad, T other, ScalarType self_st) { | Tensor div_tensor_self_backward( | ||||||
|  |     const Tensor& grad, | ||||||
|  |     T other, | ||||||
|  |     ScalarType self_st) { | ||||||
|   return div_tensor_self_backward( |   return div_tensor_self_backward( | ||||||
|       std::move(grad), std::move(other), self_st, c10::nullopt); |       grad, std::move(other), self_st, c10::nullopt); | ||||||
| } | } | ||||||
| template Tensor div_tensor_self_backward(Tensor, Tensor, ScalarType); | template Tensor div_tensor_self_backward(const Tensor&, Tensor, ScalarType); | ||||||
| template Tensor div_tensor_self_backward(Tensor, Scalar, ScalarType); | template Tensor div_tensor_self_backward(const Tensor&, Scalar, ScalarType); | ||||||
|  |  | ||||||
| Tensor div_tensor_other_backward( | Tensor div_tensor_other_backward( | ||||||
|     const Tensor& grad, |     const Tensor& grad, | ||||||
|     const Tensor& self, |     const Tensor& self, | ||||||
|     Tensor other, |     const Tensor& other, | ||||||
|     const c10::optional<c10::string_view>& rounding_mode) { |     const c10::optional<c10::string_view>& rounding_mode) { | ||||||
|   if (rounding_mode.has_value()) { |   if (rounding_mode.has_value()) { | ||||||
|     return at::zeros_like(grad, grad.options().dtype(other.scalar_type())); |     return at::zeros_like(grad, grad.options().dtype(other.scalar_type())); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   auto result = -grad * ((self / other) / other).conj(); |   auto result = -grad * ((self / other) / other).conj(); | ||||||
|   return handle_r_to_c(std::move(other), std::move(result)); |   return handle_r_to_c(other, std::move(result)); | ||||||
| } | } | ||||||
|  |  | ||||||
| Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other) { | Tensor div_tensor_other_backward( | ||||||
|   return div_tensor_other_backward( |     const Tensor& grad, | ||||||
|       std::move(grad), std::move(self), std::move(other), c10::nullopt); |     const Tensor& self, | ||||||
|  |     const Tensor& other) { | ||||||
|  |   return div_tensor_other_backward(grad, self, other, c10::nullopt); | ||||||
| } | } | ||||||
|  |  | ||||||
| Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) { | Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) { | ||||||
| @ -647,7 +654,8 @@ Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) { | |||||||
|   auto ndims = fwd_dims.size(); |   auto ndims = fwd_dims.size(); | ||||||
|   std::vector<int64_t> dims(ndims); |   std::vector<int64_t> dims(ndims); | ||||||
|   for (const auto i : c10::irange(ndims)) { |   for (const auto i : c10::irange(ndims)) { | ||||||
|     dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i; |     dims[at::maybe_wrap_dim(fwd_dims[i], static_cast<int64_t>(ndims))] = | ||||||
|  |         static_cast<int64_t>(i); | ||||||
|   } |   } | ||||||
|   return grad.permute(dims); |   return grad.permute(dims); | ||||||
| } | } | ||||||
| @ -682,7 +690,7 @@ Tensor unsqueeze_multiple( | |||||||
|   Tensor res = t; |   Tensor res = t; | ||||||
|   for (const auto i : c10::irange(n_dims)) { |   for (const auto i : c10::irange(n_dims)) { | ||||||
|     if (dims_to_unsqueeze[i]) { |     if (dims_to_unsqueeze[i]) { | ||||||
|       res = res.unsqueeze(i); |       res = res.unsqueeze(static_cast<int64_t>(i)); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   return res; |   return res; | ||||||
| @ -833,7 +841,7 @@ Tensor prod_backward( | |||||||
|   if (input.dim() == 0) { |   if (input.dim() == 0) { | ||||||
|     return grad; |     return grad; | ||||||
|   } |   } | ||||||
|   dim = at::maybe_wrap_dim(dim, input.sizes().size()); |   dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(input.sizes().size())); | ||||||
|   if (!keepdim) { |   if (!keepdim) { | ||||||
|     // `prod` reduces the dimension at `dim`, |     // `prod` reduces the dimension at `dim`, | ||||||
|     // so, unsqueeze `grad` and `result` at dim. |     // so, unsqueeze `grad` and `result` at dim. | ||||||
| @ -989,10 +997,10 @@ Tensor unbind_backward(const variable_list& grads, int64_t dim) { | |||||||
| Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) { | Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) { | ||||||
|   auto result = self; |   auto result = self; | ||||||
|  |  | ||||||
|   int64_t nDims = sym_sizes.size(); |   auto nDims = sym_sizes.size(); | ||||||
|   for (const auto dim : c10::irange(nDims)) { |   for (const auto dim : c10::irange(nDims)) { | ||||||
|     if (sym_sizes[dim] == 1) { |     if (sym_sizes[dim] == 1) { | ||||||
|       result = result.unsqueeze(dim); |       result = result.unsqueeze(static_cast<int64_t>(dim)); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   return result; |   return result; | ||||||
| @ -1008,7 +1016,7 @@ Tensor unsqueeze_to( | |||||||
|   Tensor result = self; |   Tensor result = self; | ||||||
|   for (const auto d : c10::irange(ndim)) { |   for (const auto d : c10::irange(ndim)) { | ||||||
|     if (mask.test(d) && sym_sizes[d] == 1) { |     if (mask.test(d) && sym_sizes[d] == 1) { | ||||||
|       result = result.unsqueeze(d); |       result = result.unsqueeze(static_cast<int64_t>(d)); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   return result; |   return result; | ||||||
| @ -1069,7 +1077,7 @@ std::vector<Tensor> stack_tensors_backward( | |||||||
|   } |   } | ||||||
|   bool grad_is_complex = grad.is_complex(); |   bool grad_is_complex = grad.is_complex(); | ||||||
|   for (const auto i : c10::irange(dtypes.size())) { |   for (const auto i : c10::irange(dtypes.size())) { | ||||||
|     auto gr = grad.select(dim, i); |     auto gr = grad.select(dim, static_cast<int64_t>(i)); | ||||||
|     if (grad_is_complex && !at::isComplexType(dtypes[i])) { |     if (grad_is_complex && !at::isComplexType(dtypes[i])) { | ||||||
|       gr = at::real(gr); |       gr = at::real(gr); | ||||||
|     } |     } | ||||||
| @ -1108,8 +1116,8 @@ std::vector<Tensor> block_diag_backward( | |||||||
|       continue; |       continue; | ||||||
|     } |     } | ||||||
|     // 0d case |     // 0d case | ||||||
|     auto dim0 = 1; |     int64_t dim0 = 1; | ||||||
|     auto dim1 = 1; |     int64_t dim1 = 1; | ||||||
|     // 2d case |     // 2d case | ||||||
|     if (shape.size() == 2) { |     if (shape.size() == 2) { | ||||||
|       dim0 = shape[0]; |       dim0 = shape[0]; | ||||||
| @ -1593,7 +1601,7 @@ Tensor renorm_jvp( | |||||||
|     int64_t dim, |     int64_t dim, | ||||||
|     const Scalar& maxnorm) { |     const Scalar& maxnorm) { | ||||||
|   auto self_sizes = self_p.sizes(); |   auto self_sizes = self_p.sizes(); | ||||||
|   dim = c10::maybe_wrap_dim(dim, self_sizes.size()); |   dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(self_sizes.size())); | ||||||
|  |  | ||||||
|   at::DimVector reduce_dims(self_sizes.size()); |   at::DimVector reduce_dims(self_sizes.size()); | ||||||
|   std::iota(reduce_dims.begin(), reduce_dims.end(), 0); |   std::iota(reduce_dims.begin(), reduce_dims.end(), 0); | ||||||
| @ -1643,7 +1651,7 @@ Tensor repeat_backward( | |||||||
|     return at::zeros_symint(input_shape, grad.options()); |     return at::zeros_symint(input_shape, grad.options()); | ||||||
|   } |   } | ||||||
|   const auto input_dims = input_shape.size(); |   const auto input_dims = input_shape.size(); | ||||||
|   int64_t num_unsqueezed = grad.dim() - input_dims; |   auto num_unsqueezed = grad.dim() - input_dims; | ||||||
|   for (const auto i : c10::irange(num_unsqueezed)) { |   for (const auto i : c10::irange(num_unsqueezed)) { | ||||||
|     (void)i; // Suppress unused variable warning |     (void)i; // Suppress unused variable warning | ||||||
|     grad = grad.sum(0, false); |     grad = grad.sum(0, false); | ||||||
| @ -1688,7 +1696,7 @@ Tensor repeat_backward( | |||||||
|     //                             [g2_4, g2_5]]]           [g2_4, g2_5]] |     //                             [g2_4, g2_5]]]           [g2_4, g2_5]] | ||||||
|     if (repeat != 1) { |     if (repeat != 1) { | ||||||
|       grad_size.push_back(repeat); |       grad_size.push_back(repeat); | ||||||
|       sum_dims.push_back(grad_size.size() - 1); |       sum_dims.push_back(static_cast<int64_t>(grad_size.size() - 1)); | ||||||
|     } |     } | ||||||
|     // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == |     // Don't need to reshape gradient into (repeat, input_shape[dim]) (repeat == | ||||||
|     // 1) |     // 1) | ||||||
| @ -1805,7 +1813,6 @@ Tensor var_backward( | |||||||
|     grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size()); |     grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size()); | ||||||
|   } |   } | ||||||
|   const c10::SymFloat rnumel(_safe_size(self.sym_sizes(), dim)); |   const c10::SymFloat rnumel(_safe_size(self.sym_sizes(), dim)); | ||||||
|   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-narrowing-conversions) |  | ||||||
|   return (c10::SymFloat(2.0) / (rnumel - correction)) * grad * |   return (c10::SymFloat(2.0) / (rnumel - correction)) * grad * | ||||||
|       (self - self.mean(dim, /*keepdim=*/true)); |       (self - self.mean(dim, /*keepdim=*/true)); | ||||||
| } | } | ||||||
| @ -2065,7 +2072,7 @@ Tensor split_with_sizes_backward( | |||||||
|     int64_t dim, |     int64_t dim, | ||||||
|     c10::SymIntArrayRef sizes, |     c10::SymIntArrayRef sizes, | ||||||
|     const at::TensorOptions& options) { |     const at::TensorOptions& options) { | ||||||
|   dim = at::maybe_wrap_dim(dim, sizes.size()); |   dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size())); | ||||||
|  |  | ||||||
|   // it's possible some of the grads are not defined (represents tensors of all |   // it's possible some of the grads are not defined (represents tensors of all | ||||||
|   // 0s). Since at::cat can't handle those, let's define them |   // 0s). Since at::cat can't handle those, let's define them | ||||||
| @ -2099,9 +2106,9 @@ Tensor split_backward( | |||||||
|     int64_t dim, |     int64_t dim, | ||||||
|     c10::SymIntArrayRef sym_sizes, |     c10::SymIntArrayRef sym_sizes, | ||||||
|     const at::TensorOptions& options) { |     const at::TensorOptions& options) { | ||||||
|   dim = at::maybe_wrap_dim(dim, sym_sizes.size()); |   dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sym_sizes.size())); | ||||||
|   const auto& dim_size = sym_sizes[dim]; |   const auto& dim_size = sym_sizes[dim]; | ||||||
|   int64_t num_splits = grads.size(); |   auto num_splits = grads.size(); | ||||||
|   std::vector<c10::SymInt> split_sizes(num_splits, split_size); |   std::vector<c10::SymInt> split_sizes(num_splits, split_size); | ||||||
|   split_sizes[num_splits - 1] = |   split_sizes[num_splits - 1] = | ||||||
|       split_size - (split_size * num_splits - dim_size); |       split_size - (split_size * num_splits - dim_size); | ||||||
| @ -3016,7 +3023,7 @@ static inline c10::SymInt _min_storage_size( | |||||||
|     c10::SymIntArrayRef strides, |     c10::SymIntArrayRef strides, | ||||||
|     c10::SymInt storage_offset) { |     c10::SymInt storage_offset) { | ||||||
|   c10::SymInt storage_size = storage_offset + 1; |   c10::SymInt storage_size = storage_offset + 1; | ||||||
|   int64_t dim = sizes.size(); |   auto dim = sizes.size(); | ||||||
|   for (const auto i : c10::irange(dim)) { |   for (const auto i : c10::irange(dim)) { | ||||||
|     const auto& size_i = sizes[i]; |     const auto& size_i = sizes[i]; | ||||||
|     if (size_i == 0) { |     if (size_i == 0) { | ||||||
| @ -3833,9 +3840,7 @@ std::tuple<Tensor, Tensor> linalg_qr_jvp( | |||||||
|   // trilImInv(trilIm(Q^H A_1 R_1^{-1})) |   // trilImInv(trilIm(Q^H A_1 R_1^{-1})) | ||||||
|   at::NoTF32Guard disable_tf32; |   at::NoTF32Guard disable_tf32; | ||||||
|  |  | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |   auto [compute_q, reduced] = at::native::_parse_qr_mode(mode); | ||||||
|   bool compute_q, reduced; |  | ||||||
|   std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode); |  | ||||||
|  |  | ||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
|       compute_q, |       compute_q, | ||||||
| @ -3929,9 +3934,7 @@ Tensor linalg_qr_backward( | |||||||
|   // gA = QgR + pi*(Q trilImInv*(Q^H gQ - gR R^H)R_1^{-H}) |   // gA = QgR + pi*(Q trilImInv*(Q^H gQ - gR R^H)R_1^{-H}) | ||||||
|   at::NoTF32Guard disable_tf32; |   at::NoTF32Guard disable_tf32; | ||||||
|  |  | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |   auto [compute_q, reduced] = at::native::_parse_qr_mode(mode); | ||||||
|   bool compute_q, reduced; |  | ||||||
|   std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode); |  | ||||||
|  |  | ||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
|       compute_q, |       compute_q, | ||||||
| @ -4150,8 +4153,7 @@ Tensor linalg_det_backward( | |||||||
|     auto singular = [](const Tensor& A, |     auto singular = [](const Tensor& A, | ||||||
|                        const Tensor& /*d*/, |                        const Tensor& /*d*/, | ||||||
|                        const Tensor& grad) { |                        const Tensor& grad) { | ||||||
|       Tensor U, S, Vh; |       auto [U, S, Vh] = at::linalg_svd(A); | ||||||
|       std::tie(U, S, Vh) = at::linalg_svd(A); |  | ||||||
|       auto alpha = (at::linalg_det(U) * at::linalg_det(Vh)).conj() * grad; |       auto alpha = (at::linalg_det(U) * at::linalg_det(Vh)).conj() * grad; | ||||||
|       auto D = prod_safe_zeros_backward(alpha.unsqueeze(-1), S, S.dim() - 1); |       auto D = prod_safe_zeros_backward(alpha.unsqueeze(-1), S, S.dim() - 1); | ||||||
|       return (U * D.unsqueeze(-2)).matmul(Vh); |       return (U * D.unsqueeze(-2)).matmul(Vh); | ||||||
| @ -4500,7 +4502,8 @@ Tensor fft_r2c_backward( | |||||||
|   //     3. discard the complex dim |   //     3. discard the complex dim | ||||||
|   auto half_sizes = grad.sym_sizes(); |   auto half_sizes = grad.sym_sizes(); | ||||||
|   std::vector<c10::SymInt> new_grad_shape(half_sizes.begin(), half_sizes.end()); |   std::vector<c10::SymInt> new_grad_shape(half_sizes.begin(), half_sizes.end()); | ||||||
|   const auto last_dim = at::maybe_wrap_dim(dim.back(), half_sizes.size()); |   const auto last_dim = | ||||||
|  |       at::maybe_wrap_dim(dim.back(), static_cast<int64_t>(half_sizes.size())); | ||||||
|   new_grad_shape[last_dim] = last_dim_size; |   new_grad_shape[last_dim] = last_dim_size; | ||||||
|  |  | ||||||
|   const auto zero_length = last_dim_size - grad.sym_size(dim.back()); |   const auto zero_length = last_dim_size - grad.sym_size(dim.back()); | ||||||
| @ -4607,10 +4610,10 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward( | |||||||
|   if (ggI.defined() && training) { |   if (ggI.defined() && training) { | ||||||
|     auto ggI_sum = sum_exclude_dim1(ggI); |     auto ggI_sum = sum_exclude_dim1(ggI); | ||||||
|     auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu); |     auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu); | ||||||
|     auto all_sub = |     auto all_sub = ((ggI_sum * gO_sum).div_(M)) | ||||||
|         ((ggI_sum * gO_sum).div_(M)) |                        .sub_(sum_exclude_dim1(gO * ggI)) | ||||||
|             .sub_(sum_exclude_dim1(gO * ggI)) |                        .add_((sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum) | ||||||
|             .add_((sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M)); |                                  .mul_(3. / static_cast<double>(M))); | ||||||
|     auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M); |     auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M); | ||||||
|     auto gI_1t = |     auto gI_1t = | ||||||
|         (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO); |         (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO); | ||||||
| @ -4698,11 +4701,10 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward( | |||||||
|     const Tensor& save_invstd_t, |     const Tensor& save_invstd_t, | ||||||
|     c10::SymIntArrayRef normalized_shape, |     c10::SymIntArrayRef normalized_shape, | ||||||
|     std::array<bool, 3> output_mask) { |     std::array<bool, 3> output_mask) { | ||||||
|   const int normalized_ndim = normalized_shape.size(); |   const auto normalized_ndim = normalized_shape.size(); | ||||||
|   const auto input_shape = input_t.sizes(); |   const auto input_shape = input_t.sizes(); | ||||||
|   const auto input_ndim = input_t.dim(); |   const auto input_ndim = input_t.dim(); | ||||||
|   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) |   const auto axis = input_ndim - normalized_ndim; | ||||||
|   const int axis = input_ndim - normalized_ndim; |  | ||||||
|   const int64_t M = |   const int64_t M = | ||||||
|       c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis); |       c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis); | ||||||
|   const int64_t N = |   const int64_t N = | ||||||
| @ -4754,10 +4756,10 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward( | |||||||
|     auto ggI_sum = ggI_expanded.sum(1, true); |     auto ggI_sum = ggI_expanded.sum(1, true); | ||||||
|     auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true); |     auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true); | ||||||
|  |  | ||||||
|     auto all_sub = |     auto all_sub = ((ggI_sum * gxhat_sum).div_(N)) | ||||||
|         ((ggI_sum * gxhat_sum).div_(N)) |                        .sub_((ggI_expanded * gxhat).sum(1, true)) | ||||||
|             .sub_((ggI_expanded * gxhat).sum(1, true)) |                        .add_((sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum) | ||||||
|             .add_((sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum).mul_(3. / N)); |                                  .mul_(3. / static_cast<double>(N))); | ||||||
|     auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N); |     auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N); | ||||||
|     auto gI_1t = |     auto gI_1t = | ||||||
|         (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat); |         (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat); | ||||||
| @ -4982,11 +4984,11 @@ Tensor sinc_backward(const Tensor& grad, const Tensor& self) { | |||||||
| // in pads]) | // in pads]) | ||||||
| Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad) { | Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad) { | ||||||
|   auto negated_pad = pad.vec(); |   auto negated_pad = pad.vec(); | ||||||
|   // NOLINTNEXTLINE(modernize-use-transparent-functors) |  | ||||||
|   std::transform( |   std::transform( | ||||||
|       negated_pad.cbegin(), |       negated_pad.cbegin(), | ||||||
|       negated_pad.cend(), |       negated_pad.cend(), | ||||||
|       negated_pad.begin(), |       negated_pad.begin(), | ||||||
|  |       // NOLINTNEXTLINE(modernize-use-transparent-functors) | ||||||
|       std::negate<c10::SymInt>()); |       std::negate<c10::SymInt>()); | ||||||
|   return at::constant_pad_nd_symint(grad, negated_pad, 0); |   return at::constant_pad_nd_symint(grad, negated_pad, 0); | ||||||
| } | } | ||||||
| @ -5265,10 +5267,7 @@ std::tuple<Tensor, Tensor> householder_product_backward( | |||||||
|       auto v_i = input.narrow(-1, i, 1); |       auto v_i = input.narrow(-1, i, 1); | ||||||
|       auto t_i = tau.narrow(-1, i, 1); |       auto t_i = tau.narrow(-1, i, 1); | ||||||
|  |  | ||||||
|       Tensor v_i_grad, tau_i_grad; |       std::tie(input_grads[i], tau_grads[i]) = update_grad(i, v_i, t_i, K); | ||||||
|       std::tie(v_i_grad, tau_i_grad) = update_grad(i, v_i, t_i, K); |  | ||||||
|       input_grads[i] = v_i_grad; |  | ||||||
|       tau_grads[i] = tau_i_grad; |  | ||||||
|  |  | ||||||
|       // K <- H_{i + 1}^{-1} @ K @ H_i |       // K <- H_{i + 1}^{-1} @ K @ H_i | ||||||
|       if (i != flip_i(k - 1)) { |       if (i != flip_i(k - 1)) { | ||||||
| @ -5302,8 +5301,7 @@ std::tuple<Tensor, Tensor> householder_product_backward( | |||||||
|       auto v_i = input.narrow(-1, i, 1); |       auto v_i = input.narrow(-1, i, 1); | ||||||
|       auto t_i = tau.narrow(-1, i, 1); |       auto t_i = tau.narrow(-1, i, 1); | ||||||
|  |  | ||||||
|       Tensor v_i_grad, tau_i_grad; |       auto [v_i_grad, tau_i_grad] = update_grad(i, v_i, t_i, K); | ||||||
|       std::tie(v_i_grad, tau_i_grad) = update_grad(i, v_i, t_i, K); |  | ||||||
|       input_grad.select(-1, i).copy_(v_i_grad.squeeze(-1)); |       input_grad.select(-1, i).copy_(v_i_grad.squeeze(-1)); | ||||||
|       tau_grad.select(-1, i).copy_(tau_i_grad.squeeze(-1)); |       tau_grad.select(-1, i).copy_(tau_i_grad.squeeze(-1)); | ||||||
|  |  | ||||||
| @ -5626,8 +5624,7 @@ Tensor linalg_lu_solve_LU( | |||||||
|   // gLU = gL + gU |   // gLU = gL + gU | ||||||
|  |  | ||||||
|   at::NoTF32Guard disable_tf32; |   at::NoTF32Guard disable_tf32; | ||||||
|   Tensor P, L, U; |   auto [P, L, U] = at::lu_unpack( | ||||||
|   std::tie(P, L, U) = at::lu_unpack( |  | ||||||
|       LU, pivots, /*unpack_data=*/true, /*unpack_pivots=*/left == adjoint); |       LU, pivots, /*unpack_data=*/true, /*unpack_pivots=*/left == adjoint); | ||||||
|   // TODO Optimise the order of the operations to avoid operating on large |   // TODO Optimise the order of the operations to avoid operating on large | ||||||
|   // tensors unnecessarily |   // tensors unnecessarily | ||||||
| @ -5712,8 +5709,7 @@ Tensor linalg_lu_solve_jvp( | |||||||
|     // op_3(B)^H A^{-1} = op_3(X)^H |     // op_3(B)^H A^{-1} = op_3(X)^H | ||||||
|     // We can then rewrite the formula above in terms of X as |     // We can then rewrite the formula above in terms of X as | ||||||
|     // dX = op_2(op_1(-op_3(X)^H P(LdUU^{-1} + dL)L^{-1} P^T)) + S |     // dX = op_2(op_1(-op_3(X)^H P(LdUU^{-1} + dL)L^{-1} P^T)) + S | ||||||
|     Tensor P, L, U; |     auto [P, L, U] = at::lu_unpack(LU, pivots); | ||||||
|     std::tie(P, L, U) = at::lu_unpack(LU, pivots); |  | ||||||
|     // Compute V = op_3(X)^H |     // Compute V = op_3(X)^H | ||||||
|     auto V = left ? X.mH() : X; |     auto V = left ? X.mH() : X; | ||||||
|     // Compute the inner parens LdUU^{-1} + dL |     // Compute the inner parens LdUU^{-1} + dL | ||||||
| @ -6103,9 +6099,9 @@ Tensor batch_norm_jvp( | |||||||
|   int64_t numel = 1; |   int64_t numel = 1; | ||||||
|   for (const auto dim : c10::irange(view_size.size())) { |   for (const auto dim : c10::irange(view_size.size())) { | ||||||
|     if (dim != 1) { |     if (dim != 1) { | ||||||
|       numel *= input_t.size(dim); |       numel *= input_t.size(static_cast<int64_t>(dim)); | ||||||
|       view_size[dim] = 1; |       view_size[dim] = 1; | ||||||
|       dims.push_back(dim); |       dims.push_back(static_cast<int64_t>(dim)); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   Tensor mean_p; |   Tensor mean_p; | ||||||
| @ -6159,9 +6155,9 @@ Tensor layer_norm_jvp( | |||||||
|     if (i < view_size.size() - normalized_shape.size()) { |     if (i < view_size.size() - normalized_shape.size()) { | ||||||
|       view_size_affine[i] = 1; |       view_size_affine[i] = 1; | ||||||
|     } else { |     } else { | ||||||
|       numel *= input_t.size(i); |       numel *= input_t.size(static_cast<int64_t>(i)); | ||||||
|       view_size[i] = 1; |       view_size[i] = 1; | ||||||
|       dims.push_back(i); |       dims.push_back(static_cast<int64_t>(i)); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   auto mean_p = saved_mean.view(view_size); |   auto mean_p = saved_mean.view(view_size); | ||||||
| @ -6478,8 +6474,7 @@ Tensor lu_factor_ex_backward( | |||||||
|     const Tensor& LU, |     const Tensor& LU, | ||||||
|     const Tensor& pivs, |     const Tensor& pivs, | ||||||
|     const bool pivot) { |     const bool pivot) { | ||||||
|   Tensor P, L, U; |   auto [P, L, U] = | ||||||
|   std::tie(P, L, U) = |  | ||||||
|       at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots*/ pivot); |       at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots*/ pivot); | ||||||
|  |  | ||||||
|   // L.shape == (..., m, k) |   // L.shape == (..., m, k) | ||||||
| @ -6565,13 +6560,9 @@ Tensor lu_factor_ex_jvp( | |||||||
|     const Tensor& LU, |     const Tensor& LU, | ||||||
|     const Tensor& pivs, |     const Tensor& pivs, | ||||||
|     const bool pivot) { |     const bool pivot) { | ||||||
|   Tensor dL, dU; |   auto [P, L, U] = | ||||||
|   { |       at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots=*/pivot); | ||||||
|     Tensor P, L, U; |   auto [dL, dU] = linalg_lu_jvp(dA, P, L, U, pivot); | ||||||
|     std::tie(P, L, U) = |  | ||||||
|         at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots=*/pivot); |  | ||||||
|     std::tie(dL, dU) = linalg_lu_jvp(dA, P, L, U, pivot); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   auto m = dA.size(-2); |   auto m = dA.size(-2); | ||||||
|   auto n = dA.size(-1); |   auto n = dA.size(-1); | ||||||
| @ -6975,10 +6966,10 @@ mkldnn_rnn_layer_differentiable_backward( | |||||||
|       seq_length); |       seq_length); | ||||||
|   std::vector<std::tuple<Tensor, Tensor>> layer_states(seq_length + 1); |   std::vector<std::tuple<Tensor, Tensor>> layer_states(seq_length + 1); | ||||||
|   layer_states[0] = std::make_tuple(hx_, cx_tmp); |   layer_states[0] = std::make_tuple(hx_, cx_tmp); | ||||||
|   for (int seq = 1; seq < seq_length + 1; seq++) { |   for (int64_t seq = 1; seq < seq_length + 1; seq++) { | ||||||
|     auto hx = hx_prev; |     auto hx = hx_prev; | ||||||
|     auto cx = cx_prev; |     auto cx = cx_prev; | ||||||
|     int x_index = reverse ? seq_length - seq : seq - 1; |     auto x_index = reverse ? seq_length - seq : seq - 1; | ||||||
|     auto gate = at::linear(input_[x_index], weight0, bias_ih) |     auto gate = at::linear(input_[x_index], weight0, bias_ih) | ||||||
|                     .add_(at::linear(hx, weight1, bias_hh)); |                     .add_(at::linear(hx, weight1, bias_hh)); | ||||||
|     auto chunked_gates = gate.unsafe_chunk(4, 1); |     auto chunked_gates = gate.unsafe_chunk(4, 1); | ||||||
| @ -6997,8 +6988,8 @@ mkldnn_rnn_layer_differentiable_backward( | |||||||
|   Tensor dx, dWx, dWh, db, db_, dprev_h, dprev_c, dWh_, dWx_; |   Tensor dx, dWx, dWh, db, db_, dprev_h, dprev_c, dWh_, dWx_; | ||||||
|   Tensor new_grad_hy, d1, dgp, dip, dfp, dop, do_, dg, df, di, da; |   Tensor new_grad_hy, d1, dgp, dip, dfp, dop, do_, dg, df, di, da; | ||||||
|   std::vector<at::Tensor> layer_dx(seq_length); |   std::vector<at::Tensor> layer_dx(seq_length); | ||||||
|   for (int seq = seq_length - 1; seq >= 0; seq--) { |   for (int64_t seq = seq_length - 1; seq >= 0; seq--) { | ||||||
|     int x_index = reverse ? seq_length - seq - 1 : seq; |     int64_t x_index = reverse ? seq_length - seq - 1 : seq; | ||||||
|     auto i = std::get<0>(layer_gates[x_index]); |     auto i = std::get<0>(layer_gates[x_index]); | ||||||
|     auto f = std::get<1>(layer_gates[x_index]); |     auto f = std::get<1>(layer_gates[x_index]); | ||||||
|     auto g = std::get<2>(layer_gates[x_index]); |     auto g = std::get<2>(layer_gates[x_index]); | ||||||
|  | |||||||
| @ -139,23 +139,29 @@ at::Tensor pow_backward_exponent( | |||||||
|     const at::Tensor& grad, |     const at::Tensor& grad, | ||||||
|     const at::Scalar& base, |     const at::Scalar& base, | ||||||
|     const at::Tensor& exponent, |     const at::Tensor& exponent, | ||||||
|     at::Tensor result); |     const at::Tensor& result); | ||||||
| at::Tensor angle_backward(const at::Tensor& grad, const at::Tensor& self); | at::Tensor angle_backward(const at::Tensor& grad, const at::Tensor& self); | ||||||
| template <typename T> | template <typename T> | ||||||
| at::Tensor mul_tensor_backward(Tensor grad, T other, ScalarType self_st); | at::Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st); | ||||||
| template <typename T> |  | ||||||
| at::Tensor div_tensor_self_backward(Tensor grad, T other, ScalarType self_st); |  | ||||||
| at::Tensor div_tensor_other_backward(Tensor grad, Tensor self, Tensor other); |  | ||||||
| template <typename T> | template <typename T> | ||||||
| at::Tensor div_tensor_self_backward( | at::Tensor div_tensor_self_backward( | ||||||
|     Tensor grad, |     const Tensor& grad, | ||||||
|  |     T other, | ||||||
|  |     ScalarType self_st); | ||||||
|  | at::Tensor div_tensor_other_backward( | ||||||
|  |     const Tensor& grad, | ||||||
|  |     const Tensor& self, | ||||||
|  |     const Tensor& other); | ||||||
|  | template <typename T> | ||||||
|  | at::Tensor div_tensor_self_backward( | ||||||
|  |     const Tensor& grad, | ||||||
|     T other, |     T other, | ||||||
|     ScalarType self_st, |     ScalarType self_st, | ||||||
|     const c10::optional<c10::string_view>& rounding_mode); |     const c10::optional<c10::string_view>& rounding_mode); | ||||||
| at::Tensor div_tensor_other_backward( | at::Tensor div_tensor_other_backward( | ||||||
|     const Tensor& grad, |     const Tensor& grad, | ||||||
|     const Tensor& self, |     const Tensor& self, | ||||||
|     Tensor other, |     const Tensor& other, | ||||||
|     const c10::optional<c10::string_view>& rounding_mode); |     const c10::optional<c10::string_view>& rounding_mode); | ||||||
| at::Tensor mvlgamma_backward( | at::Tensor mvlgamma_backward( | ||||||
|     const at::Tensor& grad, |     const at::Tensor& grad, | ||||||
|  | |||||||
| @ -290,17 +290,14 @@ const Variable& AutogradMeta::fw_grad( | |||||||
|         static_cast<const torch::autograd::DifferentiableViewMeta*>(this); |         static_cast<const torch::autograd::DifferentiableViewMeta*>(this); | ||||||
|     // This is ok to do as we ONLY modify fw_grad_ and this field is properly |     // This is ok to do as we ONLY modify fw_grad_ and this field is properly | ||||||
|     // locked in all methods |     // locked in all methods | ||||||
|     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |     if (const_view_meta->has_fw_view()) { | ||||||
|     auto this_view_meta = |       const auto& view_info = const_view_meta->get_forward_view(); | ||||||
|         const_cast<torch::autograd::DifferentiableViewMeta*>(const_view_meta); |  | ||||||
|     if (this_view_meta->has_fw_view()) { |  | ||||||
|       const auto& view_info = this_view_meta->get_forward_view(); |  | ||||||
|       const auto& base = view_info.base_; |       const auto& base = view_info.base_; | ||||||
|  |  | ||||||
|       const auto& base_val = base._fw_grad(level); |       const auto& base_val = base._fw_grad(level); | ||||||
|       if (base_val.defined()) { |       if (base_val.defined()) { | ||||||
|         // Lazy initialization of fw_grad_ |         // Lazy initialization of fw_grad_ | ||||||
|         this_view_meta->fw_grad_ = std::make_shared<ForwardGrad>(); |         const_view_meta->fw_grad_ = std::make_shared<ForwardGrad>(); | ||||||
|  |  | ||||||
|         Variable new_val; |         Variable new_val; | ||||||
|         if (view_info.has_view_fn()) { |         if (view_info.has_view_fn()) { | ||||||
| @ -310,8 +307,8 @@ const Variable& AutogradMeta::fw_grad( | |||||||
|               self.sizes(), self.strides(), self.storage_offset()); |               self.sizes(), self.strides(), self.storage_offset()); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         this_view_meta->fw_grad_->set_value(new_val, level); |         const_view_meta->fw_grad_->set_value(new_val, level); | ||||||
|         return this_view_meta->fw_grad_->value(level); |         return const_view_meta->fw_grad_->value(level); | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -24,11 +24,6 @@ | |||||||
| #include <utility> | #include <utility> | ||||||
| #include <vector> | #include <vector> | ||||||
|  |  | ||||||
| C10_CLANG_DIAGNOSTIC_PUSH() |  | ||||||
| #if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") |  | ||||||
| C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32") |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| namespace torch { | namespace torch { | ||||||
| namespace autograd { | namespace autograd { | ||||||
|  |  | ||||||
| @ -177,9 +172,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | |||||||
|             name(), |             name(), | ||||||
|             c10::ArrayRef<const c10::IValue>( |             c10::ArrayRef<const c10::IValue>( | ||||||
|                 inputs_vec.data(), inputs_vec.size()), |                 inputs_vec.data(), inputs_vec.size()), | ||||||
|             sequence_nr()); |             static_cast<int64_t>(sequence_nr())); | ||||||
|       } else { |       } else { | ||||||
|         guard.before(name(), sequence_nr()); |         guard.before(name(), static_cast<int64_t>(sequence_nr())); | ||||||
|       } |       } | ||||||
|       return apply(std::move(inputs)); |       return apply(std::move(inputs)); | ||||||
|     } else { |     } else { | ||||||
| @ -203,7 +198,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | |||||||
|       c10::SymIntArrayRef shape, |       c10::SymIntArrayRef shape, | ||||||
|       bool is_tensor_subclass, |       bool is_tensor_subclass, | ||||||
|       bool is_nested) noexcept { |       bool is_nested) noexcept { | ||||||
|     // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |  | ||||||
|     uint32_t input_nr = input_metadata_.size(); |     uint32_t input_nr = input_metadata_.size(); | ||||||
|     auto meta_shape = MetadataShape{c10::in_place_type<SymIntSmallVec>, shape}; |     auto meta_shape = MetadataShape{c10::in_place_type<SymIntSmallVec>, shape}; | ||||||
|     input_metadata_.emplace_back( |     input_metadata_.emplace_back( | ||||||
| @ -212,7 +206,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   uint32_t add_input_metadata(const at::Tensor& t) noexcept { |   uint32_t add_input_metadata(const at::Tensor& t) noexcept { | ||||||
|     // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |  | ||||||
|     uint32_t input_nr = input_metadata_.size(); |     uint32_t input_nr = input_metadata_.size(); | ||||||
|     input_metadata_.emplace_back(t); |     input_metadata_.emplace_back(t); | ||||||
|     return input_nr; |     return input_nr; | ||||||
| @ -220,7 +213,6 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | |||||||
|  |  | ||||||
|   /// Adds a placeholder for an input that will not be used. |   /// Adds a placeholder for an input that will not be used. | ||||||
|   uint32_t add_input_metadata(undefined_input u) noexcept { |   uint32_t add_input_metadata(undefined_input u) noexcept { | ||||||
|     // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |  | ||||||
|     uint32_t input_nr = input_metadata_.size(); |     uint32_t input_nr = input_metadata_.size(); | ||||||
|     input_metadata_.emplace_back(); |     input_metadata_.emplace_back(); | ||||||
|     return input_nr; |     return input_nr; | ||||||
| @ -597,21 +589,18 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | |||||||
|  |  | ||||||
|   // Sequence number used to correlate backward nodes with forward ops in the |   // Sequence number used to correlate backward nodes with forward ops in the | ||||||
|   // profiler and provide determinism in the engine. |   // profiler and provide determinism in the engine. | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) | ||||||
|   const uint64_t sequence_nr_; |   const uint64_t sequence_nr_; | ||||||
|  |  | ||||||
|   // See NOTE [ Topological Number ] |   // See NOTE [ Topological Number ] | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   uint64_t topological_nr_ = 0; |   uint64_t topological_nr_ = 0; | ||||||
|  |  | ||||||
|   // Tracks whether this node has been added as the next_edge of another node |   // Tracks whether this node has been added as the next_edge of another node | ||||||
|   // via set_next_edge(s), which always calls topological_nr() of all its |   // via set_next_edge(s), which always calls topological_nr() of all its | ||||||
|   // children See NOTE [ Topological Number ] for why we need this. |   // children See NOTE [ Topological Number ] for why we need this. | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   mutable bool has_parent_ = false; |   mutable bool has_parent_ = false; | ||||||
|  |  | ||||||
|   // Id of the thread that created the instance |   // Id of the thread that created the instance | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   uint64_t thread_id_ = 0; |   uint64_t thread_id_ = 0; | ||||||
|  |  | ||||||
|   // Note [Thread Safety on Autograd Node] |   // Note [Thread Safety on Autograd Node] | ||||||
| @ -656,14 +645,10 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | |||||||
|   // hooks are automatically thread safe), we rely on the user to write thread |   // hooks are automatically thread safe), we rely on the user to write thread | ||||||
|   // safe C++ hooks if they want the hook to be correctly applied in |   // safe C++ hooks if they want the hook to be correctly applied in | ||||||
|   // multithreading environment. |   // multithreading environment. | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   std::mutex mutex_; |   std::mutex mutex_; | ||||||
|  |  | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   edge_list next_edges_; |   edge_list next_edges_; | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   PyObject* pyobj_ = nullptr; // weak reference |   PyObject* pyobj_ = nullptr; // weak reference | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr; |   std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr; | ||||||
|  |  | ||||||
|   // NOTE [Hooks ordering] |   // NOTE [Hooks ordering] | ||||||
| @ -676,15 +661,11 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> { | |||||||
|   //   even if that node won't be executed. |   //   even if that node won't be executed. | ||||||
|   // - retains_grad_hook are like tensor_pre_hooks except they are always |   // - retains_grad_hook are like tensor_pre_hooks except they are always | ||||||
|   //   ordered after all other tensor pre hooks |   //   ordered after all other tensor pre hooks | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_; |   std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_; | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_; |   std::vector<std::unique_ptr<FunctionPreHook>> tensor_pre_hooks_; | ||||||
|   std::unordered_map<size_t, std::unique_ptr<FunctionPreHook>> |   std::unordered_map<size_t, std::unique_ptr<FunctionPreHook>> | ||||||
|       retains_grad_hooks_; |       retains_grad_hooks_; | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_; |   std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_; | ||||||
|   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |  | ||||||
|   at::SmallVector<InputMetadata, 2> input_metadata_; |   at::SmallVector<InputMetadata, 2> input_metadata_; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| @ -706,7 +687,6 @@ struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> { | |||||||
|   edge_list next_edges; |   edge_list next_edges; | ||||||
|   using IterArgs<MakeNextFunctionList>::operator(); |   using IterArgs<MakeNextFunctionList>::operator(); | ||||||
|   void operator()(const Variable& variable) { |   void operator()(const Variable& variable) { | ||||||
|     // NOLINTNEXTLINE(bugprone-branch-clone) |  | ||||||
|     if (variable.defined()) { |     if (variable.defined()) { | ||||||
|       next_edges.emplace_back(impl::gradient_edge(variable)); |       next_edges.emplace_back(impl::gradient_edge(variable)); | ||||||
|     } else { |     } else { | ||||||
| @ -714,17 +694,11 @@ struct MakeNextFunctionList : IterArgs<MakeNextFunctionList> { | |||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   void operator()(const Variable* variable) { |   void operator()(const Variable* variable) { | ||||||
|     // NOLINTNEXTLINE(bugprone-branch-clone) |     operator()(*variable); | ||||||
|     if (variable->defined()) { |  | ||||||
|       next_edges.emplace_back(impl::gradient_edge(*variable)); |  | ||||||
|     } else { |  | ||||||
|       next_edges.emplace_back(); |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
|   void operator()(const c10::optional<Variable>& variable) { |   void operator()(const c10::optional<Variable>& variable) { | ||||||
|     // NOLINTNEXTLINE(bugprone-branch-clone) |     if (variable.has_value()) { | ||||||
|     if (variable.has_value() && variable->defined()) { |       operator()(*variable); | ||||||
|       next_edges.emplace_back(impl::gradient_edge(*variable)); |  | ||||||
|     } else { |     } else { | ||||||
|       next_edges.emplace_back(); |       next_edges.emplace_back(); | ||||||
|     } |     } | ||||||
| @ -783,5 +757,3 @@ struct TypeAndSize { | |||||||
|  |  | ||||||
| } // namespace autograd | } // namespace autograd | ||||||
| } // namespace torch | } // namespace torch | ||||||
|  |  | ||||||
| C10_CLANG_DIAGNOSTIC_POP() |  | ||||||
|  | |||||||
| @ -107,6 +107,7 @@ static void accumulate( | |||||||
|   //  5) The other Tensor is not a Tensor subclass (except sparse), since |   //  5) The other Tensor is not a Tensor subclass (except sparse), since | ||||||
|   //     it's hard to predict the semantics of arbitrary subclass behavior. |   //     it's hard to predict the semantics of arbitrary subclass behavior. | ||||||
|  |  | ||||||
|  |   // NOLINTNEXTLINE(bugprone-branch-clone) | ||||||
|   if (at::GradMode::is_enabled()) { |   if (at::GradMode::is_enabled()) { | ||||||
|     buffer[pos] = old_var + var; |     buffer[pos] = old_var + var; | ||||||
|   } else if ( |   } else if ( | ||||||
|  | |||||||
| @ -64,8 +64,7 @@ void PyDefaultSavedVariableHooks::push_hooks( | |||||||
| } | } | ||||||
|  |  | ||||||
| void PyDefaultSavedVariableHooks::pop_hooks() { | void PyDefaultSavedVariableHooks::pop_hooks() { | ||||||
|   PyObject *pack_hook(nullptr), *unpack_hook(nullptr); |   auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::get_hooks(); | ||||||
|   std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); |  | ||||||
|   TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr); |   TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr); | ||||||
|   if (Py_IsInitialized()) { |   if (Py_IsInitialized()) { | ||||||
|     py::gil_scoped_acquire gil; |     py::gil_scoped_acquire gil; | ||||||
| @ -76,8 +75,7 @@ void PyDefaultSavedVariableHooks::pop_hooks() { | |||||||
| } | } | ||||||
|  |  | ||||||
| std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() { | std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() { | ||||||
|   PyObject *pack_hook(nullptr), *unpack_hook(nullptr); |   auto [pack_hook, unpack_hook] = at::SavedTensorDefaultHooks::get_hooks(); | ||||||
|   std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks(); |  | ||||||
|   if (!pack_hook || !unpack_hook) { |   if (!pack_hook || !unpack_hook) { | ||||||
|     return nullptr; |     return nullptr; | ||||||
|   } |   } | ||||||
|  | |||||||
| @ -222,7 +222,7 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { | |||||||
|   //     shared by multiple Tensors. See Note [ Using ForwardGrad ] |   //     shared by multiple Tensors. See Note [ Using ForwardGrad ] | ||||||
|   // Any transition from not_initialized to initialized |   // Any transition from not_initialized to initialized | ||||||
|   // must be protected by mutex_ |   // must be protected by mutex_ | ||||||
|   std::shared_ptr<ForwardGrad> fw_grad_; |   mutable std::shared_ptr<ForwardGrad> fw_grad_; | ||||||
|  |  | ||||||
|   // The hooks_ field is actually reused by both python and cpp logic |   // The hooks_ field is actually reused by both python and cpp logic | ||||||
|   // For both cases, we have a data structure, cpp_hooks_list_ (cpp) |   // For both cases, we have a data structure, cpp_hooks_list_ (cpp) | ||||||
| @ -308,7 +308,6 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { | |||||||
|     // set_requires_grad also checks error conditions. |     // set_requires_grad also checks error conditions. | ||||||
|     if (requires_grad) { |     if (requires_grad) { | ||||||
|       TORCH_INTERNAL_ASSERT(self_impl); |       TORCH_INTERNAL_ASSERT(self_impl); | ||||||
|       // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |  | ||||||
|       set_requires_grad(requires_grad, self_impl); |       set_requires_grad(requires_grad, self_impl); | ||||||
|     } |     } | ||||||
|     TORCH_CHECK( |     TORCH_CHECK( | ||||||
| @ -765,7 +764,6 @@ inline Variable make_variable( | |||||||
|         data.getIntrusivePtr()->unique_version()) { |         data.getIntrusivePtr()->unique_version()) { | ||||||
|       auto data_impl = data.unsafeReleaseIntrusivePtr(); |       auto data_impl = data.unsafeReleaseIntrusivePtr(); | ||||||
|       data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); |       data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); | ||||||
|       // NOLINTNEXTLINE(bugprone-branch-clone) |  | ||||||
|       if (requires_grad) { |       if (requires_grad) { | ||||||
|         data_impl->set_autograd_meta( |         data_impl->set_autograd_meta( | ||||||
|             std::make_unique<AutogradMeta>(data_impl.get(), requires_grad)); |             std::make_unique<AutogradMeta>(data_impl.get(), requires_grad)); | ||||||
| @ -777,7 +775,6 @@ inline Variable make_variable( | |||||||
|       auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( |       auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( | ||||||
|           /*version_counter=*/0, |           /*version_counter=*/0, | ||||||
|           /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); |           /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); | ||||||
|       // NOLINTNEXTLINE(bugprone-branch-clone) |  | ||||||
|       if (requires_grad) { |       if (requires_grad) { | ||||||
|         data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>( |         data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>( | ||||||
|             data_impl_copy.get(), requires_grad)); |             data_impl_copy.get(), requires_grad)); | ||||||
|  | |||||||
| @ -83,8 +83,7 @@ void initializeDtypes() { | |||||||
|       AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)}; |       AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_SCALAR_TYPE)}; | ||||||
|  |  | ||||||
|   for (at::ScalarType scalarType : all_scalar_types) { |   for (at::ScalarType scalarType : all_scalar_types) { | ||||||
|     std::string primary_name, legacy_name; |     auto [primary_name, legacy_name] = getDtypeNames(scalarType); | ||||||
|     std::tie(primary_name, legacy_name) = getDtypeNames(scalarType); |  | ||||||
|     PyObject* dtype = THPDtype_New(scalarType, primary_name); |     PyObject* dtype = THPDtype_New(scalarType, primary_name); | ||||||
|     torch::registerDtypeObject((THPDtype*)dtype, scalarType); |     torch::registerDtypeObject((THPDtype*)dtype, scalarType); | ||||||
|     Py_INCREF(dtype); |     Py_INCREF(dtype); | ||||||
|  | |||||||
| @ -385,10 +385,7 @@ Tensor internal_new_from_data( | |||||||
|       at::tracer::impl::NoTracerDispatchMode tracer_guard; |       at::tracer::impl::NoTracerDispatchMode tracer_guard; | ||||||
|  |  | ||||||
|       if (isStorage(data)) { |       if (isStorage(data)) { | ||||||
|         bool is_typed_storage = false; |         auto [storage, storage_scalar_type, is_typed_storage] = | ||||||
|         ScalarType storage_scalar_type{ScalarType::Undefined}; |  | ||||||
|         Storage storage; |  | ||||||
|         std::tie(storage, storage_scalar_type, is_typed_storage) = |  | ||||||
|             createStorageGetType(data); |             createStorageGetType(data); | ||||||
|  |  | ||||||
|         TORCH_CHECK( |         TORCH_CHECK( | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user