Enable clang-tidy on torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp (#143806)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143806
Approved by: https://github.com/kwen2501
Author: cyy
Date: 2025-01-24 12:22:13 +00:00
Committed by: PyTorch MergeBot
Parent: f08b9bc7e4
Commit: 6a35d9aaa4

10 changed files with 44 additions and 54 deletions


@@ -261,7 +261,6 @@ exclude_patterns = [
     'torch/csrc/api/include/torch/linalg.h',
     'torch/csrc/autograd/generated/**',
     'torch/csrc/distributed/**/*.cu',
-    'torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp',
     'torch/csrc/distributed/c10d/WinSockUtils.hpp',
     'torch/csrc/distributed/c10d/quantization/quantization_gpu.h',
     'torch/csrc/dynamo/eval_frame.h',


@@ -170,6 +170,10 @@ class File {
     }
     SYSASSERT(fd_, "open(" + path + ")");
   }
+  File(const File&) = delete;
+  File& operator=(const File&) = delete;
+  File(File&&) noexcept = delete;
+  File& operator=(File&&) noexcept = delete;
   ~File() {
     ::close(fd_);
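For context on the hunk above: clang-tidy's cppcoreguidelines-special-member-functions check asks that a type with a user-defined destructor also declare (or delete) its copy and move operations. For a class that owns a raw file descriptor, deleting them is the safe choice, since an implicitly generated copy would duplicate the descriptor and close it twice. A minimal standalone sketch of the same pattern, using a hypothetical FdGuard class rather than the actual FileStore code:

#include <fcntl.h>
#include <unistd.h>

// RAII wrapper that owns a POSIX file descriptor (hypothetical example).
class FdGuard {
 public:
  explicit FdGuard(const char* path) : fd_(::open(path, O_RDONLY)) {}
  // Copies are deleted: two FdGuard objects sharing fd_ would both close it.
  FdGuard(const FdGuard&) = delete;
  FdGuard& operator=(const FdGuard&) = delete;
  // Moves are deleted as well to keep ownership strictly single-owner here.
  FdGuard(FdGuard&&) noexcept = delete;
  FdGuard& operator=(FdGuard&&) noexcept = delete;
  ~FdGuard() {
    if (fd_ >= 0) {
      ::close(fd_);
    }
  }

 private:
  int fd_ = -1;
};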


@@ -523,8 +523,6 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
   exception_ = w.exception_;
 }
 
-ProcessGroupNCCL::WorkNCCL::~WorkNCCL() = default;
-
 bool ProcessGroupNCCL::WorkNCCL::isCompleted() {
   if (!ncclComm_->isAborted()) {
     checkAndSetException();
@@ -1011,7 +1009,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   if (options_->global_ranks_in_group.empty()) {
     this->globalRankStart = 0;
   } else {
-    this->globalRankStart = options_->global_ranks_in_group[0];
+    this->globalRankStart =
+        static_cast<int>(options_->global_ranks_in_group[0]);
   }
 
   if (options_->global_ranks_in_group.empty()) {
@@ -1033,8 +1032,9 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   }
 
   if (ranksAreStrided) {
-    this->globalRankStride = options_->global_ranks_in_group[1] -
-        options_->global_ranks_in_group[0];
+    this->globalRankStride = static_cast<int>(
+        options_->global_ranks_in_group[1] -
+        options_->global_ranks_in_group[0]);
   } else {
     this->globalRankStride = -1;
   }
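For context: globalRankStart and globalRankStride are plain ints, while the entries of options_->global_ranks_in_group are wider integers, so the assignments above previously relied on an implicit narrowing conversion that clang-tidy's bugprone-narrowing-conversions check (and its cppcoreguidelines alias) flags; the added static_cast<int> keeps the behavior but makes the narrowing explicit. A small illustrative snippet, not taken from the file:

#include <cstdint>
#include <vector>

// Returns the first rank as an int. Writing `return ranks[0];` would be an
// implicit uint64_t -> int narrowing and trip the clang-tidy check; the cast
// states that the truncation is intentional.
int firstRank(const std::vector<uint64_t>& ranks) {
  return static_cast<int>(ranks[0]);
}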
@@ -1087,7 +1087,7 @@ bool ProcessGroupNCCL::useNonblocking() {
   }
   // 2nd priority: Respect the environment variable
   else if (nbEnv.has_value()) {
-    useNonblocking_ = nbEnv.value();
+    useNonblocking_ = nbEnv;
   }
   // 3rd priority: automatically use nonblocking if we are in eager init mode
   else if (getBoundDeviceId()) {
@@ -1711,7 +1711,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
     }
 
     if (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) >=
-        heartbeatTimeoutInSec_ * 1000) {
+        heartbeatTimeoutInSec_ * 1000l) {
       // Check the heart beat of watchdog thread.
       lastTimeHeartBeatCheck = currentTime;
       auto heartbeat = heartbeat_.load();
@@ -2122,21 +2122,24 @@ void ProcessGroupNCCL::watchdogHandler() {
         kWorkStatusUpdatePeriodMs) {
       ::c10d::C10dLoggingData data;
       // logging integers
-      data.integers["pg_id"] = local_id_;
+      data.integers["pg_id"] = static_cast<int64_t>(local_id_);
       data.integers["rank"] = rank_;
       data.integers["global_rank"] = globalRank();
       data.integers["last_enqueued_work"] = pgStatus_->lastEnqueuedSeq;
       data.integers["last_started_work"] = pgStatus_->lastStartedSeq;
       data.integers["last_completed_work"] = pgStatus_->lastCompletedSeq;
-      data.integers["last_enqueued_numel_in"] = pgStatus_->lastEnqueuedNumelIn;
+      data.integers["last_enqueued_numel_in"] =
+          static_cast<int64_t>(pgStatus_->lastEnqueuedNumelIn);
       data.integers["last_enqueued_numel_out"] =
-          pgStatus_->lastEnqueuedNumelOut;
+          static_cast<int64_t>(pgStatus_->lastEnqueuedNumelOut);
       data.integers["last_completed_numel_in"] =
-          pgStatus_->lastCompletedNumelIn;
+          static_cast<int64_t>(pgStatus_->lastCompletedNumelIn);
       data.integers["last_completed_numel_out"] =
-          pgStatus_->lastCompletedNumelOut;
+          static_cast<int64_t>(pgStatus_->lastCompletedNumelOut);
-      data.integers["last_started_numel_in"] = pgStatus_->lastStartedNumelIn;
-      data.integers["last_started_numel_out"] = pgStatus_->lastStartedNumelOut;
+      data.integers["last_started_numel_in"] =
+          static_cast<int64_t>(pgStatus_->lastStartedNumelIn);
+      data.integers["last_started_numel_out"] =
+          static_cast<int64_t>(pgStatus_->lastStartedNumelOut);
       // logging strings
       data.strings["last_enqueued_work_name"] = pgStatus_->lastEnqueuedWorkName;
       data.strings["last_started_work_name"] = pgStatus_->lastStartedWorkName;
@@ -2686,6 +2689,7 @@ std::shared_ptr<NCCLComm> ProcessGroupNCCL::initNCCLComm(
           segmentInfo.device == device.index(),
           "Mismatch between CUDA memory segment device and current device");
       ncclComm->registerSegment(
+          // NOLINTNEXTLINE(performance-no-int-to-ptr)
           reinterpret_cast<void*>(segmentInfo.address),
           segmentInfo.total_size);
     }
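For context: performance-no-int-to-ptr warns on reinterpret_cast from an integer to a pointer, because such casts can defeat some compiler analyses and often hide bugs. Here the CUDA caching allocator reports segment addresses as integers, so the cast is intentional and the NOLINTNEXTLINE added above suppresses the warning for that single line. A hedged sketch of the pattern with hypothetical names:

#include <cstdint>

// An allocator-style API hands back an address as an integer; a pointer is
// needed to register the segment with another API (names are made up).
void* segmentPointer(std::uintptr_t address) {
  // Intentional integer-to-pointer conversion, so the check is silenced.
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return reinterpret_cast<void*>(address);
}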
@@ -2911,7 +2915,7 @@ void ProcessGroupNCCL::workEnqueue(
     // get deadlock. Here we enqueue work without outputs_.
     workMetaList_.emplace_back(*work);
     // update the PG status related to the last enqueued work
-    pgStatus_->lastEnqueuedSeq = work->seq_;
+    pgStatus_->lastEnqueuedSeq = static_cast<int64_t>(work->seq_);
     pgStatus_->lastEnqueuedWorkName = opTypeToString(work->opType_);
     pgStatus_->lastEnqueuedNumelIn = work->numelIn_;
     pgStatus_->lastEnqueuedNumelOut = work->numelOut_;
@@ -3860,7 +3864,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
   TORCH_CHECK(
       !isFloat8Type(tensor.scalar_type()),
       "Float8 dtypes are not currenlty supported for NCCL reductions");
-  // @lint-ignore CLANGTIDY
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(
           static_cast<int64_t>(seqCollective_) + 1,
@@ -3891,7 +3894,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
       !isFloat8Type(tensors.back().scalar_type()),
       "Float8 dtypes are not currenlty supported for NCCL reductions");
-  // @lint-ignore CLANGTIDY
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(
          static_cast<int64_t>(seqCollective_) + 1,
@@ -3946,7 +3948,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
   }
   check_gpu_single_tensor(tensor);
-  // @lint-ignore CLANGTIDY
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(
           static_cast<int64_t>(seqCollective_) + 1,
@@ -3982,7 +3983,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
             input.data_ptr(),
             input.numel(),
             getNcclDataType(input.scalar_type()),
-            root,
+            static_cast<int>(root),
             comm,
             stream.stream());
       },
@@ -4022,7 +4023,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
             output.data_ptr(),
             input.numel(),
             getNcclDataType(input.scalar_type()),
-            root,
+            static_cast<int>(root),
             comm,
             stream.stream());
       },
@@ -4036,7 +4037,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
     std::vector<at::Tensor>& tensors,
     const ReduceOptions& opts) {
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
   auto tensor = tensors.back();
   if (tensor.is_complex()) {
     TORCH_CHECK(
@@ -4083,7 +4083,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
             input.numel(),
             ncclDataType,
             ncclReduceOp,
-            root,
+            static_cast<int>(root),
             comm,
             stream.stream());
       },
@@ -4137,10 +4137,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
     std::vector<at::Tensor>& inputTensors,
     const AllgatherOptions& opts) {
   TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
   auto inputTensor = inputTensors.back();
   check_gpu_single_tensor(inputTensor);
-  // @lint-ignore CLANGTIDY
   auto outputTensors_ = outputTensors.back();
 
   RECORD_PARAM_COMMS_DATA(
@@ -4209,7 +4207,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
             c10::cuda::CUDACachingAllocator::recordStream(
                 outputTensors_[j].storage().data_ptr(), ncclStream);
           }
-          outputTensors_[j].copy_(outputFlattened[j], true);
+          outputTensors_[j].copy_(
+              outputFlattened[static_cast<int64_t>(j)], true);
         }
       },
       OpType::ALLGATHER,
@@ -4217,11 +4216,10 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
   } else {
     const auto num_reduces = outputTensors_.size();
     startCoalescing();
-    for (const int i : c10::irange(num_reduces)) {
+    for (const int64_t i : c10::irange(static_cast<int64_t>(num_reduces))) {
       auto& output = outputTensors_[i];
       auto& input = (i == rank_) ? inputTensor : output;
-      auto broadcastOpts = BroadcastOptions{
-          static_cast<int64_t>(i), static_cast<int64_t>(0), opts.timeout};
+      auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout};
       _broadcast_oop(output, input, broadcastOpts);
     }
     auto work = endCoalescing(OpType::ALLGATHER);
@@ -4242,7 +4240,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather_into_tensor_coalesced(
     std::vector<at::Tensor>& outputs,
     std::vector<at::Tensor>& inputs,
     const AllgatherOptions& opts) {
-  // @lint-ignore CLANGTIDY
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(
           static_cast<int64_t>(seqCollective_) + 1,
@@ -4286,10 +4283,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
     std::vector<std::vector<at::Tensor>>& inputTensors,
     const ReduceScatterOptions& opts) {
   TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
   auto outputTensor = outputTensors.back();
   check_gpu_single_tensor(outputTensor);
-  // @lint-ignore CLANGTIDY
   auto inputTensors_ = inputTensors.back();
   TORCH_CHECK(
       !isFloat8Type(outputTensor.scalar_type()),
@@ -4364,7 +4359,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
             c10::cuda::CUDACachingAllocator::recordStream(
                 inputTensors_[j].storage().data_ptr(), ncclStream);
           }
-          inputFlattened[j].copy_(inputTensors_[j], true);
+          inputFlattened[static_cast<int64_t>(j)].copy_(
+              inputTensors_[j], true);
         }
       },
       [&](at::cuda::CUDAStream&,
@@ -4374,7 +4370,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
   } else {
     const auto num_reduces = inputTensors_.size();
     startCoalescing();
-    for (const int i : c10::irange(num_reduces)) {
+    for (const int i : c10::irange(static_cast<int>(num_reduces))) {
       auto& input = inputTensors_[i];
       auto& output = (i == rank_) ? outputTensor : input;
       auto reduceOpts = ReduceOptions{
@@ -4404,7 +4400,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
         "input tensor must be the same size as output size times world size");
   }
-  // @lint-ignore CLANGTIDY
   const auto& tensor = outputTensor;
   TORCH_CHECK(
       !isFloat8Type(tensor.scalar_type()),
@@ -4474,7 +4469,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
       !isFloat8Type(inputs.back().scalar_type()),
       "Float8 dtypes are not currenlty supported for NCCL reductions");
-  // @lint-ignore CLANGTIDY
   RECORD_PARAM_COMMS_DATA(
       std::make_tuple(
          static_cast<int64_t>(seqCollective_) + 1,
@@ -4539,13 +4533,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
       this->getSize()); // worldSize
 
   // Device to use for barrier
-  int barDevIdx = -1;
+  c10::DeviceIndex barDevIdx = -1;
 
   // Select device to use for barrier
   // 1st choice: Use user defined GPU device ids if provided
   if (!opts.device_ids.empty()) {
     // Use the first device id because PG NCCL is single-device now
-    barDevIdx = opts.device_ids[0];
+    barDevIdx = static_cast<c10::DeviceIndex>(opts.device_ids[0]);
   } else if (getBoundDeviceId()) {
     // 2nd choice: Use the bound GPU device id if available.
     // Bounded device id can be passed to `init_process_group`.
@@ -4562,12 +4556,12 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
     // Note: it is better to use global rank because the group-local rank can be
     // offset wrt the device id if intra-node GPUs are sharded into multiple
     // dimensions.
-    barDevIdx = static_cast<int16_t>(globalRank() % localDeviceCount_);
+    barDevIdx = static_cast<c10::DeviceIndex>(globalRank() % localDeviceCount_);
     LOG(WARNING)
         << logPrefix()
         << c10::str(
               " using GPU ",
-              barDevIdx,
+              static_cast<int>(barDevIdx),
               " to perform barrier as devices used by this process are currently unknown. ",
               "This can potentially cause a hang if this rank to GPU mapping is incorrect. ",
               "Specify device_ids in barrier() to force use of a particular device, ",
@@ -4578,8 +4572,7 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
       ValueError,
       barDevIdx >= 0,
       "Failed to infer a GPU device id to perform barrier. ");
-  auto barDevice = at::Device(
-      at::DeviceType::CUDA, static_cast<c10::DeviceIndex>(barDevIdx));
+  auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx);
 
   // Create a dummy tensor on the device
   // Note: we use zeros() instead of empty() to prevent barrier from triggering
@@ -4776,7 +4769,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::send(
     int dstRank,
     int /* unused */) {
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
   auto tensor = tensors.back();
   check_gpu_single_tensor(tensor, true);
@@ -4825,7 +4817,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::recv(
     int srcRank,
     int /* unused */) {
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
   auto tensor = tensors.back();
   check_gpu_single_tensor(tensor, true);
@@ -4904,7 +4895,6 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
   assertRootRank(invalidArgument, opts.rootRank, size_);
 
   TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
-  // @lint-ignore CLANGTIDY
   auto inputTensor = inputTensors.back();
 
   std::vector<at::Tensor> outputs;


@@ -286,7 +286,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // destructs outputs_ tensors who are view tensors in autograd graph.
     WorkNCCL(const WorkNCCL& w);
 
-    ~WorkNCCL() override;
+    ~WorkNCCL() override = default;
 
     // Checks if the NCCL kernel has started to execute.
     bool isStarted();
@@ -1161,7 +1161,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   std::unordered_map<std::string, at::cuda::CUDAEvent> ncclEvents_;
 
   // Device Indexes used for all collectives in this group
-  std::set<int> usedDeviceIdxs_;
+  std::set<c10::DeviceIndex> usedDeviceIdxs_;
 
   // Flag to denote if a coalescing groupStart/groupEnd block is active
   int coalescing_state_ = 0;


@@ -71,6 +71,7 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder {
   ReduceOp(ReduceOp&& other) = default;
   ReduceOp& operator=(ReduceOp&& other) = default;
+  ~ReduceOp() override = default;
 
   operator RedOpType() const {
     return op_;


@@ -6,8 +6,6 @@
 
 #pragma once
 
-#include <stdexcept>
-
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>


@@ -548,7 +548,7 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.
           py::init(
               [](std::vector<at::Tensor> params,
                  std::vector<std::vector<size_t>> bucket_indices,
-                 std::vector<size_t> per_bucket_size_limits,
+                 const std::vector<size_t>& per_bucket_size_limits,
                  c10::intrusive_ptr<::c10d::ProcessGroup> process_group,
                  std::vector<bool> expect_sparse_gradients,
                  int64_t bucket_bytes_cap,
@@ -563,7 +563,6 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.
                 return std::make_unique<::c10d::Reducer>(
                     std::move(params),
                     std::move(bucket_indices),
-                    std::move(per_bucket_size_limits),
                     std::move(process_group),
                     std::move(expect_sparse_gradients),
                     bucket_bytes_cap,
@@ -2939,7 +2938,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
             py::gil_scoped_release nogil{};
             return c10::make_intrusive<::c10d::ProcessGroupNCCL>(
-                store, rank, size, options);
+                store, rank, size, std::move(options));
           }),
           py::arg("store"),
          py::arg("rank"),


@@ -19,6 +19,7 @@ enum class LogLevel { Trace, Debug, Info, Warning, Error };
 TORCH_API bool isLogLevelEnabled(LogLevel level) noexcept;
 
 template <typename... T>
+// NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
 std::string formatLogMessage(fmt::string_view fmt, T&&... args) {
   return fmt::vformat(fmt, fmt::make_format_args(args...));
 }
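For context: cppcoreguidelines-missing-std-forward warns when a forwarding-reference parameter pack is never passed through std::forward. In formatLogMessage the arguments go to fmt::make_format_args, which captures them by reference rather than taking ownership, so forwarding would not change anything and the added NOLINTNEXTLINE silences the check instead. For comparison, the shape the check normally wants to see (a generic sketch, not this codebase):

#include <string>
#include <utility>
#include <vector>

// A perfect-forwarding helper that does use std::forward on its pack, which
// is what cppcoreguidelines-missing-std-forward expects.
template <typename... Args>
std::vector<std::string> collect(Args&&... args) {
  std::vector<std::string> out;
  (out.emplace_back(std::forward<Args>(args)), ...);
  return out;
}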


@@ -90,7 +90,6 @@ std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
 Reducer::Reducer(
     std::vector<at::Tensor> params,
     std::vector<std::vector<size_t>> bucket_indices,
-    const std::vector<size_t>& per_bucket_size_limits,
     c10::intrusive_ptr<c10d::ProcessGroup> process_group,
     std::vector<bool> expect_sparse_gradients,
     int64_t bucket_bytes_cap,


@@ -51,7 +51,6 @@ class TORCH_API Reducer {
   explicit Reducer(
       std::vector<at::Tensor> params,
       std::vector<std::vector<size_t>> bucket_indices,
-      const std::vector<size_t>& per_bucket_size_limits,
       c10::intrusive_ptr<c10d::ProcessGroup> process_group,
       std::vector<bool> expect_sparse_gradients,
       int64_t bucket_bytes_cap,