Construct KVTransferConfig properly from Python instead of using JSON blobs without CLI (#17994)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-12 19:25:33 +01:00
committed by GitHub
parent 98ea35601c
commit 72a3f6b898
5 changed files with 37 additions and 31 deletions

View File

@ -49,9 +49,10 @@ def run_prefill(prefill_done, prompts):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
)
ktc = KVTransferConfig(kv_connector="LMCacheConnector",
kv_role="kv_producer",
kv_rank=0,
kv_parallel_size=2)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
@ -78,9 +79,10 @@ def run_decode(prefill_done, prompts, timeout=1):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
)
ktc = KVTransferConfig(kv_connector="LMCacheConnector",
kv_role="kv_consumer",
kv_rank=1,
kv_parallel_size=2)
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",

View File

@ -49,8 +49,8 @@ def run_store(store_done, prompts):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
@ -76,8 +76,8 @@ def run_retrieve(store_done, prompts, timeout=1):
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
kv_role="kv_both")
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
# of memory. Reduce the value if your GPU has less memory.
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",

View File

@ -16,16 +16,17 @@ except FileNotFoundError:
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
max_num_batched_tokens=64,
max_num_seqs=16,
kv_transfer_config=KVTransferConfig.from_cli(
'{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",'
'"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}'
)) #, max_model_len=2048, max_num_batched_tokens=2048)
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
max_num_batched_tokens=64,
max_num_seqs=16,
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage"
})) #, max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs = llm.generate(prompts, sampling_params)

View File

@ -17,11 +17,12 @@ sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig.from_cli(
'{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", '
'"kv_connector_extra_config": '
'{"shared_storage_path": "local_storage"}}')
) #, max_model_len=2048, max_num_batched_tokens=2048)
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={
"shared_storage_path": "local_storage"
})) #, max_model_len=2048, max_num_batched_tokens=2048)
# 1ST generation (prefill instance)
outputs = llm.generate(

View File

@ -32,9 +32,10 @@ def run_prefill(prefill_done):
# This instance is the prefill node (kv_producer, rank 0).
# The number of parallel instances for KV cache transfer is set to 2,
# as required for PyNcclConnector.
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
)
ktc = KVTransferConfig(kv_connector="PyNcclConnector",
kv_role="kv_producer",
kv_rank=0,
kv_parallel_size=2)
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
# memory. You may need to adjust the value to fit your GPU.
@ -71,9 +72,10 @@ def run_decode(prefill_done):
# This instance is the decode node (kv_consumer, rank 1).
# The number of parallel instances for KV cache transfer is set to 2,
# as required for PyNcclConnector.
ktc = KVTransferConfig.from_cli(
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
)
ktc = KVTransferConfig(kv_connector="PyNcclConnector",
kv_role="kv_consumer",
kv_rank=1,
kv_parallel_size=2)
# Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
# memory. You may need to adjust the value to fit your GPU.