Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-02 14:34:54 +08:00)

Compare commits: whc_flight ... flight_5 (16 commits)
| SHA1 |
|---|
| b33a283e9a |
| 7a551d81e5 |
| 1515a90475 |
| 4882ec2a91 |
| 972b8060bd |
| 3e7683ae18 |
| f2e9ec2dc5 |
| dde4324d8e |
| 94c079104d |
| a6afee6d94 |
| d092857531 |
| 6aad5e444a |
| c54ce9313b |
| 1fe59f4ef7 |
| e693fb2bb1 |
| 4fe510baf6 |
@ -1732,7 +1732,7 @@ if(BUILD_TEST)
  foreach(test_src ${Caffe2_CPU_TEST_SRCS})
    get_filename_component(test_name ${test_src} NAME_WE)
    add_executable(${test_name} "${test_src}")
    target_link_libraries(${test_name} torch_library gtest_main)
    target_link_libraries(${test_name} torch_library gtest_main stdc++)
    target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
    target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
    target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})

@ -4,9 +4,10 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.checkpoint._state_dict_utils import (
from torch.distributed._state_dict_utils import (
    _check_state_dict_similarity,
    _copy_state_dict,
    _create_cpu_state_dict,
    _gather_state_dict,
    _offload_state_dict_to_cpu,
)
@ -115,6 +116,58 @@ class TestStateDictUtils(DTensorTestBase):
        }
        self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))

    @skip_if_lt_x_gpu(2)
    def test_create_cpu_state_dict(self):
        device = torch.device("cuda")
        buffer = io.BytesIO()
        torch.save(torch.ones(10), buffer)
        buffer.seek(0)
        state_dict = {
            "tensor1": torch.arange(10, device=device),
            "tensor2": torch.ones(10, device=device),
            "non_tensor_bytes_io": copy.deepcopy(buffer),
            "non_tensor_bytes": buffer.read(),
            "step": torch.tensor(7, dtype=torch.float),
            "lr": 1.5,
            "nested": {"list": [1, 2, 3, 4]},
        }

        def _verify(cpu_state_dict):
            # Verify the correctness of _check_state_dict_similarity()
            self.assertTrue(_check_state_dict_similarity(state_dict, cpu_state_dict))
            tensor1 = cpu_state_dict["tensor1"]
            cpu_state_dict["tensor1"] = torch.arange(11)
            self.assertFalse(_check_state_dict_similarity(state_dict, cpu_state_dict))
            cpu_state_dict["tensor1"] = tensor1

            _copy_state_dict(state_dict, cpu_state_dict)

            # Verify if _copy_state_dict works
            for v in cpu_state_dict.values():
                if isinstance(v, torch.Tensor):
                    self.assertFalse(v.is_cuda)
            self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
            self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
            buffer.seek(0)
            cpu_state_dict["non_tensor_bytes_io"].seek(0)
            self.assertEqual(
                cpu_state_dict["non_tensor_bytes_io"].read(), buffer.read()
            )
            buffer.seek(0)
            self.assertEqual(cpu_state_dict["non_tensor_bytes"], buffer.read())
            self.assertEqual(cpu_state_dict["lr"], 1.5)
            self.assertEqual(cpu_state_dict["step"], 7)
            self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]})

        cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)
        _verify(cpu_state_dict)
        cpu_state_dict = _create_cpu_state_dict(state_dict, share_memory=True)
        _verify(cpu_state_dict)
        cpu_state_dict = _create_cpu_state_dict(
            state_dict, share_memory=True, pin_memory=True
        )
        _verify(cpu_state_dict)


if __name__ == "__main__":
    run_tests()
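The test above drives the new CPU-offload helpers end to end. As a hedged illustration (not part of the patch), the same private helpers combine into the usual checkpoint-offload pattern: allocate the CPU buffers once, then reuse them for every snapshot. Names and keyword arguments are taken from the imports above; treat this as a sketch only.

```python
# Sketch only: assumes a CUDA tensor state_dict and the private helpers imported
# in the test above (torch.distributed._state_dict_utils).
import torch
from torch.distributed._state_dict_utils import (
    _check_state_dict_similarity,
    _copy_state_dict,
    _create_cpu_state_dict,
)

state_dict = {"weight": torch.ones(10, device="cuda"), "step": torch.tensor(7.0)}

# Allocate matching CPU buffers once; pin_memory=True keeps the D2H copies cheap.
cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)

# On every checkpoint, verify the layout still matches and copy into the buffers.
assert _check_state_dict_similarity(state_dict, cpu_state_dict)
_copy_state_dict(state_dict, cpu_state_dict)
```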
@ -11,6 +11,7 @@ import tempfile
import threading
import pickle
import time
import json
import warnings
from contextlib import contextmanager
from datetime import datetime, timedelta
@ -1334,6 +1335,19 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
        self.assertEqual(tensor, original_tensor)


    @requires_nccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
    def test_set_process_group_desc(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        device = torch.device(f'cuda:{self.rank}')
        pg_default = self._create_process_group_nccl(store, self.opts(), device_id=device)
        self.assertEqual(pg_default.group_desc, "default_pg")
        pg_1 = c10d.new_group([0, 1], group_desc="test_purpose")
        self.assertEqual(pg_1.group_desc, "test_purpose")
        pg_2 = c10d.new_group([0, 1])
        self.assertEqual(pg_2.group_desc, "undefined")


class DistributedDataParallelTest(
    test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
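A hedged usage sketch (not part of the patch) for the `group_desc` plumbing exercised above: user code can tag subgroups with a human-readable description, and the description also feeds the new log prefix (`[PG <name> (<desc>) Rank <rank>]`) added later in this change set. The `group_desc` keyword is assumed to be accepted by `new_group` exactly as in the test.

```python
# Sketch only: the usual MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE env setup is assumed.
import torch.distributed as dist

dist.init_process_group("nccl")

dp_group = dist.new_group(ranks=[0, 1], group_desc="data_parallel")
print(dp_group.group_desc)  # "data_parallel"

other = dist.new_group(ranks=[0, 1])
print(other.group_desc)     # "undefined"; the default PG itself reports "default_pg"
```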
@ -3637,11 +3651,18 @@ class NCCLTraceTest(NCCLTraceTestBase):

        t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
        ver = t['version']
        self.assertEqual(ver, "1.0")
        self.assertEqual(ver, "1.1")
        t = t['entries']
        self.assertEqual(len(t), 2)
        last = t[-1]
        self.assertEqual(last['process_group'], ('0', 'default_pg'))
        self.assertEqual(last['state'], 'completed')
        s = last['time_discovered_started_ns']
        f = last['time_discovered_completed_ns']
        self.assertIsNotNone(f)
        if timing_enabled:
            self.assertIsNotNone(s)
            self.assertTrue(s <= f)
        self.assertIn('test_c10d_nccl.py', str(last['frames']))
        self.assertEqual(last['input_sizes'], ((3, 4),))
        self.assertEqual(last['output_sizes'], ((3, 4),))
@ -3718,6 +3739,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
self.assertEqual(len(t), 10)
|
||||
first = t[0]
|
||||
last = t[-1]
|
||||
self.assertEqual(last['profiling_name'], 'nccl:all_reduce')
|
||||
self.assertEqual(last['state'], 'completed')
|
||||
self.assertIn('test_c10d_nccl.py', str(last['frames']))
|
||||
self.assertEqual(last['input_sizes'], ((3, 4),))
|
||||
@ -3750,6 +3772,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
e.synchronize()
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
t = t['entries']
|
||||
self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
|
||||
if self.rank == 0:
|
||||
self.assertEqual(t[-1]['seq_id'], 1)
|
||||
self.assertEqual(t[-1]['state'], 'completed')
|
||||
@ -3792,12 +3815,14 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
time.sleep(5)
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
t = t['entries']
|
||||
self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
|
||||
if self.rank == 0:
|
||||
self.assertEqual(t[-1]['seq_id'], 1)
|
||||
self.assertEqual(t[-1]['state'], 'completed')
|
||||
else:
|
||||
self.assertEqual(t[-1]['seq_id'], 2)
|
||||
self.assertEqual(t[-1]['state'], self.started_or_scheduled(timing_enabled))
|
||||
self.assertIsNone(t[-1]['time_discovered_completed_ns'])
|
||||
# this will eventually cause the missing rank 0
|
||||
# to continue which will unblock the non-zero ranks
|
||||
self.parent.send('next')
|
||||
@ -3851,9 +3876,10 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
|
||||
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_timeout_dumps(self, timing_enabled):
|
||||
# We need to completely disable the coordinated timeout dump to avoid rank 0
|
||||
# also timeout so that we set the check frequency to be very large (25 min).
|
||||
os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1500000'
|
||||
# dump on heartbeatmonitor thread
|
||||
os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1000'
|
||||
# need rank0 to crash before looking for its output file
|
||||
os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '1'
|
||||
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
# wait for rank0 to crash before looking for its output file
|
||||
@ -3938,6 +3964,60 @@ class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase):
|
||||
# getting the global signal to dump the debugging info.
|
||||
time.sleep(600)
|
||||
|
||||
class NcclErrorDumpTest(NCCLTraceTestBase):
|
||||
def _wait_process(self, rank, timeout):
|
||||
try:
|
||||
self.processes[rank].join(timeout)
|
||||
return self.processes[rank].exitcode
|
||||
except TimeoutError:
|
||||
return None
|
||||
|
||||
def _check_return_codes(self, elapsed_time):
|
||||
# the base test infra assumes processes exit with matching return codes,
|
||||
# but we want rank0 to abort with exception and rank1 to exit with exit 1
|
||||
self.assertEqual(self.processes[0].exitcode, -6)
|
||||
self.assertEqual(self.processes[1].exitcode, 1)
|
||||
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_rocm
|
||||
def test_nccl_errors_dump(self):
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '1000'
|
||||
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
|
||||
# need rank0 to dump before abort
|
||||
os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '5'
|
||||
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
# wait for both rank0 and 1 to crash before looking for dump
|
||||
self.assertEqual(self._wait_process(0, timeout=90), -6)
|
||||
self.assertEqual(self._wait_process(1, timeout=90), 1)
|
||||
# verify that the trace file exists for rank0
|
||||
self.assertTrue(os.path.exists(self._trace_name(rank=0)))
|
||||
return
|
||||
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
process_group = c10d.ProcessGroupNCCL(
|
||||
store,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
timeout=timedelta(seconds=10),
|
||||
)
|
||||
process_group.allreduce(torch.rand(10).cuda(self.rank))
|
||||
if self.rank == 0:
|
||||
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
|
||||
# expect an error to be raised
|
||||
with self.assertRaisesRegex(dist.DistBackendError, ""):
|
||||
# Block the current stream on the NCCL stream
|
||||
work.wait()
|
||||
# Run some GPU operations
|
||||
a = torch.rand(10).cuda(self.rank)
|
||||
elif self.rank == 1:
|
||||
# Clean up structures (ex: files for FileStore before going down)
|
||||
del process_group
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
|
||||
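The tests above configure the flight recorder and the dump-on-timeout path entirely through environment variables. Collected in one place, as a hedged sketch with illustrative values (the pipe path in particular is a made-up example), these are the knobs they toggle; all of them are read when the NCCL process group is set up, so they must be exported before process-group initialization.

```python
import os

os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000"   # >0 enables the flight recorder
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"        # dump debug info on watchdog timeout/exception
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "5"  # monitor-thread heartbeat timeout
os.environ["TORCH_NCCL_COORD_CHECK_MILSEC"] = "1000"  # poll interval for cross-rank dump signals
os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = "/tmp/nccl_trace_rank_"  # optional on-demand pipe trigger
```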
@ -463,6 +463,7 @@ class ProcessGroup:
        backend: Optional[ProcessGroup],
    ) -> None: ...
    def _set_group_name(self, name: str) -> None: ...
    def _set_group_desc(self, desc: str) -> None: ...
    def name(self) -> str: ...
    def _has_hooks(self) -> bool: ...
    def _wait_for_pending_works(self) -> None: ...
@ -471,6 +472,10 @@ class ProcessGroup:
    def bound_device_id(self) -> Optional[torch.device]: ...
    @bound_device_id.setter
    def bound_device_id(self, device: Optional[torch.device]) -> None: ...
    @property
    def group_name(self) -> str: ...
    @property
    def group_desc(self) -> str: ...

class ProcessGroupRoundRobin(ProcessGroup): ...

@ -369,6 +369,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
    return pg_name_;
  }

  void setGroupDesc(const std::string& desc) {
    pg_desc_ = desc;
  }

  const std::string& getGroupDesc() const {
    return pg_desc_;
  }

  // See similar functions in ProcessGroup.hpp for context.
  c10::optional<at::Device> getBoundDeviceId() const {
    return bound_device_id_;
@ -399,6 +407,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
  // remains the same across use of this process group.
  DebugLevel dist_debug_level_;
  std::string pg_name_;
  std::string pg_desc_;

  std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;

@ -8,6 +8,7 @@
#include <memory>
#include <mutex>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <nccl.h>
@ -282,6 +283,18 @@ class NCCLComm {
  }
#endif

#if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP)
  std::unordered_map<std::string, std::string> ncclCommDump() {
    std::unordered_map<std::string, std::string> dump;
    if (isAborted()) {
      LOG(INFO) << "Communicator was aborted before trying to dump its state.";
      return dump;
    }
    C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), c10::nullopt);
    return dump;
  }
#endif

  ncclUniqueId getNcclId() {
    return ncclId_;
  }
@ -337,6 +350,9 @@ class NCCLComm {
    // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
    // timeout)
    commFailureReason_ = commFailureReason;
    LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
              << (commFailureReason ? *commFailureReason
                                    : "No abort reason provided.");
#ifndef NCCL_HAS_COMM_NONBLOCKING
    C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
#else
@ -436,6 +452,8 @@ class NCCLComm {
#endif
  }

  friend class ProcessGroupNCCL;

 protected:
  ncclComm_t ncclComm_;
  // Unique nccl_id for this communicator.

@ -165,6 +165,18 @@ void ProcessGroup::setGroupName(const std::string& name) {
  }
}

const std::string& ProcessGroup::getGroupDesc() const {
  return pg_desc_;
}

void ProcessGroup::setGroupDesc(const std::string& name) {
  pg_desc_ = name;
  // Also set the group desc for all backends
  for (auto& kv : deviceTypeToBackend_) {
    kv.second->setGroupDesc(name);
  }
}

void ProcessGroup::enableCollectivesTiming() {
  for (auto& kv : deviceTypeToBackend_) {
    kv.second->enableCollectivesTiming();
@ -694,6 +694,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {

  const std::string& getGroupName() const;
  void setGroupName(const std::string& name);
  const std::string& getGroupDesc() const;
  void setGroupDesc(const std::string& name);
  void enableCollectivesTiming();

  void release_resources() override;
@ -724,6 +726,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
  const int size_;
  const c10::intrusive_ptr<Options> options_;
  const BackendType backendType_;
  std::string pg_desc_;

  // Debug level setting. It is parsed once when ProcessGroup is constructed and
  // remains the same across use of this process group.

@ -4,13 +4,6 @@
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#ifdef USE_C10D_NCCL
|
||||
|
||||
#include <exception>
|
||||
@ -301,6 +294,9 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
|
||||
static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
|
||||
static std::mutex ncclCommDevIdxMapMutex;
|
||||
static bool allocatorHooksAttached = false;
|
||||
|
||||
std::atomic<bool> ProcessGroupNCCL::shouldDump_(false);
|
||||
|
||||
void cacheAllocatorRegisterHook(
|
||||
const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
|
||||
// Register after SEGMENT_ALLOC
|
||||
@ -337,9 +333,34 @@ void cacheAllocatorDeregisterHook(
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP)
|
||||
std::string dump_nccl_trace() {
|
||||
return NCCLTraceBuffer::get()->dump();
|
||||
std::unordered_map<
|
||||
std::string /* ncclUniqueID */,
|
||||
std::unordered_map<std::string, std::string> /* dump from this comm */>
|
||||
ncclDumpMap;
|
||||
// dump_nccl_trace is only called from the default PG (uid_=0), but we want to
|
||||
// dump from all comms so we need to iterate over ncclCommDevIdxMap, which
|
||||
// is static
|
||||
std::vector<std::shared_ptr<NCCLComm>> allNCCLComms;
|
||||
// within the critical section, we don't want to dump while holding the lock
|
||||
// as dump might hang
|
||||
ncclCommDevIdxMapMutex.lock();
|
||||
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
|
||||
allNCCLComms.push_back(ncclComm);
|
||||
}
|
||||
ncclCommDevIdxMapMutex.unlock();
|
||||
for (auto& ncclComm : allNCCLComms) {
|
||||
std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId());
|
||||
ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump();
|
||||
}
|
||||
return NCCLTraceBuffer::get()->dump(ncclDumpMap);
|
||||
}
|
||||
#else
|
||||
std::string dump_nccl_trace() {
|
||||
return NCCLTraceBuffer::get()->dump(c10::nullopt);
|
||||
}
|
||||
#endif
|
||||
|
||||
c10::optional<std::function<std::string()>>& get_cpp_trace_dumper() {
|
||||
static c10::optional<std::function<std::string()>> dumper(c10::nullopt);
|
||||
@ -744,13 +765,17 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
ValueError,
|
||||
at::cuda::getNumGPUs() != 0,
|
||||
"ProcessGroupNCCL is only supported with GPUs, no GPUs found!");
|
||||
this->setGroupName(options_->group_name);
|
||||
logPrefix_ = createLogPrefix();
|
||||
blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false);
|
||||
asyncErrorHandling_ = static_cast<ErrorHandlingMode>(
|
||||
getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/));
|
||||
desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
|
||||
(dist_debug_level_ >= DebugLevel::Detail);
|
||||
dumpOnTimeout_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
|
||||
// TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
|
||||
// or change its name to reflect that dump happens on exception including
|
||||
// both timeout and other errors.
|
||||
dumpOnException_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
|
||||
(dist_debug_level_ >= DebugLevel::Detail);
|
||||
heartbeat_ = 1ULL;
|
||||
monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
|
||||
@ -827,7 +852,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
<< "NCCL version: " << getNcclVersion() << ", size: " << size
|
||||
<< ", global rank: " << globalRank()
|
||||
<< ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
|
||||
<< ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeout_
|
||||
<< ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnException_
|
||||
<< ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: "
|
||||
<< waitTimeoutDumpInMilSec_
|
||||
<< ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_
|
||||
@ -848,7 +873,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
|
||||
<< ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
|
||||
<< ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
|
||||
<< ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
|
||||
<< ", ID=" << this->getID();
|
||||
<< ", PG Name: " << options_->group_name;
|
||||
|
||||
RECORD_PARAM_COMMS(
|
||||
0, // seq
|
||||
@ -1086,6 +1111,55 @@ void ProcessGroupNCCL::waitForDumpOrTimeout(
|
||||
std::this_thread::sleep_until(wakeUpTime);
|
||||
}
|
||||
|
||||
// WHC - pulled this from
|
||||
// https://github.com/pytorch/pytorch/commit/893dcac068f13542b1e00e3e55bca4530ab412cb
|
||||
// to help cherry-pick go through. did not cherry-pick entirety of the PR that
|
||||
// provided this new util function.
|
||||
void ProcessGroupNCCL::waitForFutureOrTimeout(
|
||||
std::future<bool>& fut,
|
||||
const std::chrono::milliseconds& timeOutMilSec,
|
||||
const std::string& futDescription) {
|
||||
TORCH_CHECK(fut.valid(), "Expected a valid future");
|
||||
std::future_status status = fut.wait_for(timeOutMilSec);
|
||||
if (status == std::future_status::ready) {
|
||||
// Calling .get() will re-raise any exception from the future, and we don't
|
||||
// care about the retval
|
||||
try {
|
||||
bool result = fut.get();
|
||||
if (result) {
|
||||
LOG(INFO) << logPrefix()
|
||||
<< "future is successfully executed for: " << futDescription;
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
C10_THROW_ERROR(
|
||||
DistBackendError,
|
||||
c10::str(
|
||||
logPrefix(),
|
||||
"Exception thrown when waitng for future ",
|
||||
futDescription,
|
||||
": ",
|
||||
e.what()));
|
||||
} catch (...) {
|
||||
C10_THROW_ERROR(
|
||||
DistBackendError,
|
||||
c10::str(
|
||||
logPrefix(),
|
||||
"Unknown exception thrown when waitng for future ",
|
||||
futDescription));
|
||||
}
|
||||
} else {
|
||||
C10_THROW_ERROR(
|
||||
DistBackendError,
|
||||
c10::str(
|
||||
logPrefix(),
|
||||
"Future for ",
|
||||
futDescription,
|
||||
" timed out after ",
|
||||
timeOutMilSec.count(),
|
||||
" ms"));
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessGroupNCCL::abortCommsFromMap(
|
||||
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
|
||||
ncclCommsMap,
|
||||
@ -1097,6 +1171,8 @@ void ProcessGroupNCCL::abortCommsFromMap(
|
||||
auto& ncclComms = it.second;
|
||||
|
||||
for (const auto& ncclComm : ncclComms) {
|
||||
LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ "
|
||||
<< ncclComm->ncclComm_ << " on CUDA device: " << devName;
|
||||
ncclComm->ncclCommAbort(abortReason);
|
||||
}
|
||||
// Note that we don't remove the aborted communicators from the
|
||||
@ -1117,8 +1193,9 @@ void ProcessGroupNCCL::abortCommsFromMap(
|
||||
}
|
||||
}
|
||||
|
||||
LOG(INFO) << logPrefix() << "] Destroyed " << ncclComms.size()
|
||||
<< "communicators on CUDA device: " << devName
|
||||
LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed "
|
||||
<< ncclComms.size()
|
||||
<< " communicators on CUDA device: " << devName
|
||||
<< " with stream: " << streamId;
|
||||
}
|
||||
}
|
||||
@ -1226,12 +1303,18 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
uint64_t heartBeatCounter = 0ULL;
|
||||
std::string errorMsg;
|
||||
std::string exitMsg;
|
||||
bool checkTimeoutSignal = (dumpOnTimeout_ && uid_ == 0);
|
||||
int monitorPollInterval = checkTimeoutSignal ? coordCheckIntervalMilSec_
|
||||
: heartbeatTimeoutInSec_ * 1000;
|
||||
bool checkDumpSignal = (dumpOnException_ && uid_ == 0);
|
||||
int monitorPollInterval = checkDumpSignal ? coordCheckIntervalMilSec_
|
||||
: heartbeatTimeoutInSec_ * 1000;
|
||||
auto lastTimePollStore = std::chrono::steady_clock::now();
|
||||
auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
|
||||
std::future<bool> asyncDebugDump;
|
||||
c10::optional<DumpPipe> dumpPipe = c10::nullopt;
|
||||
if (uid_ == 0) {
|
||||
// DumpPipe is one per-trainer process, and its convenient to name them
|
||||
// after 'global' ranks in the system, So we assume processgroup (uid)==0 is
|
||||
// the global PG and has globally unique rank ids across trainers.
|
||||
dumpPipe.emplace(rank_);
|
||||
}
|
||||
while (true) {
|
||||
// This won't have any lock since this lock is only used here.
|
||||
// Please be aware that mutex `monitorMutex_` should not be used
|
||||
@ -1254,7 +1337,28 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
// to see if any PG on any rank observed a timeout and signaled peers to
|
||||
// dump debugging info, and we avoid hammering the TCPStore from all PGs on
|
||||
// the same rank.
|
||||
if (checkTimeoutSignal) {
|
||||
if (checkDumpSignal) {
|
||||
// There are two scenarios where monitor thread will dump on timeout:
|
||||
// 1. The local rank is the first to observe a timeout.shouldDump_ will be
|
||||
// set to true.
|
||||
// 2. other ranks detected the timeout and signal the local rank to dump
|
||||
// In addition, monitor threads will dump if the watchdog thread has no
|
||||
// heartbeat or dumpPipe is not empty.
|
||||
if (shouldDump_.load()) {
|
||||
errorMsg = c10::str(
|
||||
logPrefix(),
|
||||
"Received a dump signal from this local rank and will ",
|
||||
"start to dump the debug info.");
|
||||
exitMsg = c10::str(
|
||||
"ProcessGroupNCCL's watchdog detected an exception from the local rank. ",
|
||||
"This is most likely caused by incorrect usages of collectives, e.g., wrong ",
|
||||
"sizes used across ranks, the order of collectives is not same for all ranks ",
|
||||
"or the scheduled collective, for some reason, didn't run. Additionally, ",
|
||||
"this can be caused by GIL deadlock or other reasons such as network errors or ",
|
||||
"bugs in the communications library (e.g. NCCL), etc. We tried our best to ",
|
||||
"dump the debug info into the storage to help you debug the issue.");
|
||||
break;
|
||||
}
|
||||
// We poll store to see if some ranks have flagged a timeout when
|
||||
// we haven't polled for `heartbeat_timeout` seconds and there haven't
|
||||
// any work added or removed for `watchdog_timeout` seconds.
|
||||
@ -1263,13 +1367,28 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
computeDeltaMS(lastTimePollStore, currentTime) >=
|
||||
coordCheckIntervalMilSec_) {
|
||||
lastTimePollStore = currentTime;
|
||||
if (globalStore_->check({std::string(TIMEOUT_DUMP)})) {
|
||||
if (globalStore_->check({std::string(EXCEPTION_DUMP)})) {
|
||||
int timeOutRank = -1;
|
||||
shouldDump_.store(true);
|
||||
try {
|
||||
auto vec = globalStore_->get(std::string(EXCEPTION_DUMP));
|
||||
TORCH_CHECK_WITH(
|
||||
DistBackendError,
|
||||
vec.size() == sizeof(int),
|
||||
"Invalid size for the timeout rank ID");
|
||||
std::memcpy(&timeOutRank, vec.data(), vec.size());
|
||||
} catch (const std::exception& e) {
|
||||
LOG(ERROR)
|
||||
<< "Failed to get timeout rank ID from the global store.";
|
||||
}
|
||||
errorMsg = c10::str(
|
||||
logPrefix(),
|
||||
"Received a global timeout from another rank and will ",
|
||||
"start to dump the debug info.");
|
||||
"Received a global dump signal from rank and will ",
|
||||
"start to dump the debug info. ");
|
||||
exitMsg = c10::str(
|
||||
"ProcessGroupNCCL's watchdog detected a collective timeout and notified current rank. ",
|
||||
"ProcessGroupNCCL's watchdog detected a dump signal from rank ",
|
||||
timeOutRank,
|
||||
" and notified the current rank. ",
|
||||
"This is most likely caused by incorrect usages of collectives, e.g., wrong ",
|
||||
"sizes used across ranks, the order of collectives is not same for all ranks ",
|
||||
"or the scheduled collective, for some reason, didn't run. Additionally, ",
|
||||
@ -1289,6 +1408,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
if (heartbeat != heartBeatCounter) {
|
||||
heartBeatCounter = heartbeat;
|
||||
} else {
|
||||
shouldDump_.store(true);
|
||||
// No heartbeat increase detected and timeout.
|
||||
errorMsg = c10::str(
|
||||
logPrefix(),
|
||||
@ -1310,6 +1430,14 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// process a request to dump the trace. only PG uid 0 will respond to dump
|
||||
// requests, but this is fine since all PG's feed into the same flight
|
||||
// recorder and dump. After dump, the training should continue.
|
||||
if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
|
||||
// best effort dump, not waiting for the dump here
|
||||
std::future<bool> fut = std::async(
|
||||
std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
|
||||
}
|
||||
}
|
||||
LOG(ERROR) << errorMsg;
|
||||
|
||||
@ -1318,10 +1446,16 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
LOG(INFO) << "Dumping c++ stacktraces: " << cpp_dumper.value()();
|
||||
}
|
||||
|
||||
auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_);
|
||||
// Store debug info to storage if no other thread does it. (By default to
|
||||
// local disk)
|
||||
asyncDebugDump = launchAsyncDebugDump();
|
||||
std::future<bool> asyncDebugDump = std::async(
|
||||
std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
|
||||
|
||||
// wait for the dump until timeout
|
||||
waitForFutureOrTimeout(
|
||||
asyncDebugDump,
|
||||
std::chrono::milliseconds(waitTimeoutDumpInMilSec_),
|
||||
"Flight recorder dump in heartbeatMonitor");
|
||||
|
||||
if (get_gil_checker() != nullptr) {
|
||||
auto fut = launchAsyncGilCheck();
|
||||
@ -1348,7 +1482,8 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
// Case two: desync might be slow or get stuck. Or we get stuck in
|
||||
// destructors, we will sleep for some time before calling std::abort() to
|
||||
// kill the whole process.
|
||||
if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load()) &&
|
||||
if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load() ||
|
||||
shouldDump_.load()) &&
|
||||
!terminateHeartbeatMonitorThread_.load()) {
|
||||
// Leave another two mins for desync report generation or process group
|
||||
// destroy.
|
||||
@ -1364,7 +1499,6 @@ void ProcessGroupNCCL::heartbeatMonitor() {
|
||||
// We already log completion inside the thread, so it may not be necessary to
|
||||
// check the return value here. We mainly use a future so we can exit early
|
||||
// if done.
|
||||
waitForDumpOrTimeout(asyncDebugDump, wakeUpTime);
|
||||
|
||||
if (!terminateHeartbeatMonitorThread_.load()) {
|
||||
// Create a error message reported from MonitorThread, so
|
||||
@ -1445,61 +1579,12 @@ std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() {
|
||||
return retrieveDesyncReport(store_, "NCCL", rank_, size_);
|
||||
}
|
||||
|
||||
#if defined(__linux__)
|
||||
struct DumpPipe {
|
||||
DumpPipe(int rank) {
|
||||
std::string fileStem =
|
||||
getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
|
||||
if (fileStem.empty() ||
|
||||
getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
|
||||
return;
|
||||
}
|
||||
TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty");
|
||||
std::string filename = c10::str(fileStem, rank, ".pipe");
|
||||
TORCH_CHECK(
|
||||
unlink(filename.c_str()) != -1 || errno == ENOENT,
|
||||
"Error removing existing named pipe ",
|
||||
filename);
|
||||
TORCH_CHECK(
|
||||
mkfifo(filename.c_str(), 0666) != -1,
|
||||
"Error creating named pipe ",
|
||||
filename);
|
||||
fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
|
||||
LOG(INFO) << "Pipe file " << filename
|
||||
<< " has been opened, write to it to trigger NCCL Debug Dump.";
|
||||
TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
|
||||
}
|
||||
bool shouldDump() {
|
||||
if (fd_ == -1) {
|
||||
return false;
|
||||
}
|
||||
char buf[128];
|
||||
// non-blocking from O_NONBLOCK above.
|
||||
// Ignore EINTR because we already will poll this
|
||||
// again later.
|
||||
ssize_t bytesRead = read(fd_, &buf, 128);
|
||||
return bytesRead > 0;
|
||||
}
|
||||
~DumpPipe() {
|
||||
if (fd_ != -1) {
|
||||
close(fd_);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int fd_ = -1;
|
||||
};
|
||||
#else
|
||||
struct DumpPipe {
|
||||
DumpPipe(int rank) {}
|
||||
bool shouldDump() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
std::string ProcessGroupNCCL::createLogPrefix() const {
|
||||
return c10::str("[PG ", uid_, " Rank ", rank_, "] ");
|
||||
if (!pg_desc_.empty() && pg_desc_ != "undefined") {
|
||||
return c10::str("[PG ", pg_name_, " (", pg_desc_, ") Rank ", rank_, "] ");
|
||||
} else {
|
||||
return c10::str("[PG ", pg_name_, " Rank ", rank_, "] ");
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& ProcessGroupNCCL::logPrefix() const {
|
||||
@ -1514,17 +1599,8 @@ const int& ProcessGroupNCCL::globalRank() const {
|
||||
void ProcessGroupNCCL::watchdogHandler() {
|
||||
bool done = false;
|
||||
lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
|
||||
c10::optional<std::future<bool>> optAsyncDebugDump;
|
||||
|
||||
std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;
|
||||
|
||||
c10::optional<DumpPipe> dumpPipe = c10::nullopt;
|
||||
if (uid_ == 0) {
|
||||
// DumpPipe is one per-trainer process, and its convenient to name them
|
||||
// after 'global' ranks in the system, So we assume processgroup (uid)==0 is
|
||||
// the global PG and has globally unique rank ids across trainers.
|
||||
dumpPipe.emplace(rank_);
|
||||
}
|
||||
while (!done || !terminateProcessGroup_.load()) {
|
||||
std::unique_lock<std::mutex> lock(workMetaListMutex_);
|
||||
// We busy-poll the work vector every kWatchdogThreadSleepMillis
|
||||
@ -1544,6 +1620,28 @@ void ProcessGroupNCCL::watchdogHandler() {
|
||||
|
||||
// If work hits an exception (either an error or timeout)
|
||||
if (work.exception()) {
|
||||
// try to dump flight records if exception happens.
|
||||
// Flight recorder behavior should be independent of desync Debug
|
||||
if (dumpOnException_) {
|
||||
try {
|
||||
auto rank = globalRank();
|
||||
auto vec = std::vector<uint8_t>(
|
||||
reinterpret_cast<uint8_t*>(&rank),
|
||||
reinterpret_cast<uint8_t*>(&rank) + sizeof(rank));
|
||||
globalStore_->set(std::string(EXCEPTION_DUMP), vec);
|
||||
// signal the monitor thread to start dumping
|
||||
shouldDump_.store(true);
|
||||
// This sleep is used to give time for dumping before throwing
|
||||
// exception
|
||||
std::this_thread::sleep_for(
|
||||
std::chrono::seconds(heartbeatTimeoutInSec_));
|
||||
} catch (const std::exception& e) {
|
||||
LOG(ERROR) << logPrefix()
|
||||
<< "Failed to set dump signal in tcpstore. "
|
||||
<< "Error: " << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
if (SHOULD_CLEAN_UP(asyncErrorHandling_)) {
|
||||
// Abort work and corresponding communicators
|
||||
work.abort();
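The added block above is the producer side of the cross-rank dump handshake: the watchdog of the failing rank publishes its global rank under the `exception_dump` store key (and flips `shouldDump_`), while the heartbeat monitor in the earlier hunk polls the same key and decodes the rank. A hedged sketch of that store convention with a plain `TCPStore` (host, port, and the single-process setup are illustrative, not the real wiring):

```python
import struct
from torch.distributed import TCPStore

store = TCPStore("127.0.0.1", 29500, world_size=1, is_master=True, wait_for_workers=False)

# Failing rank: publish its global rank as the raw bytes of a native C int.
failing_rank = 3
store.set("exception_dump", struct.pack("i", failing_rank))

# Monitor side: poll, then decode which rank asked for a flight-recorder dump.
if store.check(["exception_dump"]):
    (who,) = struct.unpack("i", store.get("exception_dump"))
    print(f"rank {who} signaled an exception dump")
```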
|
||||
@ -1554,40 +1652,22 @@ void ProcessGroupNCCL::watchdogHandler() {
|
||||
|
||||
// Report desync state in case of timeout
|
||||
if (timedOut) {
|
||||
try {
|
||||
if (desyncDebug_ || dumpOnTimeout_) {
|
||||
// Set shutdown mode, so the heartbeat monitor thread will not
|
||||
// abort process immediately.
|
||||
if (desyncDebug_) {
|
||||
try {
|
||||
collectiveDebugInfoMode_.store(true);
|
||||
std::vector<uint8_t> vec(1);
|
||||
globalStore_->set(std::string(TIMEOUT_DUMP), vec);
|
||||
}
|
||||
|
||||
auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_);
|
||||
if (dumpOnTimeout_ && !optAsyncDebugDump) {
|
||||
// Store debug info to storage. (By default to local disk)
|
||||
optAsyncDebugDump = launchAsyncDebugDump();
|
||||
}
|
||||
|
||||
if (desyncDebug_) {
|
||||
auto desyncMsg = getNCCLWatchdogDebugInfo();
|
||||
LOG(ERROR) << logPrefix() << desyncMsg;
|
||||
} catch (const std::exception& e) {
|
||||
LOG(ERROR)
|
||||
<< logPrefix()
|
||||
<< "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
|
||||
<< " Please file an issue. Error: " << e.what();
|
||||
} catch (...) {
|
||||
LOG(ERROR)
|
||||
<< logPrefix()
|
||||
<< "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error."
|
||||
<< " Please file an issue.";
|
||||
}
|
||||
|
||||
if (dumpOnTimeout_) {
|
||||
// Store debug info to storage. (By default to local disk)
|
||||
waitForDumpOrTimeout(*optAsyncDebugDump, wakeUpTime);
|
||||
}
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
LOG(ERROR) << logPrefix()
|
||||
<< "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
|
||||
<< " Please file an issue. Error: " << e.what();
|
||||
} catch (...) {
|
||||
LOG(ERROR)
|
||||
<< logPrefix()
|
||||
<< "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error."
|
||||
<< " Please file an issue.";
|
||||
}
|
||||
}
|
||||
// Throw exception
|
||||
@ -1630,12 +1710,6 @@ void ProcessGroupNCCL::watchdogHandler() {
|
||||
// in case processing is slowed down (but not hung) by cuda api contention
|
||||
heartbeat_++;
|
||||
}
|
||||
// process a request to dump the trace. only PG uid 0 will respond to dump
|
||||
// requests, but this is fine since all PG's feed into the same flight
|
||||
// recorder and dump.
|
||||
if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
|
||||
launchAsyncDebugDump();
|
||||
}
|
||||
done = workMetaList_.empty();
|
||||
}
|
||||
}
|
||||
@ -1882,6 +1956,15 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
|
||||
// Create the unique NCCL ID and broadcast it
|
||||
ncclUniqueId ncclID;
|
||||
|
||||
// reset log prefix to include group_desc
|
||||
logPrefix_ = createLogPrefix();
|
||||
|
||||
#ifdef NCCL_COMM_DESCRIPTION
|
||||
// Pass process group name and description to NCCL communicator
|
||||
std::string commDesc = pg_desc_ + ':' + pg_name_;
|
||||
options_->config.commDesc = strdup(commDesc.c_str());
|
||||
#endif
|
||||
|
||||
// For batch_isend_irecv, ncclGroupStart() would be called upfront
|
||||
bool batchP2P = ncclActiveGroupCounter_ > 0;
|
||||
bool singleP2POp = isP2POp(opType, batchP2P);
|
||||
@ -2004,6 +2087,13 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
|
||||
}
|
||||
#endif
|
||||
|
||||
for (const auto i : c10::irange(devices.size())) {
|
||||
int deviceIndex = devices[i].index();
|
||||
LOG(INFO) << logPrefix() << "ProcessGroupNCCL created ncclComm_ "
|
||||
<< ncclComms[i]->ncclComm_ << " on CUDA device: " << deviceIndex;
|
||||
}
|
||||
logPrefix_ = createLogPrefix(); // reset log prefix to include group_desc
|
||||
|
||||
// At this point NCCL should have been initialized, hence we can accurately
|
||||
// get the env value even if NCCL sets it by reading from nccl.conf file
|
||||
LOG(INFO) << logPrefix()
|
||||
@ -2052,18 +2142,17 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
|
||||
segmentInfo.total_size);
|
||||
}
|
||||
}
|
||||
|
||||
// Record the mapping between ncclComm and device index so that later
|
||||
// register hook can register a newly allocated segment to communicators
|
||||
// on the same device.
|
||||
// NOTE: we need remove the communicator from this map when it is
|
||||
// destroyed, otherwise may register onto an invalid communicator.
|
||||
ncclCommDevIdxMapMutex.lock();
|
||||
for (const auto i : c10::irange(devices.size())) {
|
||||
ncclCommDevIdxMap.emplace(ncclComms[i], devices[i].index());
|
||||
}
|
||||
ncclCommDevIdxMapMutex.unlock();
|
||||
}
|
||||
// Record the mapping between ncclComm and device index so that later
|
||||
// register hook can register a newly allocated segment to communicators
|
||||
// on the same device.
|
||||
// NOTE: we need remove the communicator from this map when it is
|
||||
// destroyed, otherwise may register onto an invalid communicator.
|
||||
ncclCommDevIdxMapMutex.lock();
|
||||
for (const auto i : c10::irange(devices.size())) {
|
||||
ncclCommDevIdxMap.emplace(ncclComms[i], devices[i].index());
|
||||
}
|
||||
ncclCommDevIdxMapMutex.unlock();
|
||||
}
|
||||
|
||||
it = devNCCLCommMap_.find(devicesKey);
|
||||
@ -2284,8 +2373,10 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
|
||||
enableTiming_.load());
|
||||
r->trace_id_ = NCCLTraceBuffer::get()->record(
|
||||
uid_,
|
||||
pg_name_,
|
||||
seq_,
|
||||
profilingTitle,
|
||||
// create a string copy of profilingTitle
|
||||
profilingTitle ? profilingTitle : "",
|
||||
inputs,
|
||||
outputs,
|
||||
r->ncclStartEvents_.get(),
|
||||
|
||||
@ -1,5 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#ifdef USE_C10D_NCCL
|
||||
|
||||
#include <atomic>
|
||||
@ -95,7 +102,7 @@ static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
|
||||
|
||||
constexpr const char* NCCL_BACKEND_NAME = "nccl";
|
||||
|
||||
constexpr const char* TIMEOUT_DUMP = "timeout_dump";
|
||||
constexpr const char* EXCEPTION_DUMP = "exception_dump";
|
||||
|
||||
constexpr auto kProcessGroupNCCLDefaultTimeout =
|
||||
std::chrono::milliseconds(10 * 60 * 1000);
|
||||
@ -134,6 +141,59 @@ static std::vector<std::string> TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK =
|
||||
{"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK",
|
||||
"NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"};
|
||||
|
||||
#if defined(__linux__)
|
||||
struct DumpPipe {
|
||||
DumpPipe(int rank) {
|
||||
std::string fileStem =
|
||||
getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
|
||||
if (fileStem.empty() ||
|
||||
getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
|
||||
return;
|
||||
}
|
||||
TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty");
|
||||
std::string filename = c10::str(fileStem, rank, ".pipe");
|
||||
TORCH_CHECK(
|
||||
unlink(filename.c_str()) != -1 || errno == ENOENT,
|
||||
"Error removing existing named pipe ",
|
||||
filename);
|
||||
TORCH_CHECK(
|
||||
mkfifo(filename.c_str(), 0666) != -1,
|
||||
"Error creating named pipe ",
|
||||
filename);
|
||||
fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
|
||||
LOG(INFO) << "Pipe file " << filename
|
||||
<< " has been opened, write to it to trigger NCCL Debug Dump.";
|
||||
TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
|
||||
}
|
||||
bool shouldDump() {
|
||||
if (fd_ == -1) {
|
||||
return false;
|
||||
}
|
||||
char buf[128];
|
||||
// non-blocking from O_NONBLOCK above.
|
||||
// Ignore EINTR because we already will poll this
|
||||
// again later.
|
||||
ssize_t bytesRead = read(fd_, &buf, 128);
|
||||
return bytesRead > 0;
|
||||
}
|
||||
~DumpPipe() {
|
||||
if (fd_ != -1) {
|
||||
close(fd_);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int fd_ = -1;
|
||||
};
|
||||
#else
|
||||
struct DumpPipe {
|
||||
DumpPipe(int rank) {}
|
||||
bool shouldDump() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
#endif
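The `DumpPipe` above opens `<TORCH_NCCL_DEBUG_INFO_PIPE_FILE><global rank>.pipe` as a non-blocking FIFO and treats any bytes written to it as a request to dump the flight recorder. A hedged sketch of triggering an on-demand dump from outside the trainer (the file stem is an assumed example value and must match whatever the trainer was launched with):

```python
rank = 0  # global rank whose trainer process should dump
pipe_path = f"/tmp/nccl_trace_rank_{rank}.pipe"  # <TORCH_NCCL_DEBUG_INFO_PIPE_FILE><rank>.pipe

# The trainer already holds the read end open (O_RDONLY | O_NONBLOCK),
# so this write returns immediately; any content triggers a best-effort dump.
with open(pipe_path, "w") as f:
    f.write("1")
```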
|
||||
|
||||
// ProcessGroupNCCL implements NCCL bindings for c10d.
|
||||
//
|
||||
// All functions of the class are expected to be called in the same order
|
||||
@ -384,6 +444,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// via `ncclCommSplit`
|
||||
std::shared_ptr<ProcessGroupNCCL> split_from;
|
||||
int64_t split_color{0};
|
||||
std::string group_name;
|
||||
};
|
||||
|
||||
// If you wish to create multiple process groups, each with a potentially
|
||||
@ -761,6 +822,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime,
|
||||
size_t timeout_sec = 30);
|
||||
|
||||
// A helper function to wait for a future to complete or timeout.
|
||||
void waitForFutureOrTimeout(
|
||||
std::future<bool>& fut,
|
||||
const std::chrono::milliseconds& timeOutMilSec,
|
||||
const std::string& futDescription);
|
||||
|
||||
// When watchdog timeout, this function will be called and return debug info
|
||||
// for users. For now we only get information from retrieveDesyncReport.
|
||||
// We are working on enabling more useful debug information for watchdog
|
||||
@ -878,6 +945,15 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Whether there are hooks pending to be fired
|
||||
std::atomic<bool> hasPendingHooks_;
|
||||
|
||||
// This is the signal from watchdog threads to indicate whether the monitor
|
||||
// thread should dump. Making it static so that it is accessible from all the
|
||||
// PGs. With this flag, monitor thread would dump debug info under any one of
|
||||
// the 3 conditions: 1: this flag is set to true by the watchdog thread when
|
||||
// it detects a timeout. 2: timeout signal is received from
|
||||
// other ranks through tcpstore 3: no heartbeat of watchdog Note that only the
|
||||
// monitor thread from PG0 should dump the debug info and only once
|
||||
static std::atomic<bool> shouldDump_;
|
||||
|
||||
// Mutex to Guard workMetaList_
|
||||
std::mutex workMetaListMutex_;
|
||||
|
||||
@ -962,8 +1038,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Whether or not to enable timeout root cause analysis.
|
||||
bool desyncDebug_;
|
||||
|
||||
// Whether or not to dump debug info on timeout
|
||||
bool dumpOnTimeout_;
|
||||
// Whether or not to dump debug info on exception including both watchdog
|
||||
// timeout and nccl errors.
|
||||
bool dumpOnException_;
|
||||
|
||||
// Whether or not to create start CUDAEvent and enable timing for start
|
||||
// and end events. Note that enableTiming_ is always true if desyncDebug_
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
#include <string>
|
||||
#include <system_error>
|
||||
#include <vector>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
/* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */
|
||||
@ -352,6 +351,18 @@ inline c10::List<c10::IValue> new_list() {
|
||||
return c10::List<c10::IValue>(c10::AnyType::get());
|
||||
}
|
||||
|
||||
inline std::string ranks_str(const std::vector<uint64_t>& ranks) {
|
||||
std::string str;
|
||||
for (const auto& rank : ranks) {
|
||||
if (str.empty()) {
|
||||
str = std::to_string(rank);
|
||||
} else {
|
||||
str += ", " + std::to_string(rank);
|
||||
}
|
||||
}
|
||||
return c10::str("[", str, "]");
|
||||
}
|
||||
|
||||
struct NCCLTraceBuffer {
|
||||
static NCCLTraceBuffer* get() {
|
||||
// intentionally leak on exit
|
||||
@ -371,8 +382,9 @@ struct NCCLTraceBuffer {
|
||||
// buffer this entry will be located to
|
||||
// update state information
|
||||
size_t pg_id_;
|
||||
std::string pg_name_;
|
||||
size_t seq_id_; // as tracked by the process group
|
||||
const char* profiling_name_;
|
||||
std::string profiling_name_;
|
||||
|
||||
std::shared_ptr<torch::CapturedTraceback> traceback_;
|
||||
// we borrow pointers to start_ and end_ so we can query the state
|
||||
@ -385,7 +397,16 @@ struct NCCLTraceBuffer {
|
||||
c10::time_t time_created_;
|
||||
c10::optional<float> duration_;
|
||||
|
||||
const char* state_ = "scheduled";
|
||||
// timestamp when our CPU threads discovered that the kernel started.
|
||||
// will always be _after_ it actually started, and can be very late
|
||||
// if the watchdog thread got stuck on CUDA APIs.
|
||||
c10::optional<c10::time_t> time_discovered_started_;
|
||||
|
||||
// timestamp when our CPU threads discovered that the kernel completed.
|
||||
// will always be _after_ it actually completed, and can be the same time
|
||||
// as the discovery of the start if the watchdog thread is stuck on CUDA
|
||||
// APIs
|
||||
c10::optional<c10::time_t> time_discovered_completed_;
|
||||
|
||||
// size information for input/output tensors
|
||||
c10::SmallVector<int, 4> input_dims_;
|
||||
@ -405,8 +426,9 @@ struct NCCLTraceBuffer {
|
||||
|
||||
c10::optional<size_t> record(
|
||||
size_t pg_id,
|
||||
const std::string& pg_name,
|
||||
size_t seq_id,
|
||||
const char* profiling_name,
|
||||
std::string profiling_name,
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
const std::vector<at::Tensor>& outputs,
|
||||
EventList* start,
|
||||
@ -421,8 +443,9 @@ struct NCCLTraceBuffer {
|
||||
auto te = Entry{
|
||||
id_,
|
||||
pg_id,
|
||||
pg_name,
|
||||
seq_id,
|
||||
profiling_name == nullptr ? "" : profiling_name,
|
||||
std::move(profiling_name),
|
||||
std::move(traceback),
|
||||
std::move(start),
|
||||
std::move(end),
|
||||
@ -460,8 +483,8 @@ struct NCCLTraceBuffer {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (started) {
|
||||
r.state_ = "started";
|
||||
if (started && !r.time_discovered_started_) {
|
||||
r.time_discovered_started_ = c10::getTime();
|
||||
}
|
||||
}
|
||||
if (r.end_ != nullptr) {
|
||||
@ -472,8 +495,8 @@ struct NCCLTraceBuffer {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (completed) {
|
||||
r.state_ = "completed";
|
||||
if (completed && !r.time_discovered_completed_) {
|
||||
r.time_discovered_completed_ = c10::getTime();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -516,15 +539,15 @@ struct NCCLTraceBuffer {
|
||||
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
|
||||
auto& entry = entries_.at(*id % max_entries_);
|
||||
if (entry.id_ == *id) {
|
||||
update_state(entry);
|
||||
Entry* entry = &entries_.at(*id % max_entries_);
|
||||
if (entry->id_ == *id) {
|
||||
update_state(*entry);
|
||||
|
||||
if (compute_duration) {
|
||||
can_compute_duration = strcmp(entry.state_, "completed") == 0 &&
|
||||
entry.start_ && entry.end_;
|
||||
startEvents = entry.start_;
|
||||
endEvents = entry.end_;
|
||||
can_compute_duration = entry->time_discovered_completed_.has_value() &&
|
||||
entry->start_ && entry->end_;
|
||||
startEvents = entry->start_;
|
||||
endEvents = entry->end_;
|
||||
}
|
||||
}
|
||||
|
||||
@ -536,33 +559,38 @@ struct NCCLTraceBuffer {
|
||||
duration = getDurationFromFirstEvent(*startEvents, *endEvents);
|
||||
guard.lock();
|
||||
|
||||
// Refresh the entry ref, see if it has been overwritten
|
||||
entry = entries_.at(*id % max_entries_);
|
||||
if (entry.id_ != *id) {
|
||||
// Refresh the entry pointer, see if the entry has been overwritten
|
||||
entry = &entries_.at(*id % max_entries_);
|
||||
if (entry->id_ != *id) {
|
||||
LOG(INFO)
|
||||
<< "retire_id abandoned for id " << *id
|
||||
<< ", event was overwritten while waiting to compute duration.";
|
||||
return;
|
||||
}
|
||||
if (duration.has_value()) {
|
||||
entry.duration_ = duration.value();
|
||||
entry->duration_ = duration.value();
|
||||
}
|
||||
}
|
||||
|
||||
entry.retired_ = true;
|
||||
entry.start_ = entry.end_ = nullptr;
|
||||
entry->retired_ = true;
|
||||
entry->start_ = entry->end_ = nullptr;
|
||||
}
|
||||
|
||||
std::string dump() {
|
||||
std::string dump(
|
||||
const c10::optional<std::unordered_map<
|
||||
std::string,
|
||||
std::unordered_map<std::string, std::string>>>& ncclDumpMap) {
|
||||
auto result = dump_entries();
|
||||
auto entries = new_list();
|
||||
c10::IValue entries_key = "entries";
|
||||
c10::IValue nccl_comm_key = "nccl_comm_state";
|
||||
c10::IValue version_key = "version";
|
||||
// Update whenever changing contents or formatting of the dump
|
||||
// (minor when adding fields, major when changing existing fields)
|
||||
c10::IValue version_val = "1.0";
|
||||
c10::IValue version_val = "1.1";
|
||||
|
||||
c10::IValue pg_id_key = "pg_id";
|
||||
c10::IValue pg_name_key = "process_group";
|
||||
c10::IValue seq_id_key = "seq_id";
|
||||
c10::IValue profiling_name_key = "profiling_name";
|
||||
c10::IValue input_sizes_key = "input_sizes";
|
||||
@ -576,6 +604,8 @@ struct NCCLTraceBuffer {
|
||||
c10::IValue name_key = "name";
|
||||
c10::IValue filename_key = "filename";
|
||||
c10::IValue retired_key = "retired";
|
||||
c10::IValue time_discovered_started_key = "time_discovered_started_ns";
|
||||
c10::IValue time_discovered_completed_key = "time_discovered_completed_ns";
|
||||
|
||||
std::vector<torch::CapturedTraceback*> tracebacks;
|
||||
for (auto& e : result) {
|
||||
@ -596,6 +626,7 @@ struct NCCLTraceBuffer {
|
||||
auto& tb = stracebacks.tracebacks.at(i);
|
||||
auto dict = new_dict();
|
||||
dict.insert(pg_id_key, int64_t(e.pg_id_));
|
||||
dict.insert(pg_name_key, e.pg_name_);
|
||||
dict.insert(seq_id_key, int64_t(e.seq_id_));
|
||||
dict.insert(profiling_name_key, e.profiling_name_);
|
||||
dict.insert(time_created_key, int64_t(e.time_created_));
|
||||
@ -619,7 +650,24 @@ struct NCCLTraceBuffer {
|
||||
|
||||
dict.insert(input_sizes_key, read_sizes(e.input_dims_));
|
||||
dict.insert(output_sizes_key, read_sizes(e.output_dims_));
|
||||
dict.insert(state_key, e.state_);
|
||||
if (e.time_discovered_completed_.has_value()) {
|
||||
dict.insert(state_key, "completed");
|
||||
} else if (e.time_discovered_started_.has_value()) {
|
||||
dict.insert(state_key, "started");
|
||||
} else {
|
||||
dict.insert(state_key, "scheduled");
|
||||
}
|
||||
|
||||
dict.insert(
|
||||
time_discovered_started_key,
|
||||
e.time_discovered_started_.has_value()
|
||||
? int64_t(*e.time_discovered_started_)
|
||||
: c10::IValue());
|
||||
dict.insert(
|
||||
time_discovered_completed_key,
|
||||
e.time_discovered_completed_.has_value()
|
||||
? int64_t(*e.time_discovered_completed_)
|
||||
: c10::IValue());
|
||||
dict.insert(retired_key, e.retired_);
|
||||
|
||||
auto frames = new_list();
|
||||
@ -629,10 +677,24 @@ struct NCCLTraceBuffer {
|
||||
dict.insert(frames_key, frames);
|
||||
entries.push_back(dict);
|
||||
}
|
||||
// convert ncclDumpMap into a dictionary
|
||||
auto per_comm_dict = new_dict();
|
||||
if (ncclDumpMap.has_value()) {
|
||||
for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) {
|
||||
auto inner_dict = new_dict();
|
||||
for (const auto& [key, value] : ncclDump) {
|
||||
inner_dict.insert(key, value);
|
||||
}
|
||||
per_comm_dict.insert(ncclId, inner_dict);
|
||||
}
|
||||
}
|
||||
|
||||
auto dict = new_dict();
|
||||
dict.insert(entries_key, entries);
|
||||
dict.insert(version_key, version_val);
|
||||
if (per_comm_dict.size() > 0) {
|
||||
dict.insert(nccl_comm_key, per_comm_dict);
|
||||
}
|
||||
|
||||
return pickle_str(dict);
|
||||
}
|
||||
|
||||
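The `dump()` above pickles a top-level dict with `version` (bumped to "1.1"), `entries`, and, on builds with `NCCL_COMM_DUMP`, `nccl_comm_state`. A hedged sketch of consuming it in-process via the private binding the tests use; field names mirror the keys defined above, but this is illustrative, not a stable API:

```python
import pickle
import torch

trace = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
assert trace["version"] == "1.1"

for entry in trace["entries"]:
    print(
        entry["process_group"],               # process group identification, e.g. ('0', 'default_pg')
        entry["seq_id"],
        entry["profiling_name"],              # e.g. "nccl:all_reduce"
        entry["state"],                       # "scheduled" / "started" / "completed"
        entry["time_discovered_started_ns"],  # may be None unless timing is enabled
        entry["time_discovered_completed_ns"],
        entry["input_sizes"],
        entry["output_sizes"],
    )

# Per-communicator NCCL state keyed by ncclUniqueId, present only on NCCL_COMM_DUMP builds.
comm_state = trace.get("nccl_comm_state", {})
```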
@ -1863,6 +1863,15 @@ Arguments:
          "group_name",
          &::c10d::ProcessGroup::getGroupName,
          "(Gets this process group name. It's cluster unique)")
      .def(
          "_set_group_desc",
          &::c10d::ProcessGroup::setGroupDesc,
          py::call_guard<py::gil_scoped_acquire>(),
          "Sets the process group description. This is an internal C10D method, do not use.")
      .def_property_readonly(
          "group_desc",
          &::c10d::ProcessGroup::getGroupDesc,
          "Gets this process group description")
      .def_property(
          "bound_device_id",
          &::c10d::ProcessGroup::getBoundDeviceId,
@ -2443,7 +2452,9 @@ Example::
      .def_readwrite(
          "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
      .def_readwrite(
          "split_color", &::c10d::ProcessGroupNCCL::Options::split_color);
          "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
      .def_readwrite(
          "group_name", &::c10d::ProcessGroupNCCL::Options::group_name);

#endif

torch/distributed/_state_dict_utils.py (new file, 448 lines)
@ -0,0 +1,448 @@
|
||||
import io
|
||||
import math
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||
|
||||
if dist.is_available() or TYPE_CHECKING:
|
||||
from torch.distributed import distributed_c10d
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
||||
from torch.distributed._tensor import DTensor, Replicate
|
||||
|
||||
|
||||
def _identity_func(
|
||||
obj: torch.Tensor,
|
||||
pg: Optional[dist.ProcessGroup],
|
||||
device: Optional[torch.device],
|
||||
companion_obj: Any,
|
||||
) -> torch.Tensor:
|
||||
return obj
|
||||
|
||||
|
||||
def _all_gather_sharded_tensor(
|
||||
sharded_tensor: "ShardedTensor",
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
if pg is None:
|
||||
pg = distributed_c10d._get_default_group()
|
||||
world_size = dist.get_world_size(pg)
|
||||
shards = sharded_tensor.local_shards()
|
||||
dim_0_size = sharded_tensor.size()[0] # type: ignore[index]
|
||||
tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr]
|
||||
chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
|
||||
pg_device = (
|
||||
distributed_c10d._get_pg_default_device(pg) if device is None else device
|
||||
)
|
||||
if shards:
|
||||
local_tensor = shards[0].tensor.flatten()
|
||||
if local_tensor.device.type != pg_device.type:
|
||||
local_tensor = local_tensor.to(pg_device)
|
||||
num_padding = chunk_size - local_tensor.numel()
|
||||
if num_padding > 0:
|
||||
local_tensor = F.pad(local_tensor, [0, num_padding])
|
||||
else:
|
||||
local_tensor = torch.zeros(
|
||||
chunk_size, dtype=sharded_tensor.dtype, device=pg_device
|
||||
)
|
||||
|
||||
tensor = torch.empty(
|
||||
chunk_size * world_size,
|
||||
dtype=local_tensor.dtype,
|
||||
device=pg_device,
|
||||
)
|
||||
dist.all_gather_into_tensor(tensor, local_tensor, group=pg)
|
||||
|
||||
tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
|
||||
return tensor
|
||||
|
||||
|
||||
class CompanionMismatch(Exception):
    ...


def _iterate_state_dict(
    iter_object: Any,
    sharded_tensor_func: Callable,
    dtensor_func: Callable,
    tensor_func: Callable,
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    companion_obj: Any = None,
    ranks_only: Tuple[int, ...] = tuple(),
    type_check: bool = True,
    non_blocking: bool = True,
) -> Dict[str, Any]:
    """Iterate through the state dict, applying the given functions to each tensor type.

    Args:
        iter_object (Any): the target state_dict.
        sharded_tensor_func (Callable): the function to apply to ShardedTensor
        dtensor_func (Callable): the function to apply to DTensor
        tensor_func (Callable): the function to apply to Tensor
        pg (Optional[dist.ProcessGroup]): process group passed to tensor functions
        device (Optional[torch.device]): device passed to tensor functions
        cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored
            if a companion_obj is supplied.
        companion_obj (Any): A companion object to the state dict. If this object
            is supplied, we attempt to copy the tensor to the companion object.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
            have the same state_dicts. Other ranks will get empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The current supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.
        non_blocking (bool): whether to use non-blocking copy when copying to the companion object.
    """
    # TODO: should we use pytree?
    cpu_device = torch.device("cpu")
    if isinstance(iter_object, ShardedTensor):
        ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
    elif isinstance(iter_object, DTensor):
        ret = dtensor_func(iter_object, pg, device, companion_obj)
    elif isinstance(iter_object, torch.Tensor):
        ret = tensor_func(iter_object, pg, device, companion_obj)
    elif (
        isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
        or iter_object is None
    ):
        ret = iter_object
    elif isinstance(iter_object, dict):
        if companion_obj is not None and (
            not isinstance(companion_obj, dict)
            or set(companion_obj.keys()) != set(iter_object.keys())
        ):
            raise CompanionMismatch()

        ret = {
            key: _iterate_state_dict(
                value,
                sharded_tensor_func,
                dtensor_func,
                tensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                companion_obj=companion_obj[key] if companion_obj is not None else None,
                ranks_only=ranks_only,
                type_check=type_check,
                non_blocking=non_blocking,
            )
            for key, value in iter_object.items()
        }
    elif isinstance(iter_object, (list, tuple)):
        if companion_obj is not None and (
            not isinstance(companion_obj, (list, tuple))
            or len(companion_obj) != len(iter_object)
        ):
            raise CompanionMismatch()

        ret = [
            _iterate_state_dict(
                v,
                sharded_tensor_func,
                dtensor_func,
                tensor_func,
                pg=pg,
                device=device,
                cpu_offload=cpu_offload,
                companion_obj=companion_obj[idx] if companion_obj is not None else None,
                ranks_only=ranks_only,
                type_check=type_check,
                non_blocking=non_blocking,
            )
            for idx, v in enumerate(iter_object)
        ]
        if isinstance(iter_object, tuple):
            ret = tuple(ret)
    elif not type_check:
        ret = iter_object
    else:
        raise ValueError(f"Unexpected value type {type(iter_object)}")

    if not ranks_only or dist.get_rank(pg) in ranks_only:
        if isinstance(ret, torch.Tensor):
            if cpu_offload and companion_obj is None:
                ret = ret.to(cpu_device)

            if companion_obj is not None:
                # TODO: support DTensor
                companion_obj.copy_(ret, non_blocking=non_blocking)
                ret = companion_obj
    else:
        ret = {} if isinstance(ret, dict) else None

    return ret
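The helpers that follow are thin wrappers around this dispatcher and differ only in the callbacks they pass. A rough, hypothetical sketch of driving it directly (assumes a CUDA device is available; this mirrors what `_offload_state_dict_to_cpu` below does with the identity callbacks):

import torch
from torch.distributed._state_dict_utils import _identity_func, _iterate_state_dict

offloaded = _iterate_state_dict(
    {"w": torch.ones(4, device="cuda"), "step": 3},
    _identity_func,    # ShardedTensor handler
    _identity_func,    # DTensor handler
    _identity_func,    # plain Tensor handler
    pg=None,
    device=None,
    cpu_offload=True,  # plain tensors are moved to CPU after the callback runs
    ranks_only=tuple(),
    type_check=True,
)
print(offloaded["w"].device)  # cpu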
def _gather_state_dict(
    state_dict: Dict[str, Any],
    *,
    pg: Optional[dist.ProcessGroup] = None,
    device: Optional[torch.device] = None,
    cpu_offload: bool = False,
    ranks_only: Tuple[int, ...] = tuple(),
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensors or DTensors in
    the state_dict.

    Args:
        state_dict (Dict[str, Any]): the target sharded state_dict.
        pg (Optional[dist.ProcessGroup]): the process group that is used to
            gather ShardedTensor. Note that gathering a DTensor will use
            the DeviceMesh. So this argument will be ignored when gathering a
            DTensor.
        device (Optional[torch.device]): the device that is used to
            perform allgather for ShardedTensor. Note that gathering a DTensor
            will use the DeviceMesh. So this argument will be ignored when
            gathering a DTensor.
        cpu_offload (bool): whether to offload the tensors to CPU memory. The
            default value is False.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
            have the same state_dicts. Other ranks will get empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The current supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        The gathered state dictionary.
    """

    def sharded_tensor_func(value, pg, device, companion_obj):
        # ShardedTensor does not seem to record the original device type.
        # So if the tensor is moved to CPU, we won't know the original type.
        # As a result, we have to rely on the user to tell us the correct one.
        cpu_device = torch.device("cpu")
        output_tensor = _all_gather_sharded_tensor(value, pg, device)
        local_shard_device = (
            value.local_shards()[0].tensor.device
            if value.local_shards()
            else cpu_device
        )
        if output_tensor.device != local_shard_device:
            value = output_tensor.to(local_shard_device)
        else:
            value = output_tensor
        return value

    def dtensor_func(value, pg, device, companion_obj):
        if value.device != value.device_mesh.device_type:
            value = value.to(value.device_mesh.device_type)
        # FSDP all_gather: [Shard(0)] -> [Replicate()]
        # HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
        # 2D FSDP + TP all_gather:
        # - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
        # - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
        placements = [Replicate() for _ in value.placements]
        value = value.redistribute(
            device_mesh=value.device_mesh,
            placements=placements,
        )
        # Call `wait()` to force the tensor to be synchronous with respect
        # to the main stream.
        # See the discussion in https://github.com/pytorch/pytorch/pull/117799.
        value = value.to_local()
        if isinstance(value, AsyncCollectiveTensor):
            value = value.wait()
        return value

    return _iterate_state_dict(
        state_dict,
        sharded_tensor_func,
        dtensor_func,
        _identity_func,
        pg=pg,
        device=device,
        cpu_offload=cpu_offload,
        ranks_only=ranks_only,
        type_check=type_check,
    )
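A hedged usage sketch (the FSDP-style model and the checkpoint path are assumptions, not part of the diff): gather a sharded state dict so that only rank 0 keeps a full CPU copy.

import torch
import torch.distributed as dist
from torch.distributed._state_dict_utils import _gather_state_dict

# Assumed: `model` is wrapped so its state_dict() contains ShardedTensor/DTensor values.
sharded_sd = model.state_dict()
full_sd = _gather_state_dict(
    sharded_sd,
    cpu_offload=True,  # move the gathered tensors to CPU
    ranks_only=(0,),   # only rank 0 keeps the result; other ranks get an empty dict
)
if dist.get_rank() == 0:
    torch.save(full_sd, "checkpoint.pt")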
def _offload_state_dict_to_cpu(
    state_dict: Dict[str, Any],
    *,
    ranks_only: Tuple[int, ...] = tuple(),
    type_check: bool = True,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API offloads all the tensors to CPU memory.

    Args:
        state_dict (Dict[str, Any]): the target state_dict.
        ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
            have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
            have the same state_dicts. Other ranks will get empty state_dicts.
        type_check (bool): check if the instance data type is a supported type
            that can be saved by DCP. The current supported data types are
            torch.Tensor, DTensor, int, float, str, list, dict, None.

    Returns:
        The offloaded state dictionary.
    """

    ret = _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        _identity_func,
        pg=None,
        device=None,
        cpu_offload=True,
        ranks_only=ranks_only,
        type_check=type_check,
    )
    return ret
def _copy_state_dict(
    state_dict: Dict[str, Any],
    copy_state_dict: Dict[str, Any],
    non_blocking: bool = False,
):
    """
    Copies all tensors in a given state dict into a different state_dict with the
    same structure.

    .. warning::
        This function expects that state_dict and copy_state_dict share
        the same structure and data types.

    .. warning::
        The current supported data types are
        torch.Tensor, DTensor, int, float, str, list, dict, None.

    Args:
        state_dict (Dict[str, Any]): the target state_dict.
        copy_state_dict (Dict[str, Any]):
            The state dict we are copying into. This state_dict must have exactly
            the same structure as the source `state_dict`.
        non_blocking (bool): Whether copy ops should be performed asynchronously.
    """

    _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        _identity_func,
        pg=None,
        device=None,
        cpu_offload=False,
        ranks_only=tuple(),
        companion_obj=copy_state_dict,
        type_check=True,
        non_blocking=non_blocking,
    )
def _create_cpu_state_dict(
    state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
) -> Dict[str, Any]:
    """
    Given a state_dict, create another state_dict with the same structure and elements.
    However, all tensors in the returned state_dict are new tensors on CPU. These
    tensors can be placed in pinned memory or shared memory, based on the provided arguments.

    .. warning::
        Setting both `pin_memory` and `share_memory` to True significantly increases the
        latency of this method because the memory has to be registered as pinned directly
        instead of going through the pin_memory cache allocator. This combination should
        only be used for long-lived tensors that are required to be shared. The caveat
        does not apply as long as at least one of `pin_memory` or `share_memory` is
        set to False.

    """

    def tensor_func(
        obj: torch.Tensor,
        pg: Optional[dist.ProcessGroup],
        device: Optional[torch.device],
        _: Any,
    ) -> torch.Tensor:
        if len(obj.size()) == 0:
            return torch.tensor(0, dtype=obj.dtype)

        if share_memory:
            t = torch.empty(*tuple(obj.size()), dtype=obj.dtype).share_memory_()
            if pin_memory:
                succ = torch.cuda.cudart().cudaHostRegister(
                    t.data_ptr(),
                    t.numel() * t.element_size(),
                    1,  # lines up with 'cudaHostRegisterPortable'
                )
                assert (
                    succ == 0
                ), f"Pinning shared memory failed with error-code: {succ}"
            return t
        elif pin_memory:
            return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
        else:
            return torch.empty(*tuple(obj.size()), dtype=obj.dtype)

    ret = _iterate_state_dict(
        state_dict,
        _identity_func,
        _identity_func,
        tensor_func,
        pg=None,
        device=None,
        cpu_offload=False,
        ranks_only=tuple(),
        type_check=False,
    )
    return ret
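A hedged sketch of combining `_create_cpu_state_dict` and `_copy_state_dict` as a reusable pinned staging buffer for periodic checkpointing (the training-loop names `train_one_step`, `num_steps`, and `checkpoint_every` are hypothetical):

import torch
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict

# Assumed: `model` lives on GPU and its state_dict() holds plain CUDA tensors.
cpu_staging = _create_cpu_state_dict(model.state_dict(), pin_memory=True)

for step in range(num_steps):
    train_one_step(model)
    if step % checkpoint_every == 0:
        # Non-blocking device-to-host copies into the pre-allocated pinned buffers.
        _copy_state_dict(model.state_dict(), cpu_staging, non_blocking=True)
        torch.cuda.synchronize()  # make sure the copies finished before serializing
        torch.save(cpu_staging, f"ckpt_{step}.pt")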
def _check_state_dict_similarity(
    state_dict: Dict[str, Any],
    compared_state_dict: Dict[str, Any],
) -> bool:
    """
    Given two state_dicts, check if the structures are the same: if a
    [key, tensor] pair exists in one state_dict, there must be a corresponding
    pair, [key, other_tensor], in the other state_dict, where tensor and
    other_tensor have the same size and dtype.

    Return the check result.
    """

    def tensor_func(
        obj: torch.Tensor,
        pg: Optional[dist.ProcessGroup],
        device: Optional[torch.device],
        companion_obj: Any,
    ) -> torch.Tensor:
        if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
            raise CompanionMismatch()
        return obj

    try:
        _iterate_state_dict(
            state_dict,
            _identity_func,
            _identity_func,
            tensor_func,
            pg=None,
            device=None,
            cpu_offload=False,
            ranks_only=tuple(),
            companion_obj=compared_state_dict,
            type_check=False,
        )
    except CompanionMismatch:
        return False

    return True
@ -1171,7 +1171,7 @@ def init_process_group(
        )

        default_pg, _ = _new_process_group_helper(
            -1, -1, [], backend, None, group_name, timeout=timeout
            -1, -1, [], backend, None, group_name, timeout=timeout, group_desc="default_pg"
        )
        _update_default_pg(default_pg)
    else:
@ -1197,6 +1197,7 @@ def init_process_group(
            pg_options=pg_options,
            timeout=timeout,
            device_id=device_id,
            group_desc="default_pg"
        )
        _update_default_pg(default_pg)

@ -1257,6 +1258,7 @@ def _new_process_group_helper(
    timeout=None,
    pg_tag=None,
    device_id=None,
    group_desc=None,
):
    """
    Create a new distributed process group.
@ -1289,6 +1291,8 @@ def _new_process_group_helper(
            _, prefix_store = _world.pg_map[existing_group]
            return existing_group, prefix_store

    group_desc = "undefined" if group_desc is None else group_desc

    # The list of group ranks is empty if we're creating the default group.
    is_default_group = len(global_ranks_in_group) == 0

@ -1375,6 +1379,7 @@ def _new_process_group_helper(
            if split_from:
                pg_options.split_from = split_from
                pg_options.split_color = _process_group_color(global_ranks_in_group)
            pg_options.group_name = group_name
            backend_class = ProcessGroupNCCL(
                backend_prefix_store, group_rank, group_size, pg_options)
            backend_type = ProcessGroup.BackendType.NCCL
@ -1461,9 +1466,11 @@ def _new_process_group_helper(

    # update global state
    assert group_name is not None
    assert group_desc is not None
    _world.pg_map[pg] = (backend, prefix_store)
    _world.pg_names[pg] = group_name
    pg._set_group_name(group_name)
    pg._set_group_desc(group_desc)

    _world.pg_backend_config[pg] = str(backend_config)
    # "" is the default tag for user PGs
@ -3614,7 +3621,7 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend:


@_time_logger
def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False):
def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None):
    """
    Create a new distributed group.

@ -3655,6 +3662,7 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local
            barrier at the end of the process group creation. This is different
            in that non-member ranks don't need to call into the API and don't
            join the barrier.
        group_desc (str, optional): a string to describe the process group.

    Returns:
        A handle of distributed group that can be given to collective calls or None if the rank is not part of ``ranks``.
@ -3669,7 +3677,15 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local
    multiple overlapping process groups. To avoid that, make sure all ranks follow the
    same global creation order.
    """
    return _new_group_with_tag(ranks, timeout, backend, pg_options, None, use_local_synchronization=use_local_synchronization)
    return _new_group_with_tag(
        ranks,
        timeout,
        backend,
        pg_options,
        None,
        use_local_synchronization=use_local_synchronization,
        group_desc=group_desc,
    )
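A hedged sketch of the new keyword in use (the four-rank layout is an assumption): callers can label a group at creation time, and the label is readable back through the `group_desc` property bound earlier in this diff.

import torch.distributed as dist

# Assumed: the default process group is already initialized with 4 ranks.
tp_group = dist.new_group(ranks=[0, 1], group_desc="tensor_parallel")
if dist.get_rank() in (0, 1):
    print(tp_group.group_desc)  # "tensor_parallel"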
def _new_group_with_tag(
    ranks=None,
@ -3677,7 +3693,8 @@ def _new_group_with_tag(
    backend=None,
    pg_options=None,
    pg_tag=None,
    use_local_synchronization=False
    use_local_synchronization=False,
    group_desc=None
):
    """
    Variant of ``new_group`` that exposes tag creation.
@ -3749,7 +3766,8 @@ def _new_group_with_tag(
        group_name,
        pg_options=pg_options,
        timeout=timeout,
        pg_tag=pg_tag
        pg_tag=pg_tag,
        group_desc=group_desc
    )

    # Create the global rank to group rank mapping
@ -3789,6 +3807,7 @@ def new_subgroups(
    timeout=None,
    backend=None,
    pg_options=None,
    group_desc=None,
):
    """
    Create subgroups of equal size.
@ -3841,6 +3860,8 @@ def new_subgroups(
            the construction of specific process groups. i.e. for the ``nccl``
            backend, ``is_high_priority_stream`` can be specified so that
            process group can pick up high priority cuda streams.
        group_desc (str, optional): A string describing the group. Each subgroup will
            inherit its group_desc.

    Returns:
        The subgroup containing the current rank, and all the subgroups used for cleanup.
@ -3886,6 +3907,7 @@ def new_subgroups(
            timeout=timeout,
            backend=backend,
            pg_options=pg_options,
            group_desc=group_desc,
        )
        subgroups.append(subgroup)

@ -3905,6 +3927,7 @@ def new_subgroups_by_enumeration(
    timeout=None,
    backend=None,
    pg_options=None,
    group_desc=None,
):
    """
    Create subgroups by dividing the global world.
@ -3945,6 +3968,8 @@ def new_subgroups_by_enumeration(
            the construction of specific process groups. i.e. for the ``nccl``
            backend, ``is_high_priority_stream`` can be specified so that
            process group can pick up high priority cuda streams.
        group_desc (str, optional): A string describing the group. Each subgroup will
            inherit its group_desc.

    Returns:
        The subgroup containing the current rank, and all the subgroups used for cleanup.
@ -3973,6 +3998,7 @@ def new_subgroups_by_enumeration(
            timeout=timeout,
            backend=backend,
            pg_options=pg_options,
            group_desc=group_desc,
        )
        subgroups.append(subgroup)
        my_rank = get_rank()

@ -28,7 +28,7 @@ def tail_logfile(
            return
        time.sleep(interval_sec)

    with open(file) as fp:
    with open(file, errors="replace") as fp:
        while True:
            line = fp.readline()
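The errors="replace" change makes the log tailer tolerant of partially written or non-UTF-8 bytes instead of failing with UnicodeDecodeError. A minimal standalone sketch of the difference (file name and bytes are made up):

# Write a log line containing an invalid UTF-8 byte sequence.
with open("worker.log", "wb") as f:
    f.write(b"step 1 ok\n\xff\xfe partial line\n")

# The default errors="strict" decoding would raise UnicodeDecodeError here;
# errors="replace" substitutes U+FFFD and lets tailing continue.
with open("worker.log", errors="replace") as fp:
    for line in fp:
        print(line.rstrip())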