diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 1254f3a2f..76906a93b 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -174,8 +174,7 @@ jobs: run: | pytest -sv tests/e2e/multicard/test_data_parallel.py pytest -sv tests/e2e/multicard/test_expert_parallel.py - # external_launcher test is not stable enough. Fix it later - # pytest -sv tests/e2e/multicard/test_external_launcher.py + pytest -sv tests/e2e/multicard/test_external_launcher.py pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py diff --git a/examples/offline_external_launcher.py b/examples/offline_external_launcher.py index 4566fdcfa..17f844b3f 100644 --- a/examples/offline_external_launcher.py +++ b/examples/offline_external_launcher.py @@ -67,11 +67,38 @@ from vllm import LLM, SamplingParams from vllm.distributed.parallel_state import ( # noqa E402 destroy_distributed_environment, destroy_model_parallel, get_tp_group) from vllm.utils import get_open_port, GiB_bytes +from safetensors.torch import load_file os.environ["VLLM_USE_MODELSCOPE"] = "True" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" +def patch_vllm_moe_model_weight_loader(model): + model = getattr(model, "model", None) or getattr(model, "language_model", None) + if model is None: + raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.") + for layer in model.layers: + mlp_attr = "mlp" + mlp = getattr(layer, mlp_attr) + param_dict = dict(mlp.named_parameters()) + for name, param in param_dict.items(): + if "w13_weight" in name or "w2_weight" in name: + param.weight_loader = mlp.experts.weight_loader + + +def load_and_merge_safetensors(directory): + if not os.path.isdir(directory): + raise ValueError(f"The provided directory does not exist: {directory}") + merged_dict = {} + for filename in os.listdir(directory): + if filename.endswith(".safetensors"): + file_path = os.path.join(directory, filename) + print(f"loading file: {file_path}") + f = load_file(file_path) + merged_dict.update(f) + return merged_dict + + def parse_args(): parser = argparse.ArgumentParser(description="External launcher Inference") @@ -125,6 +152,11 @@ def parse_args(): type=float, default=None, help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).") + parser.add_argument("--sleep-mode-level", + type=int, + choices=[1, 2], + default=1, + help="Sleep mode level: 1 or 2. This example of level 2 is only supported for dense model.") args = parser.parse_args() if args.enable_sleep_mode: @@ -152,6 +184,7 @@ def main( trust_remote_code: bool = True, enable_sleep_mode: bool = False, temperature: float = 0.8, + sleep_mode_level: int = 1, ): os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = str(master_port) @@ -193,7 +226,7 @@ def main( if enable_sleep_mode: if rank == 0: free_bytes_before_sleep, total = torch.npu.mem_get_info() - llm.sleep(level=1) + llm.sleep(level=sleep_mode_level) if rank == 0: free_bytes_after_sleep, total = torch.npu.mem_get_info() freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep @@ -201,7 +234,16 @@ def main( # now the freed memory should be larger than the model weights assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes - llm.wake_up() + if sleep_mode_level == 1: + llm.wake_up() + else: + llm.wake_up(tags=["weights"]) + run_model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model + patch_vllm_moe_model_weight_loader(run_model) + sd = load_and_merge_safetensors(model) + run_model.load_weights(sd.items()) + llm.wake_up(tags=["kv_cache"]) + outputs_after_wakeup = llm.generate(prompts, sampling_params) if rank == 0: # cmp output @@ -268,6 +310,7 @@ if __name__ == "__main__": args.trust_remote_code, args.enable_sleep_mode, args.temperature, + args.sleep_mode_level, )) proc.start() diff --git a/tests/e2e/multicard/test_external_launcher.py b/tests/e2e/multicard/test_external_launcher.py index 24c66bfcb..9bf855e30 100644 --- a/tests/e2e/multicard/test_external_launcher.py +++ b/tests/e2e/multicard/test_external_launcher.py @@ -28,6 +28,7 @@ from unittest.mock import patch import pytest import torch_npu +from modelscope import snapshot_download # type: ignore MODELS = ["Qwen/Qwen3-0.6B"] MOE_MODELS = ["Qwen/Qwen3-30B-A3B"] @@ -35,6 +36,7 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] @pytest.mark.parametrize("model", MODELS) +@patch.dict(os.environ, {"HCCL_BUFFSIZE": "500"}) def test_external_launcher(model): script = Path( __file__ @@ -152,12 +154,64 @@ def test_external_launcher_and_sleepmode(): assert proc.returncode == 0 +def test_external_launcher_and_sleepmode_level2(): + script = Path( + __file__ + ).parent.parent.parent.parent / "examples" / "offline_external_launcher.py" + env = os.environ.copy() + model_path = snapshot_download("Qwen/Qwen3-8B") + # TODO: Add moe model test + cmd = [ + sys.executable, + str(script), + "--model", + model_path, + "--tp-size", + "1", + "--node-size", + "1", + "--node-rank", + "0", + "--proc-per-node", + "2", + "--trust-remote-code", + "--enable-sleep-mode", + "--temperature", + "0", + "--model-weight-gib", + "16", + "--sleep-mode-level", + "2", + ] + + print(f"Running subprocess: {' '.join(cmd)}") + proc = subprocess.run( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=300, + ) + output = proc.stdout.decode() + + print(output) + + assert "TP RANKS: [0]" in output + assert "TP RANKS: [1]" in output + assert "Generated text:" in output + assert "Sleep and wake up successfully!!" in output + assert proc.returncode == 0 + + @pytest.mark.skipif( DEVICE_NAME != "Ascend910B", reason="This test is only for Ascend910B devices.", ) @pytest.mark.parametrize("model", MODELS) -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": "1"}) +@patch.dict(os.environ, { + "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": "1", + "HCCL_BUFFSIZE": "500" +}) def test_mm_allreduce(model): script = Path( __file__