ut: add example and e2e test for sleepmode in external_launcher (#2152)

### What this PR does / why we need it?
This pr add e2e testcase to make sure sleep mode in external_launcher is
ok.

### Does this PR introduce _any_ user-facing change?
not involved

### How was this patch tested?
not involved


- vLLM version: v0.10.0
- vLLM main:
74333ae2f6

Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
huangxialu
2025-08-06 11:11:53 +08:00
committed by GitHub
parent 8a59367d0c
commit 875a86cbe9
2 changed files with 132 additions and 8 deletions

View File

@ -28,7 +28,7 @@ Single node:
--proc-per-node=2
MOE models:
python examples/offline_external_launcher.py \
--model="Qwen/Qwen3-0.6B" \
--model="Qwen/Qwen3-30B-A3B" \
--tp-size=2 \
--proc-per-node=2 \
--enable-expert-parallel
@ -36,7 +36,7 @@ Single node:
Multi-node:
Node 0 (assume the node has ip of 10.99.48.128):
python examples/offline_external_launcher.py \
--model="Qwen/Qwen3-0.6B" \
--model="Qwen/Qwen3-30B-A3B" \
--tp-size=2 \
--node-size=2 \
--node-rank=0 \
@ -46,7 +46,7 @@ Multi-node:
--master-port=13345
Node 1:
python examples/offline_external_launcher.py \
--model="Qwen/Qwen3-0.6B" \
--model="Qwen/Qwen3-30B-A3B" \
--tp-size=2 \
--node-size=2 \
--node-rank=1 \
@ -66,7 +66,7 @@ import torch
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
from vllm.utils import get_open_port
from vllm.utils import get_open_port, GiB_bytes
os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@ -114,7 +114,28 @@ def parse_args():
parser.add_argument("--enable-expert-parallel",
action="store_true",
help="Enable expert parallel, used in MOE models.")
return parser.parse_args()
parser.add_argument("--enable-sleep-mode",
action="store_true",
help="Enable sleep mode for the engine.")
parser.add_argument("--temperature",
type=float,
default=0.8,
help="Float that controls the randomness of the sampling.")
parser.add_argument("--model-weight-gib",
type=float,
default=None,
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
args = parser.parse_args()
if args.enable_sleep_mode:
if args.model_weight_gib is None or args.temperature != 0:
parser.error("model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set.")
if args.model_weight_gib <= 0:
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
if args.model == parser.get_default("model") and args.model_weight_gib is None:
parser.error("model-weight-gib must be provided for default model when enable-sleep-mode is set.")
return args
def main(
@ -122,12 +143,15 @@ def main(
rank: int,
master_addr: str,
master_port: int,
model_weight_gib: float,
model: str = "Qwen/Qwen3-0.6B",
world_size: int = 4,
tensor_parallel_size: int = 2,
enable_expert_parallel: bool = False,
enforce_eager: bool = False,
trust_remote_code: bool = True,
enable_sleep_mode: bool = False,
temperature: float = 0.8,
):
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = str(master_port)
@ -147,7 +171,7 @@ def main(
"The future of AI is",
] * 10
sampling_params = SamplingParams(
temperature=0.8,
temperature=temperature,
top_p=0.95,
max_tokens=10,
)
@ -159,10 +183,31 @@ def main(
trust_remote_code=trust_remote_code,
distributed_executor_backend="external_launcher",
seed=0,
enable_sleep_mode=enable_sleep_mode,
)
tp_ranks = get_tp_group().ranks
print(f'TP RANKS: {tp_ranks}')
outputs = llm.generate(prompts, sampling_params)
if enable_sleep_mode:
if rank == 0:
free_bytes_before_sleep, total = torch.npu.mem_get_info()
llm.sleep(level=1)
if rank == 0:
free_bytes_after_sleep, total = torch.npu.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB")
# now the freed memory should be larger than the model weights
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
llm.wake_up()
outputs_after_wakeup = llm.generate(prompts, sampling_params)
if rank == 0:
# cmp output
assert outputs[0].outputs[0].text == outputs_after_wakeup[0].outputs[0].text
print("Sleep and wake up successfully!!")
for i, output in enumerate(outputs):
if i >= 5:
# print only 5 outputs
@ -214,12 +259,15 @@ if __name__ == "__main__":
rank,
master_addr,
master_port,
args.model_weight_gib,
args.model,
world_size,
tp_size,
args.enable_expert_parallel,
args.enforce_eager,
args.trust_remote_code,
args.enable_sleep_mode,
args.temperature,
))
proc.start()

View File

@ -24,15 +24,14 @@ import os
import subprocess
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
MODELS = ["Qwen/Qwen3-0.6B"]
MOE_MODELS = ["Qwen/Qwen3-30B-A3B"]
@pytest.mark.parametrize("model", MODELS)
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"})
def test_external_launcher(model):
script = Path(
__file__
@ -71,3 +70,80 @@ def test_external_launcher(model):
assert "TP RANKS: [1]" in output
assert "Generated text:" in output
assert proc.returncode == 0
@pytest.mark.parametrize("model", MOE_MODELS)
def test_moe_external_launcher(model):
script = Path(
__file__
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
env = os.environ.copy()
# TODO: Change to 2 when ci machine has 4 cards
cmd = [
sys.executable,
str(script), "--model", model, "--tp-size", "2", "--node-size", "1",
"--node-rank", "0", "--proc-per-node", "2", "--trust-remote-code",
"--enable-expert-parallel"
]
print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(
cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=600,
)
output = proc.stdout.decode()
print(output)
assert "TP RANKS: [0, 1]" in output
assert "Generated text:" in output
assert proc.returncode == 0
def test_external_launcher_and_sleepmode():
script = Path(
__file__
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
env = os.environ.copy()
# TODO: Change to 2 when ci machine has 4 cards
cmd = [
sys.executable,
str(script),
"--model",
"Qwen/Qwen3-8B",
"--tp-size",
"1",
"--node-size",
"1",
"--node-rank",
"0",
"--proc-per-node",
"2",
"--trust-remote-code",
"--enable-sleep-mode",
"--temperature",
"0",
"--model-weight-gib",
"16",
]
print(f"Running subprocess: {' '.join(cmd)}")
proc = subprocess.run(
cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
timeout=300,
)
output = proc.stdout.decode()
print(output)
assert "TP RANKS: [0]" in output
assert "TP RANKS: [1]" in output
assert "Generated text:" in output
assert "Sleep and wake up successfully!!" in output
assert proc.returncode == 0