Introduce ProcessGroupCudaP2P (#122163)
## Context

This stack prototypes automatic micro-pipelining of `all-gather -> matmul` and `matmul -> reduce-scatter` via Inductor. The idea originates from the paper [Overlap Communication with Dependent Computation via Decomposition in Large Deep Learning Models](https://dl.acm.org/doi/pdf/10.1145/3567955.3567959). The implementation and some key optimizations are heavily influenced by @lw's implementation in xformers.

The stack contains several components:

- `ProcessGroupCudaP2P` - a thin wrapper around `ProcessGroupNCCL`. In addition, it maintains a P2P workspace that enables SM-free, one-sided P2P communication, which is needed for optimal micro-pipelining.
- `fused_all_gather_matmul` and `fused_matmul_reduce_scatter` dispatcher ops.
- A post-grad fx pass that detects `all-gather -> matmul` and `matmul -> reduce-scatter` patterns and replaces them with the fused dispatcher ops.

To enable the prototype feature:

- Set the distributed backend to `cuda_p2p`.
- Set `torch._inductor.config._micro_pipeline_tp` to `True`.

*NOTE: the prototype sets nothing in stone w.r.t. each component's design. The purpose is to have a performant baseline with a reasonable design on which each component can be further improved.*

## Benchmark

Setup:
- 8 x H100 (500W) + 3rd gen NVSwitch.
- Llama3 8B training w/ torchtitan.
- 8-way TP. Reduced the number of layers from 32 to 8 for benchmarking purposes.

Trace (baseline): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpjaz8zgx0

<img width="832" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4addba77-5abc-4d2e-93ea-f68078587fe1">

Trace (w/ micro-pipelining): https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html#!/?url=https://interncache-all.fbcdn.net/manifold/perfetto_internal_traces/tree/shared_trace/yifu_tmpn073b4wn

<img width="963" alt="image" src="https://github.com/pytorch/pytorch/assets/4156752/4f44e78d-8196-43ab-a1ea-27390f07e9d2">

## This PR

`ProcessGroupCudaP2P` is a thin wrapper around `ProcessGroupNCCL`. By default, it routes all collectives to the underlying `ProcessGroupNCCL`. In addition, `ProcessGroupCudaP2P` initializes a P2P workspace that allows direct GPU memory access among the members. The workspace can be used in Python to optimize intra-node communication patterns or to create custom intra-node collectives in CUDA.

`ProcessGroupCudaP2P` aims to bridge the gap where certain important patterns can be better optimized via fine-grained P2P memory access than with collectives in the latest version of NCCL. It is meant to complement NCCL rather than replace it.

Usage:
```
# Using ProcessGroupCudaP2P
dist.init_process_group(backend="cuda_p2p", ...)

# Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
pg_options = ProcessGroupCudaP2P.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

# Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
pg_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

# Using ProcessGroupCudaP2P while specifying both
# ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
pg_options = ProcessGroupCudaP2P.Options()
pg_options.nccl_options = ProcessGroupNCCL.Options()
dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

# Down-casting the backend to access p2p buffers for cuda_p2p specific
# optimizations
if is_cuda_p2p_group(group):
    backend = get_cuda_p2p_backend(group)
    if required_p2p_buffer_size > backend.get_buffer_size():
        # fallback
    p2p_buffer = backend.get_p2p_buffer(...)
else:
    # fallback
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122163
Approved by: https://github.com/wanchaol
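For readers who want to try the prototype end to end, here is a minimal sketch (not part of this PR's diff) that combines the two switches described above: the `cuda_p2p` backend and the Inductor flag. The `torchrun` launch assumptions and the tiny `nn.Linear` stand-in are placeholders, not a real tensor-parallel setup.

```python
# Hedged sketch, assuming a torchrun launch (RANK/LOCAL_RANK/WORLD_SIZE/MASTER_* set)
# and one CUDA device per rank. The model below is a stand-in for a TP-sharded module.
import os

import torch
import torch._inductor.config
import torch.distributed as dist

# 1. Use the cuda_p2p backend so the P2P workspace is initialized alongside NCCL.
dist.init_process_group(backend="cuda_p2p")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# 2. Let the post-grad fx pass replace all-gather -> matmul and
#    matmul -> reduce-scatter with the fused dispatcher ops.
torch._inductor.config._micro_pipeline_tp = True

model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
compiled = torch.compile(model)  # micro-pipelining applies to compiled graphs
out = compiled(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))

dist.destroy_process_group()
```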
Committed by: PyTorch MergeBot
Parent: 01f04230cf
Commit: 4a09117d16
@@ -18,6 +18,7 @@ time python test/run_test.py --verbose -i distributed/test_c10d_gloo
 time python test/run_test.py --verbose -i distributed/test_c10d_nccl
 time python test/run_test.py --verbose -i distributed/test_c10d_spawn_gloo
 time python test/run_test.py --verbose -i distributed/test_c10d_spawn_nccl
+time python test/run_test.py --verbose -i distributed/test_cuda_p2p
 time python test/run_test.py --verbose -i distributed/test_store
 time python test/run_test.py --verbose -i distributed/test_pg_wrapper
 time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_agent
@@ -1,17 +1,16 @@
 #!/usr/bin/env python3
-# This file contains an example for using IntraNodeComm to implement efficient fused
+# This file contains an example for using cuda_p2p backend to implement efficient fused
 # allgather_matmul (inspired by https://dl.acm.org/doi/pdf/10.1145/3567955.3567959 and
 # @lw's efficient GPU implementation in xformers). Its purpose to help guide the
 # development of relevant primitives and serve as an example for interested users.
 #
 # The benchmark can be executed as follows:
 #     torchrun --nproc-per-node 8 allgather_matmul.py
-#
-# NOTE: _IntraNodeComm is a prototype API which WILL change over time.
 import os
 
 import torch
-import torch._C._distributed_c10d as c10d
+import torch.distributed as dist
+from torch.distributed._cuda_p2p import ProcessGroupCudaP2P
 
 M = 16384
 N = 8192
@@ -21,55 +20,60 @@ WARMUP_ITERS = 200
 BENCH_ITERS = 50
 
 
-comm = None
-internal_stream = None
-internal_event = None
+def allgather_matmul(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+    group = dist.group.WORLD
+    group_size = group.size()
+    A = torch.ops._c10d_functional.all_gather_into_tensor(A_shard, group_size, "0")
+    A = torch.ops._c10d_functional.wait_tensor(A)
+    return A @ B
 
 
-def allgather_matmul(A_shard, B, out, rank, world_size):
+def allgather_matmul_p2p(A_shard: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
     """
     Equivalent to `torch.matmul(dist.all_gather(A_shard), B)`.
     """
-    buf_0 = torch.empty_like(A_shard)
-    buf_1 = torch.empty_like(A_shard)
-    out_shards = [
-        out[i : i + A_shard.shape[0]]
-        for i in range(0, world_size * A_shard.shape[0], A_shard.shape[0])
-    ]
+    group = dist.group.WORLD
+    group_size = group.size()
+    rank = group.rank()
+    backend = group._get_backend(torch.device("cuda"))
+
+    out = torch.empty(
+        (A_shard.shape[0] * group.size(), B.shape[1]),
+        dtype=A_shard.dtype,
+        device="cuda",
+    )
+    out_shards = out.chunk(group_size)
+    local_p2p_buf = backend.get_p2p_buffer(rank, A_shard.shape, A_shard.dtype)
 
     # Perform matmul with the local input shard
     torch.matmul(A_shard, B, out=out_shards[rank])
 
-    # In another stream, copy the local input shard into the intra-node
-    # buffer. After the barrier, all peers' input shards are accessible
-    # via their intra-node buffer without requiring synchronization.
-    with torch.cuda.stream(internal_stream):
-        comm.put(A_shard)
-        comm.barrier()
-        internal_event.record()
-    internal_event.wait()
+    with torch.cuda.stream(backend.stream()):
+        local_p2p_buf.copy_(A_shard)
+        work = backend.intra_node_barrier()
+    work.wait()
 
-    # Copy input shard from remote buffer and perform matmul.
-    # Alternate between two streams to offset the wave quantization
-    # effect of smaller matmuls.
-    for i in range(1, world_size):
+    buf_0 = torch.empty_like(A_shard)
+    buf_1 = torch.empty_like(A_shard)
+    for i in range(1, group_size):
         if i % 2 == 0:
             buf = buf_0
             stream = torch.cuda.current_stream()
         else:
             buf = buf_1
-            stream = internal_stream
-        remote = (i + rank) % world_size
+            stream = backend.stream()
+        remote_rank = (i + rank) % group_size
+        remote_p2p_buf = backend.get_p2p_buffer(
+            remote_rank, A_shard.shape, A_shard.dtype
+        )
         with torch.cuda.stream(stream):
-            comm.get(remote, buf)
-            torch.matmul(buf, B, out=out_shards[remote])
+            buf.copy_(remote_p2p_buf)
+            torch.matmul(buf, B, out=out_shards[remote_rank])
 
-    # Perform another barrier to ensure all peers have completed consuming the
-    # intra-node buffer so it can be reused.
-    with torch.cuda.stream(internal_stream):
-        comm.barrier()
-        internal_event.record()
-    internal_event.wait()
+    with torch.cuda.stream(backend.stream()):
+        work = backend.intra_node_barrier()
+    work.wait()
+    return out
@@ -89,8 +93,6 @@ def do_bench(fn):
 
 
 def main():
-    os.environ["ENABLE_INTRA_NODE_COMM"] = "1"
-
     rank = int(os.environ["RANK"])
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
@@ -98,33 +100,32 @@ def main():
     assert M % world_size == 0
 
     torch.cuda.set_device(local_rank)
-    store, _, _ = next(torch.distributed.rendezvous("env://", rank, world_size))
 
-    global comm, internal_stream, internal_event
-    comm = c10d._IntraNodeComm(
-        store=store,
-        rank=rank,
-        world_size=world_size,
-        buffer_size=M * K * torch.finfo(torch.bfloat16).bits // 8 // world_size,
-    )
-    internal_stream = torch.cuda.Stream()
-    internal_event = torch.cuda.Event()
+    options = ProcessGroupCudaP2P.Options()
+    options.buffer_size = M * N * 2 // world_size
+    dist.init_process_group("cuda_p2p", pg_options=options)
 
     torch.manual_seed(42)
     A = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
     B = torch.randn((K, N), dtype=torch.bfloat16, device="cuda")
-    out = torch.empty((M, N), dtype=torch.bfloat16, device="cuda")
 
     stride = M // world_size
     A_shard = A[rank * stride : (rank + 1) * stride]
 
-    comm.barrier()
+    assert torch.allclose(
+        allgather_matmul(A_shard, B),
+        allgather_matmul_p2p(A_shard, B),
+    )
+
+    dist.barrier()
     torch.cuda.synchronize()
-    allgather_matmul_ms = do_bench(
-        lambda: allgather_matmul(A_shard, B, out, rank, world_size)
-    )
+    allgather_matmul_ms = do_bench(lambda: allgather_matmul(A_shard, B))
 
-    comm.barrier()
+    dist.barrier()
+    torch.cuda.synchronize()
+    allgather_matmul_p2p_ms = do_bench(lambda: allgather_matmul_p2p(A_shard, B))
+
+    dist.barrier()
     torch.cuda.synchronize()
     matmul_ms = do_bench(lambda: torch.matmul(A, B))
@@ -134,8 +135,15 @@ def main():
         f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
         f"{allgather_matmul_ms:.4} ms/iter"
     )
+    print(
+        "allgather_matmul_p2p "
+        f"(M={M // world_size}, N={N}, K={K}, world_size={world_size}): "
+        f"{allgather_matmul_p2p_ms:.4} ms/iter"
+    )
     print(f"matmul (M={M}, N={N}, K={K}): {matmul_ms:.4} ms/iter")
 
+    dist.destroy_process_group()
+
 
 if __name__ == "__main__":
     main()
@@ -675,6 +675,7 @@ libtorch_cuda_distributed_base_sources = [
 # These files are only supported on Linux (and others) but not on Windows.
 libtorch_cuda_distributed_extra_sources = [
     "torch/csrc/distributed/c10d/NCCLUtils.cpp",
+    "torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp",
     "torch/csrc/distributed/c10d/ProcessGroupUCC.cpp",
     "torch/csrc/distributed/c10d/UCCTracing.cpp",
test/distributed/test_cuda_p2p.py (new file, 142 lines)
@@ -0,0 +1,142 @@
# Owner(s): ["module: c10d"]
import os
from typing import List

import torch

import torch.distributed as dist
from torch.distributed._cuda_p2p import (
    get_cuda_p2p_backend,
    get_p2p_buffer_size,
    is_cuda_p2p_group,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)


def requires_cuda_p2p_access():
    cuda_p2p_access_available = (
        torch.cuda.is_available()
        and torch.cuda.device_count() >= 2
        and dist.is_nccl_available()
    )
    num_devices = torch.cuda.device_count()
    for i in range(num_devices - 1):
        for j in range(i + 1, num_devices):
            if not torch.cuda.can_device_access_peer(i, j):
                cuda_p2p_access_available = False
                break
        if not cuda_p2p_access_available:
            break

    return skip_but_pass_in_sandcastle_if(
        not cuda_p2p_access_available,
        "cuda p2p access is not available",
    )


@requires_nccl()
@requires_cuda_p2p_access()
class ProcessGroupCudaP2PTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def ranks(self) -> List[int]:
        return list(range(self.world_size))

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process_group(self, buffer_size: int) -> None:
        os.environ["TEST_INTRA_NODE_COMM"] = "1"
        torch.cuda.set_device(self.device)

        # Verify cuda p2p specific APIs on ProcessGroupCudaP2P
        store = dist.FileStore(self.file_name, self.world_size)
        options = dist.ProcessGroupCudaP2P.Options()
        options.buffer_size = buffer_size
        dist.init_process_group(
            backend="cuda_p2p",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
            pg_options=options,
        )

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_p2p_apis(self) -> None:
        BUFFER_SIZE = 4 * 1024

        self._init_process_group(BUFFER_SIZE)

        # Verify cuda p2p specific APIs on ProcessGroupCudaP2P
        assert is_cuda_p2p_group(dist.group.WORLD)
        assert get_p2p_buffer_size(dist.group.WORLD) == BUFFER_SIZE

        backend = get_cuda_p2p_backend(dist.group.WORLD)
        assert isinstance(backend, dist.ProcessGroupCudaP2P)
        assert backend.get_buffer_size() == BUFFER_SIZE

        backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float)
        with self.assertRaises(RuntimeError):
            backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4 + 1,), torch.float)
        with self.assertRaises(RuntimeError):
            backend.get_p2p_buffer(self.rank, (BUFFER_SIZE // 4,), torch.float, 1)

        # Verify cuda p2p specific APIs on non-cuda p2p process group
        non_cuda_p2p_pg = dist.new_group(backend="nccl")

        assert not is_cuda_p2p_group(non_cuda_p2p_pg)
        assert get_p2p_buffer_size(non_cuda_p2p_pg) == 0
        with self.assertRaises(TypeError):
            get_cuda_p2p_backend(non_cuda_p2p_pg)

        dist.barrier()
        torch.cuda.synchronize()
        dist.destroy_process_group()

    @skipIfRocm
    @skip_if_lt_x_gpu(2)
    def test_p2p_buffer(self) -> None:
        BUFFER_SIZE = 4 * 1024

        self._init_process_group(BUFFER_SIZE)
        rank = self.rank
        world_size = self.world_size

        assert is_cuda_p2p_group(dist.group.WORLD)
        backend = get_cuda_p2p_backend(dist.group.WORLD)
        local_buffer = backend.get_p2p_buffer(
            (rank) % world_size, (BUFFER_SIZE // 4,), torch.float
        )
        remote_buffer = backend.get_p2p_buffer(
            (rank + 1) % world_size, (BUFFER_SIZE // 4,), torch.float
        )

        local_buffer.fill_(rank)
        backend.intra_node_barrier()
        assert remote_buffer.eq((rank + 1) % world_size).all()

        dist.barrier()
        torch.cuda.synchronize()
        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
@@ -605,3 +605,30 @@ def _register_process_group(
 def _resolve_process_group(group_name: str) -> ProcessGroup: ...
 def _unregister_all_process_groups() -> None: ...
 def _unregister_process_group(group_name: str) -> None: ...
+
+class ProcessGroupCudaP2P(Backend):
+    class Options:
+        nccl_options: Optional[ProcessGroupNCCL.Options]
+        buffer_size: Optional[int]
+
+        def __init__(self) -> None: ...
+
+    def __init__(
+        self,
+        store: Store,
+        rank: int,
+        size: int,
+        options: ProcessGroupCudaP2P.Options,
+    ) -> None: ...
+    def is_p2p_available(self) -> bool: ...
+    def get_buffer_size(self) -> int: ...
+    def stream(self) -> torch.cuda.Stream: ...
+    def intra_node_barrier(self) -> Work: ...
+    def get_p2p_buffer(
+        self,
+        rank: int,
+        sizes: torch.Size,
+        dtype: torch.dtype,
+        storage_offset: Optional[int] = 0,
+    ) -> torch.Tensor: ...
+    def _shutdown(self) -> None: ...
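As a quick orientation to the stub above, here is a hedged sketch (adapted from the new test file rather than copied verbatim) of constructing the options, initializing the backend, and down-casting the default group. The 64 KiB buffer size and the torchrun-style launch are placeholders.

```python
# Sketch only: assumes a torchrun-style launch and NCCL availability.
import os

import torch
import torch.distributed as dist
from torch.distributed._cuda_p2p import get_cuda_p2p_backend, is_cuda_p2p_group

options = dist.ProcessGroupCudaP2P.Options()
options.buffer_size = 64 * 1024                         # P2P workspace size in bytes (placeholder)
options.nccl_options = dist.ProcessGroupNCCL.Options()  # optional NCCL tuning knobs

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="cuda_p2p", pg_options=options)

if is_cuda_p2p_group(dist.group.WORLD):
    backend = get_cuda_p2p_backend(dist.group.WORLD)
    assert backend.get_buffer_size() == options.buffer_size
    # A tensor view over this rank's workspace: 64 KiB holds 16384 fp32 elements.
    buf = backend.get_p2p_buffer(dist.get_rank(), (16384,), torch.float)

dist.destroy_process_group()
```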
torch/csrc/distributed/c10d/ProcessGroupCudaP2P.cpp (new file, 206 lines)
@@ -0,0 +1,206 @@
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp>

#include <torch/csrc/distributed/c10d/PrefixStore.hpp>

namespace c10d {

using namespace c10d::intra_node_comm;

ProcessGroupCudaP2P::ProcessGroupCudaP2P(
    const c10::intrusive_ptr<Store>& store,
    int rank,
    int size,
    c10::intrusive_ptr<Options> options)
    : Backend(rank, size), stream_(c10::cuda::getStreamFromPool()) {
  nccl_backend_ = c10::make_intrusive<ProcessGroupNCCL>(
      c10::make_intrusive<PrefixStore>("nccl", store),
      rank,
      size,
      options->nccl_options);
  nccl_backend_->setSequenceNumberForGroup();

  p2p_backend_ = c10::make_intrusive<IntraNodeComm>(
      c10::make_intrusive<PrefixStore>("p2p", store),
      rank,
      size,
      options->buffer_size);
  if (!p2p_backend_->rendezvous()) {
    p2p_backend_ = nullptr;
  }
}

bool ProcessGroupCudaP2P::is_p2p_available() {
  return p2p_backend_ != nullptr &&
      p2p_backend_->getTopology() == Topology::FULLY_CONNECTED;
}

size_t ProcessGroupCudaP2P::get_buffer_size() {
  if (p2p_backend_ == nullptr) {
    return 0;
  }
  return p2p_backend_->getBufferSize();
}

c10::Stream ProcessGroupCudaP2P::stream() {
  return stream_;
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::broadcast(
    std::vector<at::Tensor>& tensors,
    const BroadcastOptions& opts) {
  return nccl_backend_->broadcast(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  return nccl_backend_->allreduce(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce_sparse(
    std::vector<at::Tensor>& tensors,
    const AllreduceOptions& opts) {
  return nccl_backend_->allreduce_sparse(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  return nccl_backend_->allreduce_coalesced(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce(
    std::vector<at::Tensor>& tensors,
    const ReduceOptions& opts) {
  return nccl_backend_->reduce(tensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  return nccl_backend_->allgather(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::_allgather_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const AllgatherOptions& opts) {
  return nccl_backend_->_allgather_base(outputBuffer, inputBuffer, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather_coalesced(
    std::vector<std::vector<at::Tensor>>& outputTensorLists,
    std::vector<at::Tensor>& inputTensors,
    const AllgatherOptions& opts) {
  return nccl_backend_->allgather_coalesced(
      outputTensorLists, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::allgather_into_tensor_coalesced(
    std::vector<at::Tensor>& outputs,
    std::vector<at::Tensor>& inputs,
    const AllgatherOptions& opts) {
  return nccl_backend_->allgather_into_tensor_coalesced(outputs, inputs, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::gather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const GatherOptions& opts) {
  return nccl_backend_->gather(outputTensors, inputTensors);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ScatterOptions& opts) {
  return nccl_backend_->scatter(outputTensors, inputTensors);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce_scatter(
    std::vector<at::Tensor>& outputTensors,
    std::vector<std::vector<at::Tensor>>& inputTensors,
    const ReduceScatterOptions& opts) {
  return nccl_backend_->reduce_scatter(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::_reduce_scatter_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    const ReduceScatterOptions& opts) {
  return nccl_backend_->_reduce_scatter_base(outputBuffer, inputBuffer, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::reduce_scatter_tensor_coalesced(
    std::vector<at::Tensor>& outputs,
    std::vector<at::Tensor>& inputs,
    const ReduceScatterOptions& opts) {
  return nccl_backend_->reduce_scatter_tensor_coalesced(outputs, inputs, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::alltoall_base(
    at::Tensor& outputBuffer,
    at::Tensor& inputBuffer,
    std::vector<int64_t>& outputSplitSizes,
    std::vector<int64_t>& inputSplitSizes,
    const AllToAllOptions& opts) {
  return nccl_backend_->alltoall_base(
      outputBuffer, inputBuffer, outputSplitSizes, inputSplitSizes);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::alltoall(
    std::vector<at::Tensor>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const AllToAllOptions& opts) {
  return nccl_backend_->alltoall(outputTensors, inputTensors, opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::send(
    std::vector<at::Tensor>& tensors,
    int dstRank,
    int tag) {
  return nccl_backend_->send(tensors, dstRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::recv(
    std::vector<at::Tensor>& tensors,
    int srcRank,
    int tag) {
  return nccl_backend_->recv(tensors, srcRank, tag);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::recvAnysource(
    std::vector<at::Tensor>& tensors,
    int tag) {
  return nccl_backend_->recvAnysource(tensors, tag);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::barrier(
    const BarrierOptions& opts) {
  return nccl_backend_->barrier(opts);
}

c10::intrusive_ptr<Work> ProcessGroupCudaP2P::intra_node_barrier(
    c10::optional<std::vector<int64_t>> ranks) {
  TORCH_CHECK(p2p_backend_ != nullptr);
  p2p_backend_->barrier(ranks);
  return c10::make_intrusive<IntraNodeCommWork>();
}

at::Tensor ProcessGroupCudaP2P::get_p2p_buffer(
    size_t rank,
    const std::vector<int64_t>& sizes,
    c10::ScalarType dtype,
    int64_t storage_offset) {
  TORCH_CHECK(p2p_backend_ != nullptr);
  return p2p_backend_->getBuffer(rank, sizes, dtype, storage_offset);
}

void ProcessGroupCudaP2P::shutdown(c10::optional<std::string> reason) {
  nccl_backend_->shutdown(reason);
}

} // namespace c10d

#endif // USE_C10D_NCCL
torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp (new file, 148 lines)
@@ -0,0 +1,148 @@
#pragma once

#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>

constexpr auto kProcessGroupCudaP2PDefaultTimeout =
    std::chrono::milliseconds(10 * 60 * 1000);

namespace c10d {

class TORCH_API ProcessGroupCudaP2P : public Backend {
 public:
  struct Options : Backend::Options {
    c10::intrusive_ptr<ProcessGroupNCCL::Options> nccl_options;
    c10::optional<size_t> buffer_size;

    explicit Options()
        : Backend::Options("cuda_p2p", kProcessGroupCudaP2PDefaultTimeout) {}
  };

  bool is_p2p_available();
  size_t get_buffer_size();

  c10::Stream stream();

  ProcessGroupCudaP2P(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options);

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_sparse(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts) override;

  c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& outputs,
      std::vector<at::Tensor>& inputs,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  /* P2P-only */
  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  c10::intrusive_ptr<Work> intra_node_barrier(
      c10::optional<std::vector<int64_t>> ranks = c10::nullopt);

  at::Tensor get_p2p_buffer(
      size_t rank,
      const std::vector<int64_t>& sizes,
      c10::ScalarType dtype,
      int64_t storage_offest = 0);

  void shutdown(c10::optional<std::string> reason = c10::nullopt);

 private:
  c10::intrusive_ptr<ProcessGroupNCCL> nccl_backend_;
  c10::intrusive_ptr<c10d::intra_node_comm::IntraNodeComm> p2p_backend_;
  c10::Stream stream_;
};

} // namespace c10d

#endif // USE_C10D_NCCL
@@ -24,6 +24,7 @@
 
 #ifdef USE_C10D_NCCL
 #include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
+#include <torch/csrc/distributed/c10d/ProcessGroupCudaP2P.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
 #include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
 #endif
@@ -2644,14 +2645,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           py::arg("rank"),
           py::arg("world_size"),
           py::arg("buffer_size") = c10::nullopt)
-      .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none())
-      .def("put", &IntraNodeComm::put, py::arg("input"), py::arg("offset") = 0)
-      .def(
-          "get",
-          &IntraNodeComm::get,
-          py::arg("rank"),
-          py::arg("tensor"),
-          py::arg("offset") = 0);
+      .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none());
 
 #ifdef NCCL_HAS_COMM_CTA_CGA
   py::class_<ncclConfig_t>(
@@ -2727,6 +2721,54 @@ Example::
       .def_readwrite(
           "group_name", &::c10d::ProcessGroupNCCL::Options::group_name);
 
+  auto processGroupCudaP2P =
+      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupCudaP2P>(
+          module, "ProcessGroupCudaP2P", backend)
+          .def(py::init<
+               const c10::intrusive_ptr<::c10d::Store>&,
+               int,
+               int,
+               c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P::Options>>())
+          .def(
+              "is_p2p_available",
+              &::c10d::ProcessGroupCudaP2P::is_p2p_available)
+          .def("get_buffer_size", &::c10d::ProcessGroupCudaP2P::get_buffer_size)
+          .def("stream", &::c10d::ProcessGroupCudaP2P::stream)
+          .def(
+              "intra_node_barrier",
+              &::c10d::ProcessGroupCudaP2P::intra_node_barrier,
+              py::arg("ranks") = py::none())
+          .def(
+              "get_p2p_buffer",
+              [](c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P> self,
+                 size_t rank,
+                 const std::vector<int64_t>& sizes,
+                 py::object data_type_obj,
+                 int64_t storage_offset) {
+                auto scalar_type =
+                    reinterpret_cast<THPDtype*>(data_type_obj.ptr())
+                        ->scalar_type;
+                return self->get_p2p_buffer(
+                    rank, sizes, scalar_type, storage_offset);
+              },
+              py::arg("rank"),
+              py::arg("sizes"),
+              py::arg("dtype"),
+              py::arg("storage_offset") = 0)
+          .def(
+              "_shutdown",
+              [](const c10::intrusive_ptr<::c10d::ProcessGroupCudaP2P>& self) {
+                return self->shutdown();
+              });
+
+  intrusive_ptr_class_<::c10d::ProcessGroupCudaP2P::Options>(
+      processGroupCudaP2P, "Options", processGroupOptions)
+      .def(py::init<>())
+      .def_readwrite(
+          "nccl_options", &::c10d::ProcessGroupCudaP2P::Options::nccl_options)
+      .def_readwrite(
+          "buffer_size", &::c10d::ProcessGroupCudaP2P::Options::buffer_size);
+
 #endif
 
 #ifdef USE_C10D_MPI
@@ -211,9 +211,8 @@ IntraNodeComm::IntraNodeComm(
     : store_(std::move(store)),
       rank_(rank),
       worldSize_(worldSize),
-      bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize) {
-  rendezvous();
-}
+      bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize),
+      barrierReady_(at::cuda::CUDAEvent()) {}
 
 IntraNodeComm::~IntraNodeComm() {
   if (!isInitialized_) {
@@ -289,7 +288,7 @@ bool IntraNodeComm::rendezvous() {
     return true;
   }
 #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
-  if (!isIntraNodeCommSupported() || !isEnabled() || worldSize_ < 2 ||
+  if (!isIntraNodeCommSupported() || worldSize_ < 2 ||
       worldSize_ > kMaxDevices) {
     return false;
   }
@@ -504,7 +504,8 @@ at::Tensor IntraNodeComm::oneShotAllReduce(
     at::cuda::CUDAStream& stream) {
   checkInput(input, rank_);
 
-  const size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize;
+  const size_t numelPerWarp =
+      kBytesPerThread / input.element_size() * kWarpSize;
   const size_t N_aligned = alignUp(input.numel(), numelPerWarp);
   const bool isAligned = (N_aligned == static_cast<size_t>(input.numel()));
   TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size());
@@ -733,6 +734,7 @@ static __global__ void barrierKernel(
 }
 
 void IntraNodeComm::barrier(std::optional<std::vector<int64_t>> ranks) {
+  barrierReady_.block(at::cuda::getCurrentCUDAStream());
   if (!ranks.has_value()) {
     ranks = std::vector<int64_t>(worldSize_);
     std::iota(ranks->begin(), ranks->end(), 0);
@@ -745,44 +747,23 @@ void IntraNodeComm::barrier(std::optional<std::vector<int64_t>> ranks) {
   barrierKernel<<<1, kWarpSize, 0, at::cuda::getCurrentCUDAStream()>>>(
       reinterpret_cast<P2pState**>(p2pStatesDev_), mask, rank_, worldSize_);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
+  barrierReady_.record();
 }
 
-void IntraNodeComm::put(const at::Tensor& tensor, int64_t offset) {
-  TORCH_CHECK(
-      tensor.is_non_overlapping_and_dense(),
-      "IntraNodeComm::put(): tensor must be non-overlapping and dense");
-  size_t sz = tensor.numel() * tensor.element_size();
-  TORCH_CHECK(
-      offset + sz <= bufferSize_,
-      "IntraNodeComm::put(): offset + tensor size exceeded "
-      "p2p buffer size");
-  // This results in "Memcpy PtoP" which does not use SMs for copying
-  AT_CUDA_CHECK(cudaMemcpyAsync(
-      static_cast<char*>(buffers_[rank_]) + offset,
-      static_cast<char*>(tensor.data_ptr()),
-      sz,
-      cudaMemcpyDeviceToDevice,
-      at::cuda::getCurrentCUDAStream()));
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-void IntraNodeComm::get(size_t rank, at::Tensor tensor, int64_t offset) {
-  TORCH_CHECK(
-      tensor.is_non_overlapping_and_dense(),
-      "IntraNodeComm::get(): tensor must be non-overlapping and dense");
-  size_t sz = tensor.numel() * tensor.element_size();
-  TORCH_CHECK(
-      offset + sz <= bufferSize_,
-      "IntraNodeComm::get(): offset + tensor size exceeded "
-      "p2p buffer size");
-  // This results in "Memcpy PtoP" which does not use SMs for copying
-  AT_CUDA_CHECK(cudaMemcpyAsync(
-      static_cast<char*>(tensor.data_ptr()),
-      static_cast<char*>(buffers_[rank]) + offset,
-      sz,
-      cudaMemcpyDeviceToDevice,
-      at::cuda::getCurrentCUDAStream()));
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
+at::Tensor IntraNodeComm::getBuffer(
+    size_t rank,
+    const std::vector<int64_t>& sizes,
+    c10::ScalarType dtype,
+    int64_t storageOffset) {
+  const auto numel = std::accumulate(sizes.begin(), sizes.end(), 0);
+  const auto elementSize = c10::elementSize(dtype);
+  TORCH_CHECK((numel + storageOffset) * elementSize <= bufferSize_);
+  auto options = at::TensorOptions().dtype(dtype).device(
+      at::kCUDA, at::cuda::current_device());
+  return at::for_blob(buffers_[rank], sizes)
+      .storage_offset(storageOffset)
+      .options(options)
+      .make_tensor();
+}
 
 } // namespace intra_node_comm
@@ -46,6 +46,10 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
    */
   bool rendezvous();
 
+  Topology getTopology() {
+    return topology_;
+  }
+
   size_t getBufferSize() {
     return bufferSize_;
   }
@@ -63,17 +67,11 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
    */
   void barrier(std::optional<std::vector<int64_t>> ranks = c10::nullopt);
 
-  /**
-   * Puts the given tensor into the p2p buffer of the current rank at the
-   * specified offset.
-   */
-  void put(const at::Tensor& tensor, int64_t offset = 0);
-
-  /**
-   * Fills the given tensor with the data from the specified rank's p2p buffer
-   * at the specified offset.
-   */
-  void get(size_t rank, at::Tensor tensor, int64_t offset = 0);
+  at::Tensor getBuffer(
+      size_t rank,
+      const std::vector<int64_t>& sizes,
+      c10::ScalarType dtype,
+      int64_t storageOffset);
 
  private:
   at::Tensor oneShotAllReduce(
@@ -92,6 +90,7 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target {
   size_t rank_;
   size_t worldSize_;
   size_t bufferSize_;
+  at::cuda::CUDAEvent barrierReady_;
 
   /**
    * Members initialized after rendezvous
torch/distributed/_cuda_p2p/__init__.py (new file, 123 lines)
@@ -0,0 +1,123 @@
from contextlib import contextmanager
from functools import partial
from typing import Callable, cast, List, Tuple, Union

import torch
import torch.distributed._functional_collectives as funcol

import torch.distributed.distributed_c10d as c10d
from torch._C._distributed_c10d import _DistributedBackendOptions, Backend


"""
This file contains the registration logic and Python APIs for
``ProcessGroupCudaP2P`` (experimental).

``ProcessGroupCudaP2P`` is a thin wrapper around ``ProcessGroupNCCL``. By
default, it routes all collectives to the underlying ``ProcessGroupNCCL``. In
addition, ``ProcessGroupCudaP2P`` initializes a P2P workspace that allows
direct GPU memory access among the members. The workspace can be used in Python
to optimize intra-node communication patterns or to create custom intra-node
collectives in CUDA.

``ProcessGroupCudaP2P`` aims to bridge the gap where certain important patterns
can be better optimized via fine-grained P2P memory access than with
collectives in the latest version of NCCL. It is meant to complement NCCL
rather than replacing it.

Usage:

    # Using ProcessGroupCudaP2P
    dist.init_process_group(backend="cuda_p2p", ...)

    # Using ProcessGroupCudaP2P while specifying ProcessGroupCudaP2P.Options
    pg_options = ProcessGroupCudaP2P.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Using ProcessGroupCudaP2P while specifying ProcessGroupNCCL.Options
    pg_options = ProcessGroupNCCL.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Using ProcessGroupCudaP2P while specifying both
    # ProcessGroupCudaP2P.Options and ProcessGroupNCCL.Options
    pg_options = ProcessGroupCudaP2P.Options()
    pg_options.nccl_options = ProcessGroupNCCL.Options()
    dist.init_process_group(backend="cuda_p2p", pg_options=pg_options, ...)

    # Down-casting the backend to access p2p buffers for cuda_p2p specific
    # optimizations
    if is_cuda_p2p_group(group):
        backend = get_cuda_p2p_backend(group)
        if required_p2p_buffer_size > backend.get_buffer_size():
            # fallback
        p2p_buffer = backend.get_p2p_buffer(...)
    else:
        # fallback
"""


def _create_cuda_p2p_group(
    dist_backend_opts: "_DistributedBackendOptions",
    options: Union[
        "c10d.ProcessGroupCudaP2P.Options", "c10d.ProcessGroupNCCL.Options", None
    ],
) -> "Backend":
    if not c10d.is_nccl_available():
        raise RuntimeError("The cuda_p2p backend is not available")
    if options is None:
        options = c10d.ProcessGroupCudaP2P.Options()
        options.nccl_options = c10d.ProcessGroupNCCL.Options()
    elif isinstance(options, c10d.ProcessGroupNCCL.Options):
        nccl_options = options
        options = c10d.ProcessGroupCudaP2P.Options()
        options.nccl_options = nccl_options
    elif isinstance(options, c10d.ProcessGroupCudaP2P.Options):
        if options.nccl_options is None:
            options.nccl_options = c10d.ProcessGroupNCCL.Options()
    else:
        raise TypeError(
            "options for cuda_p2p must be ProcessGroupCudaP2P.Options "
            f"or ProcessGroupNCCL.Options (got: {type(options)})"
        )

    return c10d.ProcessGroupCudaP2P(
        dist_backend_opts.store,
        dist_backend_opts.group_rank,
        dist_backend_opts.group_size,
        options,
    )


def is_cuda_p2p_group(group: c10d.ProcessGroup) -> bool:
    if not c10d.is_nccl_available():
        return False
    try:
        backend = group._get_backend(torch.device("cuda"))
    except Exception:
        return False
    return isinstance(backend, c10d.ProcessGroupCudaP2P) and backend.is_p2p_available()


def get_cuda_p2p_backend(group: c10d.ProcessGroup) -> "c10d.ProcessGroupCudaP2P":
    if not is_cuda_p2p_group(group):
        raise TypeError("group is not a cuda_p2p process group.")
    return cast(
        c10d.ProcessGroupCudaP2P,
        group._get_backend(torch.device("cuda")),
    )


def get_p2p_buffer_size(group: c10d.ProcessGroup) -> int:
    if not is_cuda_p2p_group(group):
        return 0
    backend = get_cuda_p2p_backend(group)
    return backend.get_buffer_size()


c10d.Backend.register_backend(
    "cuda_p2p",
    _create_cuda_p2p_group,
    extended_api=True,
    devices=["cuda"],
)
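Building on the helpers registered above, the sketch below mirrors the `test_p2p_buffer` test added in this PR: each rank publishes data in its own workspace and reads a peer's workspace directly, with only an intra-node barrier in between. The buffer size and shapes are the test's values, repeated here purely for illustration.

```python
# Mirrors test_p2p_buffer: 2+ ranks, cuda_p2p group initialized with buffer_size=4096.
import torch
import torch.distributed as dist
from torch.distributed._cuda_p2p import get_cuda_p2p_backend, is_cuda_p2p_group

rank, world_size = dist.get_rank(), dist.get_world_size()
assert is_cuda_p2p_group(dist.group.WORLD)
backend = get_cuda_p2p_backend(dist.group.WORLD)

# Write this rank's id into its own P2P workspace (1024 fp32 elements = 4096 bytes).
local_buf = backend.get_p2p_buffer(rank, (1024,), torch.float)
local_buf.fill_(rank)
backend.intra_node_barrier()

# Read the next rank's workspace directly; no collective is issued for the data movement.
peer = (rank + 1) % world_size
remote_buf = backend.get_p2p_buffer(peer, (1024,), torch.float)
assert remote_buf.eq(peer).all()
```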
@@ -110,8 +110,10 @@ except ImportError:
 
 try:
     from torch._C._distributed_c10d import ProcessGroupNCCL
+    from torch._C._distributed_c10d import ProcessGroupCudaP2P
 
     ProcessGroupNCCL.__module__ = "torch.distributed.distributed_c10d"
-    __all__ += ["ProcessGroupNCCL"]
+    ProcessGroupCudaP2P.__module__ = "torch.distributed.distributed_c10d"
+    __all__ += ["ProcessGroupNCCL", "ProcessGroupCudaP2P"]
 except ImportError:
     _NCCL_AVAILABLE = False
@@ -1444,7 +1446,7 @@ def _shutdown_backend(pg):
         backend = pg._get_backend(torch.device("cuda"))
     except RuntimeError:
         pass
-    if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
+    if is_nccl_available() and isinstance(backend, (ProcessGroupNCCL, ProcessGroupCudaP2P)):
         # explictly call shutdown to ensure that NCCL resources are released
         backend._shutdown()