mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
[Test] enable external launcher and add e2e test for sleep mode in level2 (#3344)
### What this PR does / why we need it? 1. Enable tests/e2e/multicard/test_external_launcher.py 2. Add e2e test for sleep mode in level2 ### Does this PR introduce _any_ user-facing change? not involved ### How was this patch tested? CI passed with existing test. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: huangxialu <huangxialu1@huawei.com> Co-authored-by: Shangwei-Li <lishangwei2@huawei.com>
This commit is contained in:
3
.github/workflows/_e2e_test.yaml
vendored
3
.github/workflows/_e2e_test.yaml
vendored
@ -174,8 +174,7 @@ jobs:
|
||||
run: |
|
||||
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
||||
pytest -sv tests/e2e/multicard/test_expert_parallel.py
|
||||
# external_launcher test is not stable enough. Fix it later
|
||||
# pytest -sv tests/e2e/multicard/test_external_launcher.py
|
||||
pytest -sv tests/e2e/multicard/test_external_launcher.py
|
||||
pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py
|
||||
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
|
||||
|
||||
|
@ -67,11 +67,38 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.distributed.parallel_state import ( # noqa E402
|
||||
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
|
||||
from vllm.utils import get_open_port, GiB_bytes
|
||||
from safetensors.torch import load_file
|
||||
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def patch_vllm_moe_model_weight_loader(model):
|
||||
model = getattr(model, "model", None) or getattr(model, "language_model", None)
|
||||
if model is None:
|
||||
raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.")
|
||||
for layer in model.layers:
|
||||
mlp_attr = "mlp"
|
||||
mlp = getattr(layer, mlp_attr)
|
||||
param_dict = dict(mlp.named_parameters())
|
||||
for name, param in param_dict.items():
|
||||
if "w13_weight" in name or "w2_weight" in name:
|
||||
param.weight_loader = mlp.experts.weight_loader
|
||||
|
||||
|
||||
def load_and_merge_safetensors(directory):
|
||||
if not os.path.isdir(directory):
|
||||
raise ValueError(f"The provided directory does not exist: {directory}")
|
||||
merged_dict = {}
|
||||
for filename in os.listdir(directory):
|
||||
if filename.endswith(".safetensors"):
|
||||
file_path = os.path.join(directory, filename)
|
||||
print(f"loading file: {file_path}")
|
||||
f = load_file(file_path)
|
||||
merged_dict.update(f)
|
||||
return merged_dict
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
||||
parser = argparse.ArgumentParser(description="External launcher Inference")
|
||||
@ -125,6 +152,11 @@ def parse_args():
|
||||
type=float,
|
||||
default=None,
|
||||
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
|
||||
parser.add_argument("--sleep-mode-level",
|
||||
type=int,
|
||||
choices=[1, 2],
|
||||
default=1,
|
||||
help="Sleep mode level: 1 or 2. This example of level 2 is only supported for dense model.")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.enable_sleep_mode:
|
||||
@ -152,6 +184,7 @@ def main(
|
||||
trust_remote_code: bool = True,
|
||||
enable_sleep_mode: bool = False,
|
||||
temperature: float = 0.8,
|
||||
sleep_mode_level: int = 1,
|
||||
):
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
@ -193,7 +226,7 @@ def main(
|
||||
if enable_sleep_mode:
|
||||
if rank == 0:
|
||||
free_bytes_before_sleep, total = torch.npu.mem_get_info()
|
||||
llm.sleep(level=1)
|
||||
llm.sleep(level=sleep_mode_level)
|
||||
if rank == 0:
|
||||
free_bytes_after_sleep, total = torch.npu.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
@ -201,7 +234,16 @@ def main(
|
||||
# now the freed memory should be larger than the model weights
|
||||
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
|
||||
|
||||
llm.wake_up()
|
||||
if sleep_mode_level == 1:
|
||||
llm.wake_up()
|
||||
else:
|
||||
llm.wake_up(tags=["weights"])
|
||||
run_model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
patch_vllm_moe_model_weight_loader(run_model)
|
||||
sd = load_and_merge_safetensors(model)
|
||||
run_model.load_weights(sd.items())
|
||||
llm.wake_up(tags=["kv_cache"])
|
||||
|
||||
outputs_after_wakeup = llm.generate(prompts, sampling_params)
|
||||
if rank == 0:
|
||||
# cmp output
|
||||
@ -268,6 +310,7 @@ if __name__ == "__main__":
|
||||
args.trust_remote_code,
|
||||
args.enable_sleep_mode,
|
||||
args.temperature,
|
||||
args.sleep_mode_level,
|
||||
))
|
||||
|
||||
proc.start()
|
||||
|
@ -28,6 +28,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch_npu
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
|
||||
MODELS = ["Qwen/Qwen3-0.6B"]
|
||||
MOE_MODELS = ["Qwen/Qwen3-30B-A3B"]
|
||||
@ -35,6 +36,7 @@ DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "500"})
|
||||
def test_external_launcher(model):
|
||||
script = Path(
|
||||
__file__
|
||||
@ -152,12 +154,64 @@ def test_external_launcher_and_sleepmode():
|
||||
assert proc.returncode == 0
|
||||
|
||||
|
||||
def test_external_launcher_and_sleepmode_level2():
|
||||
script = Path(
|
||||
__file__
|
||||
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
|
||||
env = os.environ.copy()
|
||||
model_path = snapshot_download("Qwen/Qwen3-8B")
|
||||
# TODO: Add moe model test
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(script),
|
||||
"--model",
|
||||
model_path,
|
||||
"--tp-size",
|
||||
"1",
|
||||
"--node-size",
|
||||
"1",
|
||||
"--node-rank",
|
||||
"0",
|
||||
"--proc-per-node",
|
||||
"2",
|
||||
"--trust-remote-code",
|
||||
"--enable-sleep-mode",
|
||||
"--temperature",
|
||||
"0",
|
||||
"--model-weight-gib",
|
||||
"16",
|
||||
"--sleep-mode-level",
|
||||
"2",
|
||||
]
|
||||
|
||||
print(f"Running subprocess: {' '.join(cmd)}")
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
timeout=300,
|
||||
)
|
||||
output = proc.stdout.decode()
|
||||
|
||||
print(output)
|
||||
|
||||
assert "TP RANKS: [0]" in output
|
||||
assert "TP RANKS: [1]" in output
|
||||
assert "Generated text:" in output
|
||||
assert "Sleep and wake up successfully!!" in output
|
||||
assert proc.returncode == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
DEVICE_NAME != "Ascend910B",
|
||||
reason="This test is only for Ascend910B devices.",
|
||||
)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": "1"})
|
||||
@patch.dict(os.environ, {
|
||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE": "1",
|
||||
"HCCL_BUFFSIZE": "500"
|
||||
})
|
||||
def test_mm_allreduce(model):
|
||||
script = Path(
|
||||
__file__
|
||||
|
Reference in New Issue
Block a user