Revert "[Distributed] [2/N] Fix clang-tidy warnings in torch/csrc/distributed/c10d (#122892)"

This reverts commit 0ba16ffd35af3eb56da4892cc5387c5e8ac864bb.

Reverted https://github.com/pytorch/pytorch/pull/122892 on behalf of https://github.com/atalman due to broke cuda tests ([comment](https://github.com/pytorch/pytorch/pull/122892#issuecomment-2037207036))
This commit is contained in:
PyTorch MergeBot
2024-04-04 13:22:22 +00:00
parent 6890333e3d
commit 54801e6fd6
3 changed files with 34 additions and 37 deletions

View File

@ -1,3 +1,5 @@
#include <shared_mutex>
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <c10/core/DispatchKey.h>
@ -5,7 +7,6 @@
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/RankLocal.hpp>
#include <utility>
namespace {
@ -13,7 +14,7 @@ class WorkRegistry {
public:
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
c10::intrusive_ptr<c10d::Work> work) {
const auto storage = tensor.storage().getWeakStorageImpl();
std::unique_lock lock(lock_);
auto [it, inserted] = registry_.emplace(storage, work);
@ -49,8 +50,8 @@ class WorkRegistry {
"is invoked on all tensors returned from c10d_functional collective "
"ops before they are used.");
}
for (auto& it : registry_) {
it.second.release();
for (auto it = registry_.begin(); it != registry_.end(); ++it) {
it->second.release();
}
}
@ -66,7 +67,7 @@ static WorkRegistry process_registry;
void register_work(
const at::Tensor& tensor,
const c10::intrusive_ptr<c10d::Work>& work) {
c10::intrusive_ptr<c10d::Work> work) {
if (c10d::get_thread_isolation_mode()) {
c10d::RankLocal<WorkRegistry>::get().register_work(tensor, work);
} else {
@ -104,8 +105,8 @@ c10d::ReduceOp to_reduce_op(const std::string& reduce_op) {
at::Tensor& all_reduce_(
at::Tensor& input,
const std::string& reduce_op,
const std::string& group_name) {
std::string reduce_op,
std::string group_name) {
c10d::AllreduceOptions opts;
opts.reduceOp = to_reduce_op(reduce_op);
@ -118,16 +119,16 @@ at::Tensor& all_reduce_(
at::Tensor all_reduce(
const at::Tensor& input,
const std::string& reduce_op,
const std::string& group_name) {
std::string reduce_op,
std::string group_name) {
auto output = input.clone(at::MemoryFormat::Contiguous);
return all_reduce_(output, reduce_op, group_name);
}
std::vector<at::Tensor> all_reduce_coalesced_(
std::vector<at::Tensor> inputs,
const std::string& reduce_op,
const std::string& group_name) {
std::string reduce_op,
std::string group_name) {
c10d::AllreduceCoalescedOptions opts;
opts.reduceOp = to_reduce_op(reduce_op);
@ -140,9 +141,9 @@ std::vector<at::Tensor> all_reduce_coalesced_(
}
std::vector<at::Tensor> all_reduce_coalesced(
const std::vector<at::Tensor>& inputs,
const std::string& reduce_op,
const std::string& group_name) {
std::vector<at::Tensor> inputs,
std::string reduce_op,
std::string group_name) {
std::vector<at::Tensor> outputs;
outputs.reserve(inputs.size());
for (const auto& tensor : inputs) {
@ -164,9 +165,8 @@ at::Tensor allocate_all_gather_output(
std::vector<at::Tensor> all_gather_into_tensor_coalesced(
std::vector<at::Tensor> inputs,
int64_t group_size,
const std::string& group_name) {
std::string group_name) {
std::vector<at::Tensor> outputs;
outputs.reserve(inputs.size());
for (const auto& tensor : inputs) {
outputs.push_back(allocate_all_gather_output(tensor, group_size));
}
@ -183,7 +183,7 @@ std::vector<at::Tensor> all_gather_into_tensor_coalesced(
at::Tensor all_gather_into_tensor(
const at::Tensor& input,
int64_t group_size,
const std::string& group_name) {
std::string group_name) {
std::vector<at::Tensor> inputs{input};
return all_gather_into_tensor_coalesced(inputs, group_size, group_name)[0];
}
@ -205,13 +205,12 @@ at::Tensor allocate_reduce_scatter_output(
std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
std::vector<at::Tensor> inputs,
const std::string& reduce_op,
std::string reduce_op,
int64_t group_size,
const std::string& group_name) {
std::string group_name) {
c10d::ReduceScatterOptions opts;
opts.reduceOp = to_reduce_op(reduce_op);
std::vector<at::Tensor> outputs;
outputs.reserve(inputs.size());
for (const auto& tensor : inputs) {
outputs.push_back(allocate_reduce_scatter_output(tensor, group_size));
}
@ -227,9 +226,9 @@ std::vector<at::Tensor> reduce_scatter_tensor_coalesced(
at::Tensor reduce_scatter_tensor(
const at::Tensor& input,
const std::string& reduce_op,
std::string reduce_op,
int64_t group_size,
const std::string& group_name) {
std::string group_name) {
std::vector<at::Tensor> inputs{input};
return reduce_scatter_tensor_coalesced(
inputs, reduce_op, group_size, group_name)[0];
@ -239,10 +238,10 @@ at::Tensor all_to_all_single(
const at::Tensor& input,
std::vector<int64_t> output_split_sizes,
std::vector<int64_t> input_split_sizes,
const std::string& group_name) {
std::string group_name) {
std::vector<int64_t> output_sizes = input.sizes().vec();
output_sizes[0] = std::accumulate(
output_split_sizes.begin(), output_split_sizes.end(), int64_t(0));
output_sizes[0] =
std::accumulate(output_split_sizes.begin(), output_split_sizes.end(), 0);
auto output = input.new_empty(output_sizes);
auto group = c10d::resolve_process_group(group_name);
@ -255,10 +254,7 @@ at::Tensor all_to_all_single(
return output;
}
at::Tensor& broadcast_(
at::Tensor& input,
int64_t src,
const std::string& group_name) {
at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) {
c10d::BroadcastOptions opts;
opts.rootRank = src;
std::vector<at::Tensor> inputs{input};
@ -272,7 +268,7 @@ at::Tensor& broadcast_(
at::Tensor broadcast(
const at::Tensor& input,
int64_t src,
const std::string& group_name) {
std::string group_name) {
auto output = input.clone(at::MemoryFormat::Contiguous);
return broadcast_(output, src, group_name);
}

View File

@ -20,7 +20,6 @@
#include <torch/csrc/autograd/utils/lambda_post_hook.h>
#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <utility>
namespace c10d {
namespace {
@ -90,7 +89,7 @@ std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
Reducer::Reducer(
std::vector<at::Tensor> params,
std::vector<std::vector<size_t>> bucket_indices,
const std::vector<size_t>& per_bucket_size_limits,
std::vector<size_t> per_bucket_size_limits,
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
std::vector<bool> expect_sparse_gradients,
int64_t bucket_bytes_cap,
@ -450,7 +449,7 @@ void Reducer::mark_variable_ready_sparse(size_t variable_index) {
if (sparse_metadata_) {
grad = grad.coalesce();
REDUCER_CHECK(
!param_names_.empty(), logger_, "No parameter names were found");
param_names_.size() != 0, logger_, "No parameter names were found");
std::string& param_name = param_names_[variable_index];
auto iter = sparse_metadata_->find(param_name);
REDUCER_CHECK(
@ -633,7 +632,7 @@ void Reducer::delay_all_reduce() {
}
void Reducer::set_logger(std::weak_ptr<c10d::Logger> logger) {
logger_ = std::move(logger);
logger_ = logger;
}
// The function `autograd_hook` is called after the gradient for a

View File

@ -51,7 +51,7 @@ class TORCH_API Reducer {
explicit Reducer(
std::vector<at::Tensor> params,
std::vector<std::vector<size_t>> bucket_indices,
const std::vector<size_t>& per_bucket_size_limits,
std::vector<size_t> per_bucket_size_limits,
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
std::vector<bool> expect_sparse_gradients,
int64_t bucket_bytes_cap,
@ -303,9 +303,11 @@ class TORCH_API Reducer {
using GradCallback = std::function<bool(at::Tensor&)>;
#ifndef _WIN32
static_assert(
std::is_same_v<
std::is_same<
GradCallback,
torch::distributed::autograd::DistAutogradContext::GradCallback>);
torch::distributed::autograd::DistAutogradContext::GradCallback>::
value,
"");
#endif
void runGradCallbackForVariable(at::Tensor& variable, GradCallback&& cb);