[PERF] PyTorch Symmetric Memory All-Reduce (#20759)
Signed-off-by: ilmarkov <imarkov@redhat.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: ilmarkov <imarkov@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
		| @ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`. | ||||
|  | ||||
| There are other miscellaneous places hard-coding the use of `spawn`: | ||||
|  | ||||
| - <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135> | ||||
| - <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135> | ||||
| - <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184> | ||||
|  | ||||
| Related PRs: | ||||
|  | ||||
							
								
								
									
										108
									
								
								tests/distributed/test_symm_mem_allreduce.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								tests/distributed/test_symm_mem_allreduce.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,108 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
|  | ||||
| import random | ||||
| import typing | ||||
|  | ||||
| import pytest | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| import torch.multiprocessing as mp | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm.distributed import cleanup_dist_env_and_memory | ||||
| from vllm.distributed.communication_op import tensor_model_parallel_all_reduce | ||||
| from vllm.distributed.device_communicators.cuda_communicator import ( | ||||
|     CudaCommunicator) | ||||
| from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, | ||||
|                                              get_tp_group, | ||||
|                                              init_distributed_environment, | ||||
|                                              initialize_model_parallel) | ||||
| from vllm.platforms import current_platform | ||||
| from vllm.utils import update_environment_variables | ||||
|  | ||||
| torch.manual_seed(42) | ||||
| random.seed(44) | ||||
|  | ||||
| test_size_elements = 4 * 1024 * 1024 | ||||
|  | ||||
|  | ||||
| def symm_mem_allreduce_worker(local_rank: int, world_size: int): | ||||
|     monkeypatch = pytest.MonkeyPatch() | ||||
|     with monkeypatch.context() as m: | ||||
|         m.delenv("CUDA_VISIBLE_DEVICES", raising=False) | ||||
|         dtype = torch.bfloat16 | ||||
|         device = torch.device(f"cuda:{local_rank}") | ||||
|         torch.cuda.set_device(device) | ||||
|         torch.set_default_device(device) | ||||
|         torch.set_default_dtype(dtype) | ||||
|         update_environment_variables({ | ||||
|             'RANK': str(local_rank), | ||||
|             'LOCAL_RANK': str(local_rank), | ||||
|             'WORLD_SIZE': str(world_size), | ||||
|             'MASTER_ADDR': 'localhost', | ||||
|             'MASTER_PORT': '12345', | ||||
|         }) | ||||
|  | ||||
|         init_distributed_environment() | ||||
|         initialize_model_parallel(tensor_model_parallel_size=world_size) | ||||
|  | ||||
|         cuda_communicator = typing.cast(CudaCommunicator, | ||||
|                                         get_tp_group().device_communicator) | ||||
|         symm_mem_comm = cuda_communicator.symm_mem_comm | ||||
|         if symm_mem_comm is None or symm_mem_comm.disabled: | ||||
|             pytest.skip("SymmMemCommunicator is not available or disabled.") | ||||
|  | ||||
|         inp_direct_symm_mem = torch.randint(1, | ||||
|                                             23, (test_size_elements, ), | ||||
|                                             dtype=dtype, | ||||
|                                             device=device) | ||||
|         if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): | ||||
|             pytest.skip( | ||||
|                 "SymmMemCommunicator isn't used for this world and input size." | ||||
|             ) | ||||
|  | ||||
|         original_inp_direct_symm_mem = inp_direct_symm_mem.clone() | ||||
|         out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) | ||||
|         assert out_direct_symm_mem is not None | ||||
|  | ||||
|         group = get_tensor_model_parallel_group().device_group | ||||
|         dist.all_reduce(original_inp_direct_symm_mem, group=group) | ||||
|         torch.testing.assert_close(out_direct_symm_mem, | ||||
|                                    original_inp_direct_symm_mem, | ||||
|                                    atol=2.5, | ||||
|                                    rtol=0.1) | ||||
|  | ||||
|         # Test tensor_model_parallel_all_reduce which should use symm_mem | ||||
|         inp_tensor_parallel = torch.randint(-23, | ||||
|                                             1, (test_size_elements, ), | ||||
|                                             dtype=dtype, | ||||
|                                             device=device) | ||||
|         original_inp_tensor_parallel = inp_tensor_parallel.clone() | ||||
|         out_tensor_parallel = tensor_model_parallel_all_reduce( | ||||
|             inp_tensor_parallel) | ||||
|         dist.all_reduce(original_inp_tensor_parallel, group=group) | ||||
|         torch.testing.assert_close(out_tensor_parallel, | ||||
|                                    original_inp_tensor_parallel, | ||||
|                                    atol=2.5, | ||||
|                                    rtol=0.1) | ||||
|  | ||||
|  | ||||
| @pytest.mark.skipif( | ||||
|     not current_platform.is_cuda(), | ||||
|     reason="SymmMemAllreduce is only available for CUDA platforms.") | ||||
| @pytest.mark.parametrize("tp_size", [2]) | ||||
| @pytest.mark.parametrize("pipeline_parallel_size", [1]) | ||||
| @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], | ||||
|                     reason="Only test on CUDA") | ||||
| def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, | ||||
|                             pipeline_parallel_size): | ||||
|     world_size = tp_size * pipeline_parallel_size | ||||
|     if world_size > torch.cuda.device_count(): | ||||
|         pytest.skip("Not enough GPUs to run the test.") | ||||
|  | ||||
|     # Enable SymmMemCommunicator | ||||
|     monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1") | ||||
|  | ||||
|     mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) | ||||
|     cleanup_dist_env_and_memory() | ||||
| @ -37,7 +37,7 @@ ALLOWED_FILES = set([ | ||||
|     'vllm/distributed/utils.py', | ||||
|     'vllm/distributed/parallel_state.py', | ||||
|     'vllm/engine/multiprocessing/client.py', | ||||
|     'vllm/distributed/device_communicators/custom_all_reduce_utils.py', | ||||
|     'vllm/distributed/device_communicators/all_reduce_utils.py', | ||||
|     'vllm/distributed/device_communicators/shm_broadcast.py', | ||||
|     'vllm/engine/multiprocessing/engine.py', | ||||
|     'benchmarks/kernels/graph_machete_bench.py', | ||||
|  | ||||
| @ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless, | ||||
| 
 | ||||
| logger = init_logger(__name__) | ||||
| 
 | ||||
| MiB = 1024 * 1024 | ||||
| # Max size for each world size in case symmetric memory is available | ||||
| # For different SM architectures | ||||
| CUSTOM_ALL_REDUCE_MAX_SIZES = { | ||||
|     "9.0": { | ||||
|         2: 64 * MiB,  # 64 MB | ||||
|         4: 32 * MiB,  # 32 MB | ||||
|         6: MiB // 2,  # 512 KB | ||||
|         8: MiB // 4,  # 256 KB | ||||
|     }, | ||||
|     "10.0": { | ||||
|         2: 2 * MiB,  # 2 MB | ||||
|         4: 2 * MiB,  # 2 MB | ||||
|         6: 2 * MiB,  # 2 MB | ||||
|         8: 2 * MiB,  # 2 MB | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| SYMM_MEM_ALL_REDUCE_MAX_SIZES = { | ||||
|     "9.0": { | ||||
|         2: 64 * MiB,  # 64 MB | ||||
|         4: 32 * MiB,  # 32 MB | ||||
|         6: 64 * MiB,  # 64 MB | ||||
|         8: 64 * MiB,  # 64 MB | ||||
|     }, | ||||
|     "10.0": { | ||||
|         2: 8 * MiB,  # 8 MB | ||||
|         4: 32 * MiB,  # 32 MB | ||||
|         6: 128 * MiB,  # 128 MB | ||||
|         8: 128 * MiB,  # 128 MB | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| def producer(batch_src: Sequence[int], | ||||
|              producer_queue, | ||||
| @ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase): | ||||
|             PyNcclCommunicator) | ||||
|         from vllm.distributed.device_communicators.quick_all_reduce import ( | ||||
|             QuickAllReduce) | ||||
|         from vllm.distributed.device_communicators.symm_mem import ( | ||||
|             SymmMemCommunicator) | ||||
|  | ||||
|         self.pynccl_comm: Optional[PyNcclCommunicator] = None | ||||
|         if use_pynccl and self.world_size > 1: | ||||
| @ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase): | ||||
|  | ||||
|         self.ca_comm: Optional[CustomAllreduce] = None | ||||
|         self.qr_comm: Optional[QuickAllReduce] = None | ||||
|         self.symm_mem_comm: Optional[SymmMemCommunicator] = None | ||||
|         if use_custom_allreduce and self.world_size > 1: | ||||
|             # Initialize a custom fast all-reduce implementation. | ||||
|             self.ca_comm = CustomAllreduce( | ||||
| @ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase): | ||||
|                 # currently be an MI300 series. | ||||
|                 self.qr_comm = QuickAllReduce(group=self.cpu_group, | ||||
|                                               device=self.device) | ||||
|         if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): | ||||
|             self.symm_mem_comm = SymmMemCommunicator( | ||||
|                 group=self.cpu_group, | ||||
|                 device=self.device, | ||||
|             ) | ||||
|  | ||||
|         if self.use_all2all: | ||||
|             all2all_backend = envs.VLLM_ALL2ALL_BACKEND | ||||
|             if all2all_backend == "naive": | ||||
| @ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase): | ||||
|             out = ca_comm.custom_all_reduce(input_) | ||||
|             assert out is not None | ||||
|             return out | ||||
|         symm_mem_comm = self.symm_mem_comm | ||||
|         if symm_mem_comm is not None and \ | ||||
|             symm_mem_comm.should_use_symm_mem(input_): | ||||
|             out = symm_mem_comm.all_reduce(input_) | ||||
|             assert out is not None | ||||
|             return out | ||||
|         pynccl_comm = self.pynccl_comm | ||||
|         assert pynccl_comm is not None | ||||
|         out = pynccl_comm.all_reduce(input_) | ||||
|  | ||||
| @ -10,8 +10,8 @@ from torch.distributed import ProcessGroup | ||||
|  | ||||
| import vllm.envs as envs | ||||
| from vllm import _custom_ops as ops | ||||
| from vllm.distributed.device_communicators.custom_all_reduce_utils import ( | ||||
|     gpu_p2p_access_check) | ||||
| from vllm.distributed.device_communicators.all_reduce_utils import ( | ||||
|     CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check) | ||||
| from vllm.distributed.parallel_state import in_the_same_node_as | ||||
| from vllm.logger import init_logger | ||||
| from vllm.platforms import current_platform | ||||
| @ -109,7 +109,13 @@ class CustomAllreduce: | ||||
|         # now `device` is a `torch.device` object | ||||
|         assert isinstance(device, torch.device) | ||||
|         self.device = device | ||||
|  | ||||
|         device_capability = current_platform.get_device_capability( | ||||
|         ).as_version_str() | ||||
|         if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM | ||||
|                 and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): | ||||
|             max_size = min( | ||||
|                 CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], | ||||
|                 max_size) | ||||
|         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES | ||||
|         if cuda_visible_devices: | ||||
|             device_ids = list(map(int, cuda_visible_devices.split(","))) | ||||
|  | ||||
							
								
								
									
										111
									
								
								vllm/distributed/device_communicators/symm_mem.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								vllm/distributed/device_communicators/symm_mem.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,111 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| from typing import Optional, Union | ||||
|  | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| from torch.distributed import ProcessGroup | ||||
|  | ||||
| from vllm.distributed.device_communicators.all_reduce_utils import ( | ||||
|     SYMM_MEM_ALL_REDUCE_MAX_SIZES) | ||||
| from vllm.logger import init_logger | ||||
| from vllm.platforms import current_platform | ||||
|  | ||||
| try: | ||||
|     import torch.distributed._symmetric_memory as torch_symm_mem | ||||
|  | ||||
|     symm_mem_available = True | ||||
| except ImportError: | ||||
|     symm_mem_available = False | ||||
|  | ||||
| logger = init_logger(__name__) | ||||
|  | ||||
|  | ||||
| class SymmMemCommunicator: | ||||
|     _WORLD_SIZES_MULTIMEM = { | ||||
|         "9.0": [4, 6, 8], | ||||
|         "10.0": [6, 8], | ||||
|     } | ||||
|  | ||||
|     def __init__(self, group: ProcessGroup, device: Union[int, str, | ||||
|                                                           torch.device]): | ||||
|         self.disabled = True | ||||
|  | ||||
|         if not symm_mem_available: | ||||
|             return | ||||
|  | ||||
|         if not current_platform.is_cuda(): | ||||
|             logger.warning("SymmMemCommunicator: symmetric " | ||||
|                            "memory is not available.") | ||||
|             return | ||||
|         if isinstance(device, int): | ||||
|             device = torch.device(f"cuda:{device}") | ||||
|         elif isinstance(device, str): | ||||
|             device = torch.device(device) | ||||
|         torch.cuda.set_device(device) | ||||
|         self.dtype = torch.bfloat16 | ||||
|         self.device = device | ||||
|         self.group = group | ||||
|         self.world_size = dist.get_world_size(self.group) | ||||
|         self.device_capability = current_platform.get_device_capability( | ||||
|         ).as_version_str() | ||||
|         if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: | ||||
|             logger.warning( | ||||
|                 "SymmMemCommunicator: Device capability %s not supported, " | ||||
|                 "communicator is not available.", | ||||
|                 self.device_capability, | ||||
|             ) | ||||
|             return | ||||
|         if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[ | ||||
|                 self.device_capability]: | ||||
|             logger.warning( | ||||
|                 "SymmMemCommunicator: World size %d not supported, " | ||||
|                 "communicator is not available.", | ||||
|                 self.world_size, | ||||
|             ) | ||||
|             return | ||||
|         self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ | ||||
|             self.world_size] | ||||
|         self.buffer = torch_symm_mem.empty( | ||||
|             self.max_size // self.dtype.itemsize, | ||||
|             device=self.device, | ||||
|             dtype=self.dtype, | ||||
|         ) | ||||
|         handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) | ||||
|         if handle.multicast_ptr == 0: | ||||
|             logger.warning("SymmMemCommunicator: symmetric memory " | ||||
|                            "multicast operations are not supported.") | ||||
|             return | ||||
|         self.disabled = False | ||||
|  | ||||
|     def should_use_symm_mem(self, inp: torch.Tensor): | ||||
|         if self.disabled: | ||||
|             return False | ||||
|         if inp.dtype != self.dtype: | ||||
|             return False | ||||
|         inp_size = inp.numel() * inp.element_size() | ||||
|         if inp_size % 4 != 0: | ||||
|             return False | ||||
|         return inp_size < self.max_size | ||||
|  | ||||
|     def all_reduce( | ||||
|             self, | ||||
|             inp: torch.Tensor, | ||||
|             *, | ||||
|             out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: | ||||
|         if not self.should_use_symm_mem(inp): | ||||
|             return None | ||||
|         if out is None: | ||||
|             out = torch.empty_like(inp) | ||||
|         self.buffer[:inp.numel()].copy_(inp.view(-1)) | ||||
|         if self.world_size in self._WORLD_SIZES_MULTIMEM[ | ||||
|                 self.device_capability]: | ||||
|             torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], | ||||
|                                                     "sum", | ||||
|                                                     self.group.group_name) | ||||
|         else: | ||||
|             torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], | ||||
|                                                     "sum", | ||||
|                                                     self.group.group_name) | ||||
|         out.copy_(self.buffer[:inp.numel()].view(out.shape)) | ||||
|         return out | ||||
| @ -161,6 +161,7 @@ if TYPE_CHECKING: | ||||
|     VLLM_HAS_FLASHINFER_CUBIN: bool = False | ||||
|     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False | ||||
|     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False | ||||
|     VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False | ||||
|     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None | ||||
|  | ||||
|  | ||||
| @ -1156,6 +1157,10 @@ environment_variables: dict[str, Callable[[], Any]] = { | ||||
|     "VLLM_ENABLE_RESPONSES_API_STORE": | ||||
|     lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), | ||||
|  | ||||
|     # Whether to use pytorch symmetric memory for allreduce | ||||
|     "VLLM_ALLREDUCE_USE_SYMM_MEM": | ||||
|     lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), | ||||
|  | ||||
|     # Allows vllm to find tuned config under customized folder | ||||
|     "VLLM_TUNED_CONFIG_FOLDER": | ||||
|     lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), | ||||
|  | ||||
		Reference in New Issue
	
	Block a user