mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 05:33:51 +08:00
[Patch]patch of v1 executor when enable eplb. (#3511)
### What this PR does / why we need it? when using dynamic eplb, patch v1 executor to avoid create child process failed. ### How was this patch tested? deepseek in v3. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: offline0806 <3337230449@qq.com> Co-authored-by: offline0806 <3337230449@qq.com>
This commit is contained in:
@ -16,7 +16,7 @@ Expert balancing for MoE models in LLM serving is essential for optimal performa
|
||||
|
||||
### Dynamic EPLB
|
||||
|
||||
We need to add environment variable `export PYTHONOPTIMIZE=1` to get context of vllm process. Enable dynamic balancing with auto-tuned parameters. Adjust num_iterations_eplb_update and num_wait_worker_iterations based on workload patterns.
|
||||
We need to add environment variable `export DYNAMIC_EPLB=true` to enable vllm eplb. Enable dynamic balancing with auto-tuned parameters. Adjust num_iterations_eplb_update and num_wait_worker_iterations based on workload patterns.
|
||||
|
||||
```shell
|
||||
vllm serve Qwen/Qwen3-235B-A22 \
|
||||
@ -32,7 +32,7 @@ vllm serve Qwen/Qwen3-235B-A22 \
|
||||
### Static EPLB
|
||||
#### Initial Setup (Record Expert Map)
|
||||
|
||||
Generate the initial expert distribution map using expert_map_record_path. This creates a baseline configuration for future deployments.
|
||||
We need to add environment variable `export EXPERT_MAP_RECORD=true` to record expert map.Generate the initial expert distribution map using expert_map_record_path. This creates a baseline configuration for future deployments.
|
||||
|
||||
```shell
|
||||
vllm serve Qwen/Qwen3-235B-A22 \
|
||||
|
@ -14,7 +14,27 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.patch.platform.patch_common.patch_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa
|
||||
|
||||
|
||||
def patch_v1_executor():
|
||||
try:
|
||||
dynamic_eplb = os.getenv("DYNAMIC_EPLB", False) or os.getenv(
|
||||
"EXPERT_MAP_RECORD", False)
|
||||
if dynamic_eplb:
|
||||
import vllm_ascend.patch.platform.patch_common.patch_multiproc_executor # noqa
|
||||
else:
|
||||
logger.warning("Do not patch v1 executor.")
|
||||
except RuntimeError as e:
|
||||
logger.warning(
|
||||
f"Fail to patch v1 executor, please add environment params DYNAMIC_EPLB or EXPERT_MAP_RECORD : {e}"
|
||||
)
|
||||
|
||||
|
||||
patch_v1_executor()
|
||||
|
@ -0,0 +1,151 @@
|
||||
import threading
|
||||
import weakref
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import Optional
|
||||
|
||||
import vllm.v1.executor.multiproc_executor
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
from vllm.utils import (get_distributed_init_method, get_loopback_ip,
|
||||
get_mp_context, get_open_port)
|
||||
from vllm.v1.executor.abstract import FailureCallback
|
||||
from vllm.v1.executor.multiproc_executor import (
|
||||
MultiprocExecutor, UnreadyWorkerProcHandle, WorkerProc,
|
||||
set_multiprocessing_worker_envs)
|
||||
|
||||
|
||||
class AscendMultiprocExecutor(MultiprocExecutor):
|
||||
supports_pp: bool = True
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
# Call self.shutdown at exit to clean up
|
||||
# and ensure workers will be terminated.
|
||||
self._finalizer = weakref.finalize(self, self.shutdown)
|
||||
self.is_failed = False
|
||||
self.shutdown_event = threading.Event()
|
||||
self.failure_callback: Optional[FailureCallback] = None
|
||||
self.io_thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
|
||||
self.world_size = self.parallel_config.world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
|
||||
f"world_size ({self.world_size}) must be equal to the "
|
||||
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
||||
f"_parallel_size ({pp_parallel_size}). ")
|
||||
|
||||
# Set multiprocessing envs
|
||||
set_multiprocessing_worker_envs()
|
||||
|
||||
# Multiprocessing-based executor does not support multi-node setting.
|
||||
# Since it only works for single node, we can use the loopback address
|
||||
# get_loopback_ip() for communication.
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_loopback_ip(), get_open_port())
|
||||
|
||||
# Initialize worker and set up message queues for SchedulerOutputs
|
||||
# and ModelRunnerOutputs
|
||||
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
|
||||
self.rpc_broadcast_mq = MessageQueue(self.world_size,
|
||||
self.world_size,
|
||||
max_chunk_bytes=max_chunk_bytes)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
|
||||
# Create workers
|
||||
context = get_mp_context()
|
||||
shared_worker_lock = context.Lock()
|
||||
unready_workers: list[UnreadyWorkerProcHandle] = []
|
||||
success = False
|
||||
try:
|
||||
for rank in range(self.world_size):
|
||||
unready_workers.append(
|
||||
AscendWorkerProc.make_worker_process(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
input_shm_handle=scheduler_output_handle,
|
||||
shared_worker_lock=shared_worker_lock,
|
||||
))
|
||||
|
||||
# Workers must be created before wait_for_ready to avoid
|
||||
# deadlock, since worker.init_device() does a device sync.
|
||||
self.workers = WorkerProc.wait_for_ready(unready_workers)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered
|
||||
# Must be kept consistent with the WorkerProc.
|
||||
self.rpc_broadcast_mq.wait_until_ready()
|
||||
for w in self.workers:
|
||||
w.worker_response_mq.wait_until_ready()
|
||||
|
||||
self.start_worker_monitor()
|
||||
success = True
|
||||
finally:
|
||||
if not success:
|
||||
# Clean up the worker procs if there was a failure.
|
||||
# Close death_writers first to signal workers to exit
|
||||
for uw in unready_workers:
|
||||
if uw.death_writer is not None:
|
||||
uw.death_writer.close()
|
||||
self._ensure_worker_termination(
|
||||
[uw.proc for uw in unready_workers])
|
||||
|
||||
# For pipeline parallel, we use a thread pool for asynchronous
|
||||
# execute_model.
|
||||
if self.max_concurrent_batches > 1:
|
||||
# Note: must use only 1 IO thread to keep dequeue sequence
|
||||
# from the response queue
|
||||
# _async_aggregate_workers_output also assumes a single IO thread
|
||||
self.io_thread_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="mp_exec_io")
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
|
||||
|
||||
class AscendWorkerProc(WorkerProc):
|
||||
|
||||
@staticmethod
|
||||
def make_worker_process(
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
input_shm_handle, # Receive SchedulerOutput
|
||||
shared_worker_lock: LockType,
|
||||
) -> UnreadyWorkerProcHandle:
|
||||
context = get_mp_context()
|
||||
# (reader, writer)
|
||||
reader, writer = context.Pipe(duplex=False)
|
||||
|
||||
# Create death pipe to detect parent process exit
|
||||
death_reader, death_writer = context.Pipe(duplex=False)
|
||||
|
||||
process_kwargs = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
"rank": rank,
|
||||
"distributed_init_method": distributed_init_method,
|
||||
"input_shm_handle": input_shm_handle,
|
||||
"ready_pipe": (reader, writer),
|
||||
"death_pipe": death_reader,
|
||||
"shared_worker_lock": shared_worker_lock,
|
||||
}
|
||||
# Run EngineCore busy loop in background process.
|
||||
proc = context.Process(
|
||||
target=WorkerProc.worker_main,
|
||||
kwargs=process_kwargs,
|
||||
name=f"VllmWorker-{rank}",
|
||||
daemon=False,
|
||||
)
|
||||
|
||||
proc.start()
|
||||
writer.close()
|
||||
# Keep death_writer open in parent - when parent exits,
|
||||
# death_reader in child will get EOFError
|
||||
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
|
||||
|
||||
|
||||
vllm.v1.executor.multiproc_executor.MultiprocExecutor = AscendMultiprocExecutor
|
@ -1089,7 +1089,8 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
local_num_experts = (torch.sum(self.expert_map != -1)
|
||||
if self.expert_map is not None else num_experts)
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
self.moe_load = torch.zeros(local_num_experts,
|
||||
dtype=torch.int64).npu()
|
||||
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.multistream_overlap_shared_expert = \
|
||||
@ -1311,17 +1312,26 @@ class TorchairAscendFusedMoE(FusedMoE):
|
||||
tuple) and len(e_hidden_states) == 2:
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if self.dynamic_eplb and isinstance(
|
||||
if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 4:
|
||||
e_hidden_states, shared_hidden_states, group_list_type, expert_tokens = e_hidden_states
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load += expert_tokens if group_list_type else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
|
||||
if shared_experts is None and isinstance(
|
||||
e_hidden_states, tuple) and len(e_hidden_states) == 3:
|
||||
e_hidden_states, group_list_type, expert_tokens = e_hidden_states
|
||||
self.moe_load += expert_tokens if group_list_type else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load += expert_tokens if group_list_type else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce and not self.enable_shared_expert_dp):
|
||||
if tp_size > 1:
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
e_hidden_states = e_hidden_states[0]
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
|
@ -365,17 +365,18 @@ def torchair_fused_experts_with_mc2(
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
|
||||
if dynamic_eplb:
|
||||
return (hidden_states, 1, expert_token_nums)
|
||||
|
||||
if shared_experts is None:
|
||||
if dynamic_eplb:
|
||||
return (hidden_states, 1, expert_token_nums)
|
||||
return hidden_states
|
||||
else:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(shared_act, down_out_list)
|
||||
shared_output, _ = shared_experts.down_proj(
|
||||
(shared_act, swiglu_out_scale))
|
||||
return hidden_states, shared_output
|
||||
if dynamic_eplb:
|
||||
return (hidden_states, shared_output, 1, expert_token_nums)
|
||||
return (hidden_states, shared_output)
|
||||
|
||||
|
||||
def torchair_init_routing_quant(hidden_states, top_k, topk_ids,
|
||||
|
Reference in New Issue
Block a user