diff --git a/docs/source/user_guide/feature_guide/eplb_swift_balancer.md b/docs/source/user_guide/feature_guide/eplb_swift_balancer.md
index 010aa6a92..77ca9e3c3 100644
--- a/docs/source/user_guide/feature_guide/eplb_swift_balancer.md
+++ b/docs/source/user_guide/feature_guide/eplb_swift_balancer.md
@@ -16,7 +16,7 @@ Expert balancing for MoE models in LLM serving is essential for optimal performa
 
 ### Dynamic EPLB
 
-We need to add environment variable `export PYTHONOPTIMIZE=1` to get context of vllm process. Enable dynamic balancing with auto-tuned parameters. Adjust num_iterations_eplb_update and num_wait_worker_iterations based on workload patterns.
+We need to set the environment variable `export DYNAMIC_EPLB=true` to enable vLLM EPLB. Enable dynamic balancing with auto-tuned parameters. Adjust num_iterations_eplb_update and num_wait_worker_iterations based on workload patterns.
 
 ```shell
 vllm serve Qwen/Qwen3-235B-A22 \
@@ -32,7 +32,7 @@ vllm serve Qwen/Qwen3-235B-A22 \
 ### Static EPLB
 
 #### Initial Setup (Record Expert Map)
-Generate the initial expert distribution map using expert_map_record_path. This creates a baseline configuration for future deployments.
+We need to set the environment variable `export EXPERT_MAP_RECORD=true` to record the expert map. Generate the initial expert distribution map using expert_map_record_path. This creates a baseline configuration for future deployments.
 
 ```shell
 vllm serve Qwen/Qwen3-235B-A22 \
diff --git a/vllm_ascend/patch/platform/patch_common/__init__.py b/vllm_ascend/patch/platform/patch_common/__init__.py
index b1e7f4e6a..11d0b10e3 100644
--- a/vllm_ascend/patch/platform/patch_common/__init__.py
+++ b/vllm_ascend/patch/platform/patch_common/__init__.py
@@ -14,7 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+
+from vllm.logger import logger
 
 import vllm_ascend.patch.platform.patch_common.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_common.patch_mamba_config  # noqa
+
+
+def patch_v1_executor():
+    try:
+        dynamic_eplb = os.getenv("DYNAMIC_EPLB", "false").lower() == "true" \
+            or os.getenv("EXPERT_MAP_RECORD", "false").lower() == "true"
+        if dynamic_eplb:
+            import vllm_ascend.patch.platform.patch_common.patch_multiproc_executor  # noqa
+        else:
+            logger.warning("Not patching the v1 executor: EPLB is disabled.")
+    except RuntimeError as e:
+        logger.warning(
+            "Failed to patch the v1 executor for EPLB "
+            f"(DYNAMIC_EPLB/EXPERT_MAP_RECORD): {e}")
+
+
+patch_v1_executor()
diff --git a/vllm_ascend/patch/platform/patch_common/patch_multiproc_executor.py b/vllm_ascend/patch/platform/patch_common/patch_multiproc_executor.py
new file mode 100644
index 000000000..82b16fc4e
--- /dev/null
+++ b/vllm_ascend/patch/platform/patch_common/patch_multiproc_executor.py
@@ -0,0 +1,151 @@
+import threading
+import weakref
+from concurrent.futures import ThreadPoolExecutor
+from multiprocessing.synchronize import Lock as LockType
+from typing import Optional
+
+import vllm.v1.executor.multiproc_executor
+from vllm import envs
+from vllm.config import VllmConfig
+from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
+from vllm.utils import (get_distributed_init_method, get_loopback_ip,
+                        get_mp_context, get_open_port)
+from vllm.v1.executor.abstract import FailureCallback
+from vllm.v1.executor.multiproc_executor import (
+    MultiprocExecutor, UnreadyWorkerProcHandle, WorkerProc,
+    set_multiprocessing_worker_envs)
+
+
+class AscendMultiprocExecutor(MultiprocExecutor):
+    supports_pp: bool = True
+
+    def _init_executor(self) -> None:
+        # Call self.shutdown at exit to clean up
+        # and ensure workers will be terminated.
+        self._finalizer = weakref.finalize(self, self.shutdown)
+        self.is_failed = False
+        self.shutdown_event = threading.Event()
+        self.failure_callback: Optional[FailureCallback] = None
+        self.io_thread_pool: Optional[ThreadPoolExecutor] = None
+
+        self.world_size = self.parallel_config.world_size
+        tensor_parallel_size = self.parallel_config.tensor_parallel_size
+        pp_parallel_size = self.parallel_config.pipeline_parallel_size
+        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
+            f"world_size ({self.world_size}) must be equal to the "
+            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
+            f"_parallel_size ({pp_parallel_size}). ")
+
+        # Set multiprocessing envs
+        set_multiprocessing_worker_envs()
+
+        # Multiprocessing-based executor does not support multi-node setting.
+        # Since it only works for single node, we can use the loopback address
+        # get_loopback_ip() for communication.
+        distributed_init_method = get_distributed_init_method(
+            get_loopback_ip(), get_open_port())
+
+        # Initialize worker and set up message queues for SchedulerOutputs
+        # and ModelRunnerOutputs
+        max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
+        self.rpc_broadcast_mq = MessageQueue(self.world_size,
+                                             self.world_size,
+                                             max_chunk_bytes=max_chunk_bytes)
+        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
+
+        # Create workers
+        context = get_mp_context()
+        shared_worker_lock = context.Lock()
+        unready_workers: list[UnreadyWorkerProcHandle] = []
+        success = False
+        try:
+            for rank in range(self.world_size):
+                unready_workers.append(
+                    AscendWorkerProc.make_worker_process(
+                        vllm_config=self.vllm_config,
+                        local_rank=rank,
+                        rank=rank,
+                        distributed_init_method=distributed_init_method,
+                        input_shm_handle=scheduler_output_handle,
+                        shared_worker_lock=shared_worker_lock,
+                    ))
+
+            # Workers must be created before wait_for_ready to avoid
+            # deadlock, since worker.init_device() does a device sync.
+            self.workers = WorkerProc.wait_for_ready(unready_workers)
+
+            # Ensure message queues are ready. Will deadlock if re-ordered
+            # Must be kept consistent with the WorkerProc.
+            self.rpc_broadcast_mq.wait_until_ready()
+            for w in self.workers:
+                w.worker_response_mq.wait_until_ready()
+
+            self.start_worker_monitor()
+            success = True
+        finally:
+            if not success:
+                # Clean up the worker procs if there was a failure.
+                # Close death_writers first to signal workers to exit
+                for uw in unready_workers:
+                    if uw.death_writer is not None:
+                        uw.death_writer.close()
+                self._ensure_worker_termination(
+                    [uw.proc for uw in unready_workers])
+
+        # For pipeline parallel, we use a thread pool for asynchronous
+        # execute_model.
+        if self.max_concurrent_batches > 1:
+            # Note: must use only 1 IO thread to keep dequeue sequence
+            # from the response queue
+            # _async_aggregate_workers_output also assumes a single IO thread
+            self.io_thread_pool = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="mp_exec_io")
+
+        self.output_rank = self._get_output_rank()
+        self.has_connector = self.vllm_config.kv_transfer_config is not None
+
+
+class AscendWorkerProc(WorkerProc):
+
+    @staticmethod
+    def make_worker_process(
+            vllm_config: VllmConfig,
+            local_rank: int,
+            rank: int,
+            distributed_init_method: str,
+            input_shm_handle,  # Receive SchedulerOutput
+            shared_worker_lock: LockType,
+    ) -> UnreadyWorkerProcHandle:
+        context = get_mp_context()
+        # (reader, writer)
+        reader, writer = context.Pipe(duplex=False)
+
+        # Create death pipe to detect parent process exit
+        death_reader, death_writer = context.Pipe(duplex=False)
+
+        process_kwargs = {
+            "vllm_config": vllm_config,
+            "local_rank": local_rank,
+            "rank": rank,
+            "distributed_init_method": distributed_init_method,
+            "input_shm_handle": input_shm_handle,
+            "ready_pipe": (reader, writer),
+            "death_pipe": death_reader,
+            "shared_worker_lock": shared_worker_lock,
+        }
+        # Run EngineCore busy loop in background process.
+        proc = context.Process(
+            target=WorkerProc.worker_main,
+            kwargs=process_kwargs,
+            name=f"VllmWorker-{rank}",
+            daemon=False,
+        )
+
+        proc.start()
+        writer.close()
+        # Keep death_writer open in parent - when parent exits,
+        # death_reader in child will get EOFError
+        return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
+
+
+vllm.v1.executor.multiproc_executor.MultiprocExecutor = AscendMultiprocExecutor
diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py
index 1a87f3e89..72f8cb72c 100644
--- a/vllm_ascend/torchair/ops/torchair_fused_moe.py
+++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py
@@ -1089,7 +1089,8 @@ class TorchairAscendFusedMoE(FusedMoE):
         local_num_experts = (torch.sum(self.expert_map != -1)
                              if self.expert_map is not None else num_experts)
         if self.dynamic_eplb:
-            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
+            self.moe_load = torch.zeros(local_num_experts,
+                                        dtype=torch.int64).npu()
 
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.multistream_overlap_shared_expert = \
@@ -1311,17 +1312,26 @@ class TorchairAscendFusedMoE(FusedMoE):
                 tuple) and len(e_hidden_states) == 2:
             e_hidden_states, shared_hidden_states = e_hidden_states
 
-        if self.dynamic_eplb and isinstance(
+        if isinstance(e_hidden_states, tuple) and len(e_hidden_states) == 4:
+            e_hidden_states, shared_hidden_states, group_list_type, expert_tokens = e_hidden_states
+            if self.dynamic_eplb:
+                self.moe_load += expert_tokens if group_list_type else \
+                    torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+
+        if shared_experts is None and isinstance(
                 e_hidden_states, tuple) and len(e_hidden_states) == 3:
             e_hidden_states, group_list_type, expert_tokens = e_hidden_states
-            self.moe_load += expert_tokens if group_list_type else \
-                torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
+            if self.dynamic_eplb:
+                self.moe_load += expert_tokens if group_list_type else \
+                    torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
 
         if (fused_moe_state not in [
                 FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
                 FusedMoEState.NaiveMulticast
         ] and not replace_allreduce and not self.enable_shared_expert_dp):
             if tp_size > 1:
+                if isinstance(e_hidden_states, tuple):
+                    e_hidden_states = e_hidden_states[0]
                 dist.all_gather(list(chunk_hidden_states), e_hidden_states,
                                 self.tp_group)
                 final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
index 0f4615443..ceba1c4f8 100644
--- a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
+++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py
@@ -365,17 +365,18 @@ def torchair_fused_experts_with_mc2(
     ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
         **kwargs_mc2)
 
-    if dynamic_eplb:
-        return (hidden_states, 1, expert_token_nums)
-
     if shared_experts is None:
+        if dynamic_eplb:
+            return (hidden_states, 1, expert_token_nums)
         return hidden_states
     else:
        with npu_stream_switch("moe_secondary", 0):
            npu_wait_tensor(shared_act, down_out_list)
            shared_output, _ = shared_experts.down_proj(
                (shared_act, swiglu_out_scale))
-        return hidden_states, shared_output
+        if dynamic_eplb:
+            return (hidden_states, shared_output, 1, expert_token_nums)
+        return (hidden_states, shared_output)
 
 
 def torchair_init_routing_quant(hidden_states, top_k, topk_ids,
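A minimal sketch of the environment gate added in `patch_common/__init__.py`, assuming the boolean-style parsing above; the helper name `eplb_patch_enabled` is hypothetical and not part of the patch.

```python
# Sketch (not part of the patch): mirrors the gate in
# vllm_ascend/patch/platform/patch_common/__init__.py that decides whether
# to import patch_multiproc_executor and swap in AscendMultiprocExecutor.
import os


def eplb_patch_enabled() -> bool:
    # Either flag set to "true" makes patch_v1_executor() apply the patch.
    return (os.getenv("DYNAMIC_EPLB", "false").lower() == "true"
            or os.getenv("EXPERT_MAP_RECORD", "false").lower() == "true")


# Equivalent of `export DYNAMIC_EPLB=true` before running `vllm serve ...`.
os.environ["DYNAMIC_EPLB"] = "true"
assert eplb_patch_enabled()
```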
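A minimal sketch of the `moe_load` bookkeeping used in `TorchairAscendFusedMoE.forward`: when `group_list_type` is truthy, `expert_tokens` already holds per-expert token counts; when it is zero, `expert_tokens` is a cumulative sum and the first-order difference recovers the counts. The helper name and the token values below are made up for illustration, and the sketch runs on CPU just to show the arithmetic.

```python
# Sketch (not part of the patch) of the per-expert load accumulation.
import torch


def per_expert_counts(expert_tokens: torch.Tensor,
                      group_list_type: int) -> torch.Tensor:
    # group_list_type truthy: expert_tokens already holds per-expert counts.
    # group_list_type == 0: expert_tokens is a cumulative sum, so take
    # first-order differences to recover per-expert counts.
    return expert_tokens if group_list_type else \
        torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])


counts = torch.tensor([3, 0, 5, 2], dtype=torch.int64)
cumsum = torch.cumsum(counts, dim=0)  # tensor([ 3,  3,  8, 10])

moe_load = torch.zeros(4, dtype=torch.int64)  # lives on the NPU in the patch
moe_load += per_expert_counts(counts, 1)
moe_load += per_expert_counts(cumsum, 0)
assert torch.equal(moe_load, 2 * counts)
```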
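For reference, a sketch of the return-shape contract that `torchair_fused_experts_with_mc2` now follows and that `forward` dispatches on by tuple length; `unpack_moe_output` is a hypothetical helper, not part of the patch.

```python
# Sketch (illustrative only): with dynamic EPLB, the MC2 path returns
# (hidden_states, 1, expert_token_nums) when shared experts are absent and
# (hidden_states, shared_output, 1, expert_token_nums) when they are fused in;
# the caller distinguishes the cases by tuple length.
from typing import Any, Optional, Tuple

import torch


def unpack_moe_output(
    out: Any
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[int],
           Optional[torch.Tensor]]:
    shared_output = group_list_type = expert_tokens = None
    if isinstance(out, tuple) and len(out) == 4:
        # Shared experts fused in, dynamic EPLB enabled.
        hidden_states, shared_output, group_list_type, expert_tokens = out
    elif isinstance(out, tuple) and len(out) == 3:
        # No shared experts, dynamic EPLB enabled.
        hidden_states, group_list_type, expert_tokens = out
    elif isinstance(out, tuple) and len(out) == 2:
        # Shared experts fused in, dynamic EPLB disabled.
        hidden_states, shared_output = out
    else:
        hidden_states = out
    return hidden_states, shared_output, group_list_type, expert_tokens


h = torch.randn(4, 8)
tokens = torch.tensor([2, 1, 1, 0])
assert unpack_moe_output((h, 1, tokens))[0] is h
assert unpack_moe_output(h)[1] is None
```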