Introduce ProcessGroupCudaP2P (#122163)

## Context
This stack prototypes automatic micro-pipelining of `all-gather -> matmul` and `matmul -> reduce-scatter` via Inductor. The idea originates from the paper [Overlap Communication with Dependent Computation via
Decomposition in Large Deep Learning Models](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959). The implementation and some key optimizations are heavily influenced by @lw's implementation in xformers.

The stack contains several components:
- `ProcessGroupCudaP2P` - a thin wrapper around `ProcessGroupNCCL`. It in addition maintains a P2P workspace that enables SM-free, one-sided P2P communication which is needed for optimal micro-pipelining.
- `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops.
- Post-grad fx pass that detects `all-gather -> matmul` and `matmul -> reduce-scatter` and replaces them with the fused dispatcher ops.

To enable the prototype feature:
- Set the distributed backend to `cuda_p2p`.
- Set `torch._inductor.config._micro_pipeline_tp` to `True`.

*NOTE: the prototype sets nothing in stone w.r.t to each component's design. The purpose is to have a performant baseline with reasonable design on which each component can be further improved.*

## Benchmark
Setup:
- 8 x H100 (500W) + 3rd gen NVSwitch.
- Llama3 8B training w/ torchtitan.
- 8-way TP. Reduced the number of layers from 32 to 8 for benchmarking purpose.

Trace (baseline): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpjaz8zgx0
<img width="832" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4addba77-5abc-4d2e-93ea-f68078587fe1">

Trace (w/ micro pipelining): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpn073b4wn
<img width="963" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4f44e78d-8196-43ab-a1ea-27390f07e9d2">

## This PR
`ProcessGroupCudaP2P` is a thin wrapper around `ProcessGroupNCCL`. By default, it routes all collectives to the underlying `ProcessGroupNCCL`. In addition, `ProcessGroupCudaP2P` initializes a P2P workspace that allows direct GPU memory access among the members. The workspace can be used in Python to optimize intra-node communication patterns or to create custom intra-node collectives in CUDA.

`ProcessGroupCudaP2P` aims to bridge the gap where certain important patterns can be better optimized via fine-grained P2P memory access than with collectives in the latest version of NCCL. It is meant to complement NCCL rather than replacing it.
Usage:
```
    # Using ProcessGroupCudaP2P
    dist.init_process_group(backend="cuda_p2p", ...)

    # Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
    pg_options = ProcessGroupCudaP2P.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
    pg_options = ProcessGroupNCCL.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Using ProcessGroupCudaP2P while specifying both
    # ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
    pg_options = ProcessGroupCudaP2P.Options()
    pg_options.nccl_options = ProcessGroupNCCL.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Down-casting the backend to access p2p buffers for cuda_p2p specific
    # optimizations
    if is_cuda_p2p_group(group):
        backend = get_cuda_p2p_backend(group)
        if required_p2p_buffer_size > backend.get_buffer_size():
            # fallback
        p2p_buffer = backend.get_p2p_buffer(...)
    else:
        # fallback
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122163
Approved by: https://github.com/wanchaol
This commit is contained in:
Yifu Wang
2024-05-24 01:19:27 -07:00
committed by PyTorch MergeBot
parent 01f04230cf
commit 4a09117d16
13 changed files with 795 additions and 116 deletions

View File

@ -18,6 +18,7 @@ time python test/run_test.py --verbose -i distributed/test_c10d_gloo
time python test/run_test.py --verbose -i distributed/test_c10d_nccl
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
time python test/run_test.py --verbose -i distributed/test_cuda_p2p
time python test/run_test.py --verbose -i distributed/test_store
time python test/run_test.py --verbose -i distributed/test_pg_wrapper
time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent

View File

@ -1,17 +1,16 @@
#!/usr/bin/env python3
# This file contains an example for using IntraNodeComm to implement efficient fused
# This file contains an example for using cuda_p2p backend to implement efficient fused
# allgather_matmul (inspired by https://dl.acm.org/doi/pdf/10.1145/3567955.3567959 and
# @lw's efficient GPU implementation in xformers). Its purpose to help guide the
# development of relevant primitives and serve as an example for interested users.
#
# The benchmark can be executed as follows:
# torchrun --nproc-per-node 8 allgather_matmul.py
#
# NOTE: _IntraNodeComm is a prototype API which WILL change over time.
import os
import torch
import torch._C._distributed_c10d as c10d
import torch.distributed as dist
from torch.distributed._cuda_p2p import ProcessGroupCudaP2P
M = 16384
N = 8192
@ -21,55 +20,60 @@ WARMUP_ITERS = 200
BENCH_ITERS = 50
comm = None
internal_stream = None
internal_event = None
def allgather_matmul(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
group = dist.group.WORLD
group_size = group.size()
A = torch.ops._c10d_functional.all_gather_into_tensor(A_shard, group_size, "0")
A = torch.ops._c10d_functional.wait_tensor(A)
return A @ B
def allgather_matmul(A_shard, B, out, rank, world_size):
def allgather_matmul_p2p(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""
Equivalent to `torch.matmul(dist.all_gather(A_shard), B)`.
"""
buf_0 = torch.empty_like(A_shard)
buf_1 = torch.empty_like(A_shard)
out_shards = [
out[i : i + A_shard.shape[0]]
for i in range(0, world_size * A_shard.shape[0], A_shard.shape[0])
]
group = dist.group.WORLD
group_size = group.size()
rank = group.rank()
backend = group._get_backend(torch.device("cuda"))
out = torch.empty(
(A_shard.shape[0] * group.size(), B.shape[1]),
dtype=A_shard.dtype,
device="cuda",
)
out_shards = out.chunk(group_size)
local_p2p_buf = backend.get_p2p_buffer(rank, A_shard.shape, A_shard.dtype)
# Perform matmul with the local input shard
torch.matmul(A_shard, B, out=out_shards[rank])
# In another stream, copy the local input shard into the intra-node
# buffer. After the barrier, all peers' input shards are accessible
# via their intra-node buffer without requiring synchronization.
with torch.cuda.stream(internal_stream):
comm.put(A_shard)
comm.barrier()
internal_event.record()
internal_event.wait()
with torch.cuda.stream(backend.stream()):
local_p2p_buf.copy_(A_shard)
work = backend.intra_node_barrier()
work.wait()
# Copy input shard from remote buffer and perform matmul.
# Alternate between two streams to offset the wave quantization
# effect of smaller matmuls.
for i in range(1, world_size):
buf_0 = torch.empty_like(A_shard)
buf_1 = torch.empty_like(A_shard)
for i in range(1, group_size):
if i % 2 == 0:
buf = buf_0
stream = torch.cuda.current_stream()
else:
buf = buf_1
stream = internal_stream
remote = (i + rank) % world_size
stream = backend.stream()
remote_rank = (i + rank) % group_size
remote_p2p_buf = backend.get_p2p_buffer(
remote_rank, A_shard.shape, A_shard.dtype
)
with torch.cuda.stream(stream):
comm.get(remote, buf)
torch.matmul(buf, B, out=out_shards[remote])
buf.copy_(remote_p2p_buf)
torch.matmul(buf, B, out=out_shards[remote_rank])
# Perform another barrier to ensure all peers have completed consuming the
# intra-node buffer so it can be reused.
with torch.cuda.stream(internal_stream):
comm.barrier()
internal_event.record()
internal_event.wait()
with torch.cuda.stream(backend.stream()):
work = backend.intra_node_barrier()
work.wait()
return out
def do_bench(fn):
@ -89,8 +93,6 @@ def do_bench(fn):
def main():
os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
@ -98,33 +100,32 @@ def main():
assert M % world_size == 0
torch.cuda.set_device(local_rank)
store, _, _ = next(torch.distributed.rendezvous("env://", rank, world_size))
global comm, internal_stream, internal_event
comm = c10d._IntraNodeComm(
store=store,
rank=rank,
world_size=world_size,
buffer_size=M * K * torch.finfo(torch.bfloat16).bits // 8 // world_size,
)
internal_stream = torch.cuda.Stream()
internal_event = torch.cuda.Event()
options = ProcessGroupCudaP2P.Options()
options.buffer_size = M * N * 2 // world_size
dist.init_process_group("cuda_p2p", pg_options=options)
torch.manual_seed(42)
A = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
B = torch.randn((K, N), dtype=torch.bfloat16, device="cuda")
out = torch.empty((M, N), dtype=torch.bfloat16, device="cuda")
stride = M // world_size
A_shard = A[rank * stride : (rank + 1) * stride]
comm.barrier()
torch.cuda.synchronize()
allgather_matmul_ms = do_bench(
lambda: allgather_matmul(A_shard, B, out, rank, world_size)
assert torch.allclose(
allgather_matmul(A_shard, B),
allgather_matmul_p2p(A_shard, B),
)
comm.barrier()
dist.barrier()
torch.cuda.synchronize()
allgather_matmul_ms = do_bench(lambda: allgather_matmul(A_shard, B))
dist.barrier()
torch.cuda.synchronize()
allgather_matmul_p2p_ms = do_bench(lambda: allgather_matmul_p2p(A_shard, B))
dist.barrier()
torch.cuda.synchronize()
matmul_ms = do_bench(lambda: torch.matmul(A, B))
@ -134,8 +135,15 @@ def main():
f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
f"{allgather_matmul_ms:.4} ms/iter"
)
print(
"allgather_matmul_p2p "
f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
f"{allgather_matmul_p2p_ms:.4} ms/iter"
)
print(f"matmul (M={M}, N={N}, K={K}): {matmul_ms:.4} ms/iter")
dist.destroy_process_group()
if __name__ == "__main__":
main()

View File

@ -675,6 +675,7 @@ libtorch_cuda_distributed_base_sources = [
# These files are only supported on Linux (and others) but not on Windows.
libtorch_cuda_distributed_extra_sources = [
"torch/csrc/distributed/c10d/NCCLUtils.cpp",
"torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp",
"torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
"torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
"torch/csrc/distributed/c10d/UCCTracing.cpp",

View File

@ -0,0 +1,142 @@
# Owner(s): ["module: c10d"]
import os
from typing import List
import torch
import torch.distributed as dist
from torch.distributed._cuda_p2p import (
get_cuda_p2p_backend,
get_p2p_buffer_size,
is_cuda_p2p_group,
)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
)
def requires_cuda_p2p_access():
cuda_p2p_access_available = (
torch.cuda.is_available()
and torch.cuda.device_count() >= 2
and dist.is_nccl_available()
)
num_devices = torch.cuda.device_count()
for i in range(num_devices - 1):
for j in range(i + 1, num_devices):
if not torch.cuda.can_device_access_peer(i, j):
cuda_p2p_access_available = False
break
if not cuda_p2p_access_available:
break
return skip_but_pass_in_sandcastle_if(
not cuda_p2p_access_available,
"cuda p2p access is not available",
)
@requires_nccl()
@requires_cuda_p2p_access()
class ProcessGroupCudaP2PTest(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@property
def world_size(self) -> int:
return 2
@property
def ranks(self) -> List[int]:
return list(range(self.world_size))
@property
def device(self) -> torch.device:
return torch.device(f"cuda:{self.rank}")
def _init_process_group(self, buffer_size: int) -> None:
os.environ["TEST_INTRA_NODE_COMM"] = "1"
torch.cuda.set_device(self.device)
# Verify cuda p2p specific APIs on ProcessGroupCudaP2P
store = dist.FileStore(self.file_name, self.world_size)
options = dist.ProcessGroupCudaP2P.Options()
options.buffer_size = buffer_size
dist.init_process_group(
backend="cuda_p2p",
world_size=self.world_size,
rank=self.rank,
store=store,
pg_options=options,
)
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_p2p_apis(self) -> None:
BUFFER_SIZE = 4 * 1024
self._init_process_group(BUFFER_SIZE)
# Verify cuda p2p specific APIs on ProcessGroupCudaP2P
assert is_cuda_p2p_group(dist.group.WORLD)
assert get_p2p_buffer_size(dist.group.WORLD) == BUFFER_SIZE
backend = get_cuda_p2p_backend(dist.group.WORLD)
assert isinstance(backend, dist.ProcessGroupCudaP2P)
assert backend.get_buffer_size() == BUFFER_SIZE
backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float)
with self.assertRaises(RuntimeError):
backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4 + 1,), torch.float)
with self.assertRaises(RuntimeError):
backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float, 1)
# Verify cuda p2p specific APIs on non-cuda p2p process group
non_cuda_p2p_pg = dist.new_group(backend="nccl")
assert not is_cuda_p2p_group(non_cuda_p2p_pg)
assert get_p2p_buffer_size(non_cuda_p2p_pg) == 0
with self.assertRaises(TypeError):
get_cuda_p2p_backend(non_cuda_p2p_pg)
dist.barrier()
torch.cuda.synchronize()
dist.destroy_process_group()
@skipIfRocm
@skip_if_lt_x_gpu(2)
def test_p2p_buffer(self) -> None:
BUFFER_SIZE = 4 * 1024
self._init_process_group(BUFFER_SIZE)
rank = self.rank
world_size = self.world_size
assert is_cuda_p2p_group(dist.group.WORLD)
backend = get_cuda_p2p_backend(dist.group.WORLD)
local_buffer = backend.get_p2p_buffer(
(rank) % world_size, (BUFFER_SIZE // 4,), torch.float
)
remote_buffer = backend.get_p2p_buffer(
(rank + 1) % world_size, (BUFFER_SIZE // 4,), torch.float
)
local_buffer.fill_(rank)
backend.intra_node_barrier()
assert remote_buffer.eq((rank + 1) % world_size).all()
dist.barrier()
torch.cuda.synchronize()
dist.destroy_process_group()
if __name__ == "__main__":
run_tests()

View File

@ -605,3 +605,30 @@ def _register_process_group(
def _resolve_process_group(group_name: str) -> ProcessGroup: ...
def _unregister_all_process_groups() -> None: ...
def _unregister_process_group(group_name: str) -> None: ...
class ProcessGroupCudaP2P(Backend):
class Options:
nccl_options: Optional[ProcessGroupNCCL.Options]
buffer_size: Optional[int]
def __init__(self) -> None: ...
def __init__(
self,
store: Store,
rank: int,
size: int,
options: ProcessGroupCudaP2P.Options,
) -> None: ...
def is_p2p_available(self) -> bool: ...
def get_buffer_size(self) -> int: ...
def stream(self) -> torch.cuda.Stream: ...
def intra_node_barrier(self) -> Work: ...
def get_p2p_buffer(
self,
rank: int,
sizes: torch.Size,
dtype: torch.dtype,
storage_offset: Optional[int] = 0,
) -> torch.Tensor: ...
def _shutdown(self) -> None: ...

View File

@ -0,0 +1,206 @@
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
namespace c10d {
using namespace c10d::intra_node_comm;
ProcessGroupCudaP2P::ProcessGroupCudaP2P(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options)
: Backend(rank, size), stream_(c10::cuda::getStreamFromPool()) {
nccl_backend_ = c10::make_intrusive<ProcessGroupNCCL>(
c10::make_intrusive<PrefixStore>("nccl", store),
rank,
size,
options->nccl_options);
nccl_backend_->setSequenceNumberForGroup();
p2p_backend_ = c10::make_intrusive<IntraNodeComm>(
c10::make_intrusive<PrefixStore>("p2p", store),
rank,
size,
options->buffer_size);
if (!p2p_backend_->rendezvous()) {
p2p_backend_ = nullptr;
}
}
bool ProcessGroupCudaP2P::is_p2p_available() {
return p2p_backend_ != nullptr &&
p2p_backend_->getTopology() == Topology::FULLY_CONNECTED;
}
size_t ProcessGroupCudaP2P::get_buffer_size() {
if (p2p_backend_ == nullptr) {
return 0;
}
return p2p_backend_->getBufferSize();
}
c10::Stream ProcessGroupCudaP2P::stream() {
return stream_;
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts) {
return nccl_backend_->broadcast(tensors, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
return nccl_backend_->allreduce(tensors, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce_sparse(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts) {
return nccl_backend_->allreduce_sparse(tensors, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts) {
return nccl_backend_->allreduce_coalesced(tensors, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts) {
return nccl_backend_->reduce(tensors, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts) {
return nccl_backend_->allgather(outputTensors, inputTensors, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::_allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts) {
return nccl_backend_->_allgather_base(outputBuffer, inputBuffer, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather_coalesced(
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts) {
return nccl_backend_->allgather_coalesced(
outputTensorLists, inputTensors, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather_into_tensor_coalesced(
std::vector<at::Tensor>& outputs,
std::vector<at::Tensor>& inputs,
const AllgatherOptions& opts) {
return nccl_backend_->allgather_into_tensor_coalesced(outputs, inputs, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts) {
return nccl_backend_->gather(outputTensors, inputTensors);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts) {
return nccl_backend_->scatter(outputTensors, inputTensors);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) {
return nccl_backend_->reduce_scatter(outputTensors, inputTensors, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::_reduce_scatter_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const ReduceScatterOptions& opts) {
return nccl_backend_->_reduce_scatter_base(outputBuffer, inputBuffer, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce_scatter_tensor_coalesced(
std::vector<at::Tensor>& outputs,
std::vector<at::Tensor>& inputs,
const ReduceScatterOptions& opts) {
return nccl_backend_->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::alltoall_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts) {
return nccl_backend_->alltoall_base(
outputBuffer, inputBuffer, outputSplitSizes, inputSplitSizes);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts) {
return nccl_backend_->alltoall(outputTensors, inputTensors, opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) {
return nccl_backend_->send(tensors, dstRank, tag);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) {
return nccl_backend_->recv(tensors, srcRank, tag);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) {
return nccl_backend_->recvAnysource(tensors, tag);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::barrier(
const BarrierOptions& opts) {
return nccl_backend_->barrier(opts);
}
c10::intrusive_ptr<Work> ProcessGroupCudaP2P::intra_node_barrier(
c10::optional<std::vector<int64_t>> ranks) {
TORCH_CHECK(p2p_backend_ != nullptr);
p2p_backend_->barrier(ranks);
return c10::make_intrusive<IntraNodeCommWork>();
}
at::Tensor ProcessGroupCudaP2P::get_p2p_buffer(
size_t rank,
const std::vector<int64_t>& sizes,
c10::ScalarType dtype,
int64_t storage_offset) {
TORCH_CHECK(p2p_backend_ != nullptr);
return p2p_backend_->getBuffer(rank, sizes, dtype, storage_offset);
}
void ProcessGroupCudaP2P::shutdown(c10::optional<std::string> reason) {
nccl_backend_->shutdown(reason);
}
} // namespace c10d
#endif // USE_C10D_NCCL

View File

@ -0,0 +1,148 @@
#pragma once
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
constexpr auto kProcessGroupCudaP2PDefaultTimeout =
std::chrono::milliseconds(10 * 60 * 1000);
namespace c10d {
class TORCH_API ProcessGroupCudaP2P : public Backend {
public:
struct Options : Backend::Options {
c10::intrusive_ptr<ProcessGroupNCCL::Options> nccl_options;
c10::optional<size_t> buffer_size;
explicit Options()
: Backend::Options("cuda_p2p", kProcessGroupCudaP2PDefaultTimeout) {}
};
bool is_p2p_available();
size_t get_buffer_size();
c10::Stream stream();
ProcessGroupCudaP2P(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options);
c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;
c10::intrusive_ptr<Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<Work> allreduce_sparse(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;
c10::intrusive_ptr<Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
c10::intrusive_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> _allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> allgather_coalesced(
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
std::vector<at::Tensor>& outputs,
std::vector<at::Tensor>& inputs,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
c10::intrusive_ptr<Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
c10::intrusive_ptr<Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts) override;
c10::intrusive_ptr<Work> _reduce_scatter_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
std::vector<at::Tensor>& outputs,
std::vector<at::Tensor>& inputs,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<Work> alltoall_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
c10::intrusive_ptr<Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
c10::intrusive_ptr<Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) override;
/* P2P-only */
c10::intrusive_ptr<Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
c10::intrusive_ptr<Work> intra_node_barrier(
c10::optional<std::vector<int64_t>> ranks = c10::nullopt);
at::Tensor get_p2p_buffer(
size_t rank,
const std::vector<int64_t>& sizes,
c10::ScalarType dtype,
int64_t storage_offest = 0);
void shutdown(c10::optional<std::string> reason = c10::nullopt);
private:
c10::intrusive_ptr<ProcessGroupNCCL> nccl_backend_;
c10::intrusive_ptr<c10d::intra_node_comm::IntraNodeComm> p2p_backend_;
c10::Stream stream_;
};
} // namespace c10d
#endif // USE_C10D_NCCL

View File

@ -24,6 +24,7 @@
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#endif
@ -2644,14 +2645,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
py::arg("rank"),
py::arg("world_size"),
py::arg("buffer_size") = c10::nullopt)
.def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none())
.def("put", &IntraNodeComm::put, py::arg("input"), py::arg("offset") = 0)
.def(
"get",
&IntraNodeComm::get,
py::arg("rank"),
py::arg("tensor"),
py::arg("offset") = 0);
.def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none());
#ifdef NCCL_HAS_COMM_CTA_CGA
py::class_<ncclConfig_t>(
@ -2727,6 +2721,54 @@ Example::
.def_readwrite(
"group_name", &::c10d::ProcessGroupNCCL::Options::group_name);
auto processGroupCudaP2P =
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupCudaP2P>(
module, "ProcessGroupCudaP2P", backend)
.def(py::init<
const c10::intrusive_ptr<::c10d::Store>&,
int,
int,
c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P::Options>>())
.def(
"is_p2p_available",
&::c10d::ProcessGroupCudaP2P::is_p2p_available)
.def("get_buffer_size", &::c10d::ProcessGroupCudaP2P::get_buffer_size)
.def("stream", &::c10d::ProcessGroupCudaP2P::stream)
.def(
"intra_node_barrier",
&::c10d::ProcessGroupCudaP2P::intra_node_barrier,
py::arg("ranks") = py::none())
.def(
"get_p2p_buffer",
[](c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P> self,
size_t rank,
const std::vector<int64_t>& sizes,
py::object data_type_obj,
int64_t storage_offset) {
auto scalar_type =
reinterpret_cast<THPDtype*>(data_type_obj.ptr())
->scalar_type;
return self->get_p2p_buffer(
rank, sizes, scalar_type, storage_offset);
},
py::arg("rank"),
py::arg("sizes"),
py::arg("dtype"),
py::arg("storage_offset") = 0)
.def(
"_shutdown",
[](const c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P>& self) {
return self->shutdown();
});
intrusive_ptr_class_<::c10d::ProcessGroupCudaP2P::Options>(
processGroupCudaP2P, "Options", processGroupOptions)
.def(py::init<>())
.def_readwrite(
"nccl_options", &::c10d::ProcessGroupCudaP2P::Options::nccl_options)
.def_readwrite(
"buffer_size", &::c10d::ProcessGroupCudaP2P::Options::buffer_size);
#endif
#ifdef USE_C10D_MPI

View File

@ -211,9 +211,8 @@ IntraNodeComm::IntraNodeComm(
: store_(std::move(store)),
rank_(rank),
worldSize_(worldSize),
bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize) {
rendezvous();
}
bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize),
barrierReady_(at::cuda::CUDAEvent()) {}
IntraNodeComm::~IntraNodeComm() {
if (!isInitialized_) {
@ -289,7 +288,7 @@ bool IntraNodeComm::rendezvous() {
return true;
}
#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
if (!isIntraNodeCommSupported() || !isEnabled() || worldSize_ < 2 ||
if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
worldSize_ > kMaxDevices) {
return false;
}

View File

@ -504,7 +504,8 @@ at::Tensor IntraNodeComm::oneShotAllReduce(
at::cuda::CUDAStream& stream) {
checkInput(input, rank_);
const size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
const size_t numelPerWarp =
kBytesPerThread / input.element_size() * kWarpSize;
const size_t N_aligned = alignUp(input.numel(), numelPerWarp);
const bool isAligned = (N_aligned == static_cast<size_t>(input.numel()));
TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size());
@ -733,6 +734,7 @@ static __global__ void barrierKernel(
}
void IntraNodeComm::barrier(std::optional<std::vector<int64_t>> ranks) {
barrierReady_.block(at::cuda::getCurrentCUDAStream());
if (!ranks.has_value()) {
ranks = std::vector<int64_t>(worldSize_);
std::iota(ranks->begin(), ranks->end(), 0);
@ -745,44 +747,23 @@ void IntraNodeComm::barrier(std::optional<std::vector<int64_t>> ranks) {
barrierKernel<<<1, kWarpSize, 0, at::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<P2pState**>(p2pStatesDev_), mask, rank_, worldSize_);
C10_CUDA_KERNEL_LAUNCH_CHECK();
barrierReady_.record();
}
void IntraNodeComm::put(const at::Tensor& tensor, int64_t offset) {
TORCH_CHECK(
tensor.is_non_overlapping_and_dense(),
"IntraNodeComm::put(): tensor must be non-overlapping and dense");
size_t sz = tensor.numel() * tensor.element_size();
TORCH_CHECK(
offset + sz <= bufferSize_,
"IntraNodeComm::put(): offset + tensor size exceeded "
"p2p buffer size");
// This results in "Memcpy PtoP" which does not use SMs for copying
AT_CUDA_CHECK(cudaMemcpyAsync(
static_cast<char*>(buffers_[rank_]) + offset,
static_cast<char*>(tensor.data_ptr()),
sz,
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream()));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
void IntraNodeComm::get(size_t rank, at::Tensor tensor, int64_t offset) {
TORCH_CHECK(
tensor.is_non_overlapping_and_dense(),
"IntraNodeComm::get(): tensor must be non-overlapping and dense");
size_t sz = tensor.numel() * tensor.element_size();
TORCH_CHECK(
offset + sz <= bufferSize_,
"IntraNodeComm::get(): offset + tensor size exceeded "
"p2p buffer size");
// This results in "Memcpy PtoP" which does not use SMs for copying
AT_CUDA_CHECK(cudaMemcpyAsync(
static_cast<char*>(tensor.data_ptr()),
static_cast<char*>(buffers_[rank]) + offset,
sz,
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream()));
C10_CUDA_KERNEL_LAUNCH_CHECK();
at::Tensor IntraNodeComm::getBuffer(
size_t rank,
const std::vector<int64_t>& sizes,
c10::ScalarType dtype,
int64_t storageOffset) {
const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0);
const auto elementSize = c10::elementSize(dtype);
TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_);
auto options = at::TensorOptions().dtype(dtype).device(
at::kCUDA, at::cuda::current_device());
return at::for_blob(buffers_[rank], sizes)
.storage_offset(storageOffset)
.options(options)
.make_tensor();
}
} // namespace intra_node_comm

View File

@ -46,6 +46,10 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
*/
bool rendezvous();
Topology getTopology() {
return topology_;
}
size_t getBufferSize() {
return bufferSize_;
}
@ -63,17 +67,11 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
*/
void barrier(std::optional<std::vector<int64_t>> ranks = c10::nullopt);
/**
* Puts the given tensor into the p2p buffer of the current rank at the
* specified offset.
*/
void put(const at::Tensor& tensor, int64_t offset = 0);
/**
* Fills the given tensor with the data from the specified rank's p2p buffer
* at the specified offset.
*/
void get(size_t rank, at::Tensor tensor, int64_t offset = 0);
at::Tensor getBuffer(
size_t rank,
const std::vector<int64_t>& sizes,
c10::ScalarType dtype,
int64_t storageOffset);
private:
at::Tensor oneShotAllReduce(
@ -92,6 +90,7 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
size_t rank_;
size_t worldSize_;
size_t bufferSize_;
at::cuda::CUDAEvent barrierReady_;
/**
* Members initialized after rendezvous

View File

@ -0,0 +1,123 @@
from contextlib import contextmanager
from functools import partial
from typing import Callable, cast, List, Tuple, Union
import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch._C._distributed_c10d import _DistributedBackendOptions, Backend
"""
This file contains the registration logic and Python APIs for
``ProcessGroupCudaP2P`` (experimental).
``ProcessGroupCudaP2P`` is a thin wrapper around ``ProcessGroupNCCL``. By
default, it routes all collectives to the underlying ``ProcessGroupNCCL``. In
addition, ``ProcessGroupCudaP2P`` initializes a P2P workspace that allows
direct GPU memory access among the members. The workspace can be used in Python
to optimize intra-node communication patterns or to create custom intra-node
collectives in CUDA.
``ProcessGroupCudaP2P`` aims to bridge the gap where certain important patterns
can be better optimized via fine-grained P2P memory access than with
collectives in the latest version of NCCL. It is meant to complement NCCL
rather than replacing it.
Usage:
# Using ProcessGroupCudaP2P
dist.init_process_group(backend="cuda_p2p", ...)
# Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
pg_options = ProcessGroupCudaP2P.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
# Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
pg_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
# Using ProcessGroupCudaP2P while specifying both
# ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
pg_options = ProcessGroupCudaP2P.Options()
pg_options.nccl_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)
# Down-casting the backend to access p2p buffers for cuda_p2p specific
# optimizations
if is_cuda_p2p_group(group):
backend = get_cuda_p2p_backend(group)
if required_p2p_buffer_size > backend.get_buffer_size():
# fallback
p2p_buffer = backend.get_p2p_buffer(...)
else:
# fallback
"""
def _create_cuda_p2p_group(
dist_backend_opts: "_DistributedBackendOptions",
options: Union[
"c10d.ProcessGroupCudaP2P.Options", "c10d.ProcessGroupNCCL.Options", None
],
) -> "Backend":
if not c10d.is_nccl_available():
raise RuntimeError("The cuda_p2p backend is not available")
if options is None:
options = c10d.ProcessGroupCudaP2P.Options()
options.nccl_options = c10d.ProcessGroupNCCL.Options()
elif isinstance(options, c10d.ProcessGroupNCCL.Options):
nccl_options = options
options = c10d.ProcessGroupCudaP2P.Options()
options.nccl_options = nccl_options
elif isinstance(options, c10d.ProcessGroupCudaP2P.Options):
if options.nccl_options is None:
options.nccl_options = c10d.ProcessGroupNCCL.Options()
else:
raise TypeError(
"options for cuda_p2p must be ProcessGroupCudaP2P.Options "
f"or ProcessGroupNCCL.Options (got: {type(options)})"
)
return c10d.ProcessGroupCudaP2P(
dist_backend_opts.store,
dist_backend_opts.group_rank,
dist_backend_opts.group_size,
options,
)
def is_cuda_p2p_group(group: c10d.ProcessGroup) -> bool:
if not c10d.is_nccl_available():
return False
try:
backend = group._get_backend(torch.device("cuda"))
except Exception:
return False
return isinstance(backend, c10d.ProcessGroupCudaP2P) and backend.is_p2p_available()
def get_cuda_p2p_backend(group: c10d.ProcessGroup) -> "c10d.ProcessGroupCudaP2P":
if not is_cuda_p2p_group(group):
raise TypeError("group is not a cuda_p2p process group.")
return cast(
c10d.ProcessGroupCudaP2P,
group._get_backend(torch.device("cuda")),
)
def get_p2p_buffer_size(group: c10d.ProcessGroup) -> int:
if not is_cuda_p2p_group(group):
return 0
backend = get_cuda_p2p_backend(group)
return backend.get_buffer_size()
c10d.Backend.register_backend(
"cuda_p2p",
_create_cuda_p2p_group,
extended_api=True,
devices=["cuda"],
)

View File

@ -110,8 +110,10 @@ except ImportError:
try:
from torch._C._distributed_c10d import ProcessGroupNCCL
from torch._C._distributed_c10d import ProcessGroupCudaP2P
ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupNCCL"]
ProcessGroupCudaP2P.__module__ = "torch.distributed.distributed_c10d"
__all__ += ["ProcessGroupNCCL", "ProcessGroupCudaP2P"]
except ImportError:
_NCCL_AVAILABLE = False
@ -1444,7 +1446,7 @@ def _shutdown_backend(pg):
backend = pg._get_backend(torch.device("cuda"))
except RuntimeError:
pass
if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
if is_nccl_available() and isinstance(backend, (ProcessGroupNCCL, ProcessGroupCudaP2P)):
# explictly call shutdown to ensure that NCCL resources are released
backend._shutdown()