[structural binding][10/N] Replace std::tie with structural binding (#130784)

Follows  #130404

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130784
Approved by: https://github.com/malfet
This commit is contained in:
cyy
2024-07-16 10:28:14 +00:00
committed by PyTorch MergeBot
parent 747b38c131
commit 168e41009b
10 changed files with 17 additions and 36 deletions

View File

@ -489,8 +489,7 @@ void inline apply_grad_input(scalar_in* buffer_ptr, scalar_out* gin, int64_t siz
int64_t d = 0;
for (; d < size - (size % bVec::size()); d += bVec::size()) {
bVec gin_bvec = bVec::loadu(gin + d);
fVec gin_fvec0, gin_fvec1;
std::tie(gin_fvec0, gin_fvec1) = convert_to_float<scalar_out>(gin_bvec);
auto [gin_fvec0, gin_fvec1] = convert_to_float<scalar_out>(gin_bvec);
gin_fvec0 += fVec::loadu(buffer_ptr + d);
gin_fvec1 += fVec::loadu(buffer_ptr + d + fVec::size());
fVec(0).store(buffer_ptr + d);

View File

@ -421,11 +421,11 @@ struct ComputeLocation<scalar_t, GridSamplerPadding::Reflection, align_corners>
inline std::pair<Vec, Vec> apply_get_grad(const Vec &in) const {
auto [res, grad_refl] = reflect_coordinates_get_grad(unnormalize(in));
Vec grad_clip, grad(scaling_factor);
Vec grad(scaling_factor);
grad = grad_refl * grad;
std::tie(res, grad_clip) = clip_coordinates_get_grad(res);
auto [res2, grad_clip] = clip_coordinates_get_grad(res);
grad = grad_clip & grad;
return std::make_pair(res, grad);
return std::make_pair(res2, grad);
}
};

View File

@ -994,14 +994,13 @@ struct HelperInterpBase {
double scale = area_pixel_compute_scale<double>(
input_size, output_size, align_corners, opt_scale);
std::vector<Tensor> indices_weights;
double wt_max;
std::tie(indices_weights, interp_size, wt_max) = HelperInterpBase::_compute_index_ranges_weights<double, aa_filter_fn_t, sizeof(int16_t)>(
auto [indices_weights, aligned_interp_size, wt_max] = HelperInterpBase::_compute_index_ranges_weights<double, aa_filter_fn_t, sizeof(int16_t)>(
input_size, output_size, stride, ndims, reshape_dim, scale, interp_size, aa_filter_fn, antialias, align_corners);
interp_size = aligned_interp_size;
// Rescale float weights to int16 and compute weights precision
auto weights_f64 = indices_weights[3];
double * data_f64 = weights_f64.data_ptr<double>();
double * data_f64 = weights_f64. template data_ptr<double>();
unsigned int weights_precision = 0;
for (weights_precision = 0; weights_precision < 22; ++weights_precision) {
@ -1012,7 +1011,6 @@ struct HelperInterpBase {
// Rescale float values to int16
int16_t * data_i16 = (int16_t *) data_f64;
auto aligned_interp_size = interp_size;
if (align_i32) {
// We should respect int32 alignment as we will load int16 data as int32

View File

@ -362,8 +362,7 @@ static void histogram_select_outer_bin_edges_kernel(const Tensor& input,
const int64_t N,
std::vector<double>& leftmost_edges,
std::vector<double>& rightmost_edges) {
Tensor min, max;
std::tie(min, max) = at::aminmax(input, 0);
auto [min, max] = at::aminmax(input, 0);
for (const auto i : c10::irange(N)) {
leftmost_edges[i] = min[i].item().to<double>();

View File

@ -448,13 +448,10 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias,
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" +
std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* selfTensor = nil;
MPSGraphTensor* otherTensor = nil;
MPSGraphTensor* productTensor = nil;
MPSGraphTensor* biasTensor = mpsGraphRankedPlaceHolder(mpsGraph, *bias_);
// TODO: Use alpha and beta here with fill_.Scalar and mul
std::tie(selfTensor, otherTensor, productTensor) = do_mm(mpsGraph, self, other);
auto [selfTensor, otherTensor, productTensor] = do_mm(mpsGraph, self, other);
auto productTimesAlphaTensor = productTensor;
if (alpha.toDouble() != 1.0) {

View File

@ -415,8 +415,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _batch_norm_with_update_mps(const Ten
Tensor& running_var,
double momentum,
double eps) {
Tensor output, save_mean, save_var;
std::tie(output, save_mean, save_var) =
auto [output, save_mean, save_var] =
batch_norm_mps(input, weight_opt, bias_opt, running_mean, running_var, /*train*/ true, momentum, eps);
Tensor reserve = at::empty({0}, input.options().dtype(kByte));
return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, save_mean, save_var, reserve);

View File

@ -323,8 +323,7 @@ void _sparse_binary_op_intersection_kernel_impl(
//
// intersection_count and intersection_first_idx are used to form indices at which
// intersection values are selected.
Tensor intersection_count, intersection_first_idx;
std::tie(intersection_count, intersection_first_idx) = [&]() -> std::tuple<Tensor, Tensor> {
auto [intersection_count, intersection_first_idx] = [&]() -> std::tuple<Tensor, Tensor> {
const auto source_nnz = source._nnz();
auto intersection_buffer = at::empty({2, source_nnz}, sorted_hash.options());
auto intersection_count = intersection_buffer.select(0, 0);

View File

@ -479,9 +479,8 @@ Tensor reduce_sparse_csr_dim0_cuda_template(const Tensor& sparse, ReductionOp ro
Tensor values = sparse.values();
auto ncols = sparse.size(1);
auto nnz = col_indices.numel();
Tensor new_col_indices;
std::tie(new_col_indices, std::ignore) = at::_unique(col_indices, true, false);
auto new_col_indices = std::get<0>(at::_unique(col_indices, true, false));
auto new_nnz = new_col_indices.numel();
Tensor new_crow_indices = at::tensor(ArrayRef<int64_t>{0, new_nnz}, col_indices.options());

View File

@ -26,8 +26,7 @@ TEST(GraphExecutorTest, Basic_CUDA) {
auto stack = createStack({input, hx, cx, w_ih, w_hh});
executor.run(stack);
ASSERT_EQ(stack.size(), 2);
at::Tensor r0, r1;
std::tie(r0, r1) = lstm(input, hx, cx, w_ih, w_hh);
auto [r0, r1] = lstm(input, hx, cx, w_ih, w_hh);
ASSERT_TRUE(almostEqual(stack[0].toTensor(), r0));
ASSERT_TRUE(almostEqual(stack[1].toTensor(), r1));
}

View File

@ -24,9 +24,7 @@ TEST(TensorpipeSerialize, Base) {
c10::make_intrusive<torch::distributed::rpc::Message>(
std::move(payload), std::move(tensors), mtype);
sendingRpcMessage->setId(mId);
tensorpipe::Message sendingTpMessage;
torch::distributed::rpc::TensorpipeWriteBuffers sendingTpBuffers;
std::tie(sendingTpMessage, sendingTpBuffers) =
auto [sendingTpMessage, sendingTpBuffers] =
torch::distributed::rpc::tensorpipeSerialize(
std::move(sendingRpcMessage), {}, {});
@ -58,9 +56,7 @@ TEST(TensorpipeSerialize, Base) {
// Mimic readDescriptor() callback:
// - Allocate buffers
// - Fill pointers in tensorpipe message
tensorpipe::Allocation recvingTpAllocation;
torch::distributed::rpc::TensorpipeReadBuffers recvingTpBuffers;
std::tie(recvingTpAllocation, recvingTpBuffers) =
auto [recvingTpAllocation, recvingTpBuffers] =
torch::distributed::rpc::tensorpipeAllocate(recvingTpDescriptor, {});
// Mimic tensorpipe data transfer
@ -117,9 +113,7 @@ TEST(TensorpipeSerialize, RecopySparseTensors) {
c10::make_intrusive<torch::distributed::rpc::Message>(
std::move(payload), std::move(tensors), mtype);
tensorpipe::Message sendingTpMessage;
torch::distributed::rpc::TensorpipeWriteBuffers tpBuffers;
std::tie(sendingTpMessage, tpBuffers) =
auto [sendingTpMessage, tpBuffers] =
torch::distributed::rpc::tensorpipeSerialize(
std::move(sendingRpcMessage), {}, {});
@ -150,9 +144,7 @@ TEST(TensorpipeSerialize, NoDeleterTensors) {
c10::make_intrusive<torch::distributed::rpc::Message>(
std::move(payload), std::move(tensors), mtype);
tensorpipe::Message sendingTpMessage;
torch::distributed::rpc::TensorpipeWriteBuffers tpBuffers;
std::tie(sendingTpMessage, tpBuffers) =
auto [sendingTpMessage, tpBuffers] =
torch::distributed::rpc::tensorpipeSerialize(
std::move(sendingRpcMessage), {}, {});