[c10d] Faster coalescing (#98793)

### Description
This PR reduces the CPU overhead of context manager style coalescing.

By "context manager style coalescing", we mean:
Sync style:
```python
with _coalescing_manager():
    for i in range(num_coll):
        dist.all_reduce(tensors[i])
```
Async style:
```python
with _coalescing_manager(async_ops=True) as cm:
    for i in range(num_coll):
        dist.all_reduce(tensors[i])
cm.wait()
```
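For context, here is a minimal end-to-end sketch of the async pattern above. It assumes an NCCL backend, one GPU per rank, and a `torchrun` launch; the file name is arbitrary, and `_coalescing_manager` is a private API:

```python
# minimal_coalesce.py (hypothetical file name)
# Run with: torchrun --nproc-per-node=2 minimal_coalesce.py
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")          # torchrun supplies rank/world size via env vars
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

num_coll = 4
tensors = [torch.ones(8, device="cuda") for _ in range(num_coll)]

# The all-reduces below are captured by the coalescing manager and
# issued together when the context exits.
with dist._coalescing_manager(async_ops=True) as cm:
    for i in range(num_coll):
        dist.all_reduce(tensors[i])
cm.wait()  # wait for the coalesced group to complete

# After a sum all-reduce of all-ones tensors, every element equals world_size.
expected = float(dist.get_world_size())
assert all(bool(t.eq(expected).all()) for t in tensors)

dist.destroy_process_group()
```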
In the previous implementation, each collective in the `num_coll` loop made its own call into the C++ backend, accumulating pybind overhead.

In the new implementation, we capture the collectives at the Python level and only dispatch to C++ once, at the exit of the coalescing manager.
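Conceptually, eager per-op dispatch is replaced by a record-then-flush pattern. Below is a simplified, self-contained sketch of that idea; it is not the actual c10d implementation, and every name in it (`fake_all_reduce`, `fake_coalescing_manager`, `FakeBackend`, `all_reduce_coalesced`) is made up for illustration:

```python
from contextlib import contextmanager

import torch

# Module-level recording state; populated only while a coalescing manager is active.
_pending = None


def fake_all_reduce(tensor, backend):
    """Stand-in for dist.all_reduce: record while coalescing, else launch eagerly."""
    global _pending
    if _pending is not None:
        _pending.append(tensor)            # recorded in Python; no pybind crossing
        return None
    return backend.all_reduce(tensor)      # eager single-op path


@contextmanager
def fake_coalescing_manager(backend):
    """Stand-in for dist._coalescing_manager: flush everything in one backend call."""
    global _pending
    _pending = []
    try:
        yield
    finally:
        tensors, _pending = _pending, None
        if tensors:
            backend.all_reduce_coalesced(tensors)   # one crossing for the whole batch


class FakeBackend:
    """Toy backend so the sketch runs without GPUs or NCCL."""
    def all_reduce(self, t):
        print(f"single all_reduce, shape {tuple(t.shape)}")

    def all_reduce_coalesced(self, ts):
        print(f"one coalesced call covering {len(ts)} tensors")


if __name__ == "__main__":
    backend = FakeBackend()
    tensors = [torch.ones(8) for _ in range(4)]
    with fake_coalescing_manager(backend):
        for t in tensors:
            fake_all_reduce(t, backend)    # recorded, not launched
    # prints: "one coalesced call covering 4 tensors"
```

The key point is that only the final flush crosses the Python/C++ boundary, so the per-op cost inside the loop stays in pure Python.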

### Tests
In the current PR, the "fast path" applies only to all-reduce. The benchmark compares a single flattened 512M all-reduce against 64 coalesced all-reduces of 8M each:
| Configuration | End-to-end time | CPU time (included) |
| --- | --- | --- |
| Flattened 512M | 16.38 ms | 131.21 us |
| Old `_coalescing_manager`, 64 x 8M | 22.19 ms | 2865 us |
| New `_coalescing_manager`, 64 x 8M | 16.93 ms | 635 us |

Hence a roughly 4.5x reduction in CPU overhead compared with the old coalescing manager (the exact factor depends on `num_coll`).
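The figures above come from the PR author's setup; a rough way to take similar measurements, assuming an already-initialized NCCL process group, 8M-element fp32 tensors, and a warm-up run beforehand (the helper name here is ours, not part of the PR), is:

```python
import time

import torch
import torch.distributed as dist


def bench_coalesced_all_reduce(num_coll=64, numel=8 * 1024 * 1024):
    """Rough sketch: CPU launch cost vs. end-to-end time of one coalesced group."""
    tensors = [torch.ones(numel, device="cuda") for _ in range(num_coll)]
    torch.cuda.synchronize()

    t0 = time.perf_counter()
    with dist._coalescing_manager(async_ops=True) as cm:
        for t in tensors:
            dist.all_reduce(t)
    cpu_us = (time.perf_counter() - t0) * 1e6    # Python + pybind launch overhead

    cm.wait()
    torch.cuda.synchronize()
    total_ms = (time.perf_counter() - t0) * 1e3  # end-to-end, including the NCCL work
    return cpu_us, total_ms
```

In practice one would average over several iterations per rank and discard the first (warm-up) measurement.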

Cc @mrshenli @kumpera @wanchaol @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98793
Approved by: https://github.com/kumpera
Author: Ke Wen
Date: 2023-04-24 21:27:22 +00:00
Committed by: PyTorch MergeBot
Commit: 3a09aa5977 (parent 3dcc7b396c)
10 changed files with 342 additions and 112 deletions


@@ -1248,6 +1248,75 @@ class DistributedTest:
```python
        # No model averaging, so the parameters are not updated.
        self.assertEqual(param.data, tensor)

    # Coalescing manager (sync mode)
    @skip_if_no_gpu
    @skip_but_pass_in_sandcastle_if(
        BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
        "Coalescing manager currently tests with NCCL only; internal test flaky"
    )
    def test_coalescing_manager(self):
        self._barrier()
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
        device_id = rank_to_GPU[rank][0]
        torch.cuda.set_device(device_id)
        num_colls = 2
        size_per_coll = 8
        small_tensors = [
            torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
        ]

        with dist._coalescing_manager():
            for i in range(num_colls):
                dist.all_reduce(small_tensors[i])

        big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
        dist.all_reduce(big_tensor)

        for i in range(num_colls):
            self.assertEqual(
                small_tensors[i],
                big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
            )

        self._barrier()

    # Coalescing manager (async mode)
    @skip_if_no_gpu
    @skip_but_pass_in_sandcastle_if(
        BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
        "Coalescing manager currently tests with NCCL only; internal test flaky"
    )
    def test_coalescing_manager_async(self):
        self._barrier()
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
        device_id = rank_to_GPU[rank][0]
        torch.cuda.set_device(device_id)
        num_colls = 2
        size_per_coll = 8
        small_tensors = [
            torch.ones(size_per_coll, device=device_id) for _ in range(num_colls)
        ]

        with dist._coalescing_manager(async_ops=True) as cm:
            for i in range(num_colls):
                dist.all_reduce(small_tensors[i])
        cm.wait()

        big_tensor = torch.ones(num_colls * size_per_coll, device=device_id)
        dist.all_reduce(big_tensor)

        for i in range(num_colls):
            self.assertEqual(
                small_tensors[i],
                big_tensor[i * size_per_coll : (i + 1) * size_per_coll]
            )

        self._barrier()

    # NCCL Batch SEND RECV
    @skip_if_no_gpu
    @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
```