[CI/Build][LoRA] Temporarily fix long context failure issue (#9579)

This commit is contained in:
Jee Jee Li
2024-10-22 19:32:51 +08:00
committed by GitHub
parent 6c5af09b39
commit a48e3ec052

View File

@ -28,9 +28,15 @@ sampling_params = SamplingParams(
def _create_lora_request(lora_id, long_context_infos):
context_len = long_context_infos[lora_id]["context_length"]
scaling_factor = context_len_to_scaling_factor[context_len]
return LoRARequest(context_len, lora_id,
long_context_infos[lora_id]["lora"], None,
4096 * scaling_factor)
return LoRARequest(
# There are 2 LoRAs for 16K, we need to add lora_id to indicate
# they are different LoRAs.
context_len + str(lora_id),
lora_id,
long_context_infos[lora_id]["lora"],
None,
4096 * scaling_factor,
)
def evaluate_json_response(model_response, golden_response):
@ -108,14 +114,17 @@ def lora_llm(long_context_infos):
for info in long_context_infos.values()
]
llm = vllm.LLM("meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
distributed_executor_backend="mp")
llm = vllm.LLM(
"meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
# FIXME enable async output processor
disable_async_output_proc=True,
distributed_executor_backend="mp")
yield llm
del llm