[V1][P/D]P2pNcclConnector supports flashinfer (#23536)
Signed-off-by: Abatom <abzhonghua@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
@@ -30,27 +30,19 @@ logger = init_logger(__name__)
 class ReqMeta:
     # Request Id
     request_id: str
     # Request tokens
     token_ids: torch.Tensor
-    # Slot mappings, should have the same length as token_ids
-    slot_mapping: torch.Tensor
+    # Request block ids
+    block_ids: torch.Tensor
     # Request num tokens
     num_tokens: int

     @staticmethod
     def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                   block_size: int) -> "ReqMeta":
-        valid_num_tokens = len(token_ids)
         token_ids_tensor = torch.tensor(token_ids)
         block_ids_tensor = torch.tensor(block_ids)
-        num_blocks = block_ids_tensor.shape[0]
-        block_offsets = torch.arange(0, block_size)
-        slot_mapping = block_offsets.reshape((1, block_size)) + \
-                       block_ids_tensor.reshape((num_blocks, 1)) * block_size
-        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]
-
         return ReqMeta(
             request_id=request_id,
             token_ids=token_ids_tensor,
-            slot_mapping=slot_mapping,
+            block_ids=block_ids_tensor,
             num_tokens=len(token_ids),
         )
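Note on the removed lines above: the old make_meta derived a per-token slot mapping from the request's block ids, while the new metadata keeps only block_ids and leaves slot-level indexing to the backend-specific paths below. As a reference, a minimal standalone sketch of that removed arithmetic; the block ids, block size, and token count are made-up values.

# Standalone sketch of the slot-mapping computation the old make_meta
# performed; values are illustrative only.
import torch

block_ids = [7, 2, 5]    # physical KV blocks assigned to the request
block_size = 4           # tokens per block
num_tokens = 10          # request length (last block only partially filled)

block_ids_tensor = torch.tensor(block_ids)
block_offsets = torch.arange(0, block_size)
# Each block contributes block_size consecutive slots starting at id * block_size.
slot_mapping = (block_offsets.reshape(1, block_size) +
                block_ids_tensor.reshape(-1, 1) * block_size)
slot_mapping = slot_mapping.flatten()[:num_tokens]
print(slot_mapping.tolist())  # [28, 29, 30, 31, 8, 9, 10, 11, 20, 21]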
@@ -123,63 +115,58 @@ class P2pNcclConnector(KVConnectorBase_V1):
             return

         def inject_kv_into_layer(
-            dst_kv_cache_layer: torch.Tensor,
-            src_kv_cache: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            layer: torch.Tensor,
+            kv_cache: torch.Tensor,
+            block_ids: torch.Tensor,
             request_id: str,
         ) -> None:
-            """Inject the KV cache into the layer.
+            """
+            Inject KV cache data into a given attention layer tensor.

+            This function updates `layer` in-place with values from `kv_cache`,
+            handling different backend layouts:
+              - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
+                indexed along the first dimension.
+              - FlashAttention: KV tensors are indexed along the second
+                dimension.
+
+            If the number of provided block IDs does not match the number of KV
+            blocks, only the overlapping portion is updated, and a warning is
+            logged.
+
             Args:
-                dst_kv_cache_layer (torch.Tensor): the destination KV cache
-                    layer. In shape [2, num_pages, page_size, xxx] if not
-                    using MLA, [num_pages, page_size, xxx] otherwise.
-                src_kv_cache (torch.Tensor): the source KV cache. In shape
-                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
-                    otherwise.
-                slot_mapping (torch.Tensor): the slot mapping. In shape
-                    [num_tokens].
-                request_id (str): request id for log
+                layer (torch.Tensor): The attention layer KV tensor to update.
+                kv_cache (torch.Tensor): The KV cache tensor to inject.
+                block_ids (torch.Tensor): Indices of the blocks to update.
+                request_id (str): Request identifier used for logging.
+
+            Returns:
+                None. The function modifies `layer` in-place.
             """
-            dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages = dst_kv_cache_layer_shape[0]
-                page_size = dst_kv_cache_layer_shape[1]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              0)
-                num_token = src_kv_cache.shape[0]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                num_block = kv_cache.shape[0]
+                self.check_tensors_except_dim(layer, kv_cache, 0)
+                if len(block_ids) == num_block:
+                    layer[block_ids, ...] = kv_cache
                 else:
-                    dst_kv_cache_layer[slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[block_ids[:num_block], ...] = kv_cache
                     logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
-
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
-            else:
-                num_pages = dst_kv_cache_layer_shape[1]
-                page_size = dst_kv_cache_layer_shape[2]
-                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
-                    2, num_pages * page_size, -1)
-                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
-                                              1)
-                num_token = src_kv_cache.shape[1]
-                if len(slot_mapping) == num_token:
-                    dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
+                        "🚧kv_cache does not match, block_ids:%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)
+            elif layer.shape[0] == 2:  # FlashAttention
+                num_block = kv_cache.shape[1]
+                self.check_tensors_except_dim(layer, kv_cache, 1)
+                if len(block_ids) == num_block:
+                    layer[:, block_ids, ...] = kv_cache
                 else:
-                    dst_kv_cache_layer[:, slot_mapping[:num_token],
-                                       ...] = src_kv_cache
+                    layer[:, block_ids[:num_block], ...] = kv_cache
                     logger.warning(
-                        "🚧src_kv_cache does not match, num_slot:%d, "
-                        "num_token:%d, request_id:%s", len(slot_mapping),
-                        num_token, request_id)
-
-                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
+                        "🚧kv_cache does not match, block_ids:%d, "
+                        "num_block:%d, request_id:%s", len(block_ids),
+                        num_block, request_id)

         # Get the metadata
         metadata: KVConnectorMetadata = \
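For reference, a minimal sketch of the block-indexed copy described in the new docstring: FlashInfer-style caches pack K/V per block, so blocks are indexed along the first dimension, while FlashAttention-style caches keep separate K and V planes, so blocks are indexed along the second dimension. The tensor shapes below are made up and simplified.

# Block-indexed KV injection for the two layouts; shapes are illustrative only.
import torch

num_blocks, page_size, head_dim = 16, 4, 8
block_ids = torch.tensor([3, 7, 9])

# FlashInfer-style layout: [num_blocks, 2, page_size, head_dim]
# (K/V packed per block, so blocks sit on dim 0 and layer.shape[1] == 2).
layer_fi = torch.zeros(num_blocks, 2, page_size, head_dim)
kv_fi = torch.randn(len(block_ids), 2, page_size, head_dim)
layer_fi[block_ids, ...] = kv_fi

# FlashAttention-style layout: [2, num_blocks, page_size, head_dim]
# (separate K and V planes, so blocks sit on dim 1 and layer.shape[0] == 2).
layer_fa = torch.zeros(2, num_blocks, page_size, head_dim)
kv_fa = torch.randn(2, len(block_ids), page_size, head_dim)
layer_fa[:, block_ids, ...] = kv_fa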
@@ -201,19 +188,17 @@ class P2pNcclConnector(KVConnectorBase_V1):
                 if kv_cache is None:
                     continue

-                kv_cache_layer = kv_cache[ \
-                    forward_context.virtual_engine]
+                layer = kv_cache[forward_context.virtual_engine]

                 kv_cache = self.p2p_nccl_engine.recv_tensor(
                     request.request_id + "#" + layer_name)

                 if kv_cache is None:
-                    logger.warning("🚧src_kv_cache is None, %s",
-                                   request.request_id)
+                    logger.warning("🚧kv_cache is None, %s", request.request_id)
                     continue

-                inject_kv_into_layer(kv_cache_layer, kv_cache,
-                                     request.slot_mapping, request.request_id)
+                inject_kv_into_layer(layer, kv_cache, request.block_ids,
+                                     request.request_id)

     def wait_for_layer_load(self, layer_name: str) -> None:
         """Blocking until the KV for a specific layer is loaded into vLLM's
@@ -247,20 +232,33 @@ class P2pNcclConnector(KVConnectorBase_V1):

         def extract_kv_from_layer(
             layer: torch.Tensor,
-            slot_mapping: torch.Tensor,
+            block_ids: torch.Tensor,
         ) -> torch.Tensor:
-            """Extract the KV cache from the layer.
-
-            Assume the shape of the layer is (2, num_pages, page_size, xxx)
-            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
             """
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages, page_size = layer.shape[0], layer.shape[1]
-                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
-                                                                ...]
-            num_pages, page_size = layer.shape[1], layer.shape[2]
-            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
-                                                               ...]
+            Extract KV cache slices from a given attention layer tensor.
+
+            This function handles multiple backend layouts:
+              - MLA (Multi-Linear Attention) or FlashInfer: KV tensors are
+                indexed along the first dimension.
+              - FlashAttention: KV tensors are indexed along the second
+                dimension.
+
+            Args:
+                layer (torch.Tensor): The KV cache from the attention layer.
+                block_ids (torch.Tensor): Indices of blocks to extract.
+
+            Returns:
+                torch.Tensor: A tensor containing the extracted KV slices.
+                    Returns None if the layout is unsupported.
+            """
+            if (isinstance(attn_metadata, MLACommonMetadata)
+                    or layer.shape[1] == 2):  # MLA or FlashInfer
+                return layer[block_ids, ...]
+
+            if layer.shape[0] == 2:  # FlashAttention
+                return layer[:, block_ids, ...]
+
+            return None

         connector_metadata = self._get_connector_metadata()
         assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
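As a sanity check on the new extraction path, a small self-contained sketch showing that block-wise extraction mirrors the block-wise injection above. The shapes are made up, the MLA metadata branch is omitted, and extract_blocks is an illustrative helper rather than the connector's method.

import torch

def extract_blocks(layer: torch.Tensor, block_ids: torch.Tensor):
    # Pick whole KV blocks out of a cache tensor for either layout,
    # returning None for anything else (MLA branch omitted here).
    if layer.shape[1] == 2:        # FlashInfer-style: blocks on dim 0
        return layer[block_ids, ...]
    if layer.shape[0] == 2:        # FlashAttention-style: blocks on dim 1
        return layer[:, block_ids, ...]
    return None

num_blocks, page_size, head_dim = 16, 4, 8
block_ids = torch.tensor([3, 7, 9])
layer = torch.randn(2, num_blocks, page_size, head_dim)  # FlashAttention-style
kv = extract_blocks(layer, block_ids)
assert kv is not None and kv.shape == (2, len(block_ids), page_size, head_dim)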
@@ -269,7 +267,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
             ip, port = self.parse_request_id(request_id, True)
             remote_address = ip + ":" + str(port + self._rank)

-            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
+            kv_cache = extract_kv_from_layer(kv_layer, request.block_ids)
             self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                              kv_cache, remote_address)

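A sketch of the send-side conventions visible in the last two hunks: each tensor is keyed per request and per layer, and each rank targets the remote port offset by its own rank. The helper names and concrete values below are hypothetical, not part of the connector.

def remote_address_for(ip: str, base_port: int, rank: int) -> str:
    # Mirrors: remote_address = ip + ":" + str(port + self._rank)
    return f"{ip}:{base_port + rank}"

def tensor_key(request_id: str, layer_name: str) -> str:
    # Mirrors: request_id + "#" + layer_name
    return f"{request_id}#{layer_name}"

print(remote_address_for("10.0.0.2", 20001, rank=1))            # 10.0.0.2:20002
print(tensor_key("cmpl-123", "model.layers.0.self_attn.attn"))  # cmpl-123#model.layers.0.self_attn.attn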