mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 23:04:52 +08:00
Compare commits
55 Commits
Author | SHA1 | Date | |
---|---|---|---|
b33a283e9a | |||
7a551d81e5 | |||
1515a90475 | |||
4882ec2a91 | |||
972b8060bd | |||
3e7683ae18 | |||
f2e9ec2dc5 | |||
dde4324d8e | |||
94c079104d | |||
a6afee6d94 | |||
d092857531 | |||
6aad5e444a | |||
c54ce9313b | |||
1fe59f4ef7 | |||
e693fb2bb1 | |||
4fe510baf6 | |||
7c507b78c4 | |||
0019901601 | |||
18be18535b | |||
2729367313 | |||
33537aae24 | |||
dcdb1337dd | |||
9cf0f2bd59 | |||
1d2e877c05 | |||
f27b979b0c | |||
f30d6047ad | |||
75311510ef | |||
dbd6094d05 | |||
397b9d47e9 | |||
36a01a8ab9 | |||
ee336cf58a | |||
a9e2e745d7 | |||
ab4df89eea | |||
9d02ebe876 | |||
b61e01cce9 | |||
f7ce61ba53 | |||
e7bae15ab1 | |||
e71b422908 | |||
389940ce60 | |||
b2237a7c85 | |||
ef5dfe3f3e | |||
e303dc3c08 | |||
265efad2de | |||
60f0455905 | |||
4898313791 | |||
f4da9adf6b | |||
8f7f35273e | |||
44ec9612ed | |||
4d3bea2b29 | |||
0bcdddc3c1 | |||
28b6220312 | |||
210b7b65e2 | |||
4da10b5cd3 | |||
f09763814f | |||
80923ed5a6 |
@ -1732,7 +1732,7 @@ if(BUILD_TEST)
|
||||
foreach(test_src ${Caffe2_CPU_TEST_SRCS})
|
||||
get_filename_component(test_name ${test_src} NAME_WE)
|
||||
add_executable(${test_name} "${test_src}")
|
||||
target_link_libraries(${test_name} torch_library gtest_main)
|
||||
target_link_libraries(${test_name} torch_library gtest_main stdc++)
|
||||
target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
|
||||
target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
|
||||
target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})
|
||||
|
@ -371,7 +371,8 @@ std::string readTraceFromFile(const std::string& filename, size_t size) {
|
||||
// Extend the nested class outside the parent class
|
||||
class TestDebugInfoWriter : public c10d::DebugInfoWriter {
|
||||
public:
|
||||
TestDebugInfoWriter() : DebugInfoWriter(0) {}
|
||||
TestDebugInfoWriter(std::string namePrefix)
|
||||
: DebugInfoWriter(namePrefix, 0) {}
|
||||
|
||||
void write(const std::string& ncclTrace) override {
|
||||
traces_.assign(ncclTrace.begin(), ncclTrace.end());
|
||||
@ -415,10 +416,12 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
|
||||
// The storer here is very similar to the fallback storer.
|
||||
// The only difference is that we are storing traces also in memory for
|
||||
// validation.
|
||||
std::string fileNamePrefix = c10d::getCvarString(
|
||||
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
|
||||
std::unique_ptr<TestDebugInfoWriter> wrterForTestPtr =
|
||||
std::make_unique<TestDebugInfoWriter>();
|
||||
std::make_unique<TestDebugInfoWriter>(fileNamePrefix);
|
||||
std::vector<uint8_t>& traces = wrterForTestPtr->getTraces();
|
||||
pg.registerDebugInfoWriter(std::move(wrterForTestPtr));
|
||||
c10d::DebugInfoWriter::registerWriter(std::move(wrterForTestPtr));
|
||||
|
||||
// Normal collective case.
|
||||
auto work = pg.allreduce(tensors_);
|
||||
@ -449,6 +452,9 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
|
||||
class ProcessGroupNCCLWatchdogTimeoutTest : public ProcessGroupNCCLErrorsTest {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
// TODO (kwen2501)
|
||||
GTEST_SKIP() << "Skipping tests under ProcessGroupNCCLWatchdogTimeoutTest; "
|
||||
<< "will rewrite them after refactoring Work queues.";
|
||||
ProcessGroupNCCLErrorsTest::SetUp();
|
||||
std::string timeInterval = std::to_string(heartBeatIntervalInSec);
|
||||
ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0);
|
||||
|
@ -4,9 +4,10 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._functional_collectives as funcol
|
||||
|
||||
from torch.distributed._tensor import DTensor
|
||||
from torch.distributed._tensor.placement_types import Shard
|
||||
from torch.distributed.checkpoint._state_dict_utils import (
|
||||
from torch.distributed._state_dict_utils import (
|
||||
_check_state_dict_similarity,
|
||||
_copy_state_dict,
|
||||
_create_cpu_state_dict,
|
||||
_gather_state_dict,
|
||||
_offload_state_dict_to_cpu,
|
||||
)
|
||||
@ -115,6 +116,58 @@ class TestStateDictUtils(DTensorTestBase):
|
||||
}
|
||||
self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
def test_create_cpu_state_dict(self):
|
||||
device = torch.device("cuda")
|
||||
buffer = io.BytesIO()
|
||||
torch.save(torch.ones(10), buffer)
|
||||
buffer.seek(0)
|
||||
state_dict = {
|
||||
"tensor1": torch.arange(10, device=device),
|
||||
"tensor2": torch.ones(10, device=device),
|
||||
"non_tensor_bytes_io": copy.deepcopy(buffer),
|
||||
"non_tensor_bytes": buffer.read(),
|
||||
"step": torch.tensor(7, dtype=torch.float),
|
||||
"lr": 1.5,
|
||||
"nested": {"list": [1, 2, 3, 4]},
|
||||
}
|
||||
|
||||
def _verify(cpu_state_dict):
|
||||
# Verify the correctness of _check_state_dict_similarity()
|
||||
self.assertTrue(_check_state_dict_similarity(state_dict, cpu_state_dict))
|
||||
tensor1 = cpu_state_dict["tensor1"]
|
||||
cpu_state_dict["tensor1"] = torch.arange(11)
|
||||
self.assertFalse(_check_state_dict_similarity(state_dict, cpu_state_dict))
|
||||
cpu_state_dict["tensor1"] = tensor1
|
||||
|
||||
_copy_state_dict(state_dict, cpu_state_dict)
|
||||
|
||||
# Verify if _copy_state_dict works
|
||||
for v in cpu_state_dict.values():
|
||||
if isinstance(v, torch.Tensor):
|
||||
self.assertFalse(v.is_cuda)
|
||||
self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
|
||||
self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
|
||||
buffer.seek(0)
|
||||
cpu_state_dict["non_tensor_bytes_io"].seek(0)
|
||||
self.assertEqual(
|
||||
cpu_state_dict["non_tensor_bytes_io"].read(), buffer.read()
|
||||
)
|
||||
buffer.seek(0)
|
||||
self.assertEqual(cpu_state_dict["non_tensor_bytes"], buffer.read())
|
||||
self.assertEqual(cpu_state_dict["lr"], 1.5)
|
||||
self.assertEqual(cpu_state_dict["step"], 7)
|
||||
self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]})
|
||||
|
||||
cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)
|
||||
_verify(cpu_state_dict)
|
||||
cpu_state_dict = _create_cpu_state_dict(state_dict, share_memory=True)
|
||||
_verify(cpu_state_dict)
|
||||
cpu_state_dict = _create_cpu_state_dict(
|
||||
state_dict, share_memory=True, pin_memory=True
|
||||
)
|
||||
_verify(cpu_state_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -11,6 +11,7 @@ import tempfile
|
||||
import threading
|
||||
import pickle
|
||||
import time
|
||||
import json
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
@ -1334,6 +1335,19 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
|
||||
self.assertEqual(tensor, original_tensor)
|
||||
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
|
||||
def test_set_process_group_desc(self):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
device = torch.device(f'cuda:{self.rank}')
|
||||
pg_default = self._create_process_group_nccl(store, self.opts(), device_id=device)
|
||||
self.assertEqual(pg_default.group_desc, "default_pg")
|
||||
pg_1 = c10d.new_group([0, 1], group_desc="test_purpose")
|
||||
self.assertEqual(pg_1.group_desc, "test_purpose")
|
||||
pg_2 = c10d.new_group([0, 1])
|
||||
self.assertEqual(pg_2.group_desc, "undefined")
|
||||
|
||||
|
||||
class DistributedDataParallelTest(
|
||||
test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
|
||||
):
|
||||
@ -3542,11 +3556,12 @@ class SparseCollective(MultiProcessTestCase):
|
||||
class NCCLTraceTestBase(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0'
|
||||
os.environ["TORCH_NCCL_ENABLE_TIMING"] = '0' # see 'timing_enabled' parametrized tests
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '10'
|
||||
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
|
||||
self.tempdir = tempfile.TemporaryDirectory()
|
||||
os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = self._trace_basename()
|
||||
os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = self._trace_basename()
|
||||
self._spawn_processes()
|
||||
|
||||
@classmethod
|
||||
@ -3617,28 +3632,50 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
def test_short(self):
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_short(self, timing_enabled):
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
return
|
||||
pg = self._create_process_group_nccl()
|
||||
if timing_enabled:
|
||||
pg._enable_collectives_timing()
|
||||
device = self.local_device
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
for i in range(2):
|
||||
f = pg.allreduce(a)
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
|
||||
# gah ok so now the duration_ms is populated best-effort since it can only happen outside "dump()" api
|
||||
time.sleep(1)
|
||||
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
ver = t['version']
|
||||
self.assertEqual(ver, "1.1")
|
||||
t = t['entries']
|
||||
self.assertEqual(len(t), 2)
|
||||
last = t[-1]
|
||||
self.assertEqual(last['process_group'], ('0', 'default_pg'))
|
||||
self.assertEqual(last['state'], 'completed')
|
||||
s = last['time_discovered_started_ns']
|
||||
f = last['time_discovered_completed_ns']
|
||||
self.assertIsNotNone(f)
|
||||
if timing_enabled:
|
||||
self.assertIsNotNone(s)
|
||||
self.assertTrue(s <= f)
|
||||
self.assertIn('test_c10d_nccl.py', str(last['frames']))
|
||||
self.assertEqual(last['input_sizes'], ((3, 4),))
|
||||
self.assertEqual(last['output_sizes'], ((3, 4),))
|
||||
self.assertEqual(last['seq_id'], 2)
|
||||
now = datetime.now()
|
||||
event_created_time = datetime.fromtimestamp(last['time_created_us'] / 1000000)
|
||||
event_created_time = datetime.fromtimestamp(last['time_created_ns'] / 1000000000)
|
||||
before_test = now - timedelta(minutes=1)
|
||||
self.assertTrue(before_test < event_created_time < now)
|
||||
if timing_enabled:
|
||||
# very loose bounds, measured 0.036 ms on devgpu
|
||||
self.assertTrue(0 < last['duration_ms'] < 100)
|
||||
else:
|
||||
self.assertTrue("duration_ms" not in last)
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
@ -3698,9 +3735,11 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
f.wait()
|
||||
torch.cuda.synchronize(device=device)
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
t = t['entries']
|
||||
self.assertEqual(len(t), 10)
|
||||
first = t[0]
|
||||
last = t[-1]
|
||||
self.assertEqual(last['profiling_name'], 'nccl:all_reduce')
|
||||
self.assertEqual(last['state'], 'completed')
|
||||
self.assertIn('test_c10d_nccl.py', str(last['frames']))
|
||||
self.assertEqual(last['input_sizes'], ((3, 4),))
|
||||
@ -3732,6 +3771,8 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
pg.allreduce(a).wait()
|
||||
e.synchronize()
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
t = t['entries']
|
||||
self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
|
||||
if self.rank == 0:
|
||||
self.assertEqual(t[-1]['seq_id'], 1)
|
||||
self.assertEqual(t[-1]['state'], 'completed')
|
||||
@ -3773,12 +3814,15 @@ class NCCLTraceTest(NCCLTraceTestBase):
|
||||
# give the other thread some time to fill the cuda buffer
|
||||
time.sleep(5)
|
||||
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
|
||||
t = t['entries']
|
||||
self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
|
||||
if self.rank == 0:
|
||||
self.assertEqual(t[-1]['seq_id'], 1)
|
||||
self.assertEqual(t[-1]['state'], 'completed')
|
||||
else:
|
||||
self.assertEqual(t[-1]['seq_id'], 2)
|
||||
self.assertEqual(t[-1]['state'], self.started_or_scheduled(timing_enabled))
|
||||
self.assertIsNone(t[-1]['time_discovered_completed_ns'])
|
||||
# this will eventually cause the missing rank 0
|
||||
# to continue which will unblock the non-zero ranks
|
||||
self.parent.send('next')
|
||||
@ -3832,6 +3876,10 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
|
||||
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
@parametrize("timing_enabled", [True, False])
|
||||
def test_timeout_dumps(self, timing_enabled):
|
||||
# dump on heartbeatmonitor thread
|
||||
os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1000'
|
||||
# need rank0 to crash before looking for its output file
|
||||
os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '1'
|
||||
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
# wait for rank0 to crash before looking for its output file
|
||||
@ -3839,6 +3887,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
|
||||
self.assertEqual(self._wait_process(0, timeout=90), -6)
|
||||
with open(self._trace_name(rank=0), 'rb') as f:
|
||||
t = pickle.load(f)
|
||||
t = t['entries']
|
||||
self.assertEqual(len(t), 2)
|
||||
self.assertEqual(t[0]['seq_id'], 1)
|
||||
self.assertEqual(t[0]['state'], 'completed')
|
||||
@ -3868,7 +3917,7 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
|
||||
instantiate_parametrized_tests(NCCLTraceTestDumpOnTimeout)
|
||||
instantiate_parametrized_tests(NCCLTraceTest)
|
||||
|
||||
class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
|
||||
class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase):
|
||||
def _check_return_codes(self, elapsed_time):
|
||||
# the base test infra assumes processes exit with matching return codes,
|
||||
# but we want rank0 to abort and rank1 to exit cleanly in this test
|
||||
@ -3877,7 +3926,7 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
|
||||
|
||||
@requires_nccl()
|
||||
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
|
||||
def test_timeout_dumps_on_idle_ranks(self):
|
||||
def test_timeout_dumps_on_stuck_ranks(self):
|
||||
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
# wait for both rank0 and 1 to crash before looking for both ranks' output
|
||||
@ -3888,20 +3937,17 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
|
||||
self.assertTrue(os.path.exists(self._trace_name(rank=0)))
|
||||
with open(self._trace_name(rank=0), 'rb') as f:
|
||||
t = pickle.load(f)
|
||||
t = t['entries']
|
||||
self.assertEqual(len(t), 2)
|
||||
with open(self._trace_name(rank=1), 'rb') as f:
|
||||
t = pickle.load(f)
|
||||
t = t['entries']
|
||||
self.assertEqual(len(t), 1)
|
||||
self.assertEqual(t[0]['seq_id'], 1)
|
||||
self.assertEqual(t[0]['state'], 'completed')
|
||||
return
|
||||
|
||||
# Set heartbeat timeout to a shorter one (default timeout is 2 min).
|
||||
os.environ[
|
||||
"TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"
|
||||
] = f"{NCCLTraceTestDumpOnTimeoutBase.timeout_sec * 2}"
|
||||
pg = self._create_process_group_nccl()
|
||||
|
||||
device = self.local_device
|
||||
with torch.cuda.device(device):
|
||||
a = torch.full((3, 4), float(self.rank), device=device)
|
||||
@ -3910,12 +3956,68 @@ class NCCLTraceTestTimeoutDumpOnIdleRanks(NCCLTraceTestDumpOnTimeoutBase):
|
||||
if self.rank == 0:
|
||||
pg.allreduce(a).wait()
|
||||
|
||||
# rank 0 will crash before it passes the sync, but rank1 will exit quickly and cleanly
|
||||
# rank 0 will get stuck, timeout and then signal a timeout to all ranks.
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Force rank 1 to idle so that it also gets debug info dump triggered.
|
||||
if self.rank == 1:
|
||||
time.sleep(6)
|
||||
# Force rank 1 to idle so that it will eventually timeout as well after
|
||||
# getting the global signal to dump the debugging info.
|
||||
time.sleep(600)
|
||||
|
||||
class NcclErrorDumpTest(NCCLTraceTestBase):
|
||||
def _wait_process(self, rank, timeout):
|
||||
try:
|
||||
self.processes[rank].join(timeout)
|
||||
return self.processes[rank].exitcode
|
||||
except TimeoutError:
|
||||
return None
|
||||
|
||||
def _check_return_codes(self, elapsed_time):
|
||||
# the base test infra assumes processes exit with matching return codes,
|
||||
# but we want rank0 to abort with exception and rank1 to exit with exit 1
|
||||
self.assertEqual(self.processes[0].exitcode, -6)
|
||||
self.assertEqual(self.processes[1].exitcode, 1)
|
||||
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_rocm
|
||||
def test_nccl_errors_dump(self):
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
||||
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '1000'
|
||||
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
|
||||
# need rank0 to dump before abort
|
||||
os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '5'
|
||||
|
||||
if self.rank == self.MAIN_PROCESS_RANK:
|
||||
# wait for both rank0 and 1 to crash before looking for dump
|
||||
self.assertEqual(self._wait_process(0, timeout=90), -6)
|
||||
self.assertEqual(self._wait_process(1, timeout=90), 1)
|
||||
# verify that the trace file exists for rank0
|
||||
self.assertTrue(os.path.exists(self._trace_name(rank=0)))
|
||||
return
|
||||
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
process_group = c10d.ProcessGroupNCCL(
|
||||
store,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
timeout=timedelta(seconds=10),
|
||||
)
|
||||
process_group.allreduce(torch.rand(10).cuda(self.rank))
|
||||
if self.rank == 0:
|
||||
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
|
||||
# expect an error to be raised
|
||||
with self.assertRaisesRegex(dist.DistBackendError, ""):
|
||||
# Block the current stream on the NCCL stream
|
||||
work.wait()
|
||||
# Run some GPU operations
|
||||
a = torch.rand(10).cuda(self.rank)
|
||||
elif self.rank == 1:
|
||||
# Clean up structures (ex: files for FileStore before going down)
|
||||
del process_group
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert (
|
||||
|
@ -71,7 +71,7 @@ class StoreTestBase:
|
||||
def _create_store(self, i):
|
||||
raise RuntimeError("not implemented")
|
||||
|
||||
def _test_set_get(self, fs):
|
||||
def _test_set_get_check(self, fs):
|
||||
fs.add("key", 1)
|
||||
fs.add("key", 2)
|
||||
fs.add("key", 3)
|
||||
@ -90,14 +90,16 @@ class StoreTestBase:
|
||||
self.assertEqual(b"value1", fs.get("key1"))
|
||||
self.assertEqual(b"value2", fs.get("key2"))
|
||||
self.assertEqual(b"21", fs.get("key3"))
|
||||
self.assertTrue(fs.check(["key3"]))
|
||||
self.assertFalse(fs.check(["Randomkey3"]))
|
||||
|
||||
fs.set("-key3", "7")
|
||||
self.assertEqual(b"7", fs.get("-key3"))
|
||||
fs.delete_key("-key3")
|
||||
self.assertEqual(fs.num_keys(), self.num_keys_total)
|
||||
|
||||
def test_set_get(self):
|
||||
self._test_set_get(self._create_store())
|
||||
def test_set_get_check(self):
|
||||
self._test_set_get_check(self._create_store())
|
||||
|
||||
def _test_compare_set(self, store):
|
||||
missing_key_result = store.compare_set("cs_key0", "wrong_old_value", "new_value0")
|
||||
@ -441,6 +443,12 @@ class PrefixTCPStoreTest(TestCase, StoreTestBase):
|
||||
def num_keys_total(self):
|
||||
return 6
|
||||
|
||||
def test_underlying_non_prefix_store(self):
|
||||
store = self._create_store()
|
||||
wrapped_store = dist.PrefixStore(self.prefix, dist.PrefixStore(self.prefix, store))
|
||||
self.assertEqual(self.tcpstore, store._underlying_non_prefix_store)
|
||||
self.assertEqual(self.tcpstore, wrapped_store._underlying_non_prefix_store)
|
||||
|
||||
class MyPythonStore(dist.Store):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -463,6 +463,7 @@ class ProcessGroup:
|
||||
backend: Optional[ProcessGroup],
|
||||
) -> None: ...
|
||||
def _set_group_name(self, name: str) -> None: ...
|
||||
def _set_group_desc(self, desc: str) -> None: ...
|
||||
def name(self) -> str: ...
|
||||
def _has_hooks(self) -> bool: ...
|
||||
def _wait_for_pending_works(self) -> None: ...
|
||||
@ -471,6 +472,10 @@ class ProcessGroup:
|
||||
def bound_device_id(self) -> Optional[torch.device]: ...
|
||||
@bound_device_id.setter
|
||||
def bound_device_id(self, device: Optional[torch.device]) -> None: ...
|
||||
@property
|
||||
def group_name(self) -> str: ...
|
||||
@property
|
||||
def group_desc(self) -> str: ...
|
||||
|
||||
class ProcessGroupRoundRobin(ProcessGroup): ...
|
||||
|
||||
|
@ -369,6 +369,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
||||
return pg_name_;
|
||||
}
|
||||
|
||||
void setGroupDesc(const std::string& desc) {
|
||||
pg_desc_ = desc;
|
||||
}
|
||||
|
||||
const std::string& getGroupDesc() const {
|
||||
return pg_desc_;
|
||||
}
|
||||
|
||||
// See similar functions in ProcessGroup.hpp for context.
|
||||
c10::optional<at::Device> getBoundDeviceId() const {
|
||||
return bound_device_id_;
|
||||
@ -399,6 +407,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
||||
// remains the same across use of this process group.
|
||||
DebugLevel dist_debug_level_;
|
||||
std::string pg_name_;
|
||||
std::string pg_desc_;
|
||||
|
||||
std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;
|
||||
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <nccl.h>
|
||||
@ -175,14 +176,29 @@ std::string getNcclErrorDetailStr(
|
||||
c10::optional<std::string> processGroupFailureReason = c10::nullopt);
|
||||
|
||||
// Write NCCL debug info to local disk or any storage users define.
|
||||
// There are some constrains we set for the debug info writer:
|
||||
// 1. The writer should only be registered once.
|
||||
// 2. Once registered, users cannot change it including un-register.
|
||||
// 3. It is recommended to register the customized writer in the trainer setup,
|
||||
// If users don't register before calling launchAsyncDebugDump, then users
|
||||
// lose the chance to register (and the default writer will be
|
||||
// auto-registered).
|
||||
class TORCH_API DebugInfoWriter {
|
||||
public:
|
||||
DebugInfoWriter(int rank);
|
||||
virtual ~DebugInfoWriter();
|
||||
virtual void write(const std::string& ncclTrace);
|
||||
static DebugInfoWriter& getWriter(int rank);
|
||||
static void registerWriter(std::unique_ptr<DebugInfoWriter> writer);
|
||||
|
||||
protected:
|
||||
DebugInfoWriter(std::string namePrefix, int rank) {
|
||||
filename_ = c10::str(namePrefix, rank);
|
||||
}
|
||||
std::string filename_;
|
||||
|
||||
private:
|
||||
static std::unique_ptr<DebugInfoWriter> writer_;
|
||||
static std::atomic<bool> hasWriterRegistered_;
|
||||
};
|
||||
|
||||
// RAII wrapper for NCCL communicator
|
||||
@ -267,6 +283,18 @@ class NCCLComm {
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP)
|
||||
std::unordered_map<std::string, std::string> ncclCommDump() {
|
||||
std::unordered_map<std::string, std::string> dump;
|
||||
if (isAborted()) {
|
||||
LOG(INFO) << "Communicator was aborted before trying to dump its state.";
|
||||
return dump;
|
||||
}
|
||||
C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), c10::nullopt);
|
||||
return dump;
|
||||
}
|
||||
#endif
|
||||
|
||||
ncclUniqueId getNcclId() {
|
||||
return ncclId_;
|
||||
}
|
||||
@ -322,6 +350,9 @@ class NCCLComm {
|
||||
// Set true failure reason if provided by ProcessGroupNCCL (e.g. work
|
||||
// timeout)
|
||||
commFailureReason_ = commFailureReason;
|
||||
LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
|
||||
<< (commFailureReason ? *commFailureReason
|
||||
: "No abort reason provided.");
|
||||
#ifndef NCCL_HAS_COMM_NONBLOCKING
|
||||
C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
|
||||
#else
|
||||
@ -421,6 +452,8 @@ class NCCLComm {
|
||||
#endif
|
||||
}
|
||||
|
||||
friend class ProcessGroupNCCL;
|
||||
|
||||
protected:
|
||||
ncclComm_t ncclComm_;
|
||||
// Unique nccl_id for this communicator.
|
||||
|
@ -108,4 +108,22 @@ c10::intrusive_ptr<Store> PrefixStore::getUnderlyingStore() {
|
||||
return store_;
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Store> PrefixStore::getUnderlyingNonPrefixStore() {
|
||||
c10::intrusive_ptr<Store> store = store_;
|
||||
|
||||
while (store) {
|
||||
// Attempt to dynamically cast to PrefixStore
|
||||
PrefixStore* asPrefixStore = dynamic_cast<PrefixStore*>(store.get());
|
||||
if (asPrefixStore) {
|
||||
store = asPrefixStore->getUnderlyingStore();
|
||||
} else {
|
||||
break; // We've reached a non-PrefixStore
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
store != nullptr, "Underlying Non-PrefixStore shouldn't be null.");
|
||||
return store;
|
||||
}
|
||||
|
||||
} // namespace c10d
|
||||
|
@ -53,6 +53,9 @@ class TORCH_API PrefixStore : public Store {
|
||||
|
||||
c10::intrusive_ptr<Store> getUnderlyingStore();
|
||||
|
||||
// Recursively to fetch the store before layers of wrapping with PrefixStore.
|
||||
c10::intrusive_ptr<Store> getUnderlyingNonPrefixStore();
|
||||
|
||||
protected:
|
||||
std::string prefix_;
|
||||
c10::intrusive_ptr<Store> store_;
|
||||
|
@ -165,6 +165,18 @@ void ProcessGroup::setGroupName(const std::string& name) {
|
||||
}
|
||||
}
|
||||
|
||||
const std::string& ProcessGroup::getGroupDesc() const {
|
||||
return pg_desc_;
|
||||
}
|
||||
|
||||
void ProcessGroup::setGroupDesc(const std::string& name) {
|
||||
pg_desc_ = name;
|
||||
// Also set the group desc for all backends
|
||||
for (auto& kv : deviceTypeToBackend_) {
|
||||
kv.second->setGroupDesc(name);
|
||||
}
|
||||
}
|
||||
|
||||
void ProcessGroup::enableCollectivesTiming() {
|
||||
for (auto& kv : deviceTypeToBackend_) {
|
||||
kv.second->enableCollectivesTiming();
|
||||
|
@ -694,6 +694,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
|
||||
const std::string& getGroupName() const;
|
||||
void setGroupName(const std::string& name);
|
||||
const std::string& getGroupDesc() const;
|
||||
void setGroupDesc(const std::string& name);
|
||||
void enableCollectivesTiming();
|
||||
|
||||
void release_resources() override;
|
||||
@ -724,6 +726,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
|
||||
const int size_;
|
||||
const c10::intrusive_ptr<Options> options_;
|
||||
const BackendType backendType_;
|
||||
std::string pg_desc_;
|
||||
|
||||
// Debug level setting. It is parsed once when ProcessGroup is constructed and
|
||||
// remains the same across use of this process group.
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,7 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <fcntl.h>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#ifdef USE_C10D_NCCL
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <future>
|
||||
#include <iostream>
|
||||
@ -12,6 +20,7 @@
|
||||
|
||||
#include <torch/csrc/distributed/c10d/Backend.hpp>
|
||||
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
|
||||
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
|
||||
#include <torch/csrc/distributed/c10d/Store.hpp>
|
||||
|
||||
#include <ATen/DynamicLibrary.h>
|
||||
@ -26,52 +35,74 @@
|
||||
#include <torch/custom_class.h>
|
||||
|
||||
namespace c10d {
|
||||
// Environment variable which controls whether we perform a NCCL healt check
|
||||
// Control whether we perform a NCCL health check or not
|
||||
// which ensures communicators are healthy at the beginning of init.
|
||||
static std::vector<std::string> TORCH_ENABLE_NCCL_HEALTH_CHECK = {
|
||||
"TORCH_ENABLE_NCCL_HEALTH_CHECK",
|
||||
"ENABLE_NCCL_HEALTH_CHECK"};
|
||||
|
||||
// Environment variable which controls whether or not wait() is blocking or
|
||||
// non-blocking.
|
||||
// Control whether or not wait() is blocking or non-blocking.
|
||||
static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT = {
|
||||
"TORCH_NCCL_BLOCKING_WAIT",
|
||||
"NCCL_BLOCKING_WAIT"};
|
||||
|
||||
// Environment variable which controls whether or not we perform Async Error
|
||||
// Handling with NCCL.
|
||||
// Control whether or not we perform Async Error Handling with NCCL.
|
||||
static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING = {
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING",
|
||||
"NCCL_ASYNC_ERROR_HANDLING"};
|
||||
|
||||
// Environment Variable to control whether dumping debug info on watchdog
|
||||
// Control whether dumping debug info on watchdog
|
||||
// timeout is enabled. This variable must be set together with
|
||||
// TORCH_NCCL_ENABLE_MONITORING=1 and TORCH_NCCL_TRACE_BUFFER_SIZE > 0.
|
||||
static std::vector<std::string> TORCH_NCCL_DUMP_ON_TIMEOUT = {
|
||||
"TORCH_NCCL_DUMP_ON_TIMEOUT"};
|
||||
|
||||
// Environment Variable to control whether Desync Debug is enabled.
|
||||
// This variable must be set together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
|
||||
// Control whether Desync Debug is enabled. This variable must be set
|
||||
// together with TORCH_NCCL_ASYNC_ERROR_HANDLING.
|
||||
static std::vector<std::string> TORCH_NCCL_DESYNC_DEBUG = {
|
||||
"TORCH_NCCL_DESYNC_DEBUG",
|
||||
"NCCL_DESYNC_DEBUG"};
|
||||
|
||||
// Enable recording start-events for all ProcessGroupNCCL collectives, and
|
||||
// compute accurate collective timing per-collective. (Note: end-events are
|
||||
// recorded by default. Turn on this flag can increase chances of a watchdog
|
||||
// hang due to performing a CUDA event query which eventually calls
|
||||
// cudaEventElapsedTime() API.
|
||||
static std::vector<std::string> TORCH_NCCL_ENABLE_TIMING = {
|
||||
"TORCH_NCCL_ENABLE_TIMING",
|
||||
"NCCL_ENABLE_TIMING"};
|
||||
|
||||
// Enable monitoring thread which aborts the process when the ProcessGroupNCCL
|
||||
// Watchdog thread gets stuck and no heartbeat is detected after
|
||||
// TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen due to calling CUDA/NCCL
|
||||
// APIs that may hang. It is Useful to prevent jobs being stuck for a prolonged
|
||||
// time than necessary tying up cluster resources.
|
||||
static std::vector<std::string> TORCH_NCCL_ENABLE_MONITORING = {
|
||||
"TORCH_NCCL_ENABLE_MONITORING"};
|
||||
|
||||
// Control the watchdog heartbeat timeout period after which the monitoring
|
||||
// thread will abort the process.
|
||||
static std::vector<std::string> TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC = {
|
||||
"TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"};
|
||||
|
||||
// The maximum number of events we store in the flight recorder's ring buffer.
|
||||
// (One event could be the start or end of a collective, for example).
|
||||
static std::vector<std::string> TORCH_NCCL_TRACE_BUFFER_SIZE = {
|
||||
"TORCH_NCCL_TRACE_BUFFER_SIZE"};
|
||||
|
||||
// Control how much extra time we will wait for dumping the debugging info
|
||||
// before we exit and throws timeout exception.
|
||||
static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = {
|
||||
"TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"};
|
||||
|
||||
// Control the interval inside the watchdog thread to check the coordinated
|
||||
// signal from other ranks, e.g. to dump the debugging information.
|
||||
static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
|
||||
"TORCH_NCCL_COORD_CHECK_MILSEC"};
|
||||
|
||||
constexpr const char* NCCL_BACKEND_NAME = "nccl";
|
||||
|
||||
constexpr const char* TIMEOUT_DUMP = "timeout_dump";
|
||||
constexpr const char* EXCEPTION_DUMP = "exception_dump";
|
||||
|
||||
constexpr auto kProcessGroupNCCLDefaultTimeout =
|
||||
std::chrono::milliseconds(10 * 60 * 1000);
|
||||
@ -110,6 +141,59 @@ static std::vector<std::string> TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK =
|
||||
{"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK",
|
||||
"NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"};
|
||||
|
||||
#if defined(__linux__)
|
||||
struct DumpPipe {
|
||||
DumpPipe(int rank) {
|
||||
std::string fileStem =
|
||||
getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
|
||||
if (fileStem.empty() ||
|
||||
getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
|
||||
return;
|
||||
}
|
||||
TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty");
|
||||
std::string filename = c10::str(fileStem, rank, ".pipe");
|
||||
TORCH_CHECK(
|
||||
unlink(filename.c_str()) != -1 || errno == ENOENT,
|
||||
"Error removing existing named pipe ",
|
||||
filename);
|
||||
TORCH_CHECK(
|
||||
mkfifo(filename.c_str(), 0666) != -1,
|
||||
"Error creating named pipe ",
|
||||
filename);
|
||||
fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
|
||||
LOG(INFO) << "Pipe file " << filename
|
||||
<< " has been opened, write to it to trigger NCCL Debug Dump.";
|
||||
TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
|
||||
}
|
||||
bool shouldDump() {
|
||||
if (fd_ == -1) {
|
||||
return false;
|
||||
}
|
||||
char buf[128];
|
||||
// non-blocking from O_NONBLOCK above.
|
||||
// Ignore EINTR because we already will poll this
|
||||
// again later.
|
||||
ssize_t bytesRead = read(fd_, &buf, 128);
|
||||
return bytesRead > 0;
|
||||
}
|
||||
~DumpPipe() {
|
||||
if (fd_ != -1) {
|
||||
close(fd_);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int fd_ = -1;
|
||||
};
|
||||
#else
|
||||
struct DumpPipe {
|
||||
DumpPipe(int rank) {}
|
||||
bool shouldDump() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// ProcessGroupNCCL implements NCCL bindings for c10d.
|
||||
//
|
||||
// All functions of the class are expected to be called in the same order
|
||||
@ -205,6 +289,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
uint64_t getSequencenumber() const override;
|
||||
|
||||
const std::string& logPrefix() const;
|
||||
|
||||
// Helper function that sets an exception_ptr on the WorkNCCL object.
|
||||
void setException(std::exception_ptr exception_ptr);
|
||||
|
||||
@ -358,6 +444,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// via `ncclCommSplit`
|
||||
std::shared_ptr<ProcessGroupNCCL> split_from;
|
||||
int64_t split_color{0};
|
||||
std::string group_name;
|
||||
};
|
||||
|
||||
// If you wish to create multiple process groups, each with a potentially
|
||||
@ -540,8 +627,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
void enableCollectivesTiming() override;
|
||||
|
||||
// Provide an API for users to define their own ways to store NCCL debug info.
|
||||
void registerDebugInfoWriter(std::unique_ptr<DebugInfoWriter> writer);
|
||||
// Helper function for iteratively aborting communicators in the provided map
|
||||
void abortCommsFromMap(
|
||||
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
|
||||
ncclCommsMap,
|
||||
c10::optional<std::string> abortReason);
|
||||
|
||||
// Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
|
||||
// instead of relying on ProcessGroupNCCL destructor.
|
||||
@ -694,6 +784,19 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Desync debug helper
|
||||
void logWorkEnd(WorkNCCL& work);
|
||||
|
||||
// Generates a prefix that is unique to this process group and rank, for
|
||||
// disambiguating logs
|
||||
std::string createLogPrefix() const;
|
||||
|
||||
// Returns the unique prefix created in createLogPrefix
|
||||
const std::string& logPrefix() const;
|
||||
|
||||
// Returns the global rank of the device. This function assumes that users
|
||||
// always create a default global process group(PG) which includes all
|
||||
// devices. It is called in the constructor of ProcessGroupNCCL, so it always
|
||||
// return the rank_ of the the very first PG created, aka, default global PG.
|
||||
const int& globalRank() const;
|
||||
|
||||
protected:
|
||||
// Function that runs as part of a separate thread aside from watchdog
|
||||
// thread because we need to check the heartbeat from watchdog thread
|
||||
@ -712,6 +815,19 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// for dump completion.
|
||||
std::future<bool> launchAsyncDebugDump();
|
||||
|
||||
// Helper to wait up to the specified timeout and then abandon the dump.
|
||||
// Logs on timeout, and asserts the future's status is as expected.
|
||||
void waitForDumpOrTimeout(
|
||||
std::future<bool>& fut,
|
||||
const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime,
|
||||
size_t timeout_sec = 30);
|
||||
|
||||
// A helper function to wait for a future to complete or timeout.
|
||||
void waitForFutureOrTimeout(
|
||||
std::future<bool>& fut,
|
||||
const std::chrono::milliseconds& timeOutMilSec,
|
||||
const std::string& futDescription);
|
||||
|
||||
// When watchdog timeout, this function will be called and return debug info
|
||||
// for users. For now we only get information from retrieveDesyncReport.
|
||||
// We are working on enabling more useful debug information for watchdog
|
||||
@ -720,9 +836,16 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
static const int64_t kWatchdogThreadSleepMillis;
|
||||
|
||||
// The store is used to broadcast the NCCL unique ID of rank 0.
|
||||
// The store is used to broadcast the NCCL unique ID of rank 0. This store
|
||||
// comes with prefix and it is different across ProcessGroup NCCL instances
|
||||
// (aka, different ProcessGroups).
|
||||
c10::intrusive_ptr<Store> store_;
|
||||
|
||||
// Reference to the store without prefix so that keys are same across all
|
||||
// ProcessGroup NCCL instances and (key, value) pairs written to the store are
|
||||
// global.
|
||||
c10::intrusive_ptr<Store> globalStore_;
|
||||
|
||||
bool storeError_{false};
|
||||
|
||||
const c10::intrusive_ptr<Options> options_;
|
||||
@ -781,11 +904,18 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
std::mutex mutex_;
|
||||
|
||||
// Heartbeat of watchdog thread.
|
||||
uint64_t heartbeat_;
|
||||
std::atomic_uint64_t heartbeat_;
|
||||
|
||||
// The time interval used for deciding whether there is no watchdog heartbeat.
|
||||
int heartbeatTimeoutInSec_;
|
||||
|
||||
// Extra time of sleep when waiting for timeout dump to finish.
|
||||
int waitTimeoutDumpInMilSec_;
|
||||
|
||||
// Interval of check coordinated signals in ProcessGroupNCCL from other ranks
|
||||
// e.g., trigger the dump of the debugging info for timeout when notified.
|
||||
int coordCheckIntervalMilSec_;
|
||||
|
||||
// Size of ring buffer where we store NCCL Traces for debugging.
|
||||
int ncclTraceBufferSize_;
|
||||
|
||||
@ -815,6 +945,15 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Whether there are hooks pending to be fired
|
||||
std::atomic<bool> hasPendingHooks_;
|
||||
|
||||
// This is the signal from watchdog threads to indicate whether the monitor
|
||||
// thread should dump. Making it static so that it is accessiable from all the
|
||||
// PGs. With this flag, monitor thread would dump debug info under any one of
|
||||
// the 3 conditions: 1: this flag is set to true by the watchdog thread when
|
||||
// it detects a timeout. 2: timeout signal is received from
|
||||
// other ranks through tcpstore 3: no heartbeat of watchdog Note that only the
|
||||
// monitor thread from PG0 should dump the debug info and only once
|
||||
static std::atomic<bool> shouldDump_;
|
||||
|
||||
// Mutex to Guard workMetaList_
|
||||
std::mutex workMetaListMutex_;
|
||||
|
||||
@ -823,9 +962,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
bool writeDebugInfo_ = false;
|
||||
|
||||
// Mutex to Guard the check of writeDebugInfo_
|
||||
std::mutex writeDebugInfoMutex_;
|
||||
|
||||
// Condition Variable for watchdog thread sleep
|
||||
std::condition_variable workMetaListCV_;
|
||||
|
||||
@ -902,8 +1038,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
// Whether or not to enable timeout root cause analysis.
|
||||
bool desyncDebug_;
|
||||
|
||||
// Whether or not to dump debug info on timeout
|
||||
bool dumpOnTimeout_;
|
||||
// Whether or not to dump debug info on exception including both watchdog
|
||||
// timeout and nccl errors.
|
||||
bool dumpOnException_;
|
||||
|
||||
// Whether or not to create start CUDAEvent and enable timing for start
|
||||
// and end events. Note that enableTiming_ is always true if desyncDebug_
|
||||
@ -929,14 +1066,25 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
std::exception_ptr watchDogException_ = nullptr;
|
||||
|
||||
// The callback function to store NCCL debug info.
|
||||
std::unique_ptr<DebugInfoWriter> debugInfoWriter_ = nullptr;
|
||||
|
||||
size_t uid_;
|
||||
|
||||
std::string logPrefix_;
|
||||
};
|
||||
|
||||
TORCH_API std::string dump_nccl_trace();
|
||||
|
||||
// Gets a mutable reference to a global optional function. Heartbeat Monitor
|
||||
// will query this function and if available, call it to dump traces. Inside
|
||||
// fbcode, we store a function here that uses an internal tool for process
|
||||
// tracing
|
||||
TORCH_API c10::optional<std::function<std::string()>>& get_cpp_trace_dumper();
|
||||
|
||||
// Similar to get_cpp_trace_dumper, this stores a function defined in
|
||||
// torch-python layer that lets us check whether the GIL can be acquired,
|
||||
// helpful for instrumenting in cases where a hang was observed.
|
||||
typedef bool (*gil_checker_t)();
|
||||
|
||||
TORCH_API gil_checker_t& get_gil_checker();
|
||||
} // namespace c10d
|
||||
|
||||
#endif // USE_C10D_NCCL
|
||||
|
@ -13,7 +13,6 @@
|
||||
#include <string>
|
||||
#include <system_error>
|
||||
#include <vector>
|
||||
|
||||
namespace c10d {
|
||||
|
||||
/* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */
|
||||
@ -269,10 +268,20 @@ inline std::string retrieveDesyncReport(
|
||||
|
||||
#ifdef USE_C10D_NCCL
|
||||
|
||||
DebugInfoWriter::DebugInfoWriter(int rank) {
|
||||
std::string fileName = getCvarString(
|
||||
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
|
||||
filename_ = c10::str(fileName, rank);
|
||||
/* Helper used by work::getDuration() and nccl flight recorder */
|
||||
float getDurationFromFirstEvent(
|
||||
const std::vector<at::cuda::CUDAEvent>& ncclStartEvents,
|
||||
const std::vector<at::cuda::CUDAEvent>& ncclEndEvents) {
|
||||
TORCH_CHECK(
|
||||
ncclStartEvents.size() == 1,
|
||||
"getDuration only works for single device per ProcessGroup, but found multiple start events.");
|
||||
TORCH_CHECK(
|
||||
ncclEndEvents.size() == 1,
|
||||
"getDuration only works for single device per ProcessGroup, but found multiple end events.");
|
||||
TORCH_CHECK(
|
||||
ncclEndEvents[0].query(),
|
||||
"getDuration can only be called after work is succeeded.")
|
||||
return ncclStartEvents[0].elapsed_time(ncclEndEvents[0]);
|
||||
}
|
||||
|
||||
DebugInfoWriter::~DebugInfoWriter() = default;
|
||||
@ -293,6 +302,31 @@ void DebugInfoWriter::write(const std::string& ncclTrace) {
|
||||
LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_;
|
||||
}
|
||||
|
||||
DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
|
||||
if (writer_ == nullptr) {
|
||||
std::string fileNamePrefix = getCvarString(
|
||||
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
|
||||
// Using std::unique_ptr here to auto-delete the writer object
|
||||
// when the pointer itself is destroyed.
|
||||
std::unique_ptr<DebugInfoWriter> writerPtr(
|
||||
new DebugInfoWriter(fileNamePrefix, rank));
|
||||
DebugInfoWriter::registerWriter(std::move(writerPtr));
|
||||
}
|
||||
return *writer_;
|
||||
}
|
||||
|
||||
void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) {
|
||||
TORCH_CHECK_WITH(
|
||||
DistBackendError,
|
||||
hasWriterRegistered_.load() == false,
|
||||
"debugInfoWriter already registered");
|
||||
hasWriterRegistered_.store(true);
|
||||
writer_ = std::move(writer);
|
||||
}
|
||||
|
||||
std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr;
|
||||
std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false);
|
||||
|
||||
inline std::string pickle_str(const c10::IValue& v) {
|
||||
std::vector<char> result;
|
||||
{
|
||||
@ -317,6 +351,18 @@ inline c10::List<c10::IValue> new_list() {
|
||||
return c10::List<c10::IValue>(c10::AnyType::get());
|
||||
}
|
||||
|
||||
inline std::string ranks_str(const std::vector<uint64_t>& ranks) {
|
||||
std::string str;
|
||||
for (const auto& rank : ranks) {
|
||||
if (str.empty()) {
|
||||
str = std::to_string(rank);
|
||||
} else {
|
||||
str += ", " + std::to_string(rank);
|
||||
}
|
||||
}
|
||||
return c10::str("[", str, "]");
|
||||
}
|
||||
|
||||
struct NCCLTraceBuffer {
|
||||
static NCCLTraceBuffer* get() {
|
||||
// intentionally leak on exit
|
||||
@ -336,11 +382,12 @@ struct NCCLTraceBuffer {
|
||||
// buffer this entry will be located to
|
||||
// update state information
|
||||
size_t pg_id_;
|
||||
std::string pg_name_;
|
||||
size_t seq_id_; // as tracked by the process group
|
||||
const char* profiling_name_;
|
||||
std::string profiling_name_;
|
||||
|
||||
std::shared_ptr<torch::CapturedTraceback> traceback_;
|
||||
// we borrow pointser to start_ and end_ so we can query the state
|
||||
// we borrow pointers to start_ and end_ so we can query the state
|
||||
// on reporting. However, once the event is completed, the call
|
||||
// to `complete` will clear these.
|
||||
EventList *start_, *end_;
|
||||
@ -348,8 +395,18 @@ struct NCCLTraceBuffer {
|
||||
// timestamp when the entry was created, likely close to the time the work
|
||||
// was 'enqueued'- not necessarily started
|
||||
c10::time_t time_created_;
|
||||
c10::optional<float> duration_;
|
||||
|
||||
const char* state_ = "scheduled";
|
||||
// timestamp when our CPU threads discovered that the kernel started.
|
||||
// will always be _after_ it actually started, and can be very late
|
||||
// if the watchdog thread got stuck on CUDA APIs.
|
||||
c10::optional<c10::time_t> time_discovered_started_;
|
||||
|
||||
// timestamp when our CPU threads discovered that the kernel completed.
|
||||
// will always be _after_ it actually complated, and can be the same time
|
||||
// as the discovery of the start if the watchdog thread is stuck on CUDA
|
||||
// APIs
|
||||
c10::optional<c10::time_t> time_discovered_completed_;
|
||||
|
||||
// size information for input/output tensors
|
||||
c10::SmallVector<int, 4> input_dims_;
|
||||
@ -369,8 +426,9 @@ struct NCCLTraceBuffer {
|
||||
|
||||
c10::optional<size_t> record(
|
||||
size_t pg_id,
|
||||
const std::string& pg_name,
|
||||
size_t seq_id,
|
||||
const char* profiling_name,
|
||||
std::string profiling_name,
|
||||
const std::vector<at::Tensor>& inputs,
|
||||
const std::vector<at::Tensor>& outputs,
|
||||
EventList* start,
|
||||
@ -385,8 +443,9 @@ struct NCCLTraceBuffer {
|
||||
auto te = Entry{
|
||||
id_,
|
||||
pg_id,
|
||||
pg_name,
|
||||
seq_id,
|
||||
profiling_name,
|
||||
std::move(profiling_name),
|
||||
std::move(traceback),
|
||||
std::move(start),
|
||||
std::move(end),
|
||||
@ -424,8 +483,8 @@ struct NCCLTraceBuffer {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (started) {
|
||||
r.state_ = "started";
|
||||
if (started && !r.time_discovered_started_) {
|
||||
r.time_discovered_started_ = c10::getTime();
|
||||
}
|
||||
}
|
||||
if (r.end_ != nullptr) {
|
||||
@ -436,8 +495,8 @@ struct NCCLTraceBuffer {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (completed) {
|
||||
r.state_ = "completed";
|
||||
if (completed && !r.time_discovered_completed_) {
|
||||
r.time_discovered_completed_ = c10::getTime();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -456,35 +515,97 @@ struct NCCLTraceBuffer {
|
||||
return result;
|
||||
}
|
||||
|
||||
void retire_id(c10::optional<size_t> id) {
|
||||
/*
|
||||
Mark an Event as completed and free its events.
|
||||
|
||||
This is called by the watchdog thread, and is asynchronous from the
|
||||
perspective of the main thread.
|
||||
|
||||
compute_duration defaults to true since retire_id is only called in the
|
||||
watchdog thread, which is currently a place we call cuda APIs which may hang,
|
||||
but care should be taken to avoid computing duration in any function that must
|
||||
never hang. (timing must also be enabled for compute_duration - see
|
||||
TORCH_NCCL_ENABLE_TIMING).
|
||||
*/
|
||||
void retire_id(c10::optional<size_t> id, bool compute_duration = true) {
|
||||
if (!enabled_ || !id) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> guard(mutex_);
|
||||
auto& entry = entries_.at(*id % max_entries_);
|
||||
if (entry.id_ == *id) {
|
||||
update_state(entry);
|
||||
entry.retired_ = true;
|
||||
entry.start_ = entry.end_ = nullptr;
|
||||
|
||||
bool can_compute_duration = false;
|
||||
EventList* startEvents = nullptr;
|
||||
EventList* endEvents = nullptr;
|
||||
c10::optional<float> duration = c10::nullopt;
|
||||
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
|
||||
Entry* entry = &entries_.at(*id % max_entries_);
|
||||
if (entry->id_ == *id) {
|
||||
update_state(*entry);
|
||||
|
||||
if (compute_duration) {
|
||||
can_compute_duration = entry->time_discovered_completed_.has_value() &&
|
||||
entry->start_ && entry->end_;
|
||||
startEvents = entry->start_;
|
||||
endEvents = entry->end_;
|
||||
}
|
||||
}
|
||||
|
||||
if (can_compute_duration) {
|
||||
// Compute duration without without holding the lock, because
|
||||
// cudaEventDuration() can hang, and we need to acquire the lock before we
|
||||
// can dump(), which we never want to block.
|
||||
guard.unlock();
|
||||
duration = getDurationFromFirstEvent(*startEvents, *endEvents);
|
||||
guard.lock();
|
||||
|
||||
// Refresh the entry pointer, see if the entry has been overwritten
|
||||
entry = &entries_.at(*id % max_entries_);
|
||||
if (entry->id_ != *id) {
|
||||
LOG(INFO)
|
||||
<< "retire_id abandoned for id " << *id
|
||||
<< ", event was overwritten while waiting to compute duration.";
|
||||
return;
|
||||
}
|
||||
if (duration.has_value()) {
|
||||
entry->duration_ = duration.value();
|
||||
}
|
||||
}
|
||||
|
||||
entry->retired_ = true;
|
||||
entry->start_ = entry->end_ = nullptr;
|
||||
}
|
||||
|
||||
std::string dump() {
|
||||
std::string dump(
|
||||
const c10::optional<std::unordered_map<
|
||||
std::string,
|
||||
std::unordered_map<std::string, std::string>>>& ncclDumpMap) {
|
||||
auto result = dump_entries();
|
||||
auto entries = new_list();
|
||||
c10::IValue pg_id_s = "pg_id";
|
||||
c10::IValue seq_id_s = "seq_id";
|
||||
c10::IValue profiling_name_s = "profiling_name";
|
||||
c10::IValue input_sizes_s = "input_sizes";
|
||||
c10::IValue output_sizes_s = "output_sizes";
|
||||
c10::IValue time_created_s = "time_created_us";
|
||||
c10::IValue entries_key = "entries";
|
||||
c10::IValue nccl_comm_key = "nccl_comm_state";
|
||||
c10::IValue version_key = "version";
|
||||
// Update whenever changing contents or formatting of the dump
|
||||
// (minor when adding fields, major when changing existing fields)
|
||||
c10::IValue version_val = "1.1";
|
||||
|
||||
c10::IValue frames_s = "frames";
|
||||
c10::IValue state_s = "state";
|
||||
c10::IValue line_s = "line";
|
||||
c10::IValue name_s = "name";
|
||||
c10::IValue filename_s = "filename";
|
||||
c10::IValue retired_s = "retired";
|
||||
c10::IValue pg_id_key = "pg_id";
|
||||
c10::IValue pg_name_key = "process_group";
|
||||
c10::IValue seq_id_key = "seq_id";
|
||||
c10::IValue profiling_name_key = "profiling_name";
|
||||
c10::IValue input_sizes_key = "input_sizes";
|
||||
c10::IValue output_sizes_key = "output_sizes";
|
||||
c10::IValue time_created_key = "time_created_ns";
|
||||
c10::IValue duration_key = "duration_ms";
|
||||
|
||||
c10::IValue frames_key = "frames";
|
||||
c10::IValue state_key = "state";
|
||||
c10::IValue line_key = "line";
|
||||
c10::IValue name_key = "name";
|
||||
c10::IValue filename_key = "filename";
|
||||
c10::IValue retired_key = "retired";
|
||||
c10::IValue time_discovered_started_key = "time_discovered_started_ns";
|
||||
c10::IValue time_discovered_completed_key = "time_discovered_completed_ns";
|
||||
|
||||
std::vector<torch::CapturedTraceback*> tracebacks;
|
||||
for (auto& e : result) {
|
||||
@ -494,9 +615,9 @@ struct NCCLTraceBuffer {
|
||||
std::vector<c10::IValue> all_frames;
|
||||
for (const auto& f : stracebacks.all_frames) {
|
||||
auto d = new_dict();
|
||||
d.insert(name_s, f.funcname);
|
||||
d.insert(filename_s, f.filename);
|
||||
d.insert(line_s, int64_t(f.lineno));
|
||||
d.insert(name_key, f.funcname);
|
||||
d.insert(filename_key, f.filename);
|
||||
d.insert(line_key, int64_t(f.lineno));
|
||||
all_frames.emplace_back(std::move(d));
|
||||
}
|
||||
|
||||
@ -504,10 +625,14 @@ struct NCCLTraceBuffer {
|
||||
auto& e = result.at(i);
|
||||
auto& tb = stracebacks.tracebacks.at(i);
|
||||
auto dict = new_dict();
|
||||
dict.insert(pg_id_s, int64_t(e.pg_id_));
|
||||
dict.insert(seq_id_s, int64_t(e.seq_id_));
|
||||
dict.insert(profiling_name_s, e.profiling_name_);
|
||||
dict.insert(time_created_s, int64_t(e.time_created_ / 1000));
|
||||
dict.insert(pg_id_key, int64_t(e.pg_id_));
|
||||
dict.insert(pg_name_key, e.pg_name_);
|
||||
dict.insert(seq_id_key, int64_t(e.seq_id_));
|
||||
dict.insert(profiling_name_key, e.profiling_name_);
|
||||
dict.insert(time_created_key, int64_t(e.time_created_));
|
||||
if (e.duration_) {
|
||||
dict.insert(duration_key, *e.duration_);
|
||||
}
|
||||
|
||||
auto it = e.sizes_.begin();
|
||||
auto read_sizes = [&](const c10::SmallVector<int, 4>& dims) {
|
||||
@ -523,19 +648,55 @@ struct NCCLTraceBuffer {
|
||||
return sizes;
|
||||
};
|
||||
|
||||
dict.insert(input_sizes_s, read_sizes(e.input_dims_));
|
||||
dict.insert(output_sizes_s, read_sizes(e.output_dims_));
|
||||
dict.insert(state_s, e.state_);
|
||||
dict.insert(retired_s, e.retired_);
|
||||
dict.insert(input_sizes_key, read_sizes(e.input_dims_));
|
||||
dict.insert(output_sizes_key, read_sizes(e.output_dims_));
|
||||
if (e.time_discovered_completed_.has_value()) {
|
||||
dict.insert(state_key, "completed");
|
||||
} else if (e.time_discovered_started_.has_value()) {
|
||||
dict.insert(state_key, "started");
|
||||
} else {
|
||||
dict.insert(state_key, "scheduled");
|
||||
}
|
||||
|
||||
dict.insert(
|
||||
time_discovered_started_key,
|
||||
e.time_discovered_started_.has_value()
|
||||
? int64_t(*e.time_discovered_started_)
|
||||
: c10::IValue());
|
||||
dict.insert(
|
||||
time_discovered_completed_key,
|
||||
e.time_discovered_completed_.has_value()
|
||||
? int64_t(*e.time_discovered_completed_)
|
||||
: c10::IValue());
|
||||
dict.insert(retired_key, e.retired_);
|
||||
|
||||
auto frames = new_list();
|
||||
for (int64_t frame : tb) {
|
||||
frames.push_back(all_frames.at(frame));
|
||||
}
|
||||
dict.insert(frames_s, frames);
|
||||
dict.insert(frames_key, frames);
|
||||
entries.push_back(dict);
|
||||
}
|
||||
return pickle_str(entries);
|
||||
// convert ncclDumpMap into a dictionary
|
||||
auto per_comm_dict = new_dict();
|
||||
if (ncclDumpMap.has_value()) {
|
||||
for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) {
|
||||
auto inner_dict = new_dict();
|
||||
for (const auto& [key, value] : ncclDump) {
|
||||
inner_dict.insert(key, value);
|
||||
}
|
||||
per_comm_dict.insert(ncclId, inner_dict);
|
||||
}
|
||||
}
|
||||
|
||||
auto dict = new_dict();
|
||||
dict.insert(entries_key, entries);
|
||||
dict.insert(version_key, version_val);
|
||||
if (per_comm_dict.size() > 0) {
|
||||
dict.insert(nccl_comm_key, per_comm_dict);
|
||||
}
|
||||
|
||||
return pickle_str(dict);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -50,6 +50,35 @@
|
||||
|
||||
namespace {
|
||||
|
||||
#ifdef USE_C10D_NCCL
|
||||
|
||||
bool acquire_gil() {
|
||||
// basically if this function can acquire the gil, it will return quickly.
|
||||
// if not, it will hang forever. The idea is to call this from a thread
|
||||
// wrapped in a future, and then check the future after a timeout, to
|
||||
// determine whether we're facing gil contention.
|
||||
if (Py_IsInitialized()) {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
return true;
|
||||
}
|
||||
|
||||
// If we end up here, its probably still a "pass" from the perspective of
|
||||
// checking whether python is stuck. but currently we don't check the return
|
||||
// value of this function anyway, just check whether it returned quickly vs
|
||||
// timing out. Taking a long time is the main sign of trouble. Fast return
|
||||
// with true or with false is both OK from the perspective of debugging python
|
||||
// hangs.
|
||||
return false;
|
||||
}
|
||||
|
||||
bool registerGilChecker() {
|
||||
c10d::get_gil_checker() = &acquire_gil;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool registered = registerGilChecker();
|
||||
#endif // USE_C10D_NCCL
|
||||
|
||||
// Wrapper to ensure GIL is released before destructing ProcessGroupGloo
|
||||
// TODO: move this somewhere more generally useful
|
||||
template <typename T>
|
||||
@ -1033,6 +1062,29 @@ Example::
|
||||
>>> store.add("first_key", 6)
|
||||
>>> # Should return 7
|
||||
>>> store.get("first_key")
|
||||
)")
|
||||
.def(
|
||||
"check",
|
||||
&::c10d::Store::check,
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
R"(
|
||||
The call to check whether a given list of ``keys`` have value stored in
|
||||
the store. This call immediately returns in normal cases but still suffers
|
||||
from some edge deadlock cases, e.g, calling check after TCPStore has been destroyed.
|
||||
Calling :meth:`~torch.distributed.store.check` with a list of keys that
|
||||
one wants to check whether stored in the store or not.
|
||||
|
||||
Arguments:
|
||||
keys (lisr[str]): The keys to query whether stored in the store.
|
||||
|
||||
Example::
|
||||
>>> import torch.distributed as dist
|
||||
>>> from datetime import timedelta
|
||||
>>> # Using TCPStore as an example, other store types can also be used
|
||||
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
|
||||
>>> store.add("first_key", 1)
|
||||
>>> # Should return 7
|
||||
>>> store.check(["first_key"])
|
||||
)")
|
||||
.def(
|
||||
"delete_key",
|
||||
@ -1404,7 +1456,11 @@ Arguments:
|
||||
.def_property_readonly(
|
||||
"underlying_store",
|
||||
&::c10d::PrefixStore::getUnderlyingStore,
|
||||
R"(Gets the underlying store object that PrefixStore wraps around.)");
|
||||
R"(Gets the underlying store object that PrefixStore wraps around.)")
|
||||
.def_property_readonly(
|
||||
"_underlying_non_prefix_store",
|
||||
&::c10d::PrefixStore::getUnderlyingNonPrefixStore,
|
||||
R"(Recursively to get the store before layers of wrapping with PrefixStore.)");
|
||||
|
||||
auto processGroup =
|
||||
py::class_<
|
||||
@ -1807,6 +1863,15 @@ Arguments:
|
||||
"group_name",
|
||||
&::c10d::ProcessGroup::getGroupName,
|
||||
"(Gets this process group name. It's cluster unique)")
|
||||
.def(
|
||||
"_set_group_desc",
|
||||
&::c10d::ProcessGroup::setGroupDesc,
|
||||
py::call_guard<py::gil_scoped_acquire>(),
|
||||
"Sets the process group description. This is an internal C10D method, do not use.")
|
||||
.def_property_readonly(
|
||||
"group_desc",
|
||||
&::c10d::ProcessGroup::getGroupDesc,
|
||||
"Gets this process group description")
|
||||
.def_property(
|
||||
"bound_device_id",
|
||||
&::c10d::ProcessGroup::getBoundDeviceId,
|
||||
@ -2387,7 +2452,9 @@ Example::
|
||||
.def_readwrite(
|
||||
"split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
|
||||
.def_readwrite(
|
||||
"split_color", &::c10d::ProcessGroupNCCL::Options::split_color);
|
||||
"split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
|
||||
.def_readwrite(
|
||||
"group_name", &::c10d::ProcessGroupNCCL::Options::group_name);
|
||||
|
||||
#endif
|
||||
|
||||
|
448
torch/distributed/_state_dict_utils.py
Normal file
448
torch/distributed/_state_dict_utils.py
Normal file
@ -0,0 +1,448 @@
|
||||
import io
|
||||
import math
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||
|
||||
if dist.is_available() or TYPE_CHECKING:
|
||||
from torch.distributed import distributed_c10d
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
||||
from torch.distributed._tensor import DTensor, Replicate
|
||||
|
||||
|
||||
def _identity_func(
|
||||
obj: torch.Tensor,
|
||||
pg: Optional[dist.ProcessGroup],
|
||||
device: Optional[torch.device],
|
||||
companion_obj: Any,
|
||||
) -> torch.Tensor:
|
||||
return obj
|
||||
|
||||
|
||||
def _all_gather_sharded_tensor(
|
||||
sharded_tensor: "ShardedTensor",
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
if pg is None:
|
||||
pg = distributed_c10d._get_default_group()
|
||||
world_size = dist.get_world_size(pg)
|
||||
shards = sharded_tensor.local_shards()
|
||||
dim_0_size = sharded_tensor.size()[0] # type: ignore[index]
|
||||
tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr]
|
||||
chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
|
||||
pg_device = (
|
||||
distributed_c10d._get_pg_default_device(pg) if device is None else device
|
||||
)
|
||||
if shards:
|
||||
local_tensor = shards[0].tensor.flatten()
|
||||
if local_tensor.device.type != pg_device.type:
|
||||
local_tensor = local_tensor.to(pg_device)
|
||||
num_padding = chunk_size - local_tensor.numel()
|
||||
if num_padding > 0:
|
||||
local_tensor = F.pad(local_tensor, [0, num_padding])
|
||||
else:
|
||||
local_tensor = torch.zeros(
|
||||
chunk_size, dtype=sharded_tensor.dtype, device=pg_device
|
||||
)
|
||||
|
||||
tensor = torch.empty(
|
||||
chunk_size * world_size,
|
||||
dtype=local_tensor.dtype,
|
||||
device=pg_device,
|
||||
)
|
||||
dist.all_gather_into_tensor(tensor, local_tensor, group=pg)
|
||||
|
||||
tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
|
||||
return tensor
|
||||
|
||||
|
||||
class CompanionMismatch(Exception):
|
||||
...
|
||||
|
||||
|
||||
def _iterate_state_dict(
|
||||
iter_object: Any,
|
||||
sharded_tensor_func: Callable,
|
||||
dtensor_func: Callable,
|
||||
tensor_func: Callable,
|
||||
*,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
cpu_offload: bool = False,
|
||||
companion_obj: Any = None,
|
||||
ranks_only: Tuple[int, ...] = tuple(),
|
||||
type_check: bool = True,
|
||||
non_blocking: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Iterate through the state dict, applying the given functions to each tensor type.
|
||||
|
||||
Args:
|
||||
iter_object (Any): the target state_dict.
|
||||
sharded_tensor_func (Callable): the function to apply to ShardedTensor
|
||||
dtensor_func (Callable): the function to apply to DTensor
|
||||
tensor_func (Callable): the function to apply to Tensor
|
||||
pg (Optional[dist.ProcessGroup]): process group passed to tensor functions
|
||||
device (Optional[torch.device]): device passed to tensor functions
|
||||
cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored
|
||||
if a companion_obj is supplied.
|
||||
companion_obj (Any): A companion object to the state dict. If this object
|
||||
is supplied, we attempt to copy the tensor to the companion object.
|
||||
ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
|
||||
have the same state_dicts. Otherwise only ranks that in ``ranks_only``
|
||||
have the same state_dicts. Other ranks will get empty state_dicts.
|
||||
type_check (bool): check if the instance data type is a supported type
|
||||
that can be saved by DCP. The current supported data types are
|
||||
torch.Tensor, DTensor, int, float, str, list, dict, None.
|
||||
non_blocking (bool): whether to use non-blocking copy when copying to the companion object.
|
||||
"""
|
||||
# TODO: should we use pytree?
|
||||
cpu_device = torch.device("cpu")
|
||||
if isinstance(iter_object, ShardedTensor):
|
||||
ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
|
||||
elif isinstance(iter_object, DTensor):
|
||||
ret = dtensor_func(iter_object, pg, device, companion_obj)
|
||||
elif isinstance(iter_object, torch.Tensor):
|
||||
ret = tensor_func(iter_object, pg, device, companion_obj)
|
||||
elif (
|
||||
isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
|
||||
or iter_object is None
|
||||
):
|
||||
ret = iter_object
|
||||
elif isinstance(iter_object, dict):
|
||||
if companion_obj is not None and (
|
||||
not isinstance(companion_obj, dict)
|
||||
or set(companion_obj.keys()) != set(iter_object.keys())
|
||||
):
|
||||
raise CompanionMismatch()
|
||||
|
||||
ret = {
|
||||
key: _iterate_state_dict(
|
||||
value,
|
||||
sharded_tensor_func,
|
||||
dtensor_func,
|
||||
tensor_func,
|
||||
pg=pg,
|
||||
device=device,
|
||||
cpu_offload=cpu_offload,
|
||||
companion_obj=companion_obj[key] if companion_obj is not None else None,
|
||||
ranks_only=ranks_only,
|
||||
type_check=type_check,
|
||||
non_blocking=non_blocking,
|
||||
)
|
||||
for key, value in iter_object.items()
|
||||
}
|
||||
elif isinstance(iter_object, (list, tuple)):
|
||||
if companion_obj is not None and (
|
||||
not isinstance(companion_obj, (list, tuple))
|
||||
or len(companion_obj) != len(iter_object)
|
||||
):
|
||||
raise CompanionMismatch()
|
||||
|
||||
ret = [
|
||||
_iterate_state_dict(
|
||||
v,
|
||||
sharded_tensor_func,
|
||||
dtensor_func,
|
||||
tensor_func,
|
||||
pg=pg,
|
||||
device=device,
|
||||
cpu_offload=cpu_offload,
|
||||
companion_obj=companion_obj[idx] if companion_obj is not None else None,
|
||||
ranks_only=ranks_only,
|
||||
type_check=type_check,
|
||||
non_blocking=non_blocking,
|
||||
)
|
||||
for idx, v in enumerate(iter_object)
|
||||
]
|
||||
if isinstance(iter_object, tuple):
|
||||
ret = tuple(ret)
|
||||
elif not type_check:
|
||||
ret = iter_object
|
||||
else:
|
||||
raise ValueError(f"Unexpected value type {type(iter_object)}")
|
||||
|
||||
if not ranks_only or dist.get_rank(pg) in ranks_only:
|
||||
if isinstance(ret, torch.Tensor):
|
||||
if cpu_offload and companion_obj is None:
|
||||
ret = ret.to(cpu_device)
|
||||
|
||||
if companion_obj is not None:
|
||||
# TODO: support DTensor
|
||||
companion_obj.copy_(ret, non_blocking=non_blocking)
|
||||
ret = companion_obj
|
||||
else:
|
||||
ret = {} if isinstance(ret, dict) else None
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def _gather_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
*,
|
||||
pg: Optional[dist.ProcessGroup] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
cpu_offload: bool = False,
|
||||
ranks_only: Tuple[int, ...] = tuple(),
|
||||
type_check: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Given a state_dict, this API gathers all the ShardedTensors or DTensors in
|
||||
the state_dict.
|
||||
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Any]): the target sharded state_dict.
|
||||
pg (Optional[dist.ProcessGroup]): the process group that is used to
|
||||
gather ShardedTensor. Note that gathering a DTensor will use
|
||||
the DeviceMesh. So this argument will be ignored when gathering a
|
||||
DTensor.
|
||||
device: (Optional[torch.device]): the device that is used to
|
||||
perform allgather for ShardedTensor. Note that gathering a DTensor
|
||||
will use the DeviceMesh. So this argument will be ignored when
|
||||
gathering a DTensor.
|
||||
cpu_offload (bool): whether to offload the tensors to CPU memory. The
|
||||
default value is False.
|
||||
ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
|
||||
have the same state_dicts. Otherwise only ranks that in ``ranks_only``
|
||||
have the same state_dicts. Other ranks will get empty state_dicts.
|
||||
type_check: (bool): check if the instance data type is a supported type
|
||||
that can be saved by DCP. The current supported data types are
|
||||
torch.Tensor, DTensor, int, float, str, list, dict, None.
|
||||
|
||||
Returns:
|
||||
The gathered state dictionary.
|
||||
"""
|
||||
|
||||
def sharded_tensor_func(value, pg, device, companion_obj):
|
||||
# ShardedTensor does not seem to record the original device type.
|
||||
# So if the tensor is moved to CPU, we won't know the original type.
|
||||
# As a result, we have to rely on the user to tell us the correct one.
|
||||
cpu_device = torch.device("cpu")
|
||||
output_tensor = _all_gather_sharded_tensor(value, pg, device)
|
||||
local_shard_device = (
|
||||
value.local_shards()[0].tensor.device
|
||||
if value.local_shards()
|
||||
else cpu_device
|
||||
)
|
||||
if output_tensor.device != local_shard_device:
|
||||
value = output_tensor.to(local_shard_device)
|
||||
else:
|
||||
value = output_tensor
|
||||
return value
|
||||
|
||||
def dtensor_func(value, pg, device, companion_obj):
|
||||
if value.device != value.device_mesh.device_type:
|
||||
value = value.to(value.device_mesh.device_type)
|
||||
# FSDP all_gather: [Shard(0)] -> [Replicate()]
|
||||
# HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
|
||||
# 2D FSDP + TP all_gather:
|
||||
# - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
|
||||
# - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
|
||||
placements = [Replicate() for _ in value.placements]
|
||||
value = value.redistribute(
|
||||
device_mesh=value.device_mesh,
|
||||
placements=placements,
|
||||
)
|
||||
# Call `wait()` to force the tensor to be synchronous with respect
|
||||
# to the main stream.
|
||||
# See the discussion in https://github.com/pytorch/pytorch/pull/117799.
|
||||
value = value.to_local()
|
||||
if isinstance(value, AsyncCollectiveTensor):
|
||||
value = value.wait()
|
||||
return value
|
||||
|
||||
return _iterate_state_dict(
|
||||
state_dict,
|
||||
sharded_tensor_func,
|
||||
dtensor_func,
|
||||
_identity_func,
|
||||
pg=pg,
|
||||
device=device,
|
||||
cpu_offload=cpu_offload,
|
||||
ranks_only=ranks_only,
|
||||
type_check=type_check,
|
||||
)
|
||||
|
||||
|
||||
def _offload_state_dict_to_cpu(
|
||||
state_dict: Dict[str, Any],
|
||||
*,
|
||||
ranks_only: Tuple[int, ...] = tuple(),
|
||||
type_check: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Given a state_dict, this API offload all the tensors to CPU memory.
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Any]): the target state_dict.
|
||||
pg (Optional[dist.ProcessGroup]): the process group that is used to
|
||||
gather ShardedTensor. Note that gathering a DTensor will use
|
||||
the DeviceMesh. So this argument will be ignored when gathering a
|
||||
DTensor.
|
||||
ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
|
||||
have the same state_dicts. Otherwise only ranks that in ``ranks_only``
|
||||
have the same state_dicts. Other ranks will get empty state_dicts.
|
||||
type_check: (bool): check if the instance data type is a supported type
|
||||
that can be saved by DCP. The current supported data types are
|
||||
torch.Tensor, DTensor, int, float, str, list, dict, None.
|
||||
|
||||
Returns:
|
||||
The gathered state dictionary.
|
||||
"""
|
||||
|
||||
ret = _iterate_state_dict(
|
||||
state_dict,
|
||||
_identity_func,
|
||||
_identity_func,
|
||||
_identity_func,
|
||||
pg=None,
|
||||
device=None,
|
||||
cpu_offload=True,
|
||||
ranks_only=ranks_only,
|
||||
type_check=type_check,
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
def _copy_state_dict(
|
||||
state_dict: Dict[str, Any],
|
||||
copy_state_dict: Dict[str, Any],
|
||||
non_blocking: bool = False,
|
||||
):
|
||||
"""
|
||||
Copies all tensors in a given state dict into a different state_dict with the
|
||||
same structure.
|
||||
|
||||
.. warning::
|
||||
It is expected by this function that state_dict and copy_state_dict share
|
||||
the same structure and data types.
|
||||
|
||||
.. warning::
|
||||
The current supported data types are
|
||||
torch.Tensor, DTensor, int, float, str, list, dict, None.
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Any]): the target state_dict.
|
||||
copy_state_dict (Dict[str, Any]):
|
||||
The state dict we are copying into. This state_dict must have exactly
|
||||
the same structure as the source `state_dict`.
|
||||
non_blocking: (bool): Whether copy ops should be performed asynchronously
|
||||
"""
|
||||
|
||||
_iterate_state_dict(
|
||||
state_dict,
|
||||
_identity_func,
|
||||
_identity_func,
|
||||
_identity_func,
|
||||
pg=None,
|
||||
device=None,
|
||||
cpu_offload=False,
|
||||
ranks_only=tuple(),
|
||||
companion_obj=copy_state_dict,
|
||||
type_check=True,
|
||||
non_blocking=non_blocking,
|
||||
)
|
||||
|
||||
|
||||
def _create_cpu_state_dict(
|
||||
state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Given a state_dict, create another state_dict with the same structure and elements.
|
||||
However, all tensors in the returned state_dict are new tensors on CPU. These
|
||||
tensors can be placed on pin_memory or share_memory based on the provided arguments.
|
||||
|
||||
.. warning::
|
||||
Setting both `pin_memory` and `share_memory` to True significantly increases the
|
||||
latency of this method because of the nuances which require us to register memory
|
||||
as pinned directly as opposed to relying on the pin_memory cache allocator. This
|
||||
option should only be used for long lived tensors which are required to be shared.
|
||||
This is not the case as long as at least one of `pin_memory` or `share_memory` is
|
||||
set to False.
|
||||
|
||||
"""
|
||||
|
||||
def tensor_func(
|
||||
obj: torch.Tensor,
|
||||
pg: Optional[dist.ProcessGroup],
|
||||
device: Optional[torch.device],
|
||||
_: Any,
|
||||
) -> torch.Tensor:
|
||||
if len(obj.size()) == 0:
|
||||
return torch.tensor(0, dtype=obj.dtype)
|
||||
|
||||
if share_memory:
|
||||
t = torch.empty(*tuple(obj.size()), dtype=obj.dtype).share_memory_()
|
||||
if pin_memory:
|
||||
succ = torch.cuda.cudart().cudaHostRegister(
|
||||
t.data_ptr(),
|
||||
t.numel() * t.element_size(),
|
||||
1, # lines up with 'cudaHostRegisterPortable'
|
||||
)
|
||||
assert (
|
||||
succ == 0
|
||||
), f"Pinning shared memory failed with error-code: {succ}"
|
||||
return t
|
||||
elif pin_memory:
|
||||
return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
|
||||
else:
|
||||
return torch.empty(*tuple(obj.size()), dtype=obj.dtype)
|
||||
|
||||
ret = _iterate_state_dict(
|
||||
state_dict,
|
||||
_identity_func,
|
||||
_identity_func,
|
||||
tensor_func,
|
||||
pg=None,
|
||||
device=None,
|
||||
cpu_offload=False,
|
||||
ranks_only=tuple(),
|
||||
type_check=False,
|
||||
)
|
||||
return ret
|
||||
|
||||
|
||||
def _check_state_dict_similarity(
|
||||
state_dict: Dict[str, Any],
|
||||
compared_state_dict: Dict[str, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Given two state_dicts, check if the structures are the same. And
|
||||
if a [key, tensor] pair exist in one state_dict there must be
|
||||
the a corresponding pait, [key, other_tensor], in the other state_dict,
|
||||
where tensor and other_tensor have the same size and dtype.
|
||||
|
||||
Return the check result.
|
||||
"""
|
||||
|
||||
def tensor_func(
|
||||
obj: torch.Tensor,
|
||||
pg: Optional[dist.ProcessGroup],
|
||||
device: Optional[torch.device],
|
||||
companion_obj: Any,
|
||||
) -> torch.Tensor:
|
||||
if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
|
||||
raise CompanionMismatch()
|
||||
return obj
|
||||
|
||||
try:
|
||||
_iterate_state_dict(
|
||||
state_dict,
|
||||
_identity_func,
|
||||
_identity_func,
|
||||
tensor_func,
|
||||
pg=None,
|
||||
device=None,
|
||||
cpu_offload=False,
|
||||
ranks_only=tuple(),
|
||||
companion_obj=compared_state_dict,
|
||||
type_check=False,
|
||||
)
|
||||
except CompanionMismatch:
|
||||
return False
|
||||
|
||||
return True
|
@ -1171,7 +1171,7 @@ def init_process_group(
|
||||
)
|
||||
|
||||
default_pg, _ = _new_process_group_helper(
|
||||
-1, -1, [], backend, None, group_name, timeout=timeout
|
||||
-1, -1, [], backend, None, group_name, timeout=timeout, group_desc="default_pg"
|
||||
)
|
||||
_update_default_pg(default_pg)
|
||||
else:
|
||||
@ -1197,6 +1197,7 @@ def init_process_group(
|
||||
pg_options=pg_options,
|
||||
timeout=timeout,
|
||||
device_id=device_id,
|
||||
group_desc="default_pg"
|
||||
)
|
||||
_update_default_pg(default_pg)
|
||||
|
||||
@ -1257,6 +1258,7 @@ def _new_process_group_helper(
|
||||
timeout=None,
|
||||
pg_tag=None,
|
||||
device_id=None,
|
||||
group_desc=None,
|
||||
):
|
||||
"""
|
||||
Create a new distributed process group.
|
||||
@ -1289,6 +1291,8 @@ def _new_process_group_helper(
|
||||
_, prefix_store = _world.pg_map[existing_group]
|
||||
return existing_group, prefix_store
|
||||
|
||||
group_desc = "undefined" if group_desc is None else group_desc
|
||||
|
||||
# The list of group ranks is empty if we're creating the default group.
|
||||
is_default_group = len(global_ranks_in_group) == 0
|
||||
|
||||
@ -1375,6 +1379,7 @@ def _new_process_group_helper(
|
||||
if split_from:
|
||||
pg_options.split_from = split_from
|
||||
pg_options.split_color = _process_group_color(global_ranks_in_group)
|
||||
pg_options.group_name = group_name
|
||||
backend_class = ProcessGroupNCCL(
|
||||
backend_prefix_store, group_rank, group_size, pg_options)
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
@ -1461,9 +1466,11 @@ def _new_process_group_helper(
|
||||
|
||||
# update global state
|
||||
assert group_name is not None
|
||||
assert group_desc is not None
|
||||
_world.pg_map[pg] = (backend, prefix_store)
|
||||
_world.pg_names[pg] = group_name
|
||||
pg._set_group_name(group_name)
|
||||
pg._set_group_desc(group_desc)
|
||||
|
||||
_world.pg_backend_config[pg] = str(backend_config)
|
||||
# "" is the default tag for user PGs
|
||||
@ -3614,7 +3621,7 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend:
|
||||
|
||||
|
||||
@_time_logger
|
||||
def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False):
|
||||
def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None):
|
||||
"""
|
||||
Create a new distributed group.
|
||||
|
||||
@ -3655,6 +3662,7 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local
|
||||
barrier at the end of the process group creation. This is different
|
||||
in that non-member ranks don't need to call into API and don't
|
||||
join the barrier.
|
||||
group_desc (str, optional): a string to describe the process group.
|
||||
|
||||
Returns:
|
||||
A handle of distributed group that can be given to collective calls or None if the rank is not part of ``ranks``.
|
||||
@ -3669,7 +3677,15 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local
|
||||
multiple overlaping process groups. To avoid that, make sure all ranks follow the
|
||||
same global creation order.
|
||||
"""
|
||||
return _new_group_with_tag(ranks, timeout, backend, pg_options, None, use_local_synchronization=use_local_synchronization)
|
||||
return _new_group_with_tag(
|
||||
ranks,
|
||||
timeout,
|
||||
backend,
|
||||
pg_options,
|
||||
None,
|
||||
use_local_synchronization=use_local_synchronization,
|
||||
group_desc=group_desc,
|
||||
)
|
||||
|
||||
def _new_group_with_tag(
|
||||
ranks=None,
|
||||
@ -3677,7 +3693,8 @@ def _new_group_with_tag(
|
||||
backend=None,
|
||||
pg_options=None,
|
||||
pg_tag=None,
|
||||
use_local_synchronization=False
|
||||
use_local_synchronization=False,
|
||||
group_desc=None
|
||||
):
|
||||
"""
|
||||
Variant of ``new_group`` that exposes tag creation.
|
||||
@ -3749,7 +3766,8 @@ def _new_group_with_tag(
|
||||
group_name,
|
||||
pg_options=pg_options,
|
||||
timeout=timeout,
|
||||
pg_tag=pg_tag
|
||||
pg_tag=pg_tag,
|
||||
group_desc=group_desc
|
||||
)
|
||||
|
||||
# Create the global rank to group rank mapping
|
||||
@ -3789,6 +3807,7 @@ def new_subgroups(
|
||||
timeout=None,
|
||||
backend=None,
|
||||
pg_options=None,
|
||||
group_desc=None,
|
||||
):
|
||||
"""
|
||||
Create subgroups of equal size.
|
||||
@ -3841,6 +3860,8 @@ def new_subgroups(
|
||||
the construction of specific process groups. i.e. for the ``nccl``
|
||||
backend, ``is_high_priority_stream`` can be specified so that
|
||||
process group can pick up high priority cuda streams.
|
||||
group_desc (str, optional): A string describing the group. Each subgroup will
|
||||
inherit its group_desc
|
||||
|
||||
Returns:
|
||||
The subgroup containing the current rank, and all the subgroups used for cleanup.
|
||||
@ -3886,6 +3907,7 @@ def new_subgroups(
|
||||
timeout=timeout,
|
||||
backend=backend,
|
||||
pg_options=pg_options,
|
||||
group_desc=group_desc,
|
||||
)
|
||||
subgroups.append(subgroup)
|
||||
|
||||
@ -3905,6 +3927,7 @@ def new_subgroups_by_enumeration(
|
||||
timeout=None,
|
||||
backend=None,
|
||||
pg_options=None,
|
||||
group_desc=None,
|
||||
):
|
||||
"""
|
||||
Create subgroups by dividing the global world.
|
||||
@ -3945,6 +3968,8 @@ def new_subgroups_by_enumeration(
|
||||
the construction of specific process groups. i.e. for the ``nccl``
|
||||
backend, ``is_high_priority_stream`` can be specified so that
|
||||
process group can pick up high priority cuda streams.
|
||||
group_desc (str, optional): A string describing the group. Each subgroup will
|
||||
inherit its group_desc.
|
||||
|
||||
Returns:
|
||||
The subgroup containing the current rank, and all the subgroups used for cleanup.
|
||||
@ -3973,6 +3998,7 @@ def new_subgroups_by_enumeration(
|
||||
timeout=timeout,
|
||||
backend=backend,
|
||||
pg_options=pg_options,
|
||||
group_desc=group_desc,
|
||||
)
|
||||
subgroups.append(subgroup)
|
||||
my_rank = get_rank()
|
||||
|
@ -28,7 +28,7 @@ def tail_logfile(
|
||||
return
|
||||
time.sleep(interval_sec)
|
||||
|
||||
with open(file) as fp:
|
||||
with open(file, errors="replace") as fp:
|
||||
while True:
|
||||
line = fp.readline()
|
||||
|
||||
|
Reference in New Issue
Block a user