mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Core] separate distributed_init from worker (#3904)
This commit is contained in:
@ -4,6 +4,7 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
"""Tensor and pipeline parallel groups."""
|
||||
import contextlib
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -14,14 +15,59 @@ _TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
# Pipeline model parallel group that the current rank belongs to.
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
|
||||
# when people blindly call `torch.distributed.all_reduce` etc,
|
||||
# it will use this group. It is initialized with the `backend`
|
||||
# parameter of `init_distributed_environment` below.
|
||||
# Essentially, this is `torch.distributed.group.WORLD`.
|
||||
# We leave a line here to note that this is device-specific.
|
||||
# Note that this variable is not safe to use, because when users
|
||||
# call `init_distributed_environment` first, and then destroy
|
||||
# the process group themselves, this variable will keep a reference to the
|
||||
# destroyed process group, which is not useful.
|
||||
_DEVICE_WORLD_GROUP = None
|
||||
|
||||
# duing `init_distributed_environment`, we will also initialize a
|
||||
# group with `gloo` backend, to allow direct coordination between
|
||||
# processes through the CPU.
|
||||
_CPU_WORLD_GROUP = None
|
||||
|
||||
# In summary, after calling `init_distributed_environment`, we will
|
||||
# always have two groups: one for device-specific (and is the default)
|
||||
# and one for CPU. All processes will be part of both groups.
|
||||
|
||||
# A list of global ranks for each pipeline group to ease calculation of the
|
||||
# source rank when broadcasting from the first or last pipeline stage.
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
world_size: int,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
backend: str = "nccl",
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
assert distributed_init_method is not None, (
|
||||
"distributed_init_method must be provided when initializing "
|
||||
"distributed environment")
|
||||
# this backend is used for WORLD
|
||||
torch.distributed.init_process_group(
|
||||
backend=backend,
|
||||
init_method=distributed_init_method,
|
||||
world_size=world_size,
|
||||
rank=rank)
|
||||
global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
|
||||
_DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
|
||||
backend="gloo")
|
||||
|
||||
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
backend: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize model parallel groups.
|
||||
@ -48,6 +94,8 @@ def initialize_model_parallel(
|
||||
# Get world size and rank. Ensure some consistencies.
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size: int = torch.distributed.get_world_size()
|
||||
# get the backend of _DEVICE_WORLD_GROUP
|
||||
backend = backend or torch.distributed.get_backend()
|
||||
|
||||
if (world_size !=
|
||||
tensor_model_parallel_size * pipeline_model_parallel_size):
|
||||
@ -69,7 +117,7 @@ def initialize_model_parallel(
|
||||
for i in range(num_tensor_model_parallel_groups):
|
||||
ranks = range(i * tensor_model_parallel_size,
|
||||
(i + 1) * tensor_model_parallel_size)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
group = torch.distributed.new_group(ranks, backend=backend)
|
||||
if rank in ranks:
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = group
|
||||
|
||||
@ -80,7 +128,7 @@ def initialize_model_parallel(
|
||||
"pipeline model parallel group is already initialized")
|
||||
for i in range(num_pipeline_model_parallel_groups):
|
||||
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
group = torch.distributed.new_group(ranks, backend=backend)
|
||||
if rank in ranks:
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = group
|
||||
_PIPELINE_GLOBAL_RANKS = ranks
|
||||
@ -89,14 +137,17 @@ def initialize_model_parallel(
|
||||
def ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size: int,
|
||||
pipeline_model_parallel_size: int,
|
||||
backend: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Helper to initialize model parallel groups if they are not initialized,
|
||||
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
||||
values if the model parallel groups are initialized.
|
||||
"""
|
||||
# get the backend of _DEVICE_WORLD_GROUP
|
||||
backend = backend or torch.distributed.get_backend()
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(tensor_model_parallel_size,
|
||||
pipeline_model_parallel_size)
|
||||
pipeline_model_parallel_size, backend)
|
||||
return
|
||||
|
||||
assert (
|
||||
@ -117,6 +168,12 @@ def model_parallel_is_initialized():
|
||||
and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
|
||||
|
||||
|
||||
def get_cpu_world_group():
|
||||
"""Get the CPU world group."""
|
||||
assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
|
||||
return _CPU_WORLD_GROUP
|
||||
|
||||
|
||||
def get_tensor_model_parallel_group():
|
||||
"""Get the tensor model parallel group the caller rank belongs to."""
|
||||
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
|
||||
|
@ -1,8 +1,8 @@
|
||||
import ray
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
ensure_model_parallel_initialized, init_distributed_environment)
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.worker.worker import init_distributed_environment
|
||||
|
||||
|
||||
def init_test_distributed_environment(
|
||||
@ -12,15 +12,14 @@ def init_test_distributed_environment(
|
||||
distributed_init_port: str,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
parallel_config = ParallelConfig(pipeline_parallel_size,
|
||||
tensor_parallel_size,
|
||||
worker_use_ray=True)
|
||||
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
||||
init_distributed_environment(
|
||||
parallel_config,
|
||||
rank,
|
||||
world_size=pipeline_parallel_size * tensor_parallel_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
local_rank=local_rank)
|
||||
ensure_model_parallel_initialized(tensor_parallel_size,
|
||||
pipeline_parallel_size)
|
||||
|
||||
|
||||
def multi_process_tensor_parallel(
|
||||
|
@ -13,7 +13,7 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
broadcast_tensor_dict)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
ensure_model_parallel_initialized)
|
||||
ensure_model_parallel_initialized, init_distributed_environment)
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
@ -251,26 +251,12 @@ class CPUWorker:
|
||||
parallel_config = self.parallel_config
|
||||
rank = self.rank
|
||||
distributed_init_method = self.distributed_init_method
|
||||
|
||||
if torch.distributed.is_initialized():
|
||||
torch_world_size = torch.distributed.get_world_size()
|
||||
if torch_world_size != parallel_config.world_size:
|
||||
raise RuntimeError(
|
||||
"torch.distributed is already initialized but the torch "
|
||||
"world size does not match parallel_config.world_size "
|
||||
f"({torch_world_size} vs. {parallel_config.world_size}).")
|
||||
elif not distributed_init_method:
|
||||
raise ValueError(
|
||||
"distributed_init_method must be set if torch.distributed "
|
||||
"is not already initialized")
|
||||
else:
|
||||
backend = "gloo"
|
||||
torch.distributed.init_process_group(
|
||||
backend=backend,
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
init_method=distributed_init_method,
|
||||
)
|
||||
init_distributed_environment(
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
backend="gloo",
|
||||
)
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cpu())
|
||||
|
@ -15,7 +15,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
|
||||
broadcast_tensor_dict)
|
||||
from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
ensure_model_parallel_initialized)
|
||||
ensure_model_parallel_initialized, init_distributed_environment)
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
@ -97,9 +97,9 @@ class Worker:
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
init_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
init_worker_distributed_environment(self.parallel_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
@ -248,31 +248,15 @@ class Worker:
|
||||
self.parallel_config)
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
def init_worker_distributed_environment(
|
||||
parallel_config: ParallelConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
if torch.distributed.is_initialized():
|
||||
torch_world_size = torch.distributed.get_world_size()
|
||||
if torch_world_size != parallel_config.world_size:
|
||||
raise RuntimeError(
|
||||
"torch.distributed is already initialized but the torch world "
|
||||
"size does not match parallel_config.world_size "
|
||||
f"({torch_world_size} vs. {parallel_config.world_size}).")
|
||||
elif not distributed_init_method:
|
||||
raise ValueError(
|
||||
"distributed_init_method must be set if torch.distributed "
|
||||
"is not already initialized")
|
||||
else:
|
||||
torch.distributed.init_process_group(
|
||||
backend="nccl",
|
||||
world_size=parallel_config.world_size,
|
||||
rank=rank,
|
||||
init_method=distributed_init_method,
|
||||
)
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank)
|
||||
|
||||
if pynccl_utils.is_initialized():
|
||||
pynccl_world_size = pynccl_utils.get_world_size()
|
||||
@ -291,10 +275,6 @@ def init_distributed_environment(
|
||||
init_method=distributed_init_method,
|
||||
)
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||
if pynccl_utils.is_initialized():
|
||||
pynccl_utils.all_reduce(torch.zeros(1).cuda())
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
@ -302,6 +282,11 @@ def init_distributed_environment(
|
||||
if not parallel_config.disable_custom_all_reduce:
|
||||
init_custom_ar()
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||
if pynccl_utils.is_initialized():
|
||||
pynccl_utils.all_reduce(torch.zeros(1).cuda())
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
# Check if the GPU supports the dtype.
|
||||
|
Reference in New Issue
Block a user