Mirror of https://github.com/vllm-project/vllm.git, synced 2025-11-04 01:14:35 +08:00
Compare commits: v0.11.1rc5 ... memory-lea (1 commit)
| Author | SHA1 | Date |
|---|---|---|
| (not captured) | 2686925630 | (not captured) |
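The single commit in this range touches the v1 KV-transfer connector hooks in the attention layer (the file header was not captured in this view; upstream, these helpers live in vllm/attention/layer.py). On the memory-lea branch, the bodies of wait_for_kv_layer_from_connector and maybe_save_kv_layer_to_connector are commented out in favor of a debug print, and their call sites in unified_attention_with_output are commented out as well; in line with the branch name, this reads as an experiment to check whether the KV-connector path is responsible for a memory leak. The extraction below lost the diff's +/- markers; they are reconstructed from the hunk line counts and the upstream source.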
```diff
@@ -418,35 +418,39 @@ class MultiHeadAttention(nn.Module):
 def wait_for_kv_layer_from_connector(layer_name: str):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
+    print("hi --- wait_for_kv_layer_from_connector")
+    pass
+    # if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+    #     return
 
-    connector = get_kv_transfer_group()
+    # connector = get_kv_transfer_group()
 
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.wait_for_layer_load(layer_name)
+    # forward_context: ForwardContext = get_forward_context()
+    # attn_metadata = forward_context.attn_metadata
+    # if attn_metadata is None:
+    #     return
+    # assert isinstance(attn_metadata, dict)
+    # connector.wait_for_layer_load(layer_name)
 
 
 def maybe_save_kv_layer_to_connector(
     layer_name: str,
     kv_cache_layer: List[torch.Tensor],
 ):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
+    print("hi --- maybe_save_kv_layer_to_connector")
+    pass
+    # if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+    #     return
 
-    connector = get_kv_transfer_group()
+    # connector = get_kv_transfer_group()
 
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    assert isinstance(attn_metadata, dict)
-    connector.save_kv_layer(layer_name, kv_cache_layer,
-                            attn_metadata[layer_name])
+    # forward_context: ForwardContext = get_forward_context()
+    # attn_metadata = forward_context.attn_metadata
+    # if attn_metadata is None:
+    #     return
+    # assert isinstance(attn_metadata, dict)
+    # connector.save_kv_layer(layer_name, kv_cache_layer,
+    #                         attn_metadata[layer_name])
 
 
 def unified_attention(
```
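Read as plain code, the branch side of the hunk above reduces both hooks to a trace print followed by an immediate return. The following is a paraphrase of the diff with the commented-out lines dropped, not code from the repository:

```python
from typing import List

import torch


# Branch-side behavior of the two hooks once the commented lines are dropped:
# every call prints a marker and does nothing else, so the KV-transfer
# connector is never consulted.
def wait_for_kv_layer_from_connector(layer_name: str) -> None:
    print("hi --- wait_for_kv_layer_from_connector")


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
) -> None:
    print("hi --- maybe_save_kv_layer_to_connector")
```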
```diff
@@ -497,7 +501,7 @@ def unified_attention_with_output(
     output_scale: Optional[torch.Tensor] = None,
     output_block_scale: Optional[torch.Tensor] = None,
 ) -> None:
-    wait_for_kv_layer_from_connector(layer_name)
+    # wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     if isinstance(attn_metadata, dict):
```
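The same treatment is applied at the call sites: the hunk above removes the wait call at the top of unified_attention_with_output, and the final hunk below does the same for the save call, so on this branch the stubbed hooks are never invoked from this path at all.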
```diff
@@ -514,7 +518,7 @@ def unified_attention_with_output(
                       output_scale=output_scale,
                       output_block_scale=output_block_scale)
 
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
+    # maybe_save_kv_layer_to_connector(layer_name, kv_cache)
 
 
 def unified_attention_with_output_fake(
```
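For this kind of leak bisection, the same on/off experiment can also be expressed as a runtime toggle instead of commented-out code. A minimal sketch, assuming a hypothetical environment variable VLLM_DEBUG_DISABLE_KV_CONNECTOR (an assumption for illustration, not an existing vLLM option), with the real hook bodies elided:

```python
import functools
import os

# Hypothetical debug flag -- an assumption for this sketch, not a real vLLM
# setting. When set to "1", every decorated KV-connector hook becomes a
# logged no-op, mirroring what the commit achieves by commenting code out.
_DISABLE_KV_CONNECTOR = os.environ.get(
    "VLLM_DEBUG_DISABLE_KV_CONNECTOR", "0") == "1"


def skippable_kv_hook(fn):
    """Turn a KV-connector hook into a logged no-op when the flag is set."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if _DISABLE_KV_CONNECTOR:
            print(f"hi --- skipping {fn.__name__}")
            return None
        return fn(*args, **kwargs)

    return wrapper


@skippable_kv_hook
def wait_for_kv_layer_from_connector(layer_name: str) -> None:
    # The real body would consult the KV-transfer connector and block until
    # this layer's KV cache is loaded; elided here -- see the diff above.
    ...
```

The advantage over comment toggling is that both arms of the experiment ship in the same build and can be flipped per run without editing the file.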