Compare commits


10 Commits

SHA1 Message Date
cbaf5a7aa6 set seed 2025-07-04 14:24:19 +02:00
247e69deaa revert unrelated 2025-07-04 14:23:48 +02:00
cbc6e94a6a merge 2025-07-04 14:22:16 +02:00
c1412e22f8 draft sharded bias 2025-06-02 17:30:38 +02:00
35ecff348b Merge branches 'tp-cb' and 'tp-cb' of github.com:huggingface/transformers into tp-cb 2025-06-02 17:08:08 +02:00
5f95f4568c fix? 2025-06-02 17:07:01 +02:00
9b84a34922 Apply style fixes 2025-06-02 15:06:10 +00:00
f5d4c6d3c6 Merge branch 'main' into tp-cb 2025-06-02 16:58:23 +02:00
c786a83897 some changes needed 2025-05-30 15:11:16 +00:00
afe78bdb89 lazy cache init 2025-05-30 15:52:11 +02:00


@@ -38,6 +38,9 @@ from ..generation.configuration_utils import GenerationConfig
 from ..utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
 
+torch.manual_seed(0)
 
 class RequestStatus(Enum):
     """Status of a generation request through its lifecycle."""
@@ -193,21 +196,11 @@ class PagedAttentionCache(Cache):
         self.num_blocks = num_blocks
         self.cache_shape = (self.num_key_value_heads, num_blocks, self.block_size, self.head_dim)
-        self.dtype = dtype
+        self._dtype = dtype
         self.device = device
-        self.key_cache: list[torch.Tensor] = []
-        self.value_cache: list[torch.Tensor] = []
-        for idx in range(config.num_hidden_layers):
-            layer_device = layer_device_map[idx] if layer_device_map is not None else device
-            new_layer_key_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(self.cache_shape, dtype=self.dtype, device=layer_device)
-            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
-            # preventing compiled graph breaks when updating the cache.
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
 
         # Block management data structures
         self._free_blocks = deque(range(num_blocks))
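
The hunk above removes the eager per-layer allocation, and with it the only inline note explaining `torch._dynamo.mark_static_address`; the lazy path added below keeps the same tagging. For reference, a minimal standalone sketch of that pattern (not part of the diff; `toy_cache` and `write_slot` are invented names):

import torch

num_heads, total_slots, head_dim = 4, 32, 8
toy_cache = torch.zeros(num_heads, total_slots, head_dim)
# Tag the buffer as living at a fixed address so updating it in place inside a
# compiled region does not force graph breaks (the rationale given in the removed comment).
torch._dynamo.mark_static_address(toy_cache)

@torch.compile
def write_slot(cache, slot_idx, new_kv):
    # In-place scatter of new key/value vectors into the flat slot dimension.
    cache.index_copy_(1, slot_idx, new_kv)
    return cache

slots = torch.tensor([3])
new_kv = torch.randn(num_heads, 1, head_dim)
write_slot(toy_cache, slots, new_kv)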
@@ -285,6 +278,25 @@ class PagedAttentionCache(Cache):
         return physical_indices
 
+    @torch.compiler.disable
+    def initialise_cache_layer(self, layer_idx, key_states):
+        if len(self.key_cache) > layer_idx:
+            return
+        self.num_key_value_heads = key_states.shape[1]
+        device = key_states.device
+        cache_shape = (
+            self.num_key_value_heads,
+            self.num_blocks,
+            self.block_size,
+            self.head_dim,
+        )
+        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
+        self.key_cache.append(new_layer_key_cache)
+        self.value_cache.append(new_layer_value_cache)
+
     @traced
     def update(
         self,
@@ -296,6 +308,7 @@ class PagedAttentionCache(Cache):
         **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Reshape cache for easier indexing
+        self.initialise_cache_layer(layer_idx, key_states)
         total_slots = self.num_blocks * self.block_size
         k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
         v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
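
Taken together, the commits change PagedAttentionCache from allocating every layer's key/value buffers in `__init__` to allocating them on first use: `update()` calls `initialise_cache_layer()` (kept out of compiled code via `@torch.compiler.disable`), which infers the KV-head count and device from the incoming `key_states`, zero-allocates the paged buffer, and tags it as a static address; the buffer is then viewed as a flat (heads, total_slots, head_dim) tensor for slot-indexed writes. Below is a minimal toy sketch of that flow; `LazyPagedCache`, `_init_layer`, and `slot_indices` are invented names for illustration, not the transformers API:

from typing import List
import torch

class LazyPagedCache:
    """Toy version of a lazily initialised paged KV cache; illustration only."""

    def __init__(self, num_blocks: int, block_size: int, head_dim: int, dtype: torch.dtype = torch.float16):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.head_dim = head_dim
        self._dtype = dtype
        # Nothing is allocated up front; layers are added on first use.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def _init_layer(self, layer_idx: int, key_states: torch.Tensor) -> None:
        # Lazily allocate one (kv_heads, num_blocks, block_size, head_dim) buffer per layer,
        # inferring the KV-head count and device from the first tensor this layer sees.
        if len(self.key_cache) > layer_idx:
            return
        kv_heads = key_states.shape[0]
        shape = (kv_heads, self.num_blocks, self.block_size, self.head_dim)
        self.key_cache.append(torch.zeros(shape, dtype=self._dtype, device=key_states.device))
        self.value_cache.append(torch.zeros(shape, dtype=self._dtype, device=key_states.device))

    def update(self, key_states, value_states, layer_idx, slot_indices):
        # Mirror of the new update() flow: initialise on demand, then view the paged
        # buffer as (kv_heads, total_slots, head_dim) and write the physical slots.
        self._init_layer(layer_idx, key_states)
        kv_heads = key_states.shape[0]
        total_slots = self.num_blocks * self.block_size
        k_flat = self.key_cache[layer_idx].view(kv_heads, total_slots, self.head_dim)
        v_flat = self.value_cache[layer_idx].view(kv_heads, total_slots, self.head_dim)
        k_flat.index_copy_(1, slot_indices, key_states.to(self._dtype))
        v_flat.index_copy_(1, slot_indices, value_states.to(self._dtype))
        return k_flat, v_flat

# Two new tokens for layer 0, written into physical slots 5 and 6 (block 0, offsets 5-6).
cache = LazyPagedCache(num_blocks=4, block_size=16, head_dim=8)
k_new = torch.randn(2, 2, 8)   # (kv_heads, new_tokens, head_dim)
v_new = torch.randn(2, 2, 8)
k_flat, v_flat = cache.update(k_new, v_new, layer_idx=0, slot_indices=torch.tensor([5, 6]))
print(k_flat.shape)            # torch.Size([2, 64, 8])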