Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Neuron] Remove bypass on EAGLEConfig and add a test (#18514)
Signed-off-by: Elaine Zhao <elaineyz@amazon.com>
@@ -53,4 +53,11 @@ docker run --rm -it --device=/dev/neuron0 --network bridge \
        -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
        --name "${container_name}" \
        ${image_name} \
-       /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys"
+       /bin/bash -c "
+           python3 /workspace/vllm/examples/offline_inference/neuron.py;
+           python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
+           for f in /workspace/vllm/tests/neuron/2_core/*.py; do
+               echo 'Running test file: '$f;
+               python3 -m pytest \$f -v --capture=tee-sys;
+           done
+       "
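The in-container command now runs every file under tests/neuron/2_core/ in its own pytest process instead of one combined pytest run. For illustration only, the same pattern expressed as a host-side Python sketch (not part of the commit; the test directory path is an assumption about where the vllm checkout lives):

# Host-side sketch: run each Neuron 2_core test file in a separate pytest
# process, mirroring the per-file loop added to the docker command above.
import subprocess
from pathlib import Path

test_dir = Path("tests/neuron/2_core")  # assumed path inside a vllm checkout
for test_file in sorted(test_dir.glob("*.py")):
    print(f"Running test file: {test_file}")
    subprocess.run(
        ["python3", "-m", "pytest", str(test_file), "-v", "--capture=tee-sys"],
        check=False)  # keep going if one file fails, like the shell loop above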
tests/neuron/2_core/test_eagle.py  (new file, 82 lines)
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0

import json
import os
import shutil
import tempfile

import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open

from vllm import LLM, SamplingParams


def patch_eagle_draft_with_lm_head(target_model_id: str,
                                   draft_model_id: str) -> str:
    # In NxDI, draft model checkpoint must include lm_head weights from target
    # model. For more details see https://awsdocs-neuron.readthedocs-hosted.com
    # /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html
    # #eagle-checkpoint-compatibility
    final_draft_dir = "/tmp/patched_eagle_draft"

    with tempfile.TemporaryDirectory() as tmp_dir:
        target_dir = snapshot_download(repo_id=target_model_id,
                                       local_dir=os.path.join(
                                           tmp_dir, "target"))
        draft_dir = snapshot_download(repo_id=draft_model_id,
                                      local_dir=os.path.join(tmp_dir, "draft"))

        lm_head_key = "lm_head.weight"
        index_path = os.path.join(target_dir, "model.safetensors.index.json")
        with open(index_path) as f:
            index = json.load(f)
        shard_name = index["weight_map"][lm_head_key]
        target_safetensor_path = os.path.join(target_dir, shard_name)

        with safe_open(target_safetensor_path, framework="pt") as f:
            target_lm_head = f.get_tensor(lm_head_key)

        draft_path = os.path.join(draft_dir, "pytorch_model.bin")
        draft_state_dict = torch.load(draft_path, map_location="cpu")
        draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16)
        torch.save(draft_state_dict, draft_path)

        shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True)

    return final_draft_dir


def test_eagle():
    patched_draft_path = patch_eagle_draft_with_lm_head(
        target_model_id="meta-llama/Llama-2-7b-hf",
        draft_model_id="yuhuili/EAGLE-llama2-chat-7B")
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        speculative_config={
            "model": patched_draft_path,
            "num_speculative_tokens": 5,
            "max_model_len": 128
        },
        max_num_seqs=1,
        max_model_len=128,
        tensor_parallel_size=2,
        override_neuron_config={
            "enable_eagle_speculation": True,
            "enable_fused_speculation": True,
            "fused_qkv": True
        },
    )
    prompts = [
        "The president of the United States is",
    ]
    outputs = llm.generate(prompts, SamplingParams(top_k=1))
    expected_output = " the head of state and head of government of " \
        "the United States. The president direct"

    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
        assert (expected_output == generated_text)

    print("Neuron Eagle speculation test passed.")
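The invariant patch_eagle_draft_with_lm_head establishes is that the draft checkpoint ends up carrying the target model's lm_head.weight in float16, which the NxDI EAGLE checkpoint-compatibility guide requires. A minimal standalone sanity check of that invariant (a sketch, assuming the helper above has already run and written to its default /tmp output directory):

# Sanity check: confirm the patched EAGLE draft checkpoint contains the
# injected lm_head weights in float16.
import torch

patched_path = "/tmp/patched_eagle_draft/pytorch_model.bin"  # default output of the helper above
state_dict = torch.load(patched_path, map_location="cpu")
assert "lm_head.weight" in state_dict, "lm_head.weight was not injected"
assert state_dict["lm_head.weight"].dtype == torch.float16
print("patched draft lm_head shape:", tuple(state_dict["lm_head.weight"].shape))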
@@ -12,8 +12,7 @@ def test_mistral():
               override_neuron_config={
                   "sequence_parallel_enabled": False,
                   "skip_warmup": True
-              },
-              device="neuron")
+              })
 
     # Send more prompts than the compiled batch size (4) and request
     # varying generation lengths to test accuracy related to Neuron
@@ -59,4 +58,7 @@ def test_mistral():
 
     for expected_output, output in zip(expected_outputs, outputs):
         generated_text = output.outputs[0].text
+        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
         assert (expected_output == generated_text)
+
+    print("Neuron Mistral test passed.")
@@ -2529,11 +2529,10 @@ class SpeculativeConfig:
                         "Chunked prefill and EAGLE are not compatible "
                         "when using V0.")
 
-                from vllm.platforms import current_platform
                 from vllm.transformers_utils.configs.eagle import (
                     EAGLEConfig)
                 if isinstance(self.draft_model_config.hf_config,
-                              EAGLEConfig) or current_platform.is_neuron():
+                              EAGLEConfig):
                     pass
                 else:
                     eagle_config = EAGLEConfig(
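The net effect of this hunk is that the draft model's hf_config is no longer exempt from EAGLEConfig wrapping on Neuron; the decision now depends only on whether the config is already an EAGLEConfig. A toy, self-contained sketch of the before/after condition (illustrative only, not vllm source):

# Illustrative truth-table of the removed bypass: "wrap" means the else-branch
# that constructs an EAGLEConfig runs.
def should_wrap(is_already_eagle: bool, is_neuron: bool,
                bypass_removed: bool) -> bool:
    if bypass_removed:  # behavior after this commit
        return not is_already_eagle
    # behavior before this commit: Neuron skipped the wrap entirely
    return not (is_already_eagle or is_neuron)

assert should_wrap(is_already_eagle=False, is_neuron=True, bypass_removed=False) is False
assert should_wrap(is_already_eagle=False, is_neuron=True, bypass_removed=True) is True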