[PERF] PyTorch Symmetric Memory All-Reduce (#20759)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Author: Ilya Markov
Date: 2025-08-22 23:39:08 +02:00
Committed by: GitHub
Parent: 0483fabc74
Commit: 0313cf854d
8 changed files with 283 additions and 5 deletions

View File

@@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.
There are other miscellaneous places hard-coding the use of `spawn`:
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>
Related PRs:

View File

@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import typing
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_group,
init_distributed_environment,
initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
torch.manual_seed(42)
random.seed(44)
test_size_elements = 4 * 1024 * 1024
def symm_mem_allreduce_worker(local_rank: int, world_size: int):
monkeypatch = pytest.MonkeyPatch()
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
dtype = torch.bfloat16
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
cuda_communicator = typing.cast(CudaCommunicator,
get_tp_group().device_communicator)
symm_mem_comm = cuda_communicator.symm_mem_comm
if symm_mem_comm is None or symm_mem_comm.disabled:
pytest.skip("SymmMemCommunicator is not available or disabled.")
inp_direct_symm_mem = torch.randint(1,
23, (test_size_elements, ),
dtype=dtype,
device=device)
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
            pytest.skip(
                "SymmMemCommunicator isn't used for this world size and "
                "input size.")
original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
assert out_direct_symm_mem is not None
group = get_tensor_model_parallel_group().device_group
dist.all_reduce(original_inp_direct_symm_mem, group=group)
torch.testing.assert_close(out_direct_symm_mem,
original_inp_direct_symm_mem,
atol=2.5,
rtol=0.1)
# Test tensor_model_parallel_all_reduce which should use symm_mem
inp_tensor_parallel = torch.randint(-23,
1, (test_size_elements, ),
dtype=dtype,
device=device)
original_inp_tensor_parallel = inp_tensor_parallel.clone()
out_tensor_parallel = tensor_model_parallel_all_reduce(
inp_tensor_parallel)
dist.all_reduce(original_inp_tensor_parallel, group=group)
torch.testing.assert_close(out_tensor_parallel,
original_inp_tensor_parallel,
atol=2.5,
rtol=0.1)
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
pipeline_parallel_size):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
# Enable SymmMemCommunicator
monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
cleanup_dist_env_and_memory()
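
For readers who want to confirm symmetric-memory support on their hardware without going through vLLM, a minimal standalone sketch using the same torch APIs the communicator relies on is shown below. It assumes two NVLink-connected GPUs and a PyTorch build that exposes `torch.distributed._symmetric_memory`; launch one process per GPU, e.g. with `torchrun --nproc-per-node=2 check_symm_mem.py` (the file name is illustrative, not part of this PR).

# Standalone sanity-check sketch (not part of this PR); RANK is assumed to be
# set by the launcher, e.g. torchrun.
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as torch_symm_mem


def main() -> None:
    rank = int(os.environ["RANK"])
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl")
    group_name = dist.group.WORLD.group_name

    # Allocate a symmetric-memory buffer and exchange handles across ranks.
    buf = torch_symm_mem.empty(1 << 20,
                               dtype=torch.bfloat16,
                               device=f"cuda:{rank}")
    torch_symm_mem.rendezvous(buf, group_name)

    buf.fill_(float(rank + 1))
    # Two-shot variant, as the communicator uses for small world sizes.
    torch.ops.symm_mem.two_shot_all_reduce_(buf, "sum", group_name)
    torch.cuda.synchronize()
    # With 2 ranks every element should now be 1 + 2 = 3.
    print(f"rank {rank}: {buf[0].item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()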

View File

@@ -37,7 +37,7 @@ ALLOWED_FILES = set([
'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py',
'vllm/engine/multiprocessing/client.py',
'vllm/distributed/device_communicators/custom_all_reduce_utils.py',
'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/engine/multiprocessing/engine.py',
'benchmarks/kernels/graph_machete_bench.py',

View File

@@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless,
logger = init_logger(__name__)
MiB = 1024 * 1024
# Max all-reduce input size (in bytes) for each world size when symmetric
# memory is available, keyed by SM architecture version
CUSTOM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: MiB // 2,  # 512 KiB
        8: MiB // 4,  # 256 KiB
    },
    "10.0": {
        2: 2 * MiB,  # 2 MiB
        4: 2 * MiB,  # 2 MiB
        6: 2 * MiB,  # 2 MiB
        8: 2 * MiB,  # 2 MiB
    }
}
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {
        2: 64 * MiB,  # 64 MiB
        4: 32 * MiB,  # 32 MiB
        6: 64 * MiB,  # 64 MiB
        8: 64 * MiB,  # 64 MiB
    },
    "10.0": {
        2: 8 * MiB,  # 8 MiB
        4: 32 * MiB,  # 32 MiB
        6: 128 * MiB,  # 128 MiB
        8: 128 * MiB,  # 128 MiB
    }
}
def producer(batch_src: Sequence[int],
producer_queue,
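
Both tables are keyed by the SM version string returned by `current_platform.get_device_capability().as_version_str()` and then by world size. A minimal lookup sketch follows; the `lookup_cap` helper is illustrative, not part of this PR, and assumes the constants above are importable from `all_reduce_utils`.

from typing import Optional

from vllm.distributed.device_communicators.all_reduce_utils import (
    CUSTOM_ALL_REDUCE_MAX_SIZES, MiB, SYMM_MEM_ALL_REDUCE_MAX_SIZES)


def lookup_cap(table: dict[str, dict[int, int]], device_capability: str,
               world_size: int) -> Optional[int]:
    """Return the byte cap for this (SM version, world size), or None."""
    return table.get(device_capability, {}).get(world_size)


# e.g. on SM 10.0 with 2 tensor-parallel ranks:
assert lookup_cap(CUSTOM_ALL_REDUCE_MAX_SIZES, "10.0", 2) == 2 * MiB
assert lookup_cap(SYMM_MEM_ALL_REDUCE_MAX_SIZES, "10.0", 2) == 8 * MiB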

View File

@@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
@@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
@@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
@@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
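
Taken together, these hunks give `CudaCommunicator.all_reduce` the following fallback order. The sketch below is a simplified illustration, not the literal method body: the QuickAllReduce branch and the exact custom all-reduce eligibility check are elided.

import torch


def all_reduce_sketch(comm, input_: torch.Tensor) -> torch.Tensor:
    # `comm` stands for a CudaCommunicator instance.
    # 1. Custom IPC all-reduce, when it accepts the tensor.
    if comm.ca_comm is not None and not comm.ca_comm.disabled:
        out = comm.ca_comm.custom_all_reduce(input_)
        if out is not None:
            return out
    # 2. PyTorch symmetric-memory all-reduce for bf16 tensors under its size cap.
    if (comm.symm_mem_comm is not None
            and comm.symm_mem_comm.should_use_symm_mem(input_)):
        out = comm.symm_mem_comm.all_reduce(input_)
        assert out is not None
        return out
    # 3. Fall back to NCCL via PyNcclCommunicator.
    assert comm.pynccl_comm is not None
    return comm.pynccl_comm.all_reduce(input_)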

View File

@@ -10,8 +10,8 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
from vllm.distributed.device_communicators.all_reduce_utils import (
CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -109,7 +109,13 @@ class CustomAllreduce:
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
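        # When symmetric memory is enabled, cap the custom all-reduce max size
        # so that larger inputs fall through to the symmetric-memory path.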
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
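
The net effect of the clamp is to split responsibility by message size. A worked sketch with the SM 9.0 numbers follows; the 8 MiB default custom all-reduce buffer size is an assumption for illustration, not taken from this diff.

from vllm.distributed.device_communicators.all_reduce_utils import (
    CUSTOM_ALL_REDUCE_MAX_SIZES, MiB, SYMM_MEM_ALL_REDUCE_MAX_SIZES)

default_max_size = 8 * MiB  # assumed default CustomAllreduce buffer size
# SM 9.0 (e.g. H100), tensor-parallel size 8:
clamped = min(CUSTOM_ALL_REDUCE_MAX_SIZES["9.0"][8], default_max_size)
assert clamped == MiB // 4  # 256 KiB
# bf16 inputs between 256 KiB and SYMM_MEM_ALL_REDUCE_MAX_SIZES["9.0"][8]
# (64 MiB) therefore skip the custom IPC kernel and take the symmetric-memory
# path instead.
assert SYMM_MEM_ALL_REDUCE_MAX_SIZES["9.0"][8] == 64 * MiB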

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform
try:
import torch.distributed._symmetric_memory as torch_symm_mem
symm_mem_available = True
except ImportError:
symm_mem_available = False
logger = init_logger(__name__)
class SymmMemCommunicator:
_WORLD_SIZES_MULTIMEM = {
"9.0": [4, 6, 8],
"10.0": [6, 8],
}
def __init__(self, group: ProcessGroup, device: Union[int, str,
torch.device]):
self.disabled = True
if not symm_mem_available:
return
if not current_platform.is_cuda():
logger.warning("SymmMemCommunicator: symmetric "
"memory is not available.")
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
torch.cuda.set_device(device)
self.dtype = torch.bfloat16
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = current_platform.get_device_capability(
).as_version_str()
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
"communicator is not available.",
self.device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability]:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
self.world_size,
)
return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=self.device,
dtype=self.dtype,
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.disabled = False
def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled:
return False
if inp.dtype != self.dtype:
return False
inp_size = inp.numel() * inp.element_size()
if inp_size % 4 != 0:
return False
return inp_size < self.max_size
def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
else:
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
return out
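
A hedged usage sketch of the communicator above, outside of `CudaCommunicator`: it assumes `torch.distributed` is already initialized with one process per GPU; vLLM itself passes its CPU process group and device, but the default world group works for illustration.

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator

rank = dist.get_rank()
comm = SymmMemCommunicator(group=dist.group.WORLD,
                           device=torch.device(f"cuda:{rank}"))

inp = torch.randn(1 << 20, dtype=torch.bfloat16, device=f"cuda:{rank}")
if comm.should_use_symm_mem(inp):
    # Copies into the rendezvoused buffer and runs the multimem or two-shot kernel.
    out = comm.all_reduce(inp)
else:
    # The caller is responsible for a fallback, e.g. plain NCCL.
    dist.all_reduce(inp)
    out = inp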

View File

@@ -161,6 +161,7 @@ if TYPE_CHECKING:
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
@@ -1156,6 +1157,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RESPONSES_API_STORE":
lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
    # Whether to use PyTorch symmetric memory for all-reduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
# Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
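
`VLLM_ALLREDUCE_USE_SYMM_MEM` defaults to off. A hedged sketch of enabling it for a tensor-parallel run is shown below; the model name and parallel size are illustrative, and the variable must be set before the engine initializes its communicators.

import os

# Enable the symmetric-memory all-reduce path before vLLM spins up its workers.
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "1"

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
print(llm.generate("Hello, symmetric memory!")[0].outputs[0].text)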