mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
Add e2e test related to weight updates in RL scenarios. (#2954)
### What this PR does / why we need it?
Add e2e test related to weight updates in RL scenarios.
Due to CI issues, the newly added Python test files cannot locate the
correct path. As a temporary solution, use absolute paths to add test
cases.
- vLLM version: v0.10.2
- vLLM main:
52d0cb8458
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: Shangwei-Li <lishangwei2@huawei.com>
This commit is contained in:
326
examples/offline_weight_load.py
Normal file
326
examples/offline_weight_load.py
Normal file
@ -0,0 +1,326 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/examples/offline_inference/data_parallel.py
|
||||
|
||||
# Note: This script is designed to run with e2e test,
|
||||
# please be careful to modify it.
|
||||
"""
|
||||
Usage:
|
||||
Single node:
|
||||
Dense models:
|
||||
python examples/offline_weight_load.py \
|
||||
--model="Qwen/Qwen2.5-0.5B-Instruct" \
|
||||
--tp-size=1 \
|
||||
--proc-per-node=2
|
||||
MOE models:
|
||||
python examples/offline_weight_load.py \
|
||||
--model="Qwen/Qwen3-30B-A3B" \
|
||||
--tp-size=2 \
|
||||
--proc-per-node=2 \
|
||||
--enable-expert-parallel
|
||||
|
||||
Multi-node:
|
||||
Node 0 (assume the node has ip of 10.99.48.128):
|
||||
python examples/offline_weight_load.py \
|
||||
--model="Qwen/Qwen3-30B-A3B" \
|
||||
--tp-size=2 \
|
||||
--node-size=2 \
|
||||
--node-rank=0 \
|
||||
--proc-per-node=2 \
|
||||
--enable-expert-parallel \
|
||||
--master-addr=10.99.48.128 \
|
||||
--master-port=13345
|
||||
Node 1:
|
||||
python examples/offline_weight_load.py \
|
||||
--model="Qwen/Qwen3-30B-A3B" \
|
||||
--tp-size=2 \
|
||||
--node-size=2 \
|
||||
--node-rank=1 \
|
||||
--enable-expert-parallel \
|
||||
--master-addr=10.99.48.128 \
|
||||
--master-port=13345
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
from multiprocessing import Process
|
||||
from time import sleep
|
||||
|
||||
import torch
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed.parallel_state import ( # noqa E402
|
||||
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
|
||||
from vllm.utils import get_open_port, GiB_bytes
|
||||
from safetensors.torch import load_file
|
||||
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
def patch_vllm_moe_model_weight_loader(model):
|
||||
# Define MLP attribute mapping for different model types
|
||||
|
||||
model = getattr(model, "model", None) or getattr(model, "language_model", None)
|
||||
if model is None:
|
||||
raise ValueError("The provided model does not have a valid 'model' or 'language_model' attribute.")
|
||||
|
||||
for layer in model.layers:
|
||||
mlp_attr = "mlp"
|
||||
mlp = getattr(layer, mlp_attr)
|
||||
|
||||
param_dict = dict(mlp.named_parameters())
|
||||
for name, param in param_dict.items():
|
||||
if "w13_weight" in name or "w2_weight" in name:
|
||||
param.weight_loader = mlp.experts.weight_loader
|
||||
|
||||
def load_and_merge_safetensors(directory):
|
||||
merged_dict = {}
|
||||
|
||||
if not os.path.isdir(directory):
|
||||
raise ValueError(f"directory is not exist : {directory}")
|
||||
|
||||
for filename in os.listdir(directory):
|
||||
if filename.endswith('.safetensors'):
|
||||
file_path = os.path.join(directory, filename)
|
||||
print(f"loading file: {file_path}")
|
||||
|
||||
f = load_file(file_path)
|
||||
merged_dict.update(f)
|
||||
|
||||
return merged_dict
|
||||
|
||||
def parse_args():
|
||||
|
||||
parser = argparse.ArgumentParser(description="External launcher Inference")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen3-0.6B",
|
||||
help="Model name or path",
|
||||
)
|
||||
parser.add_argument("--tp-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallel size")
|
||||
parser.add_argument("--node-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Total number of nodes")
|
||||
parser.add_argument("--node-rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Rank of the current node")
|
||||
parser.add_argument("--proc-per-node",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of processes per node")
|
||||
parser.add_argument("--master-addr",
|
||||
type=str,
|
||||
default="",
|
||||
help="Master node IP address")
|
||||
parser.add_argument("--master-port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Master node port")
|
||||
parser.add_argument("--enforce-eager",
|
||||
action="store_true",
|
||||
help="Enforce eager mode execution.")
|
||||
parser.add_argument("--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Trust remote code.")
|
||||
parser.add_argument("--enable-expert-parallel",
|
||||
action="store_true",
|
||||
help="Enable expert parallel, used in MOE models.")
|
||||
parser.add_argument("--enable-sleep-mode",
|
||||
action="store_true",
|
||||
help="Enable sleep mode for the engine.")
|
||||
parser.add_argument("--temperature",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="Float that controls the randomness of the sampling.")
|
||||
parser.add_argument("--model-weight-gib",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.enable_sleep_mode:
|
||||
if args.model_weight_gib is None or args.temperature != 0:
|
||||
parser.error("model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set.")
|
||||
if args.model_weight_gib <= 0:
|
||||
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
|
||||
if args.model == parser.get_default("model") and args.model_weight_gib is None:
|
||||
parser.error("model-weight-gib must be provided for default model when enable-sleep-mode is set.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main(
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
master_addr: str,
|
||||
master_port: int,
|
||||
model_weight_gib: float,
|
||||
model: str = "Qwen/Qwen3-30B-A3B",
|
||||
world_size: int = 4,
|
||||
tensor_parallel_size: int = 2,
|
||||
enable_expert_parallel: bool = False,
|
||||
enforce_eager: bool = True,
|
||||
trust_remote_code: bool = True,
|
||||
enable_sleep_mode: bool = False,
|
||||
temperature: float = 0.8,
|
||||
):
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,npu:hccl",
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
] * 10
|
||||
sampling_params = SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=0.95,
|
||||
max_tokens=10,
|
||||
)
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enable_expert_parallel=enable_expert_parallel,
|
||||
enforce_eager=enforce_eager,
|
||||
trust_remote_code=trust_remote_code,
|
||||
distributed_executor_backend="external_launcher",
|
||||
seed=0,
|
||||
gpu_memory_utilization = 0.95,
|
||||
enable_sleep_mode=enable_sleep_mode,
|
||||
)
|
||||
model_path = model
|
||||
runmodel = llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
patch_vllm_moe_model_weight_loader(runmodel)
|
||||
sd = load_and_merge_safetensors(model_path)
|
||||
runmodel.load_weights(sd.items())
|
||||
print('load state dict done')
|
||||
tp_ranks = get_tp_group().ranks
|
||||
print(f'TP RANKS: {tp_ranks}')
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
if enable_sleep_mode:
|
||||
if rank == 0:
|
||||
free_bytes_before_sleep, total = torch.npu.mem_get_info()
|
||||
llm.sleep(level=1)
|
||||
if rank == 0:
|
||||
free_bytes_after_sleep, total = torch.npu.mem_get_info()
|
||||
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||
print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB")
|
||||
# now the freed memory should be larger than the model weights
|
||||
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
|
||||
|
||||
llm.wake_up()
|
||||
outputs_after_wakeup = llm.generate(prompts, sampling_params)
|
||||
if rank == 0:
|
||||
# cmp output
|
||||
assert outputs[0].outputs[0].text == outputs_after_wakeup[0].outputs[0].text
|
||||
print("Sleep and wake up successfully!!")
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
if i >= 5:
|
||||
# print only 5 outputs
|
||||
break
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Global rank: {rank}, Prompt: {prompt!r}, "
|
||||
f"Generated text: {generated_text!r}")
|
||||
|
||||
# Give engines time to pause their processing loops before exiting.
|
||||
sleep(5)
|
||||
del llm
|
||||
cleanup_env_and_memory()
|
||||
|
||||
|
||||
def cleanup_env_and_memory():
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
with contextlib.suppress(AssertionError):
|
||||
torch.distributed.destroy_process_group()
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
tp_size = args.tp_size
|
||||
node_size = args.node_size
|
||||
proc_per_node = args.proc_per_node
|
||||
node_rank = args.node_rank
|
||||
|
||||
if node_size == 1:
|
||||
master_addr = "127.0.0.1"
|
||||
master_port = get_open_port()
|
||||
else:
|
||||
master_addr = args.master_addr
|
||||
master_port = args.master_port
|
||||
|
||||
world_size = node_size * proc_per_node
|
||||
|
||||
procs = []
|
||||
for local_rank, rank in enumerate(
|
||||
range(proc_per_node * node_rank, proc_per_node * (node_rank + 1))):
|
||||
proc = Process(target=main,
|
||||
args=(
|
||||
local_rank,
|
||||
rank,
|
||||
master_addr,
|
||||
master_port,
|
||||
args.model_weight_gib,
|
||||
args.model,
|
||||
world_size,
|
||||
tp_size,
|
||||
args.enable_expert_parallel,
|
||||
args.enforce_eager,
|
||||
args.trust_remote_code,
|
||||
args.enable_sleep_mode,
|
||||
args.temperature,
|
||||
))
|
||||
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
exit_code = 0
|
||||
for proc in procs:
|
||||
proc.join(timeout=600)
|
||||
if proc.exitcode is None:
|
||||
print(
|
||||
f"Killing process {proc.pid} that didn't stop within 30 minutes."
|
||||
)
|
||||
proc.kill()
|
||||
exit_code = 1
|
||||
elif proc.exitcode:
|
||||
exit_code = proc.exitcode
|
||||
|
||||
exit(exit_code)
|
188
tests/e2e/multicard/test_weight_loader.py
Normal file
188
tests/e2e/multicard/test_weight_loader.py
Normal file
@ -0,0 +1,188 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Compare the outputs of vLLM with and without aclgraph.
|
||||
|
||||
Run `pytest tests/multicard/test_external_launcher.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import torch_npu
|
||||
|
||||
MOE_MODELS = ["Qwen/Qwen3-30B-A3B"]
|
||||
MODELS = ["Qwen/Qwen3-8B"]
|
||||
DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MOE_MODELS)
|
||||
def test_external_launcher_eager(model):
|
||||
script = script = "/usr/local/python3.11.13/bin/python3.11/__w/vllm-ascend/tests/examples/test_weight_loader.py"
|
||||
env = os.environ.copy()
|
||||
# TODO: Change to 2 when ci machine has 4 cards
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(script),
|
||||
"--model",
|
||||
model,
|
||||
"--tp-size",
|
||||
"2",
|
||||
"--proc-per-node",
|
||||
"2",
|
||||
"--trust-remote-code",
|
||||
"--enforce-eager",
|
||||
"--enable-expert-parallel",
|
||||
"--enable-sleep-mode",
|
||||
"--model-weight-gib",
|
||||
"20",
|
||||
]
|
||||
|
||||
print(f"Running subprocess: {' '.join(cmd)}")
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
timeout=600,
|
||||
)
|
||||
output = proc.stdout.decode()
|
||||
|
||||
print(output)
|
||||
|
||||
assert "TP RANKS: [0]" in output
|
||||
assert "TP RANKS: [1]" in output
|
||||
assert "Generated text:" in output
|
||||
assert proc.returncode == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MOE_MODELS)
|
||||
def test_external_launcher_aclgraph(model):
|
||||
script = "/usr/local/python3.11.13/bin/python3.11/__w/vllm-ascend/tests/examples/test_weight_loader.py"
|
||||
env = os.environ.copy()
|
||||
# TODO: Change to 2 when ci machine has 4 cards
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(script),
|
||||
"--model",
|
||||
model,
|
||||
"--tp-size",
|
||||
"2",
|
||||
"--proc-per-node",
|
||||
"2",
|
||||
"--trust-remote-code",
|
||||
"--enable-expert-parallel",
|
||||
"--enable-sleep-mode",
|
||||
"--model-weight-gib",
|
||||
"20",
|
||||
]
|
||||
|
||||
print(f"Running subprocess: {' '.join(cmd)}")
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
timeout=600,
|
||||
)
|
||||
output = proc.stdout.decode()
|
||||
|
||||
print(output)
|
||||
|
||||
assert "TP RANKS: [0]" in output
|
||||
assert "TP RANKS: [1]" in output
|
||||
assert "Generated text:" in output
|
||||
assert proc.returncode == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
def test_external_launcher_dense(model):
|
||||
script = "/usr/local/python3.11.13/bin/python3.11/__w/vllm-ascend/tests/examples/test_weight_loader.py"
|
||||
env = os.environ.copy()
|
||||
# TODO: Change to 2 when ci machine has 4 cards
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(script),
|
||||
"--model",
|
||||
model,
|
||||
"--tp-size",
|
||||
"2",
|
||||
"--proc-per-node",
|
||||
"2",
|
||||
"--trust-remote-code",
|
||||
"--enable-sleep-mode",
|
||||
"--model-weight-gib",
|
||||
"20",
|
||||
]
|
||||
|
||||
print(f"Running subprocess: {' '.join(cmd)}")
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
timeout=600,
|
||||
)
|
||||
output = proc.stdout.decode()
|
||||
|
||||
print(output)
|
||||
|
||||
assert "TP RANKS: [0]" in output
|
||||
assert "TP RANKS: [1]" in output
|
||||
assert "Generated text:" in output
|
||||
assert proc.returncode == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
def test_external_launcher_dense_eager(model):
|
||||
script = "/usr/local/python3.11.13/bin/python3.11/__w/vllm-ascend/tests/examples/test_weight_loader.py"
|
||||
env = os.environ.copy()
|
||||
# TODO: Change to 2 when ci machine has 4 cards
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(script),
|
||||
"--model",
|
||||
model,
|
||||
"--tp-size",
|
||||
"2",
|
||||
"--proc-per-node",
|
||||
"2",
|
||||
"--trust-remote-code",
|
||||
"--enforce-eager",
|
||||
"--enable-sleep-mode",
|
||||
"--model-weight-gib",
|
||||
"20",
|
||||
]
|
||||
|
||||
print(f"Running subprocess: {' '.join(cmd)}")
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
timeout=600,
|
||||
)
|
||||
output = proc.stdout.decode()
|
||||
|
||||
print(output)
|
||||
|
||||
assert "TP RANKS: [0]" in output
|
||||
assert "TP RANKS: [1]" in output
|
||||
assert "Generated text:" in output
|
||||
assert proc.returncode == 0
|
Reference in New Issue
Block a user