Compare commits

..

18 Commits

Author SHA1 Message Date
173368cb63 CUDAEvent::elapsed_time could accidentally initialize a non-used GPU (#122538)
This sets the device before calling cudaEventElapsedTime to avoid the case
where the "cudaGetCurrentDevice" device would be initialized even though
neither event is on that device.
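As an illustration, a minimal Python sketch of the scenario the guard addresses, assuming a machine with at least two GPUs (both events live on cuda:1 while cuda:0 is the current device):

```python
import torch

# Hypothetical two-GPU setup: both events are recorded on cuda:1 while the
# current device stays cuda:0, which should not need to be initialized.
assert torch.cuda.device_count() >= 2
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.cuda.device(1):
    x = torch.randn(2048, 2048, device="cuda")
    start.record()
    y = x @ x  # some work to time
    end.record()
    end.synchronize()

torch.cuda.set_device(0)
# Without the guard added in the C++ diff below, this call could implicitly
# create a CUDA context on cuda:0 even though neither event is on that device.
print(f"elapsed: {start.elapsed_time(end):.3f} ms")
```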
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122538
Approved by: https://github.com/shuqiangzhang, https://github.com/wconstab
2024-04-22 15:39:28 -07:00
b86edd97d6 [nccl-pg] print broadcast nccl unique id duration (#123963)
Summary: Print the duration of the NCCL unique ID broadcast in the NCCL PG, for measurement.

Differential Revision: D56048059

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123963
Approved by: https://github.com/wconstab
2024-04-16 17:03:25 -07:00
b33a283e9a [nccl-pg] Pass pg name and desc to NCCL communicator (#124149)
Summary:
Pass the Process Group Name and Desc to the NCCL communicator in order to access PG information at the NCCL layer.
The information is passed as a commDesc string (i.e. "<pg_desc>:<pg_name>").
This is only in effect when NCCL_COMM_DESCRIPTION is defined.

Differential Revision: D55703310

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124149
Approved by: https://github.com/shuqiangzhang
2024-04-16 15:08:38 -07:00
7a551d81e5 [c10d/nccl-pg] allow user to pass process group description (#123472)
Summary:
We need a way to allow users to set a customized description for a process group, e.g. FSDP, PP.

Here are several use cases of a user-specified group_desc:
- Logging: we can easily match a log line to a collective/PG and understand what it is used for.
- PyTorch traces (e.g. Kineto, Execution Trace): trace analysis and benchmarks can easily differentiate PG purposes such as FSDP and PP using the PG desc.
- Lower-layer collective (e.g. NCCL) debugging: we can expose the PG desc to the NCCL communicator so that NCCL-layer operations can easily be correlated to a PG.

Solution: Add a group_desc field to c10d
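A short usage sketch of the new field, mirroring the test added in this stack (names other than group_desc are existing c10d APIs; assumes the default NCCL process group is already initialized, e.g. under torchrun with 2 ranks):

```python
import torch.distributed as dist

# Label the purpose of a sub-group so logs, traces, and NCCL-level debug info
# can be correlated back to it.
pg = dist.new_group(ranks=[0, 1], group_desc="test_purpose")
assert pg.group_desc == "test_purpose"

# Groups created without a description report "undefined".
pg_unlabeled = dist.new_group(ranks=[0, 1])
assert pg_unlabeled.group_desc == "undefined"
```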

Differential Revision: D55781850

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123472
Approved by: https://github.com/kwen2501
2024-04-16 15:08:38 -07:00
1515a90475 [DCP] Adds ability to create a CPU state dict that is both shared and pinned (#122338)
[DCP] Adds ability to create a CPU state dict that is both shared and pinned, as well as a new utility specific to copying the state dict

https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html#group__CUDART__MEMORY_1ge8d5c17670f16ac4fc8fcb4181cb490c
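A hedged sketch of how the new utilities fit together, mirroring the test added below; `_create_cpu_state_dict`, `_copy_state_dict`, and `_check_state_dict_similarity` are private helpers from torch.distributed._state_dict_utils and may change without notice:

```python
import torch
from torch.distributed._state_dict_utils import (
    _check_state_dict_similarity,
    _copy_state_dict,
    _create_cpu_state_dict,
)

# GPU state dict to be offloaded repeatedly, e.g. for async checkpointing.
state_dict = {"w": torch.ones(10, device="cuda"), "step": torch.tensor(7.0)}

# Allocate the CPU copy once; share_memory=True gives buffers shareable across
# processes, pin_memory=True keeps them pinned for fast device-to-host copies.
cpu_state_dict = _create_cpu_state_dict(state_dict, share_memory=True, pin_memory=True)
assert _check_state_dict_similarity(state_dict, cpu_state_dict)

# Reuse the preallocated CPU buffers on every checkpoint step.
_copy_state_dict(state_dict, cpu_state_dict)
assert not cpu_state_dict["w"].is_cuda
```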

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122338
Approved by: https://github.com/fegin
2024-04-16 15:08:22 -07:00
4882ec2a91 Pass and record process_group_name when creating ProcessGroupNCCL (#123117)
Summary:
Pass the Python c10d group_name to the C++ ProcessGroupNCCL so that the PG name is consistent across layers.
Also record pg_name in the flight recorder entry.

Differential Revision: D55597200

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123117
Approved by: https://github.com/wconstab
2024-04-16 13:48:35 -07:00
972b8060bd [c10d] make monitorThread sleep when we try to dump (#123788)
Summary:
We separated the FR dump logic from the desync debug logic, so we no longer
set collectiveDebugInfoMode_ to true when we only need an FR dump. That's why
the monitor thread did not sleep and tried to kill the process without waiting
for the dump.

The fix is simple: sleep whenever shouldDump_ is true.
Test Plan:
Existing unit tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123788
Approved by: https://github.com/wconstab
2024-04-11 09:19:15 -07:00
3e7683ae18 [c10d] dump on any exception (timeout + nccl error) (#123023)
Summary:
The existing flight recorder dumping logic is: dump only on timeout, not on
NCCL error. This resulted in the faulty ranks missing dumps when an NCCL error
happened.

So in this PR, we revise the dump logic such that records are dumped when any
exception is detected. An exception could be 1. an NCCL async error or
2. a watchdog timeout.

Also, the existing code tends to mix the flight recorder dump logic with
desync debug, which is not desirable. We now dump the desync debug report only
when a timeout is detected.
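A hedged configuration sketch of the environment variables exercised by the new test below; the dump file prefix (TORCH_NCCL_DEBUG_INFO_TEMP_FILE) is an assumption about the default dump location:

```python
import os

# Enable the flight recorder and dump-on-exception behavior before creating
# the process group; values mirror the new NcclErrorDumpTest.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000"  # ring buffer entries
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"       # now also dumps on NCCL errors
# Assumed per-rank dump path prefix: files are written as <prefix><rank>.
os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "/tmp/nccl_trace_rank_"
```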
Test Plan:
Added a new unit test that triggers an NCCL error and a dump, and makes sure
the dump is triggered by the error.

Also existing dump on timeout tests should still pass.

```
(sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (84bf9d4c)]$ python
test/distributed/test_c10d_nccl.py NcclErrorDumpTest
NCCL version 2.19.3+cuda12.0
[E329 19:15:11.775879730 ProcessGroupNCCL.cpp:565] [Rank 0] Watchdog
caught collective operation timeout: WorkNCCL(SeqNum=2,
OpType=ALLREDUCE, NumelIn=10, NumelOut=10, Timeout(ms)=10000) ran for
10028 milliseconds before timing out.
[E329 19:15:11.777459894 ProcessGroupNCCL.cpp:1561] [PG 0 Rank 0]
Exception hit in NCCL work: 2
[E329 19:15:12.660717323 ProcessGroupNCCL.cpp:1332] [PG 0 Rank 0]
Received a timeout signal from this local rank and will start to dump
the debug info. Last enqueued NCCL work: 2, last completed NCCL work: 1.
[E329 19:15:12.660932242 ProcessGroupNCCL.cpp:1167] [PG 0 Rank 0]
ProcessGroupNCCL preparing to dump debug info.
[E329 19:15:12.661192990 ProcessGroupNCCL.cpp:1174] [PG 0 Rank 0]
ProcessGroupNCCL dumping nccl trace to /tmp/tmp06psqil3/trace_0
[F329 19:15:12.661485601 ProcessGroupNCCL.cpp:1185] [PG 0 Rank 0] [PG 0
Rank 0] ProcessGroupNCCL's watchdog detected a collective timeout from
the local rank. This is most likely caused by incorrect usages of
collectives, e.g., wrong sizes used across ranks, the order of
collectives is not same for all ranks or the scheduled collective, for
some reason, didn't run. Additionally, this can be caused by GIL
deadlock or other reasons such as network errors or bugs in the
communications library (e.g. NCCL), etc. We tried our best to dump the
debug info into the storage to help you debug the issue.
```

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123023
Approved by: https://github.com/wconstab
2024-04-02 15:41:15 -07:00
f2e9ec2dc5 [c10d] dump from one and only one thread (PG0's monitor thread) (#120893)
Summary:
When there are multiple PGs in a process and a hardware failure happens,
we found that multiple PGs/threads in the same process compete to dump the
same records at the same time. This affects the reliability of the dumps.

In this PR, we make the change such that only one thread/PG can dump: PG0's
monitor thread. We use a static variable to indicate that something (e.g., a
collective timeout) has triggered the dump locally.

The monitor thread dumps debug info under any one of these 3 conditions:
1. the static variable is set to true by the watchdog thread when it detects
a timeout or a pipe dump signal (see the pipe sketch below);
2. a timeout signal is received from other ranks through the TCPStore;
3. the watchdog has no heartbeat.
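For condition 1, a dump can also be requested externally through the named pipe that the dump machinery polls (see the DumpPipe struct in the header diff further below); a hedged sketch of triggering it from a sidecar script, assuming the pipe file stem is configured and we target global rank 0:

```python
import os

# The trainer side opens "<stem><rank>.pipe" (stem comes from
# TORCH_NCCL_DEBUG_INFO_PIPE_FILE); writing any bytes triggers a trace dump.
stem = os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"]  # assumption: already set
rank = 0  # assumption: trigger the dump for global rank 0
pipe_path = f"{stem}{rank}.pipe"

# O_NONBLOCK so the write fails fast instead of hanging if nothing reads.
fd = os.open(pipe_path, os.O_WRONLY | os.O_NONBLOCK)
os.write(fd, b"1")
os.close(fd)
```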
Test Plan:
python test/distributed/test_c10d_nccl.py -k
test_timeout_dumps_on_stuck_ranks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120893
Approved by: https://github.com/wconstab
2024-04-02 15:36:05 -07:00
dde4324d8e [NCCL PG] Enable ncclCommDevIdxMap unconditionally (#122049)
Differential Revision: D54993977

The initial purpose of ncclCommDevIdxMap is to support NCCL zero copy algorithms. Therefore, it is only enabled (with its values filled) if useTensorRegisterAllocatorHook_ is set to true. However, now we rely on it to support dumping NCCL information in a single PG. So we need it to be always available, regardless of whether we enabled useTensorRegisterAllocatorHook_.
Move the code that fills ncclCommDevIdxMap out of the if (useTensorRegisterAllocatorHook_) statement.

See diff

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122049
Approved by: https://github.com/shuqiangzhang
2024-03-26 17:14:06 -07:00
94c079104d [c10d] fix the macro definition of NCCL_COMM_DUMP (#120502)
Summary:
Only if both macros are defined should we enable the comm dump; otherwise,
use the original definition.

The previous implementation missed the function definition when IS_NCCL_EXP is defined but NCCL_COMM_DUMP is not.

Test Plan:
Build and unit test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120502
Approved by: https://github.com/dsjohns2, https://github.com/Skylion007
2024-03-26 14:09:00 -07:00
a6afee6d94 [c10d][flight recorder] dump additional NCCL debug info (#120063)
Summary:
This PR is mainly about the flight recorder side of the changes: it takes a
map of maps as input and dumps it as a picklable object. It also adds
functions that should be compiled only when NCCL_COMM_DUMP is defined.
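A small sketch of consuming the pickled dump from Python; the per-communicator map of maps added here is stored under a "nccl_comm_state" key and is only populated on builds where IS_NCCL_EXP and NCCL_COMM_DUMP are defined:

```python
import pickle
import torch

# Private binding used by the tests; returns the pickled flight recorder dump.
dump = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())

print(dump["version"])       # dump format version
print(len(dump["entries"]))  # recorded collectives
# One inner dict per NCCL communicator, keyed by its NCCL unique ID.
for nccl_id, comm_info in dump.get("nccl_comm_state", {}).items():
    print(nccl_id, comm_info)
```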
Test Plan:
Integration tests with NCCL will be done later; here we only test the c10d
side of the dump, aka NCCLTraceTest.

Testing the dump function is a bit tricky as we don't have existing C++ unit
tests for it. So we still use the Python NCCLTraceTest with the Python binding
of _dump_nccl_trace(); we manually feed dump_nccl_trace with a map of test
info, assert the pickle result, and print the converted Python dict:
```
(sqzhang_1) [sqzhang@devgpu009.cln1 ~/pytorch (main)]$  python
test/distributed/test_c10d_nccl.py NCCLTraceTest
NCCL version 2.19.3+cuda12.0
[rank0]:[E ProcessGroupNCCL.cpp:1200] [PG 0 Rank 0] ProcessGroupNCCL
preparing to dump debug info.
.NCCL version 2.19.3+cuda12.0
.NCCL version 2.19.3+cuda12.0
{'ncclID2': {'Key2': 'Value2', 'Key1': 'Value1'}, 'ncclID1': {'Key2':
'Value2', 'Key1': 'Value1'}}
{'ncclID2': {'Key2': 'Value2', 'Key1': 'Value1'}, 'ncclID1': {'Key2':
'Value2', 'Key1': 'Value1'}}
.NCCL version 2.19.3+cuda12.0
{'ncclID2': {'Key2': 'Value2', 'Key1': 'Value1'}, 'ncclID1': {'Key2':
'Value2', 'Key1': 'Value1'}}
{'ncclID2': {'Key2': 'Value2', 'Key1': 'Value1'}, 'ncclID1': {'Key2':
'Value2', 'Key1': 'Value1'}}
.NCCL version 2.19.3+cuda12.0
.NCCL version 2.19.3+cuda12.0
.NCCL version 2.19.3+cuda12.0
.NCCL version 2.19.3+cuda12.0
.
----------------------------------------------------------------------
Ran 8 tests in 95.761s
OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120063
Approved by: https://github.com/wconstab
2024-03-26 14:08:19 -07:00
d092857531 [Caffe2 CPU tests] Update CMakeLists.txt 2024-02-24 12:18:10 -08:00
6aad5e444a Fix missing MAST log when there is Unicode non-decodable text in logs (#119298)
Summary:
## Issue
When there is Unicode non-decodable text in logs, `tail_logger` will stop working afterwards, i.e. f527390102

In the example, the process stopped producing Python logs after 17:20:21 until the job finished
```
[0]:I0201 17:20:21.338000 3429 gen_ai/genie_projects/llm/metaformers/reward_model_score.py:335] Progress: 118 batches out of 512 total batches. 23.05 % | (gpu mem: 25.8GB, free CPU mem: 1387.8GB)
I0201 17:39:14 Stopping twtask-main.service with Service Result: [success] Exit Code: [exited] Exit Status: [0]
```
At the end, a `UnicodeDecodeError` was thrown with no call stack.

## Fix
Use `errors="replace"` to avoid throwing exception when `UnicodeDecodeError` happens.

Test Plan: f528854819

Differential Revision: D53483644

Co-authored-by: Jack Zhang <jackzh@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119298
Approved by: https://github.com/XilunWu
2024-02-24 12:16:39 -08:00
c54ce9313b [c10d][flight recorder] store a copy of string in entry (#119837)
Summary:
Previously, we just stored the char pointer in the entry; the string it points
to is a temporary object and will be destructed by the time we want to dump or
access it.

A quick fix is to store a copy of the string, without changing the upstream
char*.

An alternative is to change every profilingTitle into std::string; this,
however, would require a comprehensive overhaul of the code up to the
c10d::Work layer above WorkNCCL, RecordFunction, etc.

We chose the first option for this change.

Resolve #119808

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119837
Approved by: https://github.com/zdevito, https://github.com/wconstab
2024-02-14 11:38:10 -08:00
1fe59f4ef7 [c10d][flight recorder] remove unintended assignment of entry (#119748)
Summary:
auto& entry = entries_.at(*id % max_entries_);
entry = entries_.at(*id % max_entries_);
The above line of code has unintended consequence of invoking copy/assignment
of entry objects as ref itself cannot be re-assigned.

Also what could cause the crash is that the entry ref could become invalid if entries_ are
resized by other threads. and this could result in 'copy to a garbage
location'. The fix is to use a pointer which can be re-assigned after
re-acquiring the lock

Tests: python test/distributed/test_c10d_nccl.py NCCLTraceTest

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119748
Approved by: https://github.com/wconstab, https://github.com/fegin
2024-02-14 11:38:10 -08:00
e693fb2bb1 [nccl flight recorder] record time we discover start and complete (#119249)
Some APIs like ncclCommAbort can cause NCCL kernels to finish even if
they were previously stuck. Because we can gather the trace buffer after
those calls, we can end up seeing some collectives marked completed even
though that completion happened several minutes after they started and clearly
after the timeout. This changes how we record state so that we keep track of
the time we discovered a state change; even if the collective eventually gets
marked complete, we can observe that it happened minutes after it was scheduled.
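A sketch of reading the new discovery timestamps from a decoded trace dump; the field names mirror the assertions in the updated NCCLTraceTest below:

```python
import pickle
import torch

dump = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
for entry in dump["entries"]:
    started = entry["time_discovered_started_ns"]      # None until discovered
    completed = entry["time_discovered_completed_ns"]  # None until discovered
    if started is not None and completed is not None:
        # CPU-side discovery times: they show when the watchdog noticed the
        # state changes, not the kernel's own duration.
        print(entry["seq_id"], entry["state"], (completed - started) / 1e6, "ms")
```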

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119249
Approved by: https://github.com/wconstab
2024-02-14 11:38:10 -08:00
4fe510baf6 [NCCL PG] log NCCL comm at creation and abort (#118335)
Summary: This helps correlate an NCCL PG with the corresponding NCCL comm in separate logs.

Differential Revision: D53107647

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118335
Approved by: https://github.com/wconstab
2024-02-14 11:38:04 -08:00
16 changed files with 1088 additions and 180 deletions

View File

@ -151,6 +151,10 @@ struct TORCH_CUDA_CPP_API CUDAEvent {
TORCH_CHECK(is_created_ && other.isCreated(),
"Both events must be recorded before calculating elapsed time.");
float time_ms = 0;
// We do not strictly have to set the device index to the same as our event,
// but if we don't and the current device is not initialized, it will
// create a new cuda context, which will consume a lot of memory.
CUDAGuard guard(device_index_);
// raise cudaErrorNotReady if either event is recorded but not yet completed
AT_CUDA_CHECK(cudaEventElapsedTime(&time_ms, event_, other.event_));
return time_ms;

View File

@ -1732,7 +1732,7 @@ if(BUILD_TEST)
foreach(test_src ${Caffe2_CPU_TEST_SRCS})
get_filename_component(test_name ${test_src} NAME_WE)
add_executable(${test_name} "${test_src}")
target_link_libraries(${test_name} torch_library gtest_main)
target_link_libraries(${test_name} torch_library gtest_main stdc++)
target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
target_include_directories(${test_name} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})

View File

@ -4,9 +4,10 @@ import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.checkpoint._state_dict_utils import (
from torch.distributed._state_dict_utils import (
_check_state_dict_similarity,
_copy_state_dict,
_create_cpu_state_dict,
_gather_state_dict,
_offload_state_dict_to_cpu,
)
@ -115,6 +116,58 @@ class TestStateDictUtils(DTensorTestBase):
}
self.assertEqual(state_dict, _gather_state_dict(dist_state_dict))
@skip_if_lt_x_gpu(2)
def test_create_cpu_state_dict(self):
device = torch.device("cuda")
buffer = io.BytesIO()
torch.save(torch.ones(10), buffer)
buffer.seek(0)
state_dict = {
"tensor1": torch.arange(10, device=device),
"tensor2": torch.ones(10, device=device),
"non_tensor_bytes_io": copy.deepcopy(buffer),
"non_tensor_bytes": buffer.read(),
"step": torch.tensor(7, dtype=torch.float),
"lr": 1.5,
"nested": {"list": [1, 2, 3, 4]},
}
def _verify(cpu_state_dict):
# Verify the correctness of _check_state_dict_similarity()
self.assertTrue(_check_state_dict_similarity(state_dict, cpu_state_dict))
tensor1 = cpu_state_dict["tensor1"]
cpu_state_dict["tensor1"] = torch.arange(11)
self.assertFalse(_check_state_dict_similarity(state_dict, cpu_state_dict))
cpu_state_dict["tensor1"] = tensor1
_copy_state_dict(state_dict, cpu_state_dict)
# Verify if _copy_state_dict works
for v in cpu_state_dict.values():
if isinstance(v, torch.Tensor):
self.assertFalse(v.is_cuda)
self.assertEqual(cpu_state_dict["tensor1"], torch.arange(10))
self.assertEqual(cpu_state_dict["tensor2"], torch.ones(10))
buffer.seek(0)
cpu_state_dict["non_tensor_bytes_io"].seek(0)
self.assertEqual(
cpu_state_dict["non_tensor_bytes_io"].read(), buffer.read()
)
buffer.seek(0)
self.assertEqual(cpu_state_dict["non_tensor_bytes"], buffer.read())
self.assertEqual(cpu_state_dict["lr"], 1.5)
self.assertEqual(cpu_state_dict["step"], 7)
self.assertEqual(cpu_state_dict["nested"], {"list": [1, 2, 3, 4]})
cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True)
_verify(cpu_state_dict)
cpu_state_dict = _create_cpu_state_dict(state_dict, share_memory=True)
_verify(cpu_state_dict)
cpu_state_dict = _create_cpu_state_dict(
state_dict, share_memory=True, pin_memory=True
)
_verify(cpu_state_dict)
if __name__ == "__main__":
run_tests()

View File

@ -11,6 +11,7 @@ import tempfile
import threading
import pickle
import time
import json
import warnings
from contextlib import contextmanager
from datetime import datetime, timedelta
@ -1334,6 +1335,19 @@ class ProcessGroupNCCLTest(MultiProcessTestCase):
self.assertEqual(tensor, original_tensor)
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_set_process_group_desc(self):
store = c10d.FileStore(self.file_name, self.world_size)
device = torch.device(f'cuda:{self.rank}')
pg_default = self._create_process_group_nccl(store, self.opts(), device_id=device)
self.assertEqual(pg_default.group_desc, "default_pg")
pg_1 = c10d.new_group([0, 1], group_desc="test_purpose")
self.assertEqual(pg_1.group_desc, "test_purpose")
pg_2 = c10d.new_group([0, 1])
self.assertEqual(pg_2.group_desc, "undefined")
class DistributedDataParallelTest(
test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
):
@ -3637,11 +3651,18 @@ class NCCLTraceTest(NCCLTraceTestBase):
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
ver = t['version']
self.assertEqual(ver, "1.0")
self.assertEqual(ver, "1.1")
t = t['entries']
self.assertEqual(len(t), 2)
last = t[-1]
self.assertEqual(last['process_group'], ('0', 'default_pg'))
self.assertEqual(last['state'], 'completed')
s = last['time_discovered_started_ns']
f = last['time_discovered_completed_ns']
self.assertIsNotNone(f)
if timing_enabled:
self.assertIsNotNone(s)
self.assertTrue(s <= f)
self.assertIn('test_c10d_nccl.py', str(last['frames']))
self.assertEqual(last['input_sizes'], ((3, 4),))
self.assertEqual(last['output_sizes'], ((3, 4),))
@ -3718,6 +3739,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
self.assertEqual(len(t), 10)
first = t[0]
last = t[-1]
self.assertEqual(last['profiling_name'], 'nccl:all_reduce')
self.assertEqual(last['state'], 'completed')
self.assertIn('test_c10d_nccl.py', str(last['frames']))
self.assertEqual(last['input_sizes'], ((3, 4),))
@ -3750,6 +3772,7 @@ class NCCLTraceTest(NCCLTraceTestBase):
e.synchronize()
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
t = t['entries']
self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
if self.rank == 0:
self.assertEqual(t[-1]['seq_id'], 1)
self.assertEqual(t[-1]['state'], 'completed')
@ -3792,12 +3815,14 @@ class NCCLTraceTest(NCCLTraceTestBase):
time.sleep(5)
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
t = t['entries']
self.assertEqual(t[-1]['profiling_name'], 'nccl:all_reduce')
if self.rank == 0:
self.assertEqual(t[-1]['seq_id'], 1)
self.assertEqual(t[-1]['state'], 'completed')
else:
self.assertEqual(t[-1]['seq_id'], 2)
self.assertEqual(t[-1]['state'], self.started_or_scheduled(timing_enabled))
self.assertIsNone(t[-1]['time_discovered_completed_ns'])
# this will eventually cause the missing rank 0
# to continue which will unblock the non-zero ranks
self.parent.send('next')
@ -3851,9 +3876,10 @@ class NCCLTraceTestDumpOnTimeout(NCCLTraceTestDumpOnTimeoutBase):
@skip_but_pass_in_sandcastle_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_timeout_dumps(self, timing_enabled):
# We need to completely disable the coordinated timeout dump to avoid rank 0
# also timeout so that we set the check frequency to be very large (25 min).
os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1500000'
# dump on heartbeatmonitor thread
os.environ['TORCH_NCCL_COORD_CHECK_MILSEC'] = '1000'
# need rank0 to crash before looking for its output file
os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '1'
if self.rank == self.MAIN_PROCESS_RANK:
# wait for rank0 to crash before looking for its output file
@ -3938,6 +3964,60 @@ class NCCLTraceTestTimeoutDumpOnStuckRanks(NCCLTraceTestDumpOnTimeoutBase):
# getting the global signal to dump the debugging info.
time.sleep(600)
class NcclErrorDumpTest(NCCLTraceTestBase):
def _wait_process(self, rank, timeout):
try:
self.processes[rank].join(timeout)
return self.processes[rank].exitcode
except TimeoutError:
return None
def _check_return_codes(self, elapsed_time):
# the base test infra assumes processes exit with matching return codes,
# but we want rank0 to abort with exception and rank1 to exit with exit 1
self.assertEqual(self.processes[0].exitcode, -6)
self.assertEqual(self.processes[1].exitcode, 1)
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_nccl_errors_dump(self):
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = '1000'
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = '1'
# need rank0 to dump before abort
os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '5'
if self.rank == self.MAIN_PROCESS_RANK:
# wait for both rank0 and 1 to crash before looking for dump
self.assertEqual(self._wait_process(0, timeout=90), -6)
self.assertEqual(self._wait_process(1, timeout=90), 1)
# verify that the trace file exists for rank0
self.assertTrue(os.path.exists(self._trace_name(rank=0)))
return
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(
store,
self.rank,
self.world_size,
timeout=timedelta(seconds=10),
)
process_group.allreduce(torch.rand(10).cuda(self.rank))
if self.rank == 0:
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
# expect an error to be raised
with self.assertRaisesRegex(dist.DistBackendError, ""):
# Block the current stream on the NCCL stream
work.wait()
# Run some GPU operations
a = torch.rand(10).cuda(self.rank)
elif self.rank == 1:
# Clean up structures (ex: files for FileStore before going down)
del process_group
sys.exit(1)
if __name__ == "__main__":
assert (

View File

@ -463,6 +463,7 @@ class ProcessGroup:
backend: Optional[ProcessGroup],
) -> None: ...
def _set_group_name(self, name: str) -> None: ...
def _set_group_desc(self, desc: str) -> None: ...
def name(self) -> str: ...
def _has_hooks(self) -> bool: ...
def _wait_for_pending_works(self) -> None: ...
@ -471,6 +472,10 @@ class ProcessGroup:
def bound_device_id(self) -> Optional[torch.device]: ...
@bound_device_id.setter
def bound_device_id(self, device: Optional[torch.device]) -> None: ...
@property
def group_name(self) -> str: ...
@property
def group_desc(self) -> str: ...
class ProcessGroupRoundRobin(ProcessGroup): ...

View File

@ -369,6 +369,14 @@ class TORCH_API Backend : public torch::CustomClassHolder {
return pg_name_;
}
void setGroupDesc(const std::string& desc) {
pg_desc_ = desc;
}
const std::string& getGroupDesc() const {
return pg_desc_;
}
// See similar functions in ProcessGroup.hpp for context.
c10::optional<at::Device> getBoundDeviceId() const {
return bound_device_id_;
@ -399,6 +407,7 @@ class TORCH_API Backend : public torch::CustomClassHolder {
// remains the same across use of this process group.
DebugLevel dist_debug_level_;
std::string pg_name_;
std::string pg_desc_;
std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;

View File

@ -8,6 +8,7 @@
#include <memory>
#include <mutex>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <nccl.h>
@ -282,6 +283,18 @@ class NCCLComm {
}
#endif
#if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP)
std::unordered_map<std::string, std::string> ncclCommDump() {
std::unordered_map<std::string, std::string> dump;
if (isAborted()) {
LOG(INFO) << "Communicator was aborted before trying to dump its state.";
return dump;
}
C10D_NCCL_CHECK(::ncclCommDump(ncclComm_, dump), c10::nullopt);
return dump;
}
#endif
ncclUniqueId getNcclId() {
return ncclId_;
}
@ -337,6 +350,9 @@ class NCCLComm {
// Set true failure reason if provided by ProcessGroupNCCL (e.g. work
// timeout)
commFailureReason_ = commFailureReason;
LOG(INFO) << "Aborting ncclComm_ " << ncclComm_ << " with reason: "
<< (commFailureReason ? *commFailureReason
: "No abort reason provided.");
#ifndef NCCL_HAS_COMM_NONBLOCKING
C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
#else
@ -436,6 +452,8 @@ class NCCLComm {
#endif
}
friend class ProcessGroupNCCL;
protected:
ncclComm_t ncclComm_;
// Unique nccl_id for this communicator.

View File

@ -165,6 +165,18 @@ void ProcessGroup::setGroupName(const std::string& name) {
}
}
const std::string& ProcessGroup::getGroupDesc() const {
return pg_desc_;
}
void ProcessGroup::setGroupDesc(const std::string& name) {
pg_desc_ = name;
// Also set the group desc for all backends
for (auto& kv : deviceTypeToBackend_) {
kv.second->setGroupDesc(name);
}
}
void ProcessGroup::enableCollectivesTiming() {
for (auto& kv : deviceTypeToBackend_) {
kv.second->enableCollectivesTiming();

View File

@ -694,6 +694,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const std::string& getGroupName() const;
void setGroupName(const std::string& name);
const std::string& getGroupDesc() const;
void setGroupDesc(const std::string& name);
void enableCollectivesTiming();
void release_resources() override;
@ -724,6 +726,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
const int size_;
const c10::intrusive_ptr<Options> options_;
const BackendType backendType_;
std::string pg_desc_;
// Debug level setting. It is parsed once when ProcessGroup is constructed and
// remains the same across use of this process group.

View File

@ -4,13 +4,6 @@
#include <mutex>
#include <sstream>
#if defined(__linux__)
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#ifdef USE_C10D_NCCL
#include <exception>
@ -301,6 +294,9 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) {
static std::unordered_map<std::shared_ptr<NCCLComm>, int> ncclCommDevIdxMap;
static std::mutex ncclCommDevIdxMapMutex;
static bool allocatorHooksAttached = false;
std::atomic<bool> ProcessGroupNCCL::shouldDump_(false);
void cacheAllocatorRegisterHook(
const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
// Register after SEGMENT_ALLOC
@ -337,9 +333,34 @@ void cacheAllocatorDeregisterHook(
}
}
#if defined(IS_NCCL_EXP) && defined(NCCL_COMM_DUMP)
std::string dump_nccl_trace() {
return NCCLTraceBuffer::get()->dump();
std::unordered_map<
std::string /* ncclUniqueID */,
std::unordered_map<std::string, std::string> /* dump from this comm */>
ncclDumpMap;
// dump_nccl_trace is only called from the default PG (uid_=0), but we want to
// dump from all comms so we need to iterate over ncclCommDevIdxMap, which
// is static
std::vector<std::shared_ptr<NCCLComm>> allNCCLComms;
// within the critical section, we don't want to dump while holding the lock
// as dump might hang
ncclCommDevIdxMapMutex.lock();
for (auto& [ncclComm, _] : ncclCommDevIdxMap) {
allNCCLComms.push_back(ncclComm);
}
ncclCommDevIdxMapMutex.unlock();
for (auto& ncclComm : allNCCLComms) {
std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId());
ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump();
}
return NCCLTraceBuffer::get()->dump(ncclDumpMap);
}
#else
std::string dump_nccl_trace() {
return NCCLTraceBuffer::get()->dump(c10::nullopt);
}
#endif
c10::optional<std::function<std::string()>>& get_cpp_trace_dumper() {
static c10::optional<std::function<std::string()>> dumper(c10::nullopt);
@ -744,13 +765,17 @@ ProcessGroupNCCL::ProcessGroupNCCL(
ValueError,
at::cuda::getNumGPUs() != 0,
"ProcessGroupNCCL is only supported with GPUs, no GPUs found!");
this->setGroupName(options_->group_name);
logPrefix_ = createLogPrefix();
blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false);
asyncErrorHandling_ = static_cast<ErrorHandlingMode>(
getCvarInt(TORCH_NCCL_ASYNC_ERROR_HANDLING, 3 /*SkipCleanUp*/));
desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
(dist_debug_level_ >= DebugLevel::Detail);
dumpOnTimeout_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
// TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
// or change its name to reflect that dump happens on exception including
// both timeout and other errors.
dumpOnException_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) ||
(dist_debug_level_ >= DebugLevel::Detail);
heartbeat_ = 1ULL;
monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
@ -827,7 +852,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
<< "NCCL version: " << getNcclVersion() << ", size: " << size
<< ", global rank: " << globalRank()
<< ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
<< ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnTimeout_
<< ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnException_
<< ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: "
<< waitTimeoutDumpInMilSec_
<< ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_
@ -848,7 +873,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
<< ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
<< ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
<< ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
<< ", ID=" << this->getID();
<< ", PG Name: " << options_->group_name;
RECORD_PARAM_COMMS(
0, // seq
@ -1086,6 +1111,55 @@ void ProcessGroupNCCL::waitForDumpOrTimeout(
std::this_thread::sleep_until(wakeUpTime);
}
// WHC - pulled this from
// https://github.com/pytorch/pytorch/commit/893dcac068f13542b1e00e3e55bca4530ab412cb
// to help cherry-pick go through. did not cherry-pick entirety of the PR that
// provided this new util function.
void ProcessGroupNCCL::waitForFutureOrTimeout(
std::future<bool>& fut,
const std::chrono::milliseconds& timeOutMilSec,
const std::string& futDescription) {
TORCH_CHECK(fut.valid(), "Expected a valid future");
std::future_status status = fut.wait_for(timeOutMilSec);
if (status == std::future_status::ready) {
// Calling .get() will re-raise any exception from the future, and we don't
// care about the retval
try {
bool result = fut.get();
if (result) {
LOG(INFO) << logPrefix()
<< "future is successfully executed for: " << futDescription;
}
} catch (const std::exception& e) {
C10_THROW_ERROR(
DistBackendError,
c10::str(
logPrefix(),
"Exception thrown when waitng for future ",
futDescription,
": ",
e.what()));
} catch (...) {
C10_THROW_ERROR(
DistBackendError,
c10::str(
logPrefix(),
"Unknown exception thrown when waitng for future ",
futDescription));
}
} else {
C10_THROW_ERROR(
DistBackendError,
c10::str(
logPrefix(),
"Future for ",
futDescription,
" timed out after ",
timeOutMilSec.count(),
" ms"));
}
}
void ProcessGroupNCCL::abortCommsFromMap(
std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>&
ncclCommsMap,
@ -1097,6 +1171,8 @@ void ProcessGroupNCCL::abortCommsFromMap(
auto& ncclComms = it.second;
for (const auto& ncclComm : ncclComms) {
LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ "
<< ncclComm->ncclComm_ << " on CUDA device: " << devName;
ncclComm->ncclCommAbort(abortReason);
}
// Note that we don't remove the aborted communicators from the
@ -1117,8 +1193,9 @@ void ProcessGroupNCCL::abortCommsFromMap(
}
}
LOG(INFO) << logPrefix() << "] Destroyed " << ncclComms.size()
<< "communicators on CUDA device: " << devName
LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroyed "
<< ncclComms.size()
<< " communicators on CUDA device: " << devName
<< " with stream: " << streamId;
}
}
@ -1226,12 +1303,18 @@ void ProcessGroupNCCL::heartbeatMonitor() {
uint64_t heartBeatCounter = 0ULL;
std::string errorMsg;
std::string exitMsg;
bool checkTimeoutSignal = (dumpOnTimeout_ && uid_ == 0);
int monitorPollInterval = checkTimeoutSignal ? coordCheckIntervalMilSec_
: heartbeatTimeoutInSec_ * 1000;
bool checkDumpSignal = (dumpOnException_ && uid_ == 0);
int monitorPollInterval = checkDumpSignal ? coordCheckIntervalMilSec_
: heartbeatTimeoutInSec_ * 1000;
auto lastTimePollStore = std::chrono::steady_clock::now();
auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
std::future<bool> asyncDebugDump;
c10::optional<DumpPipe> dumpPipe = c10::nullopt;
if (uid_ == 0) {
// DumpPipe is one per-trainer process, and it's convenient to name them
// after 'global' ranks in the system, so we assume processgroup (uid)==0 is
// the global PG and has globally unique rank ids across trainers.
dumpPipe.emplace(rank_);
}
while (true) {
// This won't have any lock since this lock is only used here.
// Please be aware that mutex `monitorMutex_` should not be used
@ -1254,7 +1337,28 @@ void ProcessGroupNCCL::heartbeatMonitor() {
// to see if any PG on any rank observed a timeout and signaled peers to
// dump debugging info, and we avoid hammering the TCPStore from all PGs on
// the same rank.
if (checkTimeoutSignal) {
if (checkDumpSignal) {
// There are two scenarios where the monitor thread will dump on timeout:
// 1. The local rank is the first to observe a timeout; shouldDump_ will be
// set to true.
// 2. Other ranks detected the timeout and signaled the local rank to dump.
// In addition, the monitor thread will dump if the watchdog thread has no
// heartbeat or dumpPipe is not empty.
if (shouldDump_.load()) {
errorMsg = c10::str(
logPrefix(),
"Received a dump signal from this local rank and will ",
"start to dump the debug info.");
exitMsg = c10::str(
"ProcessGroupNCCL's watchdog detected an exception from the local rank. ",
"This is most likely caused by incorrect usages of collectives, e.g., wrong ",
"sizes used across ranks, the order of collectives is not same for all ranks ",
"or the scheduled collective, for some reason, didn't run. Additionally, ",
"this can be caused by GIL deadlock or other reasons such as network errors or ",
"bugs in the communications library (e.g. NCCL), etc. We tried our best to ",
"dump the debug info into the storage to help you debug the issue.");
break;
}
// We poll store to see if some ranks have flagged a timeout when
// we haven't polled for `heartbeat_timeout` seconds and there haven't
// any work added or removed for `watchdog_timeout` seconds.
@ -1263,13 +1367,28 @@ void ProcessGroupNCCL::heartbeatMonitor() {
computeDeltaMS(lastTimePollStore, currentTime) >=
coordCheckIntervalMilSec_) {
lastTimePollStore = currentTime;
if (globalStore_->check({std::string(TIMEOUT_DUMP)})) {
if (globalStore_->check({std::string(EXCEPTION_DUMP)})) {
int timeOutRank = -1;
shouldDump_.store(true);
try {
auto vec = globalStore_->get(std::string(EXCEPTION_DUMP));
TORCH_CHECK_WITH(
DistBackendError,
vec.size() == sizeof(int),
"Invalid size for the timeout rank ID");
std::memcpy(&timeOutRank, vec.data(), vec.size());
} catch (const std::exception& e) {
LOG(ERROR)
<< "Failed to get timeout rank ID from the global store.";
}
errorMsg = c10::str(
logPrefix(),
"Received a global timeout from another rank and will ",
"start to dump the debug info.");
"Received a global dump signal from rank and will ",
"start to dump the debug info. ");
exitMsg = c10::str(
"ProcessGroupNCCL's watchdog detected a collective timeout and notified current rank. ",
"ProcessGroupNCCL's watchdog detected a dump signal from rank ",
timeOutRank,
" and notified the current rank. ",
"This is most likely caused by incorrect usages of collectives, e.g., wrong ",
"sizes used across ranks, the order of collectives is not same for all ranks ",
"or the scheduled collective, for some reason, didn't run. Additionally, ",
@ -1289,6 +1408,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
if (heartbeat != heartBeatCounter) {
heartBeatCounter = heartbeat;
} else {
shouldDump_.store(true);
// No heartbeat increase detected and timeout.
errorMsg = c10::str(
logPrefix(),
@ -1310,6 +1430,14 @@ void ProcessGroupNCCL::heartbeatMonitor() {
break;
}
}
// process a request to dump the trace. only PG uid 0 will respond to dump
// requests, but this is fine since all PG's feed into the same flight
// recorder and dump. After dump, the training should continue.
if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
// best effort dump, not waiting for the dump here
std::future<bool> fut = std::async(
std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
}
}
LOG(ERROR) << errorMsg;
@ -1318,10 +1446,16 @@ void ProcessGroupNCCL::heartbeatMonitor() {
LOG(INFO) << "Dumping c++ stacktraces: " << cpp_dumper.value()();
}
auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_);
// Store debug info to storage if no other thread does it. (By default to
// local disk)
asyncDebugDump = launchAsyncDebugDump();
std::future<bool> asyncDebugDump = std::async(
std::launch::async, [this]() { return this->dumpDebuggingInfo(); });
// wait for the dump until timeout
waitForFutureOrTimeout(
asyncDebugDump,
std::chrono::milliseconds(waitTimeoutDumpInMilSec_),
"Flight recorder dump in heartbeatMonitor");
if (get_gil_checker() != nullptr) {
auto fut = launchAsyncGilCheck();
@ -1348,7 +1482,8 @@ void ProcessGroupNCCL::heartbeatMonitor() {
// Case two: desync might be slow or get stuck. Or we get stuck in
// destructors, we will sleep for some time before calling std::abort() to
// kill the whole process.
if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load()) &&
if ((terminateProcessGroup_.load() || collectiveDebugInfoMode_.load() ||
shouldDump_.load()) &&
!terminateHeartbeatMonitorThread_.load()) {
// Leave another two mins for desync report generation or process group
// destroy.
@ -1364,7 +1499,6 @@ void ProcessGroupNCCL::heartbeatMonitor() {
// We already log completion inside the thread, so it may not be necessary to
// check the return value here. We mainly use a future so we can exit early
// if done.
waitForDumpOrTimeout(asyncDebugDump, wakeUpTime);
if (!terminateHeartbeatMonitorThread_.load()) {
// Create a error message reported from MonitorThread, so
@ -1445,61 +1579,12 @@ std::string ProcessGroupNCCL::getNCCLWatchdogDebugInfo() {
return retrieveDesyncReport(store_, "NCCL", rank_, size_);
}
#if defined(__linux__)
struct DumpPipe {
DumpPipe(int rank) {
std::string fileStem =
getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
if (fileStem.empty() ||
getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
return;
}
TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty");
std::string filename = c10::str(fileStem, rank, ".pipe");
TORCH_CHECK(
unlink(filename.c_str()) != -1 || errno == ENOENT,
"Error removing existing named pipe ",
filename);
TORCH_CHECK(
mkfifo(filename.c_str(), 0666) != -1,
"Error creating named pipe ",
filename);
fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
LOG(INFO) << "Pipe file " << filename
<< " has been opened, write to it to trigger NCCL Debug Dump.";
TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
}
bool shouldDump() {
if (fd_ == -1) {
return false;
}
char buf[128];
// non-blocking from O_NONBLOCK above.
// Ignore EINTR because we already will poll this
// again later.
ssize_t bytesRead = read(fd_, &buf, 128);
return bytesRead > 0;
}
~DumpPipe() {
if (fd_ != -1) {
close(fd_);
}
}
private:
int fd_ = -1;
};
#else
struct DumpPipe {
DumpPipe(int rank) {}
bool shouldDump() {
return false;
}
};
#endif
std::string ProcessGroupNCCL::createLogPrefix() const {
return c10::str("[PG ", uid_, " Rank ", rank_, "] ");
if (!pg_desc_.empty() && pg_desc_ != "undefined") {
return c10::str("[PG ", pg_name_, " (", pg_desc_, ") Rank ", rank_, "] ");
} else {
return c10::str("[PG ", pg_name_, " Rank ", rank_, "] ");
}
}
const std::string& ProcessGroupNCCL::logPrefix() const {
@ -1514,17 +1599,8 @@ const int& ProcessGroupNCCL::globalRank() const {
void ProcessGroupNCCL::watchdogHandler() {
bool done = false;
lastWorkListUpdateTime_ = std::chrono::steady_clock::now();
c10::optional<std::future<bool>> optAsyncDebugDump;
std::list<ProcessGroupNCCL::WorkNCCL> completedWorkList;
c10::optional<DumpPipe> dumpPipe = c10::nullopt;
if (uid_ == 0) {
// DumpPipe is one per-trainer process, and its convenient to name them
// after 'global' ranks in the system, So we assume processgroup (uid)==0 is
// the global PG and has globally unique rank ids across trainers.
dumpPipe.emplace(rank_);
}
while (!done || !terminateProcessGroup_.load()) {
std::unique_lock<std::mutex> lock(workMetaListMutex_);
// We busy-poll the work vector every kWatchdogThreadSleepMillis
@ -1544,6 +1620,28 @@ void ProcessGroupNCCL::watchdogHandler() {
// If work hits an exception (either an error or timeout)
if (work.exception()) {
// try to dump flight records if exception happens.
// Flight recorder behavior should be independent of desync Debug
if (dumpOnException_) {
try {
auto rank = globalRank();
auto vec = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(&rank),
reinterpret_cast<uint8_t*>(&rank) + sizeof(rank));
globalStore_->set(std::string(EXCEPTION_DUMP), vec);
// signal the monitor thread to start dumping
shouldDump_.store(true);
// This sleep is used to give time for dumping before throwing
// exception
std::this_thread::sleep_for(
std::chrono::seconds(heartbeatTimeoutInSec_));
} catch (const std::exception& e) {
LOG(ERROR) << logPrefix()
<< "Failed to set dump signal in tcpstore. "
<< "Error: " << e.what();
}
}
if (SHOULD_CLEAN_UP(asyncErrorHandling_)) {
// Abort work and corresponding communicators
work.abort();
@ -1554,40 +1652,22 @@ void ProcessGroupNCCL::watchdogHandler() {
// Report desync state in case of timeout
if (timedOut) {
try {
if (desyncDebug_ || dumpOnTimeout_) {
// Set shutdown mode, so the heartbeat monitor thread will not
// abort process immediately.
if (desyncDebug_) {
try {
collectiveDebugInfoMode_.store(true);
std::vector<uint8_t> vec(1);
globalStore_->set(std::string(TIMEOUT_DUMP), vec);
}
auto wakeUpTime = getWakeupTime(waitTimeoutDumpInMilSec_);
if (dumpOnTimeout_ && !optAsyncDebugDump) {
// Store debug info to storage. (By default to local disk)
optAsyncDebugDump = launchAsyncDebugDump();
}
if (desyncDebug_) {
auto desyncMsg = getNCCLWatchdogDebugInfo();
LOG(ERROR) << logPrefix() << desyncMsg;
} catch (const std::exception& e) {
LOG(ERROR)
<< logPrefix()
<< "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
<< " Please file an issue. Error: " << e.what();
} catch (...) {
LOG(ERROR)
<< logPrefix()
<< "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error."
<< " Please file an issue.";
}
if (dumpOnTimeout_) {
// Store debug info to storage. (By default to local disk)
waitForDumpOrTimeout(*optAsyncDebugDump, wakeUpTime);
}
} catch (const std::exception& e) {
LOG(ERROR) << logPrefix()
<< "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. "
<< " Please file an issue. Error: " << e.what();
} catch (...) {
LOG(ERROR)
<< logPrefix()
<< "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error."
<< " Please file an issue.";
}
}
// Throw exception
@ -1630,12 +1710,6 @@ void ProcessGroupNCCL::watchdogHandler() {
// in case processing is slowed down (but not hung) by cuda api contention
heartbeat_++;
}
// process a request to dump the trace. only PG uid 0 will respond to dump
// requests, but this is fine since all PG's feed into the same flight
// recorder and dump.
if (dumpPipe.has_value() && dumpPipe->shouldDump()) {
launchAsyncDebugDump();
}
done = workMetaList_.empty();
}
}
@ -1882,6 +1956,15 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
// Create the unique NCCL ID and broadcast it
ncclUniqueId ncclID;
// reset log prefix to include group_desc
logPrefix_ = createLogPrefix();
#ifdef NCCL_COMM_DESCRIPTION
// Pass process group name and description to NCCL communicator
std::string commDesc = pg_desc_ + ':' + pg_name_;
options_->config.commDesc = strdup(commDesc.c_str());
#endif
// For batch_isend_irecv, ncclGroupStart() would be called upfront
bool batchP2P = ncclActiveGroupCounter_ > 0;
bool singleP2POp = isP2POp(opType, batchP2P);
@ -1893,7 +1976,16 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
// For point-to-point communication on the same process, don't need broadcast.
if (!isSendRecvSelf) {
// Broadcast so that each process can have a unique NCCL ID
auto timeStarted = std::chrono::steady_clock::now();
broadcastUniqueNCCLID(&ncclID, singleP2POp, devicesKey, p2pRank);
auto timerDeltaMs =
std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::steady_clock::now() - timeStarted)
.count() *
1000;
LOG(INFO) << logPrefix()
<< "ProcessGroupNCCL broadcast unique ID through store took "
<< timerDeltaMs << " ms";
}
at::cuda::OptionalCUDAGuard gpuGuard;
@ -2004,6 +2096,13 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
}
#endif
for (const auto i : c10::irange(devices.size())) {
int deviceIndex = devices[i].index();
LOG(INFO) << logPrefix() << "ProcessGroupNCCL created ncclComm_ "
<< ncclComms[i]->ncclComm_ << " on CUDA device: " << deviceIndex;
}
logPrefix_ = createLogPrefix(); // reset log prefix to include group_desc
// At this point NCCL should have been initialized, hence we can accurately
// get the env value even if NCCL sets it by reading from nccl.conf file
LOG(INFO) << logPrefix()
@ -2052,18 +2151,17 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
segmentInfo.total_size);
}
}
// Record the mapping between ncclComm and device index so that later
// register hook can register a newly allocated segment to communicators
// on the same device.
// NOTE: we need remove the communicator from this map when it is
// destroyed, otherwise may register onto an invalid communicator.
ncclCommDevIdxMapMutex.lock();
for (const auto i : c10::irange(devices.size())) {
ncclCommDevIdxMap.emplace(ncclComms[i], devices[i].index());
}
ncclCommDevIdxMapMutex.unlock();
}
// Record the mapping between ncclComm and device index so that later
// register hook can register a newly allocated segment to communicators
// on the same device.
// NOTE: we need to remove the communicator from this map when it is
// destroyed; otherwise we may register onto an invalid communicator.
ncclCommDevIdxMapMutex.lock();
for (const auto i : c10::irange(devices.size())) {
ncclCommDevIdxMap.emplace(ncclComms[i], devices[i].index());
}
ncclCommDevIdxMapMutex.unlock();
}
it = devNCCLCommMap_.find(devicesKey);
@ -2284,8 +2382,10 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
enableTiming_.load());
r->trace_id_ = NCCLTraceBuffer::get()->record(
uid_,
pg_name_,
seq_,
profilingTitle,
// create a string copy of profilingTitle
profilingTitle ? profilingTitle : "",
inputs,
outputs,
r->ncclStartEvents_.get(),

View File

@ -1,5 +1,12 @@
#pragma once
#if defined(__linux__)
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#ifdef USE_C10D_NCCL
#include <atomic>
@ -95,7 +102,7 @@ static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC = {
constexpr const char* NCCL_BACKEND_NAME = "nccl";
constexpr const char* TIMEOUT_DUMP = "timeout_dump";
constexpr const char* EXCEPTION_DUMP = "exception_dump";
constexpr auto kProcessGroupNCCLDefaultTimeout =
std::chrono::milliseconds(10 * 60 * 1000);
@ -134,6 +141,59 @@ static std::vector<std::string> TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK =
{"TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK",
"NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"};
#if defined(__linux__)
struct DumpPipe {
DumpPipe(int rank) {
std::string fileStem =
getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
if (fileStem.empty() ||
getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
return;
}
TORCH_CHECK(!fileStem.empty(), "TORCH_NCCL_DEBUG_INFO_TEMP_FILE is empty");
std::string filename = c10::str(fileStem, rank, ".pipe");
TORCH_CHECK(
unlink(filename.c_str()) != -1 || errno == ENOENT,
"Error removing existing named pipe ",
filename);
TORCH_CHECK(
mkfifo(filename.c_str(), 0666) != -1,
"Error creating named pipe ",
filename);
fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
LOG(INFO) << "Pipe file " << filename
<< " has been opened, write to it to trigger NCCL Debug Dump.";
TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
}
bool shouldDump() {
if (fd_ == -1) {
return false;
}
char buf[128];
// non-blocking from O_NONBLOCK above.
// Ignore EINTR because we already will poll this
// again later.
ssize_t bytesRead = read(fd_, &buf, 128);
return bytesRead > 0;
}
~DumpPipe() {
if (fd_ != -1) {
close(fd_);
}
}
private:
int fd_ = -1;
};
#else
struct DumpPipe {
DumpPipe(int rank) {}
bool shouldDump() {
return false;
}
};
#endif
// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
@ -384,6 +444,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// via `ncclCommSplit`
std::shared_ptr<ProcessGroupNCCL> split_from;
int64_t split_color{0};
std::string group_name;
};
// If you wish to create multiple process groups, each with a potentially
@ -761,6 +822,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
const std::chrono::time_point<std::chrono::steady_clock>& wakeUpTime,
size_t timeout_sec = 30);
// A helper function to wait for a future to complete or timeout.
void waitForFutureOrTimeout(
std::future<bool>& fut,
const std::chrono::milliseconds& timeOutMilSec,
const std::string& futDescription);
// When watchdog timeout, this function will be called and return debug info
// for users. For now we only get information from retrieveDesyncReport.
// We are working on enabling more useful debug information for watchdog
@ -878,6 +945,15 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Whether there are hooks pending to be fired
std::atomic<bool> hasPendingHooks_;
// This is the signal from watchdog threads to indicate whether the monitor
// thread should dump. Making it static so that it is accessible from all the
// PGs. With this flag, the monitor thread will dump debug info under any one of
// the 3 conditions: 1: this flag is set to true by the watchdog thread when
// it detects a timeout; 2: a timeout signal is received from
// other ranks through the tcpstore; 3: no heartbeat from the watchdog. Note that
// only the monitor thread from PG0 should dump the debug info, and only once.
static std::atomic<bool> shouldDump_;
// Mutex to Guard workMetaList_
std::mutex workMetaListMutex_;
@ -962,8 +1038,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// Whether or not to enable timeout root cause analysis.
bool desyncDebug_;
// Whether or not to dump debug info on timeout
bool dumpOnTimeout_;
// Whether or not to dump debug info on exception including both watchdog
// timeout and nccl errors.
bool dumpOnException_;
// Whether or not to create start CUDAEvent and enable timing for start
// and end events. Note that enableTiming_ is always true if desyncDebug_

View File

@ -13,7 +13,6 @@
#include <string>
#include <system_error>
#include <vector>
namespace c10d {
/* Trace Utils Related to TORCH_NCCL_DESYNC_DEBUG */
@ -352,6 +351,18 @@ inline c10::List<c10::IValue> new_list() {
return c10::List<c10::IValue>(c10::AnyType::get());
}
inline std::string ranks_str(const std::vector<uint64_t>& ranks) {
std::string str;
for (const auto& rank : ranks) {
if (str.empty()) {
str = std::to_string(rank);
} else {
str += ", " + std::to_string(rank);
}
}
return c10::str("[", str, "]");
}
struct NCCLTraceBuffer {
static NCCLTraceBuffer* get() {
// intentionally leak on exit
@ -371,8 +382,9 @@ struct NCCLTraceBuffer {
// buffer this entry will be located to
// update state information
size_t pg_id_;
std::string pg_name_;
size_t seq_id_; // as tracked by the process group
const char* profiling_name_;
std::string profiling_name_;
std::shared_ptr<torch::CapturedTraceback> traceback_;
// we borrow pointers to start_ and end_ so we can query the state
@ -385,7 +397,16 @@ struct NCCLTraceBuffer {
c10::time_t time_created_;
c10::optional<float> duration_;
const char* state_ = "scheduled";
// timestamp when our CPU threads discovered that the kernel started.
// will always be _after_ it actually started, and can be very late
// if the watchdog thread got stuck on CUDA APIs.
c10::optional<c10::time_t> time_discovered_started_;
// timestamp when our CPU threads discovered that the kernel completed.
// will always be _after_ it actually completed, and can be the same time
// as the discovery of the start if the watchdog thread is stuck on CUDA
// APIs
c10::optional<c10::time_t> time_discovered_completed_;
// size information for input/output tensors
c10::SmallVector<int, 4> input_dims_;
@ -405,8 +426,9 @@ struct NCCLTraceBuffer {
c10::optional<size_t> record(
size_t pg_id,
const std::string& pg_name,
size_t seq_id,
const char* profiling_name,
std::string profiling_name,
const std::vector<at::Tensor>& inputs,
const std::vector<at::Tensor>& outputs,
EventList* start,
@ -421,8 +443,9 @@ struct NCCLTraceBuffer {
auto te = Entry{
id_,
pg_id,
pg_name,
seq_id,
profiling_name == nullptr ? "" : profiling_name,
std::move(profiling_name),
std::move(traceback),
std::move(start),
std::move(end),
@ -460,8 +483,8 @@ struct NCCLTraceBuffer {
break;
}
}
if (started) {
r.state_ = "started";
if (started && !r.time_discovered_started_) {
r.time_discovered_started_ = c10::getTime();
}
}
if (r.end_ != nullptr) {
@ -472,8 +495,8 @@ struct NCCLTraceBuffer {
break;
}
}
if (completed) {
r.state_ = "completed";
if (completed && !r.time_discovered_completed_) {
r.time_discovered_completed_ = c10::getTime();
}
}
}
@ -516,15 +539,15 @@ struct NCCLTraceBuffer {
std::unique_lock<std::mutex> guard(mutex_);
auto& entry = entries_.at(*id % max_entries_);
if (entry.id_ == *id) {
update_state(entry);
Entry* entry = &entries_.at(*id % max_entries_);
if (entry->id_ == *id) {
update_state(*entry);
if (compute_duration) {
can_compute_duration = strcmp(entry.state_, "completed") == 0 &&
entry.start_ && entry.end_;
startEvents = entry.start_;
endEvents = entry.end_;
can_compute_duration = entry->time_discovered_completed_.has_value() &&
entry->start_ && entry->end_;
startEvents = entry->start_;
endEvents = entry->end_;
}
}
@ -536,33 +559,38 @@ struct NCCLTraceBuffer {
duration = getDurationFromFirstEvent(*startEvents, *endEvents);
guard.lock();
// Refresh the entry ref, see if it has been overwritten
entry = entries_.at(*id % max_entries_);
if (entry.id_ != *id) {
// Refresh the entry pointer, see if the entry has been overwritten
entry = &entries_.at(*id % max_entries_);
if (entry->id_ != *id) {
LOG(INFO)
<< "retire_id abandoned for id " << *id
<< ", event was overwritten while waiting to compute duration.";
return;
}
if (duration.has_value()) {
entry.duration_ = duration.value();
entry->duration_ = duration.value();
}
}
entry.retired_ = true;
entry.start_ = entry.end_ = nullptr;
entry->retired_ = true;
entry->start_ = entry->end_ = nullptr;
}
std::string dump() {
std::string dump(
const c10::optional<std::unordered_map<
std::string,
std::unordered_map<std::string, std::string>>>& ncclDumpMap) {
auto result = dump_entries();
auto entries = new_list();
c10::IValue entries_key = "entries";
c10::IValue nccl_comm_key = "nccl_comm_state";
c10::IValue version_key = "version";
// Update whenever changing contents or formatting of the dump
// (minor when adding fields, major when changing existing fields)
c10::IValue version_val = "1.0";
c10::IValue version_val = "1.1";
c10::IValue pg_id_key = "pg_id";
c10::IValue pg_name_key = "process_group";
c10::IValue seq_id_key = "seq_id";
c10::IValue profiling_name_key = "profiling_name";
c10::IValue input_sizes_key = "input_sizes";
@ -576,6 +604,8 @@ struct NCCLTraceBuffer {
c10::IValue name_key = "name";
c10::IValue filename_key = "filename";
c10::IValue retired_key = "retired";
c10::IValue time_discovered_started_key = "time_discovered_started_ns";
c10::IValue time_discovered_completed_key = "time_discovered_completed_ns";
std::vector<torch::CapturedTraceback*> tracebacks;
for (auto& e : result) {
@ -596,6 +626,7 @@ struct NCCLTraceBuffer {
auto& tb = stracebacks.tracebacks.at(i);
auto dict = new_dict();
dict.insert(pg_id_key, int64_t(e.pg_id_));
dict.insert(pg_name_key, e.pg_name_);
dict.insert(seq_id_key, int64_t(e.seq_id_));
dict.insert(profiling_name_key, e.profiling_name_);
dict.insert(time_created_key, int64_t(e.time_created_));
@ -619,7 +650,24 @@ struct NCCLTraceBuffer {
dict.insert(input_sizes_key, read_sizes(e.input_dims_));
dict.insert(output_sizes_key, read_sizes(e.output_dims_));
dict.insert(state_key, e.state_);
if (e.time_discovered_completed_.has_value()) {
dict.insert(state_key, "completed");
} else if (e.time_discovered_started_.has_value()) {
dict.insert(state_key, "started");
} else {
dict.insert(state_key, "scheduled");
}
dict.insert(
time_discovered_started_key,
e.time_discovered_started_.has_value()
? int64_t(*e.time_discovered_started_)
: c10::IValue());
dict.insert(
time_discovered_completed_key,
e.time_discovered_completed_.has_value()
? int64_t(*e.time_discovered_completed_)
: c10::IValue());
dict.insert(retired_key, e.retired_);
auto frames = new_list();
@ -629,10 +677,24 @@ struct NCCLTraceBuffer {
dict.insert(frames_key, frames);
entries.push_back(dict);
}
// convert ncclDumpMap into a dictionary
auto per_comm_dict = new_dict();
if (ncclDumpMap.has_value()) {
for (const auto& [ncclId, ncclDump] : ncclDumpMap.value()) {
auto inner_dict = new_dict();
for (const auto& [key, value] : ncclDump) {
inner_dict.insert(key, value);
}
per_comm_dict.insert(ncclId, inner_dict);
}
}
auto dict = new_dict();
dict.insert(entries_key, entries);
dict.insert(version_key, version_val);
if (per_comm_dict.size() > 0) {
dict.insert(nccl_comm_key, per_comm_dict);
}
return pickle_str(dict);
}

View File

@ -1863,6 +1863,15 @@ Arguments:
"group_name",
&::c10d::ProcessGroup::getGroupName,
"(Gets this process group name. It's cluster unique)")
.def(
"_set_group_desc",
&::c10d::ProcessGroup::setGroupDesc,
py::call_guard<py::gil_scoped_acquire>(),
"Sets the process group description. This is an internal C10D method, do not use.")
.def_property_readonly(
"group_desc",
&::c10d::ProcessGroup::getGroupDesc,
"Gets this process group description")
.def_property(
"bound_device_id",
&::c10d::ProcessGroup::getBoundDeviceId,
@ -2443,7 +2452,9 @@ Example::
.def_readwrite(
"split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
.def_readwrite(
"split_color", &::c10d::ProcessGroupNCCL::Options::split_color);
"split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
.def_readwrite(
"group_name", &::c10d::ProcessGroupNCCL::Options::group_name);
#endif

View File

@ -0,0 +1,448 @@
import io
import math
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed._functional_collectives import AsyncCollectiveTensor
if dist.is_available() or TYPE_CHECKING:
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, Replicate
def _identity_func(
obj: torch.Tensor,
pg: Optional[dist.ProcessGroup],
device: Optional[torch.device],
companion_obj: Any,
) -> torch.Tensor:
return obj
def _all_gather_sharded_tensor(
sharded_tensor: "ShardedTensor",
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
if pg is None:
pg = distributed_c10d._get_default_group()
world_size = dist.get_world_size(pg)
shards = sharded_tensor.local_shards()
dim_0_size = sharded_tensor.size()[0] # type: ignore[index]
tensor_numel = sharded_tensor.size().numel() # type: ignore[union-attr]
chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
pg_device = (
distributed_c10d._get_pg_default_device(pg) if device is None else device
)
if shards:
local_tensor = shards[0].tensor.flatten()
if local_tensor.device.type != pg_device.type:
local_tensor = local_tensor.to(pg_device)
num_padding = chunk_size - local_tensor.numel()
if num_padding > 0:
local_tensor = F.pad(local_tensor, [0, num_padding])
else:
local_tensor = torch.zeros(
chunk_size, dtype=sharded_tensor.dtype, device=pg_device
)
tensor = torch.empty(
chunk_size * world_size,
dtype=local_tensor.dtype,
device=pg_device,
)
dist.all_gather_into_tensor(tensor, local_tensor, group=pg)
tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
return tensor
class CompanionMismatch(Exception):
...
def _iterate_state_dict(
iter_object: Any,
sharded_tensor_func: Callable,
dtensor_func: Callable,
tensor_func: Callable,
*,
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
cpu_offload: bool = False,
companion_obj: Any = None,
ranks_only: Tuple[int, ...] = tuple(),
type_check: bool = True,
non_blocking: bool = True,
) -> Dict[str, Any]:
"""Iterate through the state dict, applying the given functions to each tensor type.
Args:
iter_object (Any): the target state_dict.
sharded_tensor_func (Callable): the function to apply to ShardedTensor
dtensor_func (Callable): the function to apply to DTensor
tensor_func (Callable): the function to apply to Tensor
pg (Optional[dist.ProcessGroup]): process group passed to tensor functions
device (Optional[torch.device]): device passed to tensor functions
cpu_offload (bool): whether to offload the tensors to CPU memory. This option is ignored
if a companion_obj is supplied.
companion_obj (Any): A companion object to the state dict. If this object
is supplied, we attempt to copy the tensor to the companion object.
ranks_only (Tuple[int, ...]): if this tuple is empty, all ranks will
have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
have the same state_dicts. Other ranks will get empty state_dicts.
type_check (bool): check if the instance data type is a supported type
that can be saved by DCP. The current supported data types are
torch.Tensor, DTensor, int, float, str, list, dict, None.
non_blocking (bool): whether to use non-blocking copy when copying to the companion object.
"""
# TODO: should we use pytree?
cpu_device = torch.device("cpu")
if isinstance(iter_object, ShardedTensor):
ret = sharded_tensor_func(iter_object, pg, device, companion_obj)
elif isinstance(iter_object, DTensor):
ret = dtensor_func(iter_object, pg, device, companion_obj)
elif isinstance(iter_object, torch.Tensor):
ret = tensor_func(iter_object, pg, device, companion_obj)
elif (
isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
or iter_object is None
):
ret = iter_object
elif isinstance(iter_object, dict):
if companion_obj is not None and (
not isinstance(companion_obj, dict)
or set(companion_obj.keys()) != set(iter_object.keys())
):
raise CompanionMismatch()
ret = {
key: _iterate_state_dict(
value,
sharded_tensor_func,
dtensor_func,
tensor_func,
pg=pg,
device=device,
cpu_offload=cpu_offload,
companion_obj=companion_obj[key] if companion_obj is not None else None,
ranks_only=ranks_only,
type_check=type_check,
non_blocking=non_blocking,
)
for key, value in iter_object.items()
}
elif isinstance(iter_object, (list, tuple)):
if companion_obj is not None and (
not isinstance(companion_obj, (list, tuple))
or len(companion_obj) != len(iter_object)
):
raise CompanionMismatch()
ret = [
_iterate_state_dict(
v,
sharded_tensor_func,
dtensor_func,
tensor_func,
pg=pg,
device=device,
cpu_offload=cpu_offload,
companion_obj=companion_obj[idx] if companion_obj is not None else None,
ranks_only=ranks_only,
type_check=type_check,
non_blocking=non_blocking,
)
for idx, v in enumerate(iter_object)
]
if isinstance(iter_object, tuple):
ret = tuple(ret)
elif not type_check:
ret = iter_object
else:
raise ValueError(f"Unexpected value type {type(iter_object)}")
if not ranks_only or dist.get_rank(pg) in ranks_only:
if isinstance(ret, torch.Tensor):
if cpu_offload and companion_obj is None:
ret = ret.to(cpu_device)
if companion_obj is not None:
# TODO: support DTensor
companion_obj.copy_(ret, non_blocking=non_blocking)
ret = companion_obj
else:
ret = {} if isinstance(ret, dict) else None
return ret
def _gather_state_dict(
state_dict: Dict[str, Any],
*,
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
cpu_offload: bool = False,
ranks_only: Tuple[int, ...] = tuple(),
type_check: bool = True,
) -> Dict[str, Any]:
"""
Given a state_dict, this API gathers all the ShardedTensors or DTensors in
the state_dict.
Args:
state_dict (Dict[str, Any]): the target sharded state_dict.
pg (Optional[dist.ProcessGroup]): the process group that is used to
gather ShardedTensor. Note that gathering a DTensor will use
the DeviceMesh. So this argument will be ignored when gathering a
DTensor.
device: (Optional[torch.device]): the device that is used to
perform allgather for ShardedTensor. Note that gathering a DTensor
will use the DeviceMesh. So this argument will be ignored when
gathering a DTensor.
cpu_offload (bool): whether to offload the tensors to CPU memory. The
default value is False.
ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
have the same state_dicts. Other ranks will get empty state_dicts.
type_check: (bool): check if the instance data type is a supported type
that can be saved by DCP. The current supported data types are
torch.Tensor, DTensor, int, float, str, list, dict, None.
Returns:
The gathered state dictionary.
"""
def sharded_tensor_func(value, pg, device, companion_obj):
# ShardedTensor does not seem to record the original device type.
# So if the tensor is moved to CPU, we won't know the original type.
# As a result, we have to rely on the user to tell us the correct one.
cpu_device = torch.device("cpu")
output_tensor = _all_gather_sharded_tensor(value, pg, device)
local_shard_device = (
value.local_shards()[0].tensor.device
if value.local_shards()
else cpu_device
)
if output_tensor.device != local_shard_device:
value = output_tensor.to(local_shard_device)
else:
value = output_tensor
return value
def dtensor_func(value, pg, device, companion_obj):
if value.device != value.device_mesh.device_type:
value = value.to(value.device_mesh.device_type)
# FSDP all_gather: [Shard(0)] -> [Replicate()]
# HSDP all_gather: [Replicate(), Shard(0)] -> [Replicate(), Replicate()]
# 2D FSDP + TP all_gather:
# - [Shard(0), Shard(n)] -> [Replicate(), Replicate()]
# - [Shard(0), Replicate()] -> [Replicate(), Replicate()]
placements = [Replicate() for _ in value.placements]
value = value.redistribute(
device_mesh=value.device_mesh,
placements=placements,
)
# Call `wait()` to force the tensor to be synchronous with respect
# to the main stream.
# See the discussion in https://github.com/pytorch/pytorch/pull/117799.
value = value.to_local()
if isinstance(value, AsyncCollectiveTensor):
value = value.wait()
return value
return _iterate_state_dict(
state_dict,
sharded_tensor_func,
dtensor_func,
_identity_func,
pg=pg,
device=device,
cpu_offload=cpu_offload,
ranks_only=ranks_only,
type_check=type_check,
)
def _offload_state_dict_to_cpu(
state_dict: Dict[str, Any],
*,
ranks_only: Tuple[int, ...] = tuple(),
type_check: bool = True,
) -> Dict[str, Any]:
"""
Given a state_dict, this API offloads all the tensors to CPU memory.
Args:
state_dict (Dict[str, Any]): the target state_dict.
ranks_only: (Tuple[int, ...]): if this tuple is empty, all ranks will
have the same state_dicts. Otherwise only ranks that are in ``ranks_only``
have the same state_dicts. Other ranks will get empty state_dicts.
type_check: (bool): check if the instance data type is a supported type
that can be saved by DCP. The current supported data types are
torch.Tensor, DTensor, int, float, str, list, dict, None.
Returns:
The offloaded state dictionary.
"""
ret = _iterate_state_dict(
state_dict,
_identity_func,
_identity_func,
_identity_func,
pg=None,
device=None,
cpu_offload=True,
ranks_only=ranks_only,
type_check=type_check,
)
return ret
def _copy_state_dict(
state_dict: Dict[str, Any],
copy_state_dict: Dict[str, Any],
non_blocking: bool = False,
):
"""
Copies all tensors in a given state dict into a different state_dict with the
same structure.
.. warning::
This function expects ``state_dict`` and ``copy_state_dict`` to share
the same structure and data types.
.. warning::
The current supported data types are
torch.Tensor, DTensor, int, float, str, list, dict, None.
Args:
state_dict (Dict[str, Any]): the target state_dict.
copy_state_dict (Dict[str, Any]):
The state dict we are copying into. This state_dict must have exactly
the same structure as the source `state_dict`.
non_blocking: (bool): Whether copy ops should be performed asynchronously
"""
_iterate_state_dict(
state_dict,
_identity_func,
_identity_func,
_identity_func,
pg=None,
device=None,
cpu_offload=False,
ranks_only=tuple(),
companion_obj=copy_state_dict,
type_check=True,
non_blocking=non_blocking,
)
def _create_cpu_state_dict(
state_dict: Dict[str, Any], pin_memory: bool = False, share_memory: bool = False
) -> Dict[str, Any]:
"""
Given a state_dict, create another state_dict with the same structure and elements.
However, all tensors in the returned state_dict are new tensors on CPU. These
tensors can be allocated in pinned and/or shared memory based on the provided arguments.
.. warning::
Setting both `pin_memory` and `share_memory` to True significantly increases the
latency of this method because of the nuances which require us to register memory
as pinned directly as opposed to relying on the pin_memory cache allocator. This
option should only be used for long-lived tensors that are required to be shared.
This overhead does not apply when at least one of `pin_memory` or `share_memory` is
set to False.
"""
def tensor_func(
obj: torch.Tensor,
pg: Optional[dist.ProcessGroup],
device: Optional[torch.device],
_: Any,
) -> torch.Tensor:
if len(obj.size()) == 0:
return torch.tensor(0, dtype=obj.dtype)
if share_memory:
t = torch.empty(*tuple(obj.size()), dtype=obj.dtype).share_memory_()
if pin_memory:
succ = torch.cuda.cudart().cudaHostRegister(
t.data_ptr(),
t.numel() * t.element_size(),
1, # lines up with 'cudaHostRegisterPortable'
)
assert (
succ == 0
), f"Pinning shared memory failed with error-code: {succ}"
return t
elif pin_memory:
return torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()
else:
return torch.empty(*tuple(obj.size()), dtype=obj.dtype)
ret = _iterate_state_dict(
state_dict,
_identity_func,
_identity_func,
tensor_func,
pg=None,
device=None,
cpu_offload=False,
ranks_only=tuple(),
type_check=False,
)
return ret
def _check_state_dict_similarity(
state_dict: Dict[str, Any],
compared_state_dict: Dict[str, Any],
) -> bool:
"""
Given two state_dicts, check if the structures are the same. If a
[key, tensor] pair exists in one state_dict, there must be a corresponding
pair, [key, other_tensor], in the other state_dict, where tensor and
other_tensor have the same size and dtype.
Return the check result.
"""
def tensor_func(
obj: torch.Tensor,
pg: Optional[dist.ProcessGroup],
device: Optional[torch.device],
companion_obj: Any,
) -> torch.Tensor:
if companion_obj.dtype != obj.dtype or companion_obj.size() != obj.size():
raise CompanionMismatch()
return obj
try:
_iterate_state_dict(
state_dict,
_identity_func,
_identity_func,
tensor_func,
pg=None,
device=None,
cpu_offload=False,
ranks_only=tuple(),
companion_obj=compared_state_dict,
type_check=False,
)
except CompanionMismatch:
return False
return True

View File

@ -1171,7 +1171,7 @@ def init_process_group(
)
default_pg, _ = _new_process_group_helper(
-1, -1, [], backend, None, group_name, timeout=timeout
-1, -1, [], backend, None, group_name, timeout=timeout, group_desc="default_pg"
)
_update_default_pg(default_pg)
else:
@ -1197,6 +1197,7 @@ def init_process_group(
pg_options=pg_options,
timeout=timeout,
device_id=device_id,
group_desc="default_pg"
)
_update_default_pg(default_pg)
@ -1257,6 +1258,7 @@ def _new_process_group_helper(
timeout=None,
pg_tag=None,
device_id=None,
group_desc=None,
):
"""
Create a new distributed process group.
@ -1289,6 +1291,8 @@ def _new_process_group_helper(
_, prefix_store = _world.pg_map[existing_group]
return existing_group, prefix_store
group_desc = "undefined" if group_desc is None else group_desc
# The list of group ranks is empty if we're creating the default group.
is_default_group = len(global_ranks_in_group) == 0
@ -1375,6 +1379,7 @@ def _new_process_group_helper(
if split_from:
pg_options.split_from = split_from
pg_options.split_color = _process_group_color(global_ranks_in_group)
pg_options.group_name = group_name
backend_class = ProcessGroupNCCL(
backend_prefix_store, group_rank, group_size, pg_options)
backend_type = ProcessGroup.BackendType.NCCL
@ -1461,9 +1466,11 @@ def _new_process_group_helper(
# update global state
assert group_name is not None
assert group_desc is not None
_world.pg_map[pg] = (backend, prefix_store)
_world.pg_names[pg] = group_name
pg._set_group_name(group_name)
pg._set_group_desc(group_desc)
_world.pg_backend_config[pg] = str(backend_config)
# "" is the default tag for user PGs
@ -3614,7 +3621,7 @@ def _get_backend_from_str(backend: Optional[str] = None) -> Backend:
@_time_logger
def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False):
def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local_synchronization=False, group_desc=None):
"""
Create a new distributed group.
@ -3655,6 +3662,7 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local
barrier at the end of the process group creation. This is different
in that non-member ranks don't need to call into the API and don't
join the barrier.
group_desc (str, optional): a string to describe the process group.
Returns:
A handle of distributed group that can be given to collective calls or None if the rank is not part of ``ranks``.
@ -3669,7 +3677,15 @@ def new_group(ranks=None, timeout=None, backend=None, pg_options=None, use_local
multiple overlapping process groups. To avoid that, make sure all ranks follow the
same global creation order.
"""
return _new_group_with_tag(ranks, timeout, backend, pg_options, None, use_local_synchronization=use_local_synchronization)
return _new_group_with_tag(
ranks,
timeout,
backend,
pg_options,
None,
use_local_synchronization=use_local_synchronization,
group_desc=group_desc,
)
def _new_group_with_tag(
ranks=None,
@ -3677,7 +3693,8 @@ def _new_group_with_tag(
backend=None,
pg_options=None,
pg_tag=None,
use_local_synchronization=False
use_local_synchronization=False,
group_desc=None
):
"""
Variant of ``new_group`` that exposes tag creation.
@ -3749,7 +3766,8 @@ def _new_group_with_tag(
group_name,
pg_options=pg_options,
timeout=timeout,
pg_tag=pg_tag
pg_tag=pg_tag,
group_desc=group_desc
)
# Create the global rank to group rank mapping
@ -3789,6 +3807,7 @@ def new_subgroups(
timeout=None,
backend=None,
pg_options=None,
group_desc=None,
):
"""
Create subgroups of equal size.
@ -3841,6 +3860,8 @@ def new_subgroups(
the construction of specific process groups. i.e. for the ``nccl``
backend, ``is_high_priority_stream`` can be specified so that
process group can pick up high priority cuda streams.
group_desc (str, optional): A string describing the group. Each subgroup will
inherit its group_desc.
Returns:
The subgroup containing the current rank, and all the subgroups used for cleanup.
@ -3886,6 +3907,7 @@ def new_subgroups(
timeout=timeout,
backend=backend,
pg_options=pg_options,
group_desc=group_desc,
)
subgroups.append(subgroup)
@ -3905,6 +3927,7 @@ def new_subgroups_by_enumeration(
timeout=None,
backend=None,
pg_options=None,
group_desc=None,
):
"""
Create subgroups by dividing the global world.
@ -3945,6 +3968,8 @@ def new_subgroups_by_enumeration(
the construction of specific process groups. i.e. for the ``nccl``
backend, ``is_high_priority_stream`` can be specified so that
process group can pick up high priority cuda streams.
group_desc (str, optional): A string describing the group. Each subgroup will
inherit its group_desc.
Returns:
The subgroup containing the current rank, and all the subgroups used for cleanup.
@ -3973,6 +3998,7 @@ def new_subgroups_by_enumeration(
timeout=timeout,
backend=backend,
pg_options=pg_options,
group_desc=group_desc,
)
subgroups.append(subgroup)
my_rank = get_rank()

View File

@ -28,7 +28,7 @@ def tail_logfile(
return
time.sleep(interval_sec)
with open(file) as fp:
with open(file, errors="replace") as fp:
while True:
line = fp.readline()