[Platform][Dist] Make torch distributed process group extendable (#18763)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
Author: Mengqing Cao
Date: 2025-05-28 18:52:34 +08:00
Committed by: GitHub
Parent: ce75efeecb
Commit: d781930f90
4 changed files with 134 additions and 37 deletions

vllm/distributed/utils.py

@@ -5,7 +5,6 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
-import datetime
 import os
 import pickle
 import socket
@@ -14,14 +13,14 @@ import time
 import uuid
 from collections import deque
 from collections.abc import Sequence
+from datetime import timedelta
 from typing import Any, Optional

 import torch
 from torch.distributed import ProcessGroup, TCPStore
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
-                                                _unregister_process_group,
-                                                is_nccl_available)
+                                                _unregister_process_group)
 from torch.distributed.rendezvous import rendezvous

 import vllm.envs as envs
@@ -406,7 +405,7 @@ class StatelessProcessGroup:
             port=port,
             world_size=world_size,
             is_master=launch_server,
-            timeout=datetime.timedelta(seconds=store_timeout),
+            timeout=timedelta(seconds=store_timeout),
             use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
             master_listen_fd=listen_fd,
         )
@@ -419,6 +418,43 @@
             data_expiration_seconds=data_expiration_seconds)


+def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore,
+                            group_rank: int, group_size: int,
+                            timeout: timedelta) -> ProcessGroup:
+    """
+    Stateless init ProcessGroup with gloo backend compatible with
+    different torch versions.
+    """
+    if is_torch_equal_or_newer("2.6"):
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+    else:
+        options = ProcessGroup.Options(backend=backend)
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+            options,
+        )
+    from torch.distributed.distributed_c10d import ProcessGroupGloo
+    backend_class = ProcessGroupGloo(prefix_store,
+                                     group_rank,
+                                     group_size,
+                                     timeout=timeout)
+    backend_type = ProcessGroup.BackendType.GLOO
+    device = torch.device("cpu")
+
+    if is_torch_equal_or_newer("2.6"):
+        # _set_default_backend is supported in torch >= 2.6
+        pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+    pg._register_backend(device, backend_type, backend_class)
+    return pg
+
+
 def stateless_init_torch_distributed_process_group(
         host: str, port: int, rank: int, world_size: int,
         backend: str) -> ProcessGroup:
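A note on the version gate in init_gloo_process_group: on torch >= 2.6 the ProcessGroup constructor no longer takes an Options argument and _set_default_backend becomes available, hence the two construction branches above. The is_torch_equal_or_newer helper used here is defined elsewhere in vLLM; a minimal sketch of what such a gate amounts to, assuming packaging is available (the real helper may differ):

import torch
from packaging import version


def is_torch_equal_or_newer(target: str) -> bool:
    # Compare the installed torch version (e.g. "2.6.0+cu124") against a
    # target like "2.6"; PEP 440 parsing handles the local "+cu124" suffix.
    return version.parse(str(torch.__version__)) >= version.parse(target)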
@@ -468,40 +504,19 @@
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)

-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        group_size,
-    )
-
     if backend == "gloo":
-        from torch.distributed.distributed_c10d import ProcessGroupGloo
-        backend_class = ProcessGroupGloo(prefix_store,
-                                         group_rank,
-                                         group_size,
-                                         timeout=timeout)
-        backend_type = ProcessGroup.BackendType.GLOO
-        device = torch.device("cpu")
-    elif backend == "nccl":
-        assert is_nccl_available()
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-    else:
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
-
-    pg._set_default_backend(backend_type)
-    backend_class._set_sequence_number_for_group()
-    pg._register_backend(device, backend_type, backend_class)
-
-    return pg
+        return init_gloo_process_group(backend=backend,
+                                       prefix_store=prefix_store,
+                                       group_rank=group_rank,
+                                       group_size=group_size,
+                                       timeout=timeout)
+    from vllm.platforms import current_platform
+    return current_platform.stateless_init_device_torch_dist_pg(
+        backend=backend,
+        prefix_store=prefix_store,
+        group_rank=group_rank,
+        group_size=group_size,
+        timeout=timeout)


 def stateless_destroy_torch_distributed_process_group(
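Callers of stateless_init_torch_distributed_process_group see no behavior change from this refactor: "gloo" keeps the generic CPU path via init_gloo_process_group, while any other backend is now resolved through current_platform. A minimal usage sketch, one process per rank; the host/port values and env-var plumbing are illustrative, not vLLM API:

import os

import torch
import torch.distributed as dist

from vllm.distributed.utils import (
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group)

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# backend="nccl" would instead dispatch to
# current_platform.stateless_init_device_torch_dist_pg().
pg = stateless_init_torch_distributed_process_group(
    host="127.0.0.1", port=29500, rank=rank, world_size=world_size,
    backend="gloo")

t = torch.ones(4)
dist.all_reduce(t, group=pg)  # collectives take the stateless group explicitly
assert t.eq(world_size).all()

stateless_destroy_torch_distributed_process_group(pg)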

vllm/platforms/cuda.py

@@ -4,10 +4,13 @@ pynvml. However, it should not initialize cuda context.
 """
 import os
+from datetime import timedelta
 from functools import wraps
 from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union

 import torch
+from torch.distributed import PrefixStore, ProcessGroup
+from torch.distributed.distributed_c10d import is_nccl_available
 from typing_extensions import ParamSpec

 # import custom ops, trigger op registration
# import custom ops, trigger op registration
@@ -316,6 +319,36 @@ class CudaPlatformBase(Platform):
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa

+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        assert is_nccl_available()
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+        pg._set_default_backend(backend_type)
+
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(device, backend_type, backend_class)
+
+        return pg
+

 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/interface.py

@@ -3,11 +3,13 @@ import enum
 import os
 import platform
 import random
+from datetime import timedelta
 from platform import uname
 from typing import TYPE_CHECKING, NamedTuple, Optional, Union

 import numpy as np
 import torch
+from torch.distributed import PrefixStore, ProcessGroup

 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
from vllm.logger import init_logger
@@ -486,6 +488,20 @@ class Platform:
         """
         return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend"  # noqa

+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        """
+        Init platform-specific torch distributed process group.
+        """
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+

 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
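This base-class default is what makes the construction extendable: any platform that does not override the hook falls through to the "Unsupported torch distributed backend" error, and an out-of-tree platform can supply its own process group without touching vllm.distributed.utils. A hypothetical sketch; FooPlatform and "foo_ccl" are invented names, and a real plugin would construct its device-specific backend where this one borrows the gloo helper:

from datetime import timedelta

from torch.distributed import PrefixStore, ProcessGroup

from vllm.distributed.utils import init_gloo_process_group
from vllm.platforms.interface import Platform


class FooPlatform(Platform):

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        if backend != "foo_ccl":
            raise RuntimeError(
                f"Unsupported torch distributed backend: {backend}")
        # Stand-in body: register a real device backend here, mirroring the
        # ProcessGroupNCCL wiring in the CUDA/ROCm overrides of this commit.
        return init_gloo_process_group(backend="gloo",
                                       prefix_store=prefix_store,
                                       group_rank=group_rank,
                                       group_size=group_size,
                                       timeout=timeout)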

vllm/platforms/rocm.py

@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+from datetime import timedelta
 from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING, Optional

 import torch
+from torch.distributed import PrefixStore, ProcessGroup
+from torch.distributed.distributed_c10d import is_nccl_available

 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -387,3 +390,33 @@ class RocmPlatform(Platform):
     @classmethod
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
+
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        assert is_nccl_available()
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+        pg._set_default_backend(backend_type)
+
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(device, backend_type, backend_class)
+
+        return pg
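The ROCm override is line-for-line the CUDA one: PyTorch's ROCm builds expose RCCL through the same ProcessGroupNCCL bindings, so both platforms wire the group identically. If the duplication grows stale, one possible follow-up (hypothetical, not part of this commit) is a shared helper that both classmethods delegate to:

from datetime import timedelta

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import is_nccl_available


def _init_nccl_process_group(prefix_store: PrefixStore, group_rank: int,
                             group_size: int,
                             timeout: timedelta) -> ProcessGroup:
    # Shared NCCL/RCCL wiring, lifted verbatim from the two overrides above.
    assert is_nccl_available()
    pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)
    from torch.distributed.distributed_c10d import ProcessGroupNCCL
    backend_options = ProcessGroupNCCL.Options()
    backend_options._timeout = timeout
    backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                     backend_options)
    backend_type = ProcessGroup.BackendType.NCCL
    pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()
    pg._register_backend(torch.device("cuda"), backend_type, backend_class)
    return pg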