Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Work: block_current_stream API (#156883)
This implements a new `block_current_stream` API on Work that gives CPU-based backends such as Gloo the same stream-blocking behavior that `wait` already provides for ProcessGroupNCCL. The goal is to support Gloo communication overlap in FSDPv2/HSDP with minimal changes to FSDP. A previous attempt made FSDPv2 call Work.wait directly, but given the extensive stream semantics FSDP uses, that did not play nicely: https://github.com/pytorch/pytorch/pull/148780

The implementation uses a "Baton" CUDA kernel that spinlocks on a pinned CPU tensor, waiting for it to be set.

Test plan:

```
pytest test/distributed/test_c10d_gloo.py -v -k block_current_stream
pytest test/distributed/test_c10d_nccl.py -v -k block_current_stream
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156883
Approved by: https://github.com/kwen2501, https://github.com/fduwjj
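For context, here is a minimal sketch of the intended usage pattern. It is illustrative only: it assumes a Gloo process group has already been initialized, that CUDA tensors are used with it, and that the Python `Work` binding exposes the new `block_current_stream()` method; the tensor sizes and names are made up.

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("gloo", ...) has already run and a CUDA
# device has been selected for this rank.
t = torch.rand(1024, device="cuda")

work = dist.all_reduce(t, async_op=True)  # Gloo collective progresses on the CPU side

# With Gloo, work.wait() would block this CPU thread until the collective
# finishes. block_current_stream() instead makes the *current CUDA stream*
# wait: the CPU keeps issuing work, and kernels launched on this stream below
# only start executing once the all-reduce has completed.
work.block_current_stream()

y = t * 2  # ordered after the all-reduce on the current stream
torch.cuda.current_stream().synchronize()  # y now reflects the reduced values
```

The NCCL test added below exercises the same contract: after `block_current_stream()`, a plain `current_stream().synchronize()` is enough to observe the collective's result.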
Committed by: PyTorch MergeBot
Parent: 92f41ccc26
Commit: 1b3d69b59f
```diff
@@ -1217,6 +1217,21 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
         )
         dist.all_reduce(torch.empty(1, device=torch.device("cuda", device_idx)))
+
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_block_current_stream(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device(f"cuda:{self.rank}")
+        pg = self._create_process_group_nccl(store, self.opts(), device_id=device)
+
+        t = torch.rand(10, device=device)
+        work = pg.allreduce(t)
+        work.block_current_stream()
+
+        torch.cuda.current_stream().synchronize()
+        work.wait()
+        torch.cuda.synchronize()
 
 
 class DistributedDataParallelTest(
     test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase
```
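The commit message describes the blocking mechanism as a "Baton" CUDA kernel that spinlocks on a pinned CPU tensor. The sketch below illustrates that general idea using an inline extension; the kernel, the `baton_wait` name, and the host-side protocol are assumptions made for exposition, not the PR's actual implementation.

```python
import torch
from torch.utils.cpp_extension import load_inline

cuda_src = r"""
#include <ATen/cuda/CUDAContext.h>

// Single-thread kernel that spins until the pinned host flag becomes nonzero.
// The flag lives in page-locked CPU memory, so (on typical 64-bit CUDA setups
// with unified addressing) the device can read it directly, and a plain CPU
// store becomes visible to the spinning kernel.
__global__ void baton_wait_kernel(volatile int32_t* flag) {
    while (*flag == 0) {
    }
}

// Launch the spin kernel on the current CUDA stream; everything enqueued on
// that stream afterwards is held back until the flag is set from the CPU.
void baton_wait(torch::Tensor flag) {
    TORCH_CHECK(flag.is_pinned() && flag.scalar_type() == torch::kInt32);
    auto stream = at::cuda::getCurrentCUDAStream();
    baton_wait_kernel<<<1, 1, 0, stream.stream()>>>(
        reinterpret_cast<volatile int32_t*>(flag.data_ptr<int32_t>()));
}
"""

ext = load_inline(
    name="baton_sketch",
    cpp_sources="void baton_wait(torch::Tensor flag);",
    cuda_sources=cuda_src,
    functions=["baton_wait"],
)

flag = torch.zeros(1, dtype=torch.int32, pin_memory=True)
ext.baton_wait(flag)                 # current stream is now held back
# ... the CPU-side collective (e.g. Gloo) would run here ...
flag[0] = 1                          # "set the baton": releases the stream
torch.cuda.current_stream().synchronize()
```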