Apply clang-format to distributed/c10d folder (#107140)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107140
Approved by: https://github.com/H-Huang
Shen Li authored on 2023-08-14 06:59:47 -07:00, committed by PyTorch MergeBot
parent 858b465d74
commit dd6319198d
28 changed files with 370 additions and 307 deletions
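Since every hunk below is mechanical reformatting, here is a minimal sketch of how such a sweep can be reproduced locally. This is illustrative only: the commit does not record the clang-format version or the exact driver used (PyTorch normally runs formatters through its own lint tooling), so the binary name, file extensions, and flags below are assumptions; only the folder path comes from the commit itself.

#!/usr/bin/env python3
# Illustrative sketch: reformat every C/C++ source under
# torch/csrc/distributed/c10d in place with clang-format.
# Assumes a `clang-format` binary on PATH and a .clang-format file in an
# enclosing directory; neither is guaranteed by this commit.
import pathlib
import subprocess

root = pathlib.Path("torch/csrc/distributed/c10d")
extensions = {".h", ".hpp", ".c", ".cc", ".cpp", ".cu", ".cuh"}

for path in sorted(root.rglob("*")):
    if path.suffix in extensions:
        # -i rewrites the file in place; -style=file picks up the nearest
        # .clang-format found by walking up from the source file.
        subprocess.run(
            ["clang-format", "-i", "-style=file", str(path)], check=True)

Running the script twice should be a no-op the second time, which is a quick way to check that a tree is already formatted.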

View File

@@ -11,9 +11,9 @@
 #include <ATen/ATen.h>
 #include <c10/macros/Macros.h>
-#include <torch/csrc/distributed/c10d/Work.hpp>
 #include <torch/csrc/distributed/c10d/Types.hpp>
 #include <torch/csrc/distributed/c10d/Utils.hpp>
+#include <torch/csrc/distributed/c10d/Work.hpp>
 #include <torch/csrc/distributed/c10d/debug.h>
 #include <torch/csrc/distributed/c10d/sequence_num.hpp>
@@ -24,7 +24,6 @@ namespace c10d {
 class TORCH_API Backend : public torch::CustomClassHolder {
  public:
-
   // Backend Options is a base struct that defines the basic options
   // when constructing a Backend. Each Backend subclass should
   // extend this struct and define its options if it wants to provide more
@@ -62,13 +61,17 @@ class TORCH_API Backend : public torch::CustomClassHolder {
   virtual void startCoalescing() {
     TORCH_CHECK(
         false,
-        c10::str("Backend ", getBackendName(), " does not implement startCoalescing"));
+        c10::str(
+            "Backend ",
+            getBackendName(),
+            " does not implement startCoalescing"));
   }

   virtual c10::intrusive_ptr<Work> endCoalescing() {
     TORCH_CHECK(
         false,
-        c10::str("Backend ", getBackendName(), " does not implement endCoalescing"));
+        c10::str(
+            "Backend ", getBackendName(), " does not implement endCoalescing"));
   }

   // Subclasses must override this method to return the backend name
@@ -215,8 +218,8 @@ class TORCH_API Backend : public torch::CustomClassHolder {
   }

   // This function is a coalesced version of `reduce_scatter_tensor` (currently
-  // still named as `_reduce_scatter_base`). Each tensor in the vector corresponds to
-  // an input/output of one `reduce_scatter_tensor` operation.
+  // still named as `_reduce_scatter_base`). Each tensor in the vector
+  // corresponds to an input/output of one `reduce_scatter_tensor` operation.
   virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
       std::vector<at::Tensor>& /* outputs */,
       std::vector<at::Tensor>& /* inputs */,
@@ -293,7 +296,8 @@ class TORCH_API Backend : public torch::CustomClassHolder {
       int /* dstRank */,
       int /* tag */) {
     TORCH_CHECK(
-        false, c10::str("Backend ", getBackendName(), " does not support send"));
+        false,
+        c10::str("Backend ", getBackendName(), " does not support send"));
   }

   virtual c10::intrusive_ptr<Work> recv(
@@ -301,7 +305,8 @@ class TORCH_API Backend : public torch::CustomClassHolder {
       int /* srcRank */,
       int /* tag */) {
     TORCH_CHECK(
-        false, c10::str("Backend ", getBackendName(), " does not support recv"));
+        false,
+        c10::str("Backend ", getBackendName(), " does not support recv"));
   }

   virtual c10::intrusive_ptr<Work> recvAnysource(

View File

@@ -11,7 +11,7 @@ namespace c10d {
 class TORCH_API FileStore : public Store {
  public:
   explicit FileStore(std::string path, int numWorkers);

   ~FileStore() override;

View File

@@ -39,15 +39,15 @@ class TORCH_API HashStore : public Store {
   bool deleteKey(const std::string& key) override;

-  void append(
-      const std::string& key,
-      const std::vector<uint8_t>& value) override;
+  void append(const std::string& key, const std::vector<uint8_t>& value)
+      override;

-  std::vector<std::vector<uint8_t>> multiGet(const std::vector<std::string>& keys) override;
+  std::vector<std::vector<uint8_t>> multiGet(
+      const std::vector<std::string>& keys) override;

   void multiSet(
       const std::vector<std::string>& keys,
       const std::vector<std::vector<uint8_t>>& values) override;

   // Returns true if this store support append, multiGet and multiSet
   bool hasExtendedApi() const override;

View File

@@ -8,9 +8,9 @@
 #include <memory>
 #include <mutex>

-#include <nccl.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>
+#include <nccl.h>

 #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
     (NCCL_MINOR >= 14)
@@ -46,20 +46,22 @@
 #define ENABLE_NCCL_P2P_SUPPORT
 #endif

-#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 11)
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
+    (NCCL_MINOR >= 11)
 #define ENABLE_NCCL_PREMUL_SUM_SUPPORT
 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
 #define ENABLE_NCCL_PREMUL_SUM_SUPPORT
 #endif

-#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 17)
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
+    (NCCL_MINOR >= 17)
 #define NCCL_HAS_COMM_CTA_CGA
 #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
 #define NCCL_HAS_COMM_CTA_CGA
 #endif

 // Macro to throw on a non-successful NCCL return value.
 #define C10D_NCCL_CHECK(cmd, failureReason) \
   do { \
     ncclResult_t result = cmd; \
     if (result != ncclSuccess) { \
@@ -71,57 +73,63 @@
   } while (0)

 // Macro to throw on a non-successful NCCL return value, non-blocking.
 #define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \
   ncclResult_t result = cmd; \
   auto startTimepoint = std::chrono::steady_clock::now(); \
   while (result == ncclInProgress) { \
     if (nccl_nonblocking_timeout() > 0) { \
       auto currentTimepoint = std::chrono::steady_clock::now(); \
-      auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(currentTimepoint - startTimepoint).count(); \
+      auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
+                             currentTimepoint - startTimepoint) \
+                             .count(); \
       if (timeElapsed > nccl_nonblocking_timeout()) { \
         std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \
-            std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
-            "\n" + getNcclErrorDetailStr(result, failureReason); \
+            std::to_string(__LINE__) + ", " + \
+            ncclGetErrorWithVersion(result) + "\n" + \
+            getNcclErrorDetailStr(result, failureReason); \
         TORCH_CHECK_WITH(DistBackendError, false, err); \
       } \
     } \
     ncclCommGetAsyncError(comm, &result); \
   } \
   if (result != ncclSuccess) { \
     std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
         std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \
         "\n" + getNcclErrorDetailStr(result, failureReason); \
     TORCH_CHECK_WITH(DistBackendError, false, err); \
   }

 #define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comms_, failureReason) \
   ncclResult_t state = cmd; \
   auto startTimepoint = std::chrono::steady_clock::now(); \
   if (state == ncclInProgress) { \
     for (const auto i : c10::irange(comms_.size())) { \
       do { \
         if (nccl_nonblocking_timeout() > 0) { \
           auto currentTimepoint = std::chrono::steady_clock::now(); \
-          auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>(currentTimepoint - startTimepoint).count(); \
+          auto timeElapsed = std::chrono::duration_cast<std::chrono::seconds>( \
+                                 currentTimepoint - startTimepoint) \
+                                 .count(); \
           if (timeElapsed > nccl_nonblocking_timeout()) { \
-            std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \
-                std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
-                "\n" + getNcclErrorDetailStr(state, failureReason); \
+            std::string err = "NCCL timeout in: " + std::string(__FILE__) + \
+                ":" + std::to_string(__LINE__) + ", " + \
+                ncclGetErrorWithVersion(state) + "\n" + \
+                getNcclErrorDetailStr(state, failureReason); \
             TORCH_CHECK_WITH(DistBackendError, false, err); \
           } \
         } \
         ncclCommGetAsyncError(comms_[i]->getNcclComm(), &state); \
       } while (state == ncclInProgress); \
       if (state != ncclSuccess) { \
         break; /* fall through to failed case */ \
       } \
     } \
   } \
   if (state != ncclSuccess) { \
     std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \
         std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \
         "\n" + getNcclErrorDetailStr(state, failureReason); \
     TORCH_CHECK_WITH(DistBackendError, false, err); \
   }

 // Macro to print and abort on a non-successful NCCL return value.
@@ -150,8 +158,8 @@ int nccl_nonblocking_timeout();
 // Provides additional detail into NCCL error codes based on when these are
 // thrown in the NCCL codebase.
 std::string getNcclErrorDetailStr(
     ncclResult_t error,
     c10::optional<std::string> processGroupFailureReason = c10::nullopt);

 // RAII wrapper for NCCL communicator
 class NCCLComm {
@@ -186,7 +194,8 @@ class NCCLComm {
       ncclUniqueId commId) {
     auto comm = std::make_shared<NCCLComm>();
     C10D_NCCL_CHECK(
-        ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
+        ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank),
+        c10::nullopt);
     comm->ncclId_ = commId;
     comm->rank_ = rank;
     return comm;
@@ -202,10 +211,15 @@ class NCCLComm {
     if (nccl_use_nonblocking()) {
       config.blocking = 0;
       C10D_NCCL_CHECK_TIMEOUT(
-          ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
+          ncclCommInitRankConfig(
+              &(comm->ncclComm_), numRanks, commId, rank, &config),
+          comm->ncclComm_,
+          c10::nullopt);
     } else {
       C10D_NCCL_CHECK(
-          ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), c10::nullopt);
+          ncclCommInitRankConfig(
+              &(comm->ncclComm_), numRanks, commId, rank, &config),
+          c10::nullopt);
     }
     comm->ncclId_ = commId;
     comm->rank_ = rank;
@@ -257,7 +271,7 @@ class NCCLComm {
     C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
 #else
     C10D_NCCL_CHECK_TIMEOUT(
         ::ncclCommAbort(ncclComm_), ncclComm_, commFailureReason_);
 #endif
     aborted_ = true;
     ncclComm_ = nullptr;
@@ -283,7 +297,8 @@ class NCCLComm {
     if (ncclAsyncErr_ != ncclSuccess) {
       return ncclAsyncErr_;
     }
-    C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
+    C10D_NCCL_CHECK(
+        ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
     return ncclAsyncErr_;
 #else
     // Always return success, if error checks are disabled.
@@ -309,8 +324,8 @@ class NCCLComm {
 struct ncclRedOpRAII {
   ncclRedOpRAII() = default;
   ncclRedOpRAII(ncclRedOp_t op) : op_(op) {}
-  ncclRedOpRAII(ncclRedOp_t op, ncclComm_t comm) :
-      op_(op), comm_(comm), premul_sum_(true) {}
+  ncclRedOpRAII(ncclRedOp_t op, ncclComm_t comm)
+      : op_(op), comm_(comm), premul_sum_(true) {}
   ncclRedOpRAII(const ncclRedOpRAII&) = delete;
   ncclRedOpRAII& operator=(const ncclRedOpRAII&) = delete;
   ncclRedOpRAII(ncclRedOpRAII&& tmp) : ncclRedOpRAII() {
@@ -325,13 +340,14 @@ struct ncclRedOpRAII {
     }
   }
 #endif
-  operator ncclRedOp_t() const { return op_; }
+  operator ncclRedOp_t() const {
+    return op_;
+  }
   ncclRedOp_t op_;
   ncclComm_t comm_;
   bool premul_sum_ = false;
 };

 } // namespace c10d

 #endif // USE_C10D_NCCL

View File

@@ -1,29 +1,27 @@
 #pragma once

-#include <string>
-#include <vector>
+#include <ATen/core/ivalue.h>
+#include <ATen/record_function.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/ThreadLocalDebugInfo.h>
-#include <ATen/record_function.h>
-#include <ATen/core/ivalue.h>
+#include <string>
+#include <vector>

 namespace torch {

 extern TORCH_API const std::string kParamCommsCallName;

-class TORCH_API ParamCommsDebugInfo
-    : public c10::DebugInfoBase {
+class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
  public:
   ParamCommsDebugInfo() = default;
   ParamCommsDebugInfo(
       int rank,
       std::string&& colName,
       int inSize,
       int outSize,
       at::ScalarType dType,
       std::vector<int64_t> inSplitSizes,
       std::vector<int64_t> outSplitSizes);

   ~ParamCommsDebugInfo() override = default;
@@ -80,7 +78,7 @@ class TORCH_API ParamCommsDebugInfo
   c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
   std::initializer_list<const c10::IValue> paramList = { \
       c10::IValue(seq), \
       c10::IValue(pg_ptr), \
       rank, \
       colName, \
       inSplitSizes, \
@@ -91,8 +89,8 @@ class TORCH_API ParamCommsDebugInfo
 #define RECORD_PARAM_COMMS_DATA( \
     seq, \
     pg_ptr, \
     InputTensors, \
     OutputTensors, \
     rank, \
     colName, \
     inSize, \
@@ -104,7 +102,7 @@ class TORCH_API ParamCommsDebugInfo
       rank, colName, inSize, outSize, dType, inSplitSizes, outSplitSizes); \
   c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
   std::initializer_list<const c10::IValue> paramList = { \
       c10::IValue(InputTensors), \
       c10::IValue(seq), \
       c10::IValue(pg_ptr), \
       rank, \

View File

@@ -7,9 +7,7 @@ namespace c10d {
 class TORCH_API PrefixStore : public Store {
  public:
-  explicit PrefixStore(
-      std::string prefix,
-      c10::intrusive_ptr<Store> store);
+  explicit PrefixStore(std::string prefix, c10::intrusive_ptr<Store> store);

   ~PrefixStore() override = default;
@@ -42,20 +40,19 @@ class TORCH_API PrefixStore : public Store {
   void setTimeout(const std::chrono::milliseconds& timeout) override;

-  void append(
-      const std::string& key,
-      const std::vector<uint8_t>& value) override;
+  void append(const std::string& key, const std::vector<uint8_t>& value)
+      override;

-  std::vector<std::vector<uint8_t>> multiGet(const std::vector<std::string>& keys) override;
+  std::vector<std::vector<uint8_t>> multiGet(
+      const std::vector<std::string>& keys) override;

   void multiSet(
       const std::vector<std::string>& keys,
       const std::vector<std::vector<uint8_t>>& values) override;

   // Returns true if this store support append, multiGet and multiSet
   bool hasExtendedApi() const override;

   c10::intrusive_ptr<Store> getUnderlyingStore();

  protected:

View File

@@ -366,14 +366,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
       at::Tensor& outputBuffer,
       at::Tensor& inputBuffer,
       const ReduceScatterOptions& opts = ReduceScatterOptions()) {
-    static auto op = c10::Dispatcher::singleton()
-                         .findSchemaOrThrow("c10d::_reduce_scatter_base_", "")
-                         .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
-                             at::Tensor&,
-                             at::Tensor&,
-                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
-                             const c10::intrusive_ptr<::c10d::ReduceOp>&,
-                             int64_t)>();
+    static auto op =
+        c10::Dispatcher::singleton()
+            .findSchemaOrThrow("c10d::_reduce_scatter_base_", "")
+            .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
+                at::Tensor&,
+                at::Tensor&,
+                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+                const c10::intrusive_ptr<::c10d::ReduceOp>&,
+                int64_t)>();
     return std::get<1>(op.call(
         outputBuffer,
         inputBuffer,
@@ -383,8 +384,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
   }

   // This function is a coalesced version of `reduce_scatter_tensor` (currently
-  // still named as `_reduce_scatter_base`). Each tensor in the vector corresponds to
-  // an input/output of one `reduce_scatter_tensor` operation.
+  // still named as `_reduce_scatter_base`). Each tensor in the vector
+  // corresponds to an input/output of one `reduce_scatter_tensor` operation.
   virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
       std::vector<at::Tensor>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
@@ -435,13 +436,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
       std::vector<at::Tensor>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
       const AllToAllOptions& opts = AllToAllOptions()) {
-    static auto op = c10::Dispatcher::singleton()
-                         .findSchemaOrThrow("c10d::alltoall_", "")
-                         .typed<std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
-                             const at::TensorList&,
-                             const at::TensorList&,
-                             const c10::intrusive_ptr<::c10d::ProcessGroup>&,
-                             int64_t)>();
+    static auto op =
+        c10::Dispatcher::singleton()
+            .findSchemaOrThrow("c10d::alltoall_", "")
+            .typed<
+                std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
+                    const at::TensorList&,
+                    const at::TensorList&,
+                    const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+                    int64_t)>();
     return std::get<1>(op.call(
         outputTensors,
         inputTensors,
@@ -570,8 +573,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
     if (device.has_value()) {
       // set device tensor from argument
       tensor = at::empty(
-          {1},
-          at::TensorOptions().device(device.value()).dtype(at::kByte));
+          {1}, at::TensorOptions().device(device.value()).dtype(at::kByte));
     } else if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
       // set cuda tensor
       tensor = at::empty(

View File

@@ -9,7 +9,6 @@
 #include <unordered_map>
 #include <vector>

-#include <gloo/rendezvous/store.h>
 #include <gloo/algorithm.h>
 #include <gloo/common/error.h>
 #include <gloo/context.h>
@@ -66,14 +65,15 @@ class TORCH_API ProcessGroupGloo : public Backend {
   // operations using the new AsyncWork base class. Over time we will port
   // all operations and perform needed cleanup.
   //
-  // FIXME: This probably should be called WorkGloo since the work is executed in sync mode
-  // by a background thread.
+  // FIXME: This probably should be called WorkGloo since the work is executed
+  // in sync mode by a background thread.
   class TORCH_API AsyncWork : public Work {
    public:
     explicit AsyncWork(
         std::vector<std::vector<at::Tensor>> outputTensors,
         const char* profilingTitle = nullptr,
-        const c10::optional<std::vector<at::Tensor>>& inputTensors = c10::nullopt);
+        const c10::optional<std::vector<at::Tensor>>& inputTensors =
+            c10::nullopt);

   ~AsyncWork() override = default;
@@ -129,40 +129,44 @@ class TORCH_API ProcessGroupGloo : public Backend {
     }

     void wait(
         const std::vector<std::string>& keys,
         const std::chrono::milliseconds& timeout) override {
       store_->wait(keys, timeout);
     }

 #ifdef GLOO_STORE_HAS_STORE_V2
     bool has_v2_support() override {
       return store_->hasExtendedApi();
     }

-    std::vector<std::vector<char>> multi_get(const std::vector<std::string>& keys) override {
+    std::vector<std::vector<char>> multi_get(
+        const std::vector<std::string>& keys) override {
       std::vector<std::vector<char>> res;
-      for(auto& value : store_->multiGet(keys)) {
+      for (auto& value : store_->multiGet(keys)) {
         res.emplace_back(std::vector<char>(value.begin(), value.end()));
       }
       return res;
     }

-    void multi_set(const std::vector<std::string>& keys, const std::vector<std::vector<char>>& values) override {
+    void multi_set(
+        const std::vector<std::string>& keys,
+        const std::vector<std::vector<char>>& values) override {
       std::vector<std::vector<uint8_t>> u_values;
-      for(auto& value : values) {
+      for (auto& value : values) {
         u_values.emplace_back(std::vector<uint8_t>(value.begin(), value.end()));
       }
       store_->multiSet(keys, u_values);
     }

-    void append(const std::string& key, const std::vector<char>& value) override {
+    void append(const std::string& key, const std::vector<char>& value)
+        override {
       std::vector<uint8_t> tmp(value.begin(), value.end());
       return store_->append(key, tmp);
     }

     int64_t add(const std::string& key, int64_t value) override {
       return store_->add(key, value);
     }
 #endif

    protected:
@@ -247,10 +251,10 @@ class TORCH_API ProcessGroupGloo : public Backend {
   // Create ProcessGroupGloo instance.
   static c10::intrusive_ptr<ProcessGroupGloo> createProcessGroupGloo(
       const c10::intrusive_ptr<Store>& store,
       int rank,
       int size,
       std::chrono::milliseconds timeout);

   explicit ProcessGroupGloo(
       const c10::intrusive_ptr<Store>& store,

View File

@@ -33,8 +33,7 @@ struct WorkEntry {
       std::vector<at::Tensor>* srcPtr,
       std::vector<at::Tensor>* dstPtr,
       std::function<void(std::unique_ptr<WorkEntry>&)> run)
-      : dst(dstPtr ? *dstPtr : std::vector<at::Tensor>()),
-        run(std::move(run)) {
+      : dst(dstPtr ? *dstPtr : std::vector<at::Tensor>()), run(std::move(run)) {
     if (srcPtr) {
       src = *srcPtr;
     }
@@ -72,8 +71,8 @@ struct WorkEntry {
 // group. In other words, no more than 1 process group can be created globally.
 //
 // If you would like to use multiple ProcessGroupMPI, it requires your MPI
-// implementation to have a thread support value of MPI_THREAD_MULTIPLE, that is,
-// multiple threads may call MPI, with no restriction.
+// implementation to have a thread support value of MPI_THREAD_MULTIPLE, that
+// is, multiple threads may call MPI, with no restriction.
 //
 // Also note that ProcessGroupMPI only supports a single Tensor operation. In
 // other words, the size of the input Tensor vector should always be 1.
@@ -244,7 +243,8 @@ class TORCH_API ProcessGroupMPI : public Backend {
   c10::intrusive_ptr<Work> enqueue(
       std::unique_ptr<WorkEntry> entry,
       const char* profilingTitle = nullptr,
-      const c10::optional<std::vector<at::Tensor>>& inputTensors = c10::nullopt);
+      const c10::optional<std::vector<at::Tensor>>& inputTensors =
+          c10::nullopt);

   bool stop_;

View File

@@ -9,8 +9,8 @@
 #include <thread>
 #include <unordered_map>

-#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/Backend.hpp>
+#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
 #include <torch/csrc/distributed/c10d/Store.hpp>
 #include <torch/csrc/distributed/c10d/UCCForNCCL.hpp>
@@ -46,11 +46,17 @@ constexpr const char* NCCL_BACKEND_NAME = "nccl";
 // NoHandling: do not handle asynchronous NCCL errors
 // TearDown: tear down process upon error, see `WorkNCCL::handleException`
-// CleanUpOnly: just clean up collectives and abort communicators without tearing down process
-// SkipCleanUp: (this is a temporary option and can be removed in future) tear
-// down process without cleaning up NCCL communicators. This should be used as a
-// last resort in case `ncclCommAbort` itself is hanging
-enum ErrorHandlingMode { NoHandling = 0, TearDown = 1, CleanUpOnly = 2, SkipCleanUp = 3 };
+// CleanUpOnly: just clean up collectives and abort communicators without
+// tearing down process SkipCleanUp: (this is a temporary option and can be
+// removed in future) tear down process without cleaning up NCCL communicators.
+// This should be used as a last resort in case `ncclCommAbort` itself is
+// hanging
+enum ErrorHandlingMode {
+  NoHandling = 0,
+  TearDown = 1,
+  CleanUpOnly = 2,
+  SkipCleanUp = 3
+};

 #define SHOULD_CLEAN_UP(a) (a != NoHandling && a != SkipCleanUp)
@@ -62,7 +68,8 @@ enum ErrorHandlingMode { NoHandling = 0, TearDown = 1, CleanUpOnly = 2, SkipClea
 // Instead, it stashes live references to those tensors until after
 // user-facing streams are synced with comm streams.
 // See stashed_for_allocator_safety_ below.
-constexpr const char* TORCH_NCCL_AVOID_RECORD_STREAMS = "TORCH_NCCL_AVOID_RECORD_STREAMS";
+constexpr const char* TORCH_NCCL_AVOID_RECORD_STREAMS =
+    "TORCH_NCCL_AVOID_RECORD_STREAMS";

 // ProcessGroupNCCL implements NCCL bindings for c10d.
 //
@@ -101,8 +108,7 @@ constexpr const char* TORCH_NCCL_AVOID_RECORD_STREAMS = "TORCH_NCCL_AVOID_RECORD
 // // Now continue on other work in the current stream.
 class TORCH_API ProcessGroupNCCL : public Backend {
  public:
-  class WorkNCCL : public Work,
-                   public std::enable_shared_from_this<WorkNCCL> {
+  class WorkNCCL : public Work, public std::enable_shared_from_this<WorkNCCL> {
    public:
     // Constructor takes a list of CUDA devices
     WorkNCCL(
@@ -159,7 +165,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // Helper function that returns True if the WorkNCCL object has timed out
     // and False otherwise.
     // In case of timeout, set exception on the WorkNCCL object.
-    bool checkTimeout(c10::optional<std::chrono::milliseconds> timeout = c10::nullopt);
+    bool checkTimeout(
+        c10::optional<std::chrono::milliseconds> timeout = c10::nullopt);

     std::vector<at::Tensor> result() override;
@@ -281,8 +288,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   struct Options : Backend::Options {
     // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for
     // operations. This is only used when blockingWait_ is enabled.
-    explicit Options(
-        bool is_high_priority_stream = false);
+    explicit Options(bool is_high_priority_stream = false);

     // return intrusive_ptr of the object
     static c10::intrusive_ptr<Options> create(
@@ -337,7 +343,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   }

   const std::string getBackendName() const override {
     return std::string(NCCL_BACKEND_NAME);
   }

   void startCoalescing() override;
@@ -456,7 +462,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       std::vector<at::Tensor>& tensors,
       int tag) override;

   // Agrees on an initial sequence number for the whole group by having rank 0
   // create it and broadcast it to other ranks using the store.
   void setSequenceNumberForGroup() override;
@@ -497,7 +503,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       std::vector<at::Device> devices,
       int rank,
       OpType opType,
-      const char* profilingTitle=nullptr,
+      const char* profilingTitle = nullptr,
       const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt);

   virtual c10::intrusive_ptr<ProcessGroupNCCL::CoalescedWorkNCCL>
@@ -583,7 +589,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   void destroyNCCLComms(const std::string& devNCCLCommMapKey);

   // Watchdog's inside loop.
-  // Takes care of cleaning up completed work, and aborting upon failure or timeout.
+  // Takes care of cleaning up completed work, and aborting upon failure or
+  // timeout.
   void workCleanupLoop();

   // Desync debug helper

View File

@@ -28,7 +28,7 @@ class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup {
   ~ProcessGroupRoundRobin() override;

   const std::string getBackendName() const override {
     return std::string(ROUND_ROBIN_BACKEND_NAME);
   }

   c10::intrusive_ptr<Work> broadcast(

View File

@@ -110,7 +110,7 @@ class TORCH_API ProcessGroupWrapper : public Backend {
   c10::intrusive_ptr<Work> barrier(
       const BarrierOptions& opts = BarrierOptions()) override;

   c10::intrusive_ptr<Work> _reduce_scatter_base(
       at::Tensor& outputBuffer,
       at::Tensor& inputBuffer,
       const ReduceScatterOptions& opts) override;

View File

@@ -1,8 +1,8 @@
 #pragma once

 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
-#include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
+#include <torch/csrc/utils/pybind.h>

 namespace c10d {
@@ -25,19 +25,21 @@ class PyProcessGroup : public ProcessGroup {
   }

   c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
     // We cannot use PYBIND11_OVERRIDE because:
     // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
     // 2. The python name is get_future
     pybind11::gil_scoped_acquire gil;
-    auto override = pybind11::get_override(static_cast<const Work *>(this), "get_future");
+    auto override =
+        pybind11::get_override(static_cast<const Work*>(this), "get_future");

     if (override) {
       py::object o = override();
-      auto futWrapper = o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
+      auto futWrapper =
+          o.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
       return futWrapper->fut;
     }

     return Work::getFuture();
   }
 };

View File

@@ -70,25 +70,23 @@ class TORCH_API Store : public torch::CustomClassHolder {
   virtual void setTimeout(const std::chrono::milliseconds& timeout);

   // watchKey() is deprecated and no longer supported.
   virtual void watchKey(
       const std::string& /* unused */,
       WatchKeyCallback /* unused */) {
-    TORCH_CHECK(
-        false,
-        "watchKey is deprecated, no implementation support it.");
+    TORCH_CHECK(false, "watchKey is deprecated, no implementation support it.");
   }

   virtual void append(
       const std::string& key,
       const std::vector<uint8_t>& value);

-  virtual std::vector<std::vector<uint8_t>> multiGet(const std::vector<std::string>& keys);
+  virtual std::vector<std::vector<uint8_t>> multiGet(
+      const std::vector<std::string>& keys);

   virtual void multiSet(
       const std::vector<std::string>& keys,
       const std::vector<std::vector<uint8_t>>& values);

   // Returns true if this store support append, multiGet and multiSet
   virtual bool hasExtendedApi() const;

View File

@@ -79,15 +79,15 @@ class TORCH_API TCPStore : public Store {
       const std::vector<std::string>& keys,
       const std::chrono::milliseconds& timeout) override;

-  void append(
-      const std::string& key,
-      const std::vector<uint8_t>& value) override;
+  void append(const std::string& key, const std::vector<uint8_t>& value)
+      override;

-  std::vector<std::vector<uint8_t>> multiGet(const std::vector<std::string>& keys) override;
+  std::vector<std::vector<uint8_t>> multiGet(
+      const std::vector<std::string>& keys) override;

   void multiSet(
       const std::vector<std::string>& keys,
       const std::vector<std::vector<uint8_t>>& values) override;

   bool hasExtendedApi() const override;

View File

@@ -4,8 +4,8 @@
 #include <thread>
 #include <vector>

-#include <torch/csrc/distributed/c10d/socket.h>
 #include <torch/csrc/distributed/c10d/TCPStore.hpp>
+#include <torch/csrc/distributed/c10d/socket.h>

 #ifdef _WIN32
 #include <io.h>
@@ -49,18 +49,24 @@ class BackgroundThread {
   void start();
   bool stop_requested();

  protected:
   void dispose();
   virtual void run() = 0;
   virtual void stop() = 0;
-  bool is_running() { return is_running_.load(); }
+  bool is_running() {
+    return is_running_.load();
+  }

  private:
   std::atomic<bool> is_running_;
   std::thread daemonThread_{};
 };

-std::unique_ptr<BackgroundThread> create_tcpstore_backend(const TCPStoreOptions& opts);
-std::unique_ptr<BackgroundThread> create_libuv_tcpstore_backend(const TCPStoreOptions& opts);
+std::unique_ptr<BackgroundThread> create_tcpstore_backend(
+    const TCPStoreOptions& opts);
+std::unique_ptr<BackgroundThread> create_libuv_tcpstore_backend(
+    const TCPStoreOptions& opts);
 bool is_libuv_tcpstore_backend_available();

 } // namespace detail

View File

@@ -5,8 +5,8 @@
 #include <chrono>
 #include <cstdint>

-#include <ATen/core/ivalue.h>
 #include <ATen/core/Tensor.h>
+#include <ATen/core/ivalue.h>

 #include <c10/macros/Macros.h>
 #include <c10/util/intrusive_ptr.h>
@@ -50,12 +50,13 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder {
   ReduceOp(RedOpType op) : op_(op) {
     TORCH_INTERNAL_ASSERT(
         op_ != PREMUL_SUM,
-        "Use `torch.distributed._make_nccl_premul_sum` to create an instance of ReduceOp with PREMUL_SUM"
-    );
+        "Use `torch.distributed._make_nccl_premul_sum` to create an instance of ReduceOp with PREMUL_SUM");
   }

-  ReduceOp(RedOpType op, c10::intrusive_ptr<_SupplementBase> optional_supplement) {
+  ReduceOp(
+      RedOpType op,
+      c10::intrusive_ptr<_SupplementBase> optional_supplement) {
     if (optional_supplement.get()) {
       op_ = op;
     } else {
@@ -63,10 +64,10 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder {
     }
   }

-  // The heap resource supplement_, if it exists, is managed by a c10::intrusive_ptr,
-  // so constructors and operator= can be simple
-  ReduceOp(const ReduceOp& other) :
-      op_(other.op_), supplement_(other.supplement_) {}
+  // The heap resource supplement_, if it exists, is managed by a
+  // c10::intrusive_ptr, so constructors and operator= can be simple
+  ReduceOp(const ReduceOp& other)
+      : op_(other.op_), supplement_(other.supplement_) {}

   const ReduceOp& operator=(const ReduceOp& other) {
     op_ = other.op_;
@@ -74,7 +75,9 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder {
     return *this;
   }

-  operator RedOpType() const { return op_; }
+  operator RedOpType() const {
+    return op_;
+  }

   bool operator==(const std::uint8_t other) {
     TORCH_INTERNAL_ASSERT(other < 9, "Invalid other op value");
@@ -101,7 +104,8 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder {
   c10::intrusive_ptr<_SupplementBase> supplement_;
 };

-template<typename T> ReduceOp makeNCCLPreMulSum(const T& factor) {
+template <typename T>
+ReduceOp makeNCCLPreMulSum(const T& factor) {
   ReduceOp rop;
   rop.op_ = ReduceOp::PREMUL_SUM;
   rop.supplement_ = c10::make_intrusive<NCCLPreMulSumSupplement>(factor);

View File

@@ -1,25 +1,27 @@
 #pragma once

-#include <string>
-#include <vector>
 #include <cassert>
 #include <memory>
+#include <string>
+#include <vector>

 #include <ATen/DynamicLibrary.h>

 namespace c10d {

 inline std::shared_ptr<at::DynamicLibrary> loadTorchUCC() {
-  const char *path = std::getenv("TORCH_UCC_LIBRARY_PATH");
+  const char* path = std::getenv("TORCH_UCC_LIBRARY_PATH");
   if (path != nullptr) {
     try {
       return std::make_shared<at::DynamicLibrary>(path);
-    } catch (const c10::DynamicLibraryError &e) {
-      TORCH_WARN("TORCH_UCC_LIBRARY_PATH is set, "
-                 "but the loading of torch_ucc.so failed with:", e.msg());
+    } catch (const c10::DynamicLibraryError& e) {
+      TORCH_WARN(
+          "TORCH_UCC_LIBRARY_PATH is set, "
+          "but the loading of torch_ucc.so failed with:",
+          e.msg());
     }
   }
   return nullptr;
 }

 } // namespace c10d

View File

@@ -11,20 +11,20 @@ namespace c10d {
 // Macro to generate the error message on a non-successful UCC return value.
 #define TORCH_UCC_GET_ERROR_MSG(_err, _error_msg, _result) \
   do { \
     _err = c10::str( \
         "[", \
         std::string(__FILE__), \
         ":", \
         std::to_string(__LINE__), \
         "] ", \
         logger->getLogPrefix(), \
         _error_msg, \
         ", error code ", \
         _result, \
         ": ", \
         ucc_status_string(_result), \
         ", system error code ", \
         errno); \
   } while (0)

 // Macro to throw on a non-successful UCC return value.
// Macro to throw on a non-successful UCC return value. // Macro to throw on a non-successful UCC return value.

View File

@@ -7,12 +7,14 @@ namespace tcputil {
 #define CONNECT_SOCKET_OFFSET 2

-inline int poll(struct pollfd *fds, unsigned long nfds, int timeout) {
+inline int poll(struct pollfd* fds, unsigned long nfds, int timeout) {
   return ::poll(fds, nfds, timeout);
 }

-inline void addPollfd(std::vector<struct pollfd> &fds, int socket,
-                      short events) {
+inline void addPollfd(
+    std::vector<struct pollfd>& fds,
+    int socket,
+    short events) {
   fds.push_back({.fd = socket, .events = events});
 }

View File

@@ -35,7 +35,8 @@ namespace c10d {
 TORCH_API std::string parse_env(const char* env_var_name);

 // Retrieve tensor shapes from a given tensor.
-TORCH_API std::vector<at::Tensor> getTensorShapes(const std::vector<at::Tensor>& tensors);
+TORCH_API std::vector<at::Tensor> getTensorShapes(
+    const std::vector<at::Tensor>& tensors);

 // Use -2 to represent unset state of env vars
 #define C10D_ENV_NOT_SET -2
@@ -73,7 +74,9 @@ inline void assertSameType(
   }
 }

-inline std::vector<std::string> split(char separator, const std::string& string) {
+inline std::vector<std::string> split(
+    char separator,
+    const std::string& string) {
   std::vector<std::string> pieces;
   std::stringstream ss(string);
   std::string item;
@@ -90,7 +93,8 @@ inline int parseEnvVarInt(const char* envVarName) {
   try {
     val = std::stoi(stringValue);
   } catch (std::exception& e) {
-    TORCH_CHECK(false,
+    TORCH_CHECK(
+        false,
         "Invalid value for environment variable: " + std::string(envVarName));
   }
   return val;
@@ -98,7 +102,9 @@ inline int parseEnvVarInt(const char* envVarName) {
   return C10D_ENV_NOT_SET;
 }

-inline const char* parseEnvVarString(const char* envVarName, const char* default_val) {
+inline const char* parseEnvVarString(
+    const char* envVarName,
+    const char* default_val) {
   const char* val = std::getenv(envVarName);
   if (val == nullptr) {
     val = default_val;
@@ -107,22 +113,23 @@ inline const char* parseEnvVarString(const char* envVarName, const char* default
 }

 inline int parseEnvVarIntDefault(const char* envVarName, int defaultVal) {
   int val = parseEnvVarInt(envVarName);
   if (val == C10D_ENV_NOT_SET)
     return defaultVal;
   return val;
 }

 inline bool parseEnvVarFlag(const char* envVarName) {
   int val = parseEnvVarInt(envVarName);
   if (val == 1) {
     return true;
   } else if (val == 0 || val == C10D_ENV_NOT_SET) {
     return false;
   }
-  TORCH_CHECK(false,
+  TORCH_CHECK(
+      false,
       "Invalid value for environment variable: " + std::string(envVarName));
   return false;
 }

 inline void assertSameSizes(
@@ -466,7 +473,7 @@ size_t computeLengthsAndOffsets(
     equal_splits = true;
     split_size = tensor.size(0) / group_size;
   }
-  for(const auto i : c10::irange(group_size)) {
+  for (const auto i : c10::irange(group_size)) {
     size_t length = row_size * (equal_splits ? split_size : split_sizes[i]);
     (*lengths)[i] = length;
     (*offsets)[i] = offset;
@@ -483,7 +490,7 @@ size_t computeLengthsAndOffsets(
     std::vector<T>* offsets) {
   size_t group_size = lengths->size();
   size_t offset = 0;
-  for(const auto i : c10::irange(group_size)) {
+  for (const auto i : c10::irange(group_size)) {
     size_t length = tensors[i].numel();
     (*lengths)[i] = length;
     (*offsets)[i] = offset;
@@ -514,7 +521,7 @@ using SizeType = uint64_t;
       continue; \
     } else if ( \
         errno_local == WSAETIMEDOUT || errno_local == WSAEWOULDBLOCK) { \
       TORCH_CHECK(false, "Socket Timeout"); \
     } else { \
       throw std::system_error(errno_local, std::system_category()); \
     } \
@@ -531,7 +538,7 @@ using SizeType = uint64_t;
     if (errno == EINTR) { \
       continue; \
     } else if (errno == EAGAIN || errno == EWOULDBLOCK) { \
       TORCH_CHECK(false, "Socket Timeout"); \
     } else { \
       throw std::system_error(errno, std::system_category()); \
     } \

View File

@@ -7,12 +7,14 @@ namespace tcputil {
 #define CONNECT_SOCKET_OFFSET 1

-inline int poll(struct pollfd *fdArray, unsigned long fds, int timeout) {
+inline int poll(struct pollfd* fdArray, unsigned long fds, int timeout) {
   return WSAPoll(fdArray, fds, timeout);
 }

-inline void addPollfd(std::vector<struct pollfd> &fds, int socket,
-                      short events) {
+inline void addPollfd(
+    std::vector<struct pollfd>& fds,
+    int socket,
+    short events) {
   fds.push_back({(SOCKET)socket, events});
 }

View File

@@ -111,8 +111,7 @@ class TORCH_API CommHookInterface {
   // Returns the resulting tensor once the communication hook result is
   // ready. The resulting tensor will then be copied to the grads of
   // individual parameters.
-  virtual at::Tensor parseHookResult(
-      const c10::IValue& result) = 0;
+  virtual at::Tensor parseHookResult(const c10::IValue& result) = 0;
 };

 namespace detail {

View File

@@ -10,7 +10,8 @@ enum class BuiltinCommHookType {
   FP16_COMPRESS = 2,
 };

-class AllReduceCommHook : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
+class AllReduceCommHook
+    : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
  public:
   explicit AllReduceCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
       : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {}
@@ -20,7 +21,8 @@ class AllReduceCommHook : public CppCommHookInterface<c10::intrusive_ptr<Process
   c10::intrusive_ptr<c10::ivalue::Future> runHook(GradBucket& bucket) override;
 };

-class FP16CompressCommHook : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
+class FP16CompressCommHook
+    : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
  public:
   explicit FP16CompressCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
       : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {}
@@ -32,12 +34,14 @@ class FP16CompressCommHook : public CppCommHookInterface<c10::intrusive_ptr<Proc
 // Almost same as AllReduceCommHook, but without division inside the hook.
 // This enables the optimization of fusing copy and division and saves one scan
-// over all the input parameters, when no communication hook is provided by the user.
-// Only used internally and not released as a public built-in communication hook.
+// over all the input parameters, when no communication hook is provided by the
+// user. Only used internally and not released as a public built-in
+// communication hook.
 class _AllReduceBySumCommHook
     : public CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>> {
  public:
-  explicit _AllReduceBySumCommHook(const c10::intrusive_ptr<ProcessGroup>& state)
+  explicit _AllReduceBySumCommHook(
+      const c10::intrusive_ptr<ProcessGroup>& state)
       : CppCommHookInterface<c10::intrusive_ptr<ProcessGroup>>(state) {}

   ~_AllReduceBySumCommHook() override = default;

View File

@@ -16,8 +16,7 @@ class TORCH_API Logger {
       int output_device,
       bool broadcast_buffers,
       bool has_sync_bn,
-      bool static_graph
-  );
+      bool static_graph);

   void set_static_graph();

@@ -62,11 +61,7 @@ class TORCH_API Logger {
       Timer::Event end_event);

   // Set the absolute time of the event that has been recorded in reducer.
-  void set_event_time(
-      int64_t& event_time,
-      Timer& timer,
-      Timer::Event event
-  );
+  void set_event_time(int64_t& event_time, Timer& timer, Timer::Event event);
   // Set stats that can be collected only during
   // training loop. It is called at the beginning of forward call
   // to record the run time stats of sampled iterations that previously ran.

@@ -97,7 +92,6 @@ class TORCH_API Logger {
   // optimization.
   void log_if_graph_static(bool is_static);

 private:
   // ddp_logging_data_ is used to hold all the ddp related logging
   // data fields.
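
A plausible implementation sketch for the reflowed set_event_time signature above, assuming Timer::getTimestamp from reducer_timer.hpp returns c10::optional<int64_t> (the actual definition lives in logger.cpp):

void Logger::set_event_time(
    int64_t& event_time,
    Timer& timer,
    Timer::Event event) {
  auto timestamp = timer.getTimestamp(event);
  if (timestamp != c10::nullopt) {
    // Only overwrite when the event was actually recorded.
    event_time = *timestamp;
  }
}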

--- a/torch/csrc/distributed/c10d/reducer.hpp
+++ b/torch/csrc/distributed/c10d/reducer.hpp

@@ -1,25 +1,25 @@
 #pragma once

+#include <c10/core/ScalarType.h>
 #include <atomic>
 #include <memory>
 #include <mutex>
 #include <tuple>
 #include <unordered_map>
 #include <vector>
-#include <c10/core/ScalarType.h>

 #include <ATen/core/ivalue_inl.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/intrusive_ptr.h>
+#include <torch/csrc/autograd/function.h>
+#include <torch/csrc/autograd/profiler.h>
+#include <torch/csrc/autograd/variable.h>
 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
 #include <torch/csrc/distributed/c10d/Utils.hpp>
 #include <torch/csrc/distributed/c10d/comm.hpp>
 #include <torch/csrc/distributed/c10d/debug.h>
-#include <torch/csrc/distributed/c10d/reducer_timer.hpp>
 #include <torch/csrc/distributed/c10d/default_comm_hooks.hpp>
-#include <torch/csrc/autograd/function.h>
-#include <torch/csrc/autograd/profiler.h>
-#include <torch/csrc/autograd/variable.h>
+#include <torch/csrc/distributed/c10d/reducer_timer.hpp>

 #ifndef _WIN32
 #include <torch/csrc/distributed/autograd/context/context.h>
 #endif
@@ -101,7 +101,9 @@ class TORCH_API Reducer {
   // Informs reducer that optimizer is running in backward, so gradients
   // don't need to be copied from buckets as the optimizer would've already
   // been applied.
-  void set_optimizer_in_backward() { optim_in_backward_ = true; };
+  void set_optimizer_in_backward() {
+    optim_in_backward_ = true;
+  };

   // Runs allreduce or installed communication hook given GradBucket instance.
   c10::intrusive_ptr<c10::ivalue::Future> run_comm_hook(
@@ -109,7 +111,7 @@ class TORCH_API Reducer {
   // Runs default allreduce hook.
   c10::intrusive_ptr<c10::ivalue::Future> run_allreduce_hook(
       GradBucket& grad_bucket);

   // Returns gradient buckets in sequential order of buckets_. This is the order
   // in which buckets are reduced across processes. If return_zero_tensors=true,
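
The declarations above imply a simple dispatch: run the user-installed hook when one exists, otherwise fall back to the built-in allreduce hook. A sketch of that logic, assuming the Reducer's comm_hook_ member (the real body is in reducer.cpp):

c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_comm_hook(
    GradBucket& grad_bucket) {
  if (comm_hook_ == nullptr) {
    // No hook installed: use the default allreduce path.
    return run_allreduce_hook(grad_bucket);
  }
  return comm_hook_->runHook(grad_bucket);
}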
@@ -133,8 +135,8 @@ class TORCH_API Reducer {
   void setSparseMetadata(std::map<std::string, at::Tensor>& metadata);

   // Install futures that should be awaited at end of backwards. Currently these
-  // are only used by user-defined custom buffer reduction hooks, but can be generalized
-  // to any user-originating futures that need to be awaited.
+  // are only used by user-defined custom buffer reduction hooks, but can be
+  // generalized to any user-originating futures that need to be awaited.
   void install_futures(c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs);

   // Returns true if we should rebuild buckets, else false. We only rebuild
@@ -183,7 +185,8 @@ class TORCH_API Reducer {
   // Removes autograd hooks registered by the Reducer on the model parameters.
   void remove_autograd_hooks();

-  // Checks whether or not the reducer has finalized the current backward iteration.
+  // Checks whether or not the reducer has finalized the current backward
+  // iteration.
   void check_finalized();

 protected:
@@ -248,9 +251,10 @@ class TORCH_API Reducer {
   // Weak pointer to associated DDP logger.
   std::weak_ptr<c10d::Logger> logger_;

-  // List of futures installed by Reducer::install_futures that should be awaited
-  // at the end of backwards pass.
-  c10::optional<c10::List<c10::intrusive_ptr<c10::ivalue::Future>>> installed_futures_{c10::nullopt};
+  // List of futures installed by Reducer::install_futures that should be
+  // awaited at the end of backwards pass.
+  c10::optional<c10::List<c10::intrusive_ptr<c10::ivalue::Future>>>
+      installed_futures_{c10::nullopt};

   // Mixed precision parameter dtype for bucket type checking.
   c10::optional<c10::ScalarType> mixed_precision_param_dtype_{c10::nullopt};
@@ -273,7 +277,8 @@ class TORCH_API Reducer {
   // bucket_index is a key to cache after buckets are rebuilt, after which this
   // mapping never changes.
   std::vector<at::Tensor> get_variables_for_bucket(
-      size_t bucket_index, const Bucket& bucket) const;
+      size_t bucket_index,
+      const Bucket& bucket) const;

   // Asserts that the reduction for the previous iteration has finished before
   // rebuilding buckets or kicking off the next one.
@@ -385,7 +390,6 @@ class TORCH_API Reducer {
     // done on different CUDA streams. We record an event for every copy
     // so that we can synchronize with them prior to kicking off the reduction.
     // std::vector<at::cuda::CUDAEvent> events;
   };

   std::vector<Bucket> buckets_;
@ -401,7 +405,9 @@ class TORCH_API Reducer {
VariableLocator() = default; VariableLocator() = default;
VariableLocator(size_t bucket_index_, size_t intra_bucket_index_) : bucket_index(bucket_index_), intra_bucket_index(intra_bucket_index_) {} VariableLocator(size_t bucket_index_, size_t intra_bucket_index_)
: bucket_index(bucket_index_),
intra_bucket_index(intra_bucket_index_) {}
}; };
// Map the index of a variable to its location in the bucket structure. // Map the index of a variable to its location in the bucket structure.
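
As a usage sketch of the mapping this comment describes, assuming the adjacent Reducer members variable_locators_ and buckets_ and the Bucket::bucket_views_out field (variable names here are for illustration only):

// Map a flat variable index to its bucket and its slot inside that bucket.
const VariableLocator& loc = variable_locators_[variable_index];
Bucket& bucket = buckets_[loc.bucket_index];
at::Tensor& grad_view = bucket.bucket_views_out[loc.intra_bucket_index];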
@@ -409,10 +415,12 @@ class TORCH_API Reducer {
   // track the number of iterations to synchronize grads in training so far.
   long num_iterations_;
-  // track distinct iteration of backward call. This is distinct from num_iterations_,
-  // for example in the case of multiple forward before backward.
+  // track distinct iteration of backward call. This is distinct from
+  // num_iterations_, for example in the case of multiple forward before
+  // backward.
   long num_bwd_calls_;
-  // whether the first autograd hook for a distinct backward pass has been called.
+  // whether the first autograd hook for a distinct backward pass has been
+  // called.
   bool first_autograd_hook_called_;
   // track the number of buckets that have been ready for
   // communication calls like allReduce or communication hooks.
@@ -543,7 +551,8 @@ class TORCH_API Reducer {
   // Cached bucket index to model parameter mapping. Populated after buckets
   // are rebuilt after which this mapping is static.
-  mutable std::unordered_map<size_t, std::vector<at::Tensor>> cached_variables_for_bucket_;
+  mutable std::unordered_map<size_t, std::vector<at::Tensor>>
+      cached_variables_for_bucket_;

   bool optim_in_backward_{false};
   friend class Logger;
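
Tying the installed_futures_ member above back to install_futures: at the end of the backward pass the reducer collects and blocks on every user-installed future, then clears the list. A sketch of that step, assuming c10::collectAll from ATen/core/ivalue_inl.h (cf. Reducer::finalize_backward in reducer.cpp):

if (installed_futures_ != c10::nullopt) {
  // Wait on all user-originating futures before declaring backward done.
  c10::collectAll(*installed_futures_)->wait();
  installed_futures_ = c10::nullopt;
}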

--- a/torch/csrc/distributed/c10d/reducer_timer.hpp
+++ b/torch/csrc/distributed/c10d/reducer_timer.hpp

@@ -71,5 +71,10 @@ class TORCH_API Timer {
   }
 };

-TORCH_DECLARE_TYPED_REGISTRY(TimerRegistry, c10::DeviceType, Timer, std::unique_ptr, c10::Device);
+TORCH_DECLARE_TYPED_REGISTRY(
+    TimerRegistry,
+    c10::DeviceType,
+    Timer,
+    std::unique_ptr,
+    c10::Device);
 } // namespace c10d
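
For context on the declaration above: TORCH_DECLARE_TYPED_REGISTRY pairs with a definition and per-device registrations elsewhere. A sketch of the pattern, assuming the standard c10 typed-registry macros and a CpuTimer subclass of Timer (the real registrations live in reducer.cpp):

// In one .cpp file: define the registry declared above.
C10_DEFINE_TYPED_REGISTRY(
    TimerRegistry,
    c10::DeviceType,
    Timer,
    std::unique_ptr,
    c10::Device);

// Register a Timer implementation for a device type.
C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer);

// At a use site: create a timer for the parameters' device.
std::unique_ptr<Timer> timer = TimerRegistry()->Create(device.type(), device);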

--- a/torch/csrc/distributed/c10d/sequence_num.hpp
+++ b/torch/csrc/distributed/c10d/sequence_num.hpp

@@ -1,15 +1,15 @@
 #pragma once

-#include <vector>
 #include <c10/macros/Macros.h>
 #include <c10/util/Optional.h>
 #include <c10/util/irange.h>
+#include <vector>

 namespace c10d {

 const int kUnsetSeqNum = 0;

 namespace {
 constexpr int kByteOffset = 8;
 }

 // Converts from int to char vec to write in store
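
A sketch of the conversion that this comment introduces, assuming little-endian packing with kByteOffset bits per byte (the real helper in this header may differ in name and signature):

inline std::vector<uint8_t> toVec(uint64_t num, int numBytes) {
  std::vector<uint8_t> values;
  // Shift and mask out one byte of the integer at a time.
  for (const auto i : c10::irange(numBytes)) {
    values.push_back(uint8_t((num >> (i * kByteOffset)) & 0xff));
  }
  return values;
}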