# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
usage: torchrun --standalone --nnodes=1 \
    --nproc_per_node=2 $(which pytest) \
    -s test_sglang_async_spmd.py
"""

import asyncio

import torch
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.utils import broadcast_pyobj
from torch.distributed.device_mesh import init_device_mesh

from utils_sglang import (
    are_lists_similar,
    clean_torchelastic_env,
    generate_hf_output,
    initialize_global_process_group,
    load_tokenizer_and_model,
    prepare_inputs,
)
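

# `prepare_inputs` left-pads every prompt to max_prompt_length, but the SGLang
# engine below consumes raw (unpadded) token-id lists. This helper strips the
# leading pad tokens so each prompt is dispatched at its true length.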
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor):
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


def test_sglang_spmd():
    """Compare SGLang async SPMD generation against a Hugging Face baseline.

    Every torchrun rank executes this test, but only tp rank 0 drives the
    SGLang Engine; its outputs are broadcast to the rest of the tp group and
    then checked against plain HF greedy decoding.
    """
    assert torch.cuda.device_count() >= 2, "This test requires at least 2 GPUs"
    initialize_global_process_group(spmd=True)
    clean_torchelastic_env()
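
    # Keep prompts and responses short so the test runs quickly.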
    max_prompt_length = 16
    max_response_length = 16

    local_model_path = "Qwen/Qwen2.5-0.5B"
    tokenizer, actor_model = load_tokenizer_and_model(local_model_path)
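
    # Batch of test prompts, left-padded to max_prompt_length by prepare_inputs.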
    preencode_prompts = ["Who won the Champions League in 2019?", "The founder of Apple is", "What's your name?"]
    input_ids, attention_mask, _ = prepare_inputs(tokenizer, preencode_prompts, max_prompt_length)
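
    # Reference completions from plain Hugging Face generation; the SGLang
    # outputs are compared against these at the end of the test.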
    hf_response_tokens = generate_hf_output(actor_model, input_ids, attention_mask, tokenizer, max_response_length)
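
    # A (dp=1, tp=2, pp=1) CPU device mesh: dp and pp are both 1 here, so only
    # the tensor-parallel dimension is actually exercised.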
    tensor_parallel_size = 2
    inference_device_mesh_cpu = init_device_mesh(
        "cpu", mesh_shape=(1, tensor_parallel_size, 1), mesh_dim_names=["dp", "tp", "pp"]
    )
    tp_rank = inference_device_mesh_cpu["tp"].get_local_rank()
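
    # SPMD pattern: only tp rank 0 constructs the Engine and issues the
    # generation request; the other ranks produce no output of their own and
    # receive the results via broadcast_pyobj below.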
    if tp_rank == 0:
        llm = Engine(
            model_path=local_model_path,
            dtype="bfloat16",
            mem_fraction_static=0.5,
            enable_memory_saver=True,
            tp_size=inference_device_mesh_cpu["tp"].size(),
        )
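
        # Strip the left padding from each prompt so the engine receives
        # variable-length token-id lists instead of padded rows.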
        input_ids = input_ids.cuda()
        idx_list = []

        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        for i in range(input_ids.shape[0]):
            idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
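
        # Greedy decoding (temperature=0) with every penalty disabled, so the
        # output is deterministic and directly comparable to the HF baseline.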
        sampling_params = dict(
            n=1,
            temperature=0,
            top_p=1,
            top_k=-1,
            max_new_tokens=max_response_length,
            presence_penalty=0.0,
            frequency_penalty=0.0,
            repetition_penalty=1.0,
            skip_special_tokens=True,
            spaces_between_special_tokens=True,
            ignore_eos=False,
        )
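
        # async_generate is a coroutine; since the test itself is synchronous,
        # drive it to completion on the event loop.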
        loop = asyncio.get_event_loop()
        outputs = loop.run_until_complete(llm.async_generate(input_ids=idx_list, sampling_params=sampling_params))
    else:
        outputs = None
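
    # Replicate the results from tp rank 0 to every rank in the tp group so
    # all ranks can run the same assertions.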
    [outputs] = broadcast_pyobj(
        [outputs],
        rank=inference_device_mesh_cpu["tp"].get_local_rank(),
        src=inference_device_mesh_cpu["tp"].mesh[0].item(),
        dist_group=inference_device_mesh_cpu["tp"].get_group(),
        force_cpu_device=False,
    )

    sglang_response_tokens = [output["text"] for output in outputs]

    print(f"sglang response: {sglang_response_tokens}")
    assert are_lists_similar(hf_response_tokens, sglang_response_tokens), (
        f"SGLang output diverged from the HF baseline: {sglang_response_tokens} vs {hf_response_tokens}"
    )
    print("SPMD Test Passed!")

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()