Enable clang-tidy on torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp (#143806)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143806
Approved by: https://github.com/kwen2501
Author: cyy
Date: 2025-01-24 12:22:13 +00:00
Committed by: PyTorch MergeBot
Parent: f08b9bc7e4
Commit: 6a35d9aaa4
10 changed files with 44 additions and 54 deletions

View File

@ -261,7 +261,6 @@ exclude_patterns = [
'torch/csrc/api/include/torch/linalg.h',
'torch/csrc/autograd/generated/**',
'torch/csrc/distributed/**/*.cu',
'torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp',
'torch/csrc/distributed/c10d/WinSockUtils.hpp',
'torch/csrc/distributed/c10d/quantization/quantization_gpu.h',
'torch/csrc/dynamo/eval_frame.h',

View File

@ -170,6 +170,10 @@ class File {
}
SYSASSERT(fd_, "open(" + path + ")");
}
File(const File&) = delete;
File& operator=(const File&) = delete;
File(File&&) noexcept = delete;
File& operator=(File&&) noexcept = delete;
~File() {
::close(fd_);
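
Editor's note on this hunk: deleting the copy and move operations of an fd-owning RAII type is the standard way to satisfy clang-tidy's `cppcoreguidelines-special-member-functions` once a user-defined destructor exists. A minimal standalone sketch of the same pattern (the `FdGuard` name, open flags, and error handling are illustrative, not taken from the diff):

```cpp
#include <fcntl.h>
#include <unistd.h>

#include <stdexcept>
#include <string>

// Illustrative RAII wrapper: once the destructor releases the descriptor,
// copy and move are deleted so the fd can never be closed twice.
class FdGuard {
 public:
  explicit FdGuard(const std::string& path)
      : fd_(::open(path.c_str(), O_RDWR | O_CREAT, 0644)) {
    if (fd_ < 0) {
      throw std::runtime_error("open(" + path + ") failed");
    }
  }
  FdGuard(const FdGuard&) = delete;
  FdGuard& operator=(const FdGuard&) = delete;
  FdGuard(FdGuard&&) noexcept = delete;
  FdGuard& operator=(FdGuard&&) noexcept = delete;
  ~FdGuard() {
    ::close(fd_);
  }
  int fd() const {
    return fd_;
  }

 private:
  int fd_;
};
```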

View File

@ -523,8 +523,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
exception_ = w.exception_;
}
ProcessGroupNCCL::WorkNCCL::~WorkNCCL() = default;
bool ProcessGroupNCCL::WorkNCCL::isCompleted() {
if (!ncclComm_->isAborted()) {
checkAndSetException();
@ -1011,7 +1009,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
if (options_->global_ranks_in_group.empty()) {
this->globalRankStart = 0;
} else {
this->globalRankStart = options_->global_ranks_in_group[0];
this->globalRankStart =
static_cast<int>(options_->global_ranks_in_group[0]);
}
if (options_->global_ranks_in_group.empty()) {
@ -1033,8 +1032,9 @@ ProcessGroupNCCL::ProcessGroupNCCL(
}
if (ranksAreStrided) {
this->globalRankStride = options_->global_ranks_in_group[1] -
options_->global_ranks_in_group[0];
this->globalRankStride = static_cast<int>(
options_->global_ranks_in_group[1] -
options_->global_ranks_in_group[0]);
} else {
this->globalRankStride = -1;
}
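
Editor's note: the two hunks above replace implicit 64-bit-to-int conversions with explicit `static_cast`, the usual way to acknowledge clang-tidy's narrowing-conversion warnings while keeping behavior unchanged. A hedged, self-contained illustration of the stride computation (the `uint64_t` element type and the free-function wrapper are assumptions for the sketch, not code from the file):

```cpp
#include <cstdint>
#include <vector>

// Ranks arrive as 64-bit values while the process-group field is a plain
// int; the explicit cast documents (and silences) the intended narrowing.
int computeRankStride(const std::vector<uint64_t>& global_ranks_in_group) {
  if (global_ranks_in_group.size() < 2) {
    return -1;
  }
  return static_cast<int>(
      global_ranks_in_group[1] - global_ranks_in_group[0]);
}
```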
@ -1087,7 +1087,7 @@ bool ProcessGroupNCCL::useNonblocking() {
}
// 2nd priority: Respect the environment variable
else if (nbEnv.has_value()) {
useNonblocking_ = nbEnv.value();
useNonblocking_ = nbEnv;
}
// 3rd priority: automatically use nonblocking if we are in eager init mode
else if (getBoundDeviceId()) {
@ -1711,7 +1711,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
}
if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >=
heartbeatTimeoutInSec_ * 1000) {
heartbeatTimeoutInSec_ * 1000l) {
// Check the heart beat of watchdog thread.
lastTimeHeartBeatCheck = currentTime;
auto heartbeat = heartbeat_.load();
@ -2122,21 +2122,24 @@ void ProcessGroupNCCL::watchdogHandler() {
kWorkStatusUpdatePeriodMs) {
::c10d::C10dLoggingData data;
// logging integers
data.integers["pg_id"] = local_id_;
data.integers["pg_id"] = static_cast<int64_t>(local_id_);
data.integers["rank"] = rank_;
data.integers["global_rank"] = globalRank();
data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq;
data.integers["last_started_work"] = pgStatus_->lastStartedSeq;
data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq;
data.integers["last_enqueued_numel_in"] = pgStatus_->lastEnqueuedNumelIn;
data.integers["last_enqueued_numel_in"] =
static_cast<int64_t>(pgStatus_->lastEnqueuedNumelIn);
data.integers["last_enqueued_numel_out"] =
pgStatus_->lastEnqueuedNumelOut;
static_cast<int64_t>(pgStatus_->lastEnqueuedNumelOut);
data.integers["last_completed_numel_in"] =
pgStatus_->lastCompletedNumelIn;
static_cast<int64_t>(pgStatus_->lastCompletedNumelIn);
data.integers["last_completed_numel_out"] =
pgStatus_->lastCompletedNumelOut;
data.integers["last_started_numel_in"] = pgStatus_->lastStartedNumelIn;
data.integers["last_started_numel_out"] = pgStatus_->lastStartedNumelOut;
static_cast<int64_t>(pgStatus_->lastCompletedNumelOut);
data.integers["last_started_numel_in"] =
static_cast<int64_t>(pgStatus_->lastStartedNumelIn);
data.integers["last_started_numel_out"] =
static_cast<int64_t>(pgStatus_->lastStartedNumelOut);
// logging strings
data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName;
data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName;
@ -2686,6 +2689,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
segmentInfo.device == device.index(),
"Mismatch between CUDA memory segment device and current device");
ncclComm->registerSegment(
// NOLINTNEXTLINE(performance-no-int-to-ptr)
reinterpret_cast<void*>(segmentInfo.address),
segmentInfo.total_size);
}
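
Editor's note: here the int-to-pointer conversion is intentional, so the diff suppresses just this occurrence of `performance-no-int-to-ptr` with `NOLINTNEXTLINE` instead of disabling the check file-wide. A minimal sketch of that suppression style, with a stub `registerSegment` standing in for the real call:

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

// Stub standing in for the real segment-registration call.
void registerSegment(void* ptr, std::size_t size) {
  std::cout << "registering " << size << " bytes at " << ptr << '\n';
}

void registerFromAddress(std::uintptr_t address, std::size_t size) {
  // The address comes out of an allocator snapshot as an integer, so the
  // int-to-pointer cast is intentional; suppress only this diagnostic.
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  registerSegment(reinterpret_cast<void*>(address), size);
}
```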
@ -2911,7 +2915,7 @@ void ProcessGroupNCCL::workEnqueue(
// get deadlock. Here we enqueue work without outputs_.
workMetaList_.emplace_back(*work);
// update the PG status related to the last enqueued work
pgStatus_->lastEnqueuedSeq = work->seq_;
pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);
pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_);
pgStatus_->lastEnqueuedNumelIn = work->numelIn_;
pgStatus_->lastEnqueuedNumelOut = work->numelOut_;
@ -3860,7 +3864,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
TORCH_CHECK(
!isFloat8Type(tensor.scalar_type()),
"Float8 dtypes are not currenlty supported for NCCL reductions");
// @lint-ignore CLANGTIDY
RECORD_PARAM_COMMS_DATA(
std::make_tuple(
static_cast<int64_t>(seqCollective_) + 1,
@ -3891,7 +3894,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
!isFloat8Type(tensors.back().scalar_type()),
"Float8 dtypes are not currenlty supported for NCCL reductions");
// @lint-ignore CLANGTIDY
RECORD_PARAM_COMMS_DATA(
std::make_tuple(
static_cast<int64_t>(seqCollective_) + 1,
@ -3946,7 +3948,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
}
check_gpu_single_tensor(tensor);
// @lint-ignore CLANGTIDY
RECORD_PARAM_COMMS_DATA(
std::make_tuple(
static_cast<int64_t>(seqCollective_) + 1,
@ -3982,7 +3983,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
input.data_ptr(),
input.numel(),
getNcclDataType(input.scalar_type()),
root,
static_cast<int>(root),
comm,
stream.stream());
},
@ -4022,7 +4023,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
output.data_ptr(),
input.numel(),
getNcclDataType(input.scalar_type()),
root,
static_cast<int>(root),
comm,
stream.stream());
},
@ -4036,7 +4037,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts) {
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
// @lint-ignore CLANGTIDY
auto tensor = tensors.back();
if (tensor.is_complex()) {
TORCH_CHECK(
@ -4083,7 +4083,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
input.numel(),
ncclDataType,
ncclReduceOp,
root,
static_cast<int>(root),
comm,
stream.stream());
},
@ -4137,10 +4137,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts) {
TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
// @lint-ignore CLANGTIDY
auto inputTensor = inputTensors.back();
check_gpu_single_tensor(inputTensor);
// @lint-ignore CLANGTIDY
auto outputTensors_ = outputTensors.back();
RECORD_PARAM_COMMS_DATA(
@ -4209,7 +4207,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
c10::cuda::CUDACachingAllocator::recordStream(
outputTensors_[j].storage().data_ptr(), ncclStream);
}
outputTensors_[j].copy_(outputFlattened[j], true);
outputTensors_[j].copy_(
outputFlattened[static_cast<int64_t>(j)], true);
}
},
OpType::ALLGATHER,
@ -4217,11 +4216,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
} else {
const auto num_reduces = outputTensors_.size();
startCoalescing();
for (const int i : c10::irange(num_reduces)) {
for (const int64_t i : c10::irange(static_cast<int64_t>(num_reduces))) {
auto& output = outputTensors_[i];
auto& input = (i == rank_) ? inputTensor : output;
auto broadcastOpts = BroadcastOptions{
static_cast<int64_t>(i), static_cast<int64_t>(0), opts.timeout};
auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout};
_broadcast_oop(output, input, broadcastOpts);
}
auto work = endCoalescing(OpType::ALLGATHER);
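
Editor's note: declaring the loop index as `int64_t` up front lets the body drop the per-use `static_cast`s the old `int` loop needed, since `c10::irange` yields indices of the type of its bound. A small sketch of the same idea (only the `c10/util/irange.h` helper is assumed; the container and function here are made up):

```cpp
#include <cstdint>
#include <vector>

#include <c10/util/irange.h>

// With an int64_t bound, c10::irange yields int64_t indices, so the loop
// body needs no per-iteration static_cast when indexing.
void bumpAll(std::vector<int64_t>& values) {
  for (const auto i : c10::irange(static_cast<int64_t>(values.size()))) {
    values[i] += 1;
  }
}
```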
@ -4242,7 +4240,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_into_tensor_coalesced(
std::vector<at::Tensor>& outputs,
std::vector<at::Tensor>& inputs,
const AllgatherOptions& opts) {
// @lint-ignore CLANGTIDY
RECORD_PARAM_COMMS_DATA(
std::make_tuple(
static_cast<int64_t>(seqCollective_) + 1,
@ -4286,10 +4283,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) {
TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
// @lint-ignore CLANGTIDY
auto outputTensor = outputTensors.back();
check_gpu_single_tensor(outputTensor);
// @lint-ignore CLANGTIDY
auto inputTensors_ = inputTensors.back();
TORCH_CHECK(
!isFloat8Type(outputTensor.scalar_type()),
@ -4364,7 +4359,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
c10::cuda::CUDACachingAllocator::recordStream(
inputTensors_[j].storage().data_ptr(), ncclStream);
}
inputFlattened[j].copy_(inputTensors_[j], true);
inputFlattened[static_cast<int64_t>(j)].copy_(
inputTensors_[j], true);
}
},
[&](at::cuda::CUDAStream&,
@ -4374,7 +4370,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
} else {
const auto num_reduces = inputTensors_.size();
startCoalescing();
for (const int i : c10::irange(num_reduces)) {
for (const int i : c10::irange(static_cast<int>(num_reduces))) {
auto& input = inputTensors_[i];
auto& output = (i == rank_) ? outputTensor : input;
auto reduceOpts = ReduceOptions{
@ -4404,7 +4400,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
"input tensor must be the same size as output size times world size");
}
// @lint-ignore CLANGTIDY
const auto& tensor = outputTensor;
TORCH_CHECK(
!isFloat8Type(tensor.scalar_type()),
@ -4474,7 +4469,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
!isFloat8Type(inputs.back().scalar_type()),
"Float8 dtypes are not currenlty supported for NCCL reductions");
// @lint-ignore CLANGTIDY
RECORD_PARAM_COMMS_DATA(
std::make_tuple(
static_cast<int64_t>(seqCollective_) + 1,
@ -4539,13 +4533,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
this->getSize()); // worldSize
// Device to use for barrier
int barDevIdx = -1;
c10::DeviceIndex barDevIdx = -1;
// Select device to use for barrier
// 1st choice: Use user defined GPU device ids if provided
if (!opts.device_ids.empty()) {
// Use the first device id because PG NCCL is single-device now
barDevIdx = opts.device_ids[0];
barDevIdx = static_cast<c10::DeviceIndex>(opts.device_ids[0]);
} else if (getBoundDeviceId()) {
// 2nd choice: Use the bound GPU device id if available.
// Bounded device id can be passed to `init_process_group`.
@ -4562,12 +4556,12 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
// Note: it is better to use global rank because the group-local rank can be
// offset wrt the device id if intra-node GPUs are sharded into multiple
// dimensions.
barDevIdx = static_cast<int16_t>(globalRank() % localDeviceCount_);
barDevIdx = static_cast<c10::DeviceIndex>(globalRank() % localDeviceCount_);
LOG(WARNING)
<< logPrefix()
<< c10::str(
" using GPU ",
barDevIdx,
static_cast<int>(barDevIdx),
" to perform barrier as devices used by this process are currently unknown. ",
"This can potentially cause a hang if this rank to GPU mapping is incorrect. ",
"Specify device_ids in barrier() to force use of a particular device, ",
@ -4578,8 +4572,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
ValueError,
barDevIdx >= 0,
"Failed to infer a GPU device id to perform barrier. ");
auto barDevice = at::Device(
at::DeviceType::CUDA, static_cast<c10::DeviceIndex>(barDevIdx));
auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx);
// Create a dummy tensor on the device
// Note: we use zeros() instead of empty() to prevent barrier from triggering
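
Editor's note: the barrier change narrows `barDevIdx` to `c10::DeviceIndex` up front and only widens it back to `int` for logging, since streaming an 8-bit integer would otherwise print a character. A generic sketch of the pattern, with `int8_t` standing in for `c10::DeviceIndex` and an invented helper function:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

using DeviceIndex = int8_t;  // stand-in for c10::DeviceIndex

DeviceIndex pickBarrierDevice(
    const std::vector<int>& device_ids,
    int global_rank,
    int local_device_count) {
  if (!device_ids.empty()) {
    // Explicit narrowing from the user-provided int id.
    return static_cast<DeviceIndex>(device_ids[0]);
  }
  // Fall back to a guess derived from the global rank.
  return static_cast<DeviceIndex>(global_rank % local_device_count);
}

int main() {
  DeviceIndex idx = pickBarrierDevice({}, 5, 4);
  // Widen back to int for printing; streaming int8_t would print a char.
  std::cout << "barrier device " << static_cast<int>(idx) << '\n';
}
```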
@ -4776,7 +4769,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::send(
int dstRank,
int /* unused */) {
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
// @lint-ignore CLANGTIDY
auto tensor = tensors.back();
check_gpu_single_tensor(tensor, true);
@ -4825,7 +4817,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::recv(
int srcRank,
int /* unused */) {
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
// @lint-ignore CLANGTIDY
auto tensor = tensors.back();
check_gpu_single_tensor(tensor, true);
@ -4904,7 +4895,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
assertRootRank(invalidArgument, opts.rootRank, size_);
TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
// @lint-ignore CLANGTIDY
auto inputTensor = inputTensors.back();
std::vector<at::Tensor> outputs;

View File

@ -286,7 +286,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// destructs outputs_ tensors who are view tensors in autograd graph.
WorkNCCL(const WorkNCCL& w);
~WorkNCCL() override;
~WorkNCCL() override = default;
// Checks if the NCCL kernel has started to execute.
bool isStarted();
@ -1161,7 +1161,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
std::unordered_map<std::string, at::cuda::CUDAEvent> ncclEvents_;
// Device Indexes used for all collectives in this group
std::set<int> usedDeviceIdxs_;
std::set<c10::DeviceIndex> usedDeviceIdxs_;
// Flag to denote if a coalescing groupStart/groupEnd block is active
int coalescing_state_ = 0;

View File

@ -71,6 +71,7 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder {
ReduceOp(ReduceOp&& other) = default;
ReduceOp& operator=(ReduceOp&& other) = default;
~ReduceOp() override = default;
operator RedOpType() const {
return op_;

View File

@ -6,8 +6,6 @@
#pragma once
#include <stdexcept>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

View File

@ -548,7 +548,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
py::init(
[](std::vector<at::Tensor> params,
std::vector<std::vector<size_t>> bucket_indices,
std::vector<size_t> per_bucket_size_limits,
const std::vector<size_t>& per_bucket_size_limits,
c10::intrusive_ptr<::c10d::ProcessGroup> process_group,
std::vector<bool> expect_sparse_gradients,
int64_t bucket_bytes_cap,
@ -563,7 +563,6 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
return std::make_unique<::c10d::Reducer>(
std::move(params),
std::move(bucket_indices),
std::move(per_bucket_size_limits),
std::move(process_group),
std::move(expect_sparse_gradients),
bucket_bytes_cap,
@ -2939,7 +2938,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
py::gil_scoped_release nogil{};
return c10::make_intrusive<::c10d::ProcessGroupNCCL>(
store, rank, size, options);
store, rank, size, std::move(options));
}),
py::arg("store"),
py::arg("rank"),

View File

@ -19,6 +19,7 @@ enum class LogLevel { Trace, Debug, Info, Warning, Error };
TORCH_API bool isLogLevelEnabled(LogLevel level) noexcept;
template <typename... T>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
std::string formatLogMessage(fmt::string_view fmt, T&&... args) {
return fmt::vformat(fmt, fmt::make_format_args(args...));
}
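
Editor's note on the `NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)` added above: `fmt::make_format_args` captures its arguments as lvalue references, so the forwarding-reference pack is deliberately not `std::forward`ed and the check is silenced at this one site. A self-contained version of the helper, assuming the {fmt} library is available:

```cpp
#include <string>

#include <fmt/core.h>

// The pack is taken by forwarding reference only so both lvalues and
// rvalues bind; make_format_args wants lvalues, hence no std::forward.
template <typename... T>
// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
std::string formatLogMessage(fmt::string_view fmt_str, T&&... args) {
  return fmt::vformat(fmt_str, fmt::make_format_args(args...));
}

// Example: formatLogMessage("rank {} of {}", 3, 8) returns "rank 3 of 8".
```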

View File

@ -90,7 +90,6 @@ std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
Reducer::Reducer(
std::vector<at::Tensor> params,
std::vector<std::vector<size_t>> bucket_indices,
const std::vector<size_t>& per_bucket_size_limits,
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
std::vector<bool> expect_sparse_gradients,
int64_t bucket_bytes_cap,

View File

@ -51,7 +51,6 @@ class TORCH_API Reducer {
explicit Reducer(
std::vector<at::Tensor> params,
std::vector<std::vector<size_t>> bucket_indices,
const std::vector<size_t>& per_bucket_size_limits,
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
std::vector<bool> expect_sparse_gradients,
int64_t bucket_bytes_cap,