Compare commits

...

35 Commits

Author SHA1 Message Date
fbd288fb04 update cache shape 2025-07-29 16:18:11 +00:00
eb714c4bc1 update 2025-07-29 13:44:18 +00:00
028e6d6e76 chaing v2 default 2025-07-29 13:44:17 +00:00
b6322afd7d add flash sdpa 2025-07-29 13:44:14 +00:00
04828074eb nits 2025-07-29 13:43:45 +00:00
3dc01c46bd comment latest sync for now 2025-07-29 13:43:45 +00:00
d0be498d03 adding v2 2025-07-29 13:43:41 +00:00
5e9b251ab7 mps support 2025-07-29 13:42:08 +00:00
255bc3c951 seems to work 2025-07-29 13:42:05 +00:00
1fd77c3d54 updates 2025-07-29 13:41:41 +00:00
5189c26e33 updates 2025-07-29 13:41:41 +00:00
2721261e1b more fixes 2025-07-29 13:41:40 +00:00
571af3e235 kernel either works or not used 2025-07-29 13:41:37 +00:00
51e115ac74 figuring out why sdpa without kernel is broken 2025-07-29 13:41:01 +00:00
962f7c1625 remove mps device 2025-07-29 13:40:27 +00:00
fa3e0db02d update flash_paged 2025-07-29 13:40:26 +00:00
3cbe825443 I need to test on cuda, this shouldd work 2025-07-29 13:40:26 +00:00
6fbe4177fa the shape expected by the kernel is not the sliced cache I reckon 2025-07-29 13:40:25 +00:00
9f47da148a make sure we can compare without reshape kernel 2025-07-29 13:40:25 +00:00
da935a6827 more cleanup 2025-07-29 13:40:24 +00:00
f73db2b6d0 bunch of cleanup and let's not sample, easier to teset 2025-07-29 13:40:21 +00:00
c5765160a2 smaller for faster 2025-07-29 13:39:54 +00:00
6eec32cf80 works with this 😉 2025-07-29 13:39:54 +00:00
cbe3e6324b where I am at 2025-07-29 13:39:54 +00:00
c2cf536c3d comment out stuff for now! 2025-07-29 13:39:53 +00:00
4d86a5bb11 losing my mind, test on cluster 2025-07-29 13:39:50 +00:00
ccf8a46c73 change shapes 2025-07-29 13:39:01 +00:00
f205a9c9d6 second test 2025-07-29 13:39:01 +00:00
458b35224f fisrt integration 2025-07-29 13:39:00 +00:00
012d8f474d Update src/transformers/generation/continuous_batching.py 2025-07-29 13:39:00 +00:00
7381bfabb5 rename 2025-07-29 13:38:59 +00:00
449dfd135c oups 2025-07-29 13:38:59 +00:00
1652a2f9f0 mps does not have kernels for that 2025-07-29 13:38:58 +00:00
0bfa6cc60a fix 2025-07-29 13:38:58 +00:00
4c5fbb9b92 add pin memory and block table 2025-07-29 13:38:53 +00:00
4 changed files with 440 additions and 56 deletions

View File

@@ -9,7 +9,7 @@ from transformers.generation import GenerationConfig
torch.set_float32_matmul_precision("high")
model_id = "meta-llama/Llama-3.2-3b-Instruct"
model_id = "meta-llama/Llama-3.2-1b-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
).eval()
@@ -20,14 +20,22 @@ generation_config = GenerationConfig(
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
num_blocks=2048,
block_size=128,
num_blocks=128,
block_size=32,
do_sample=True,
max_batch_tokens=1024, # Maximum number of tokens to process in a single batch
max_batch_tokens=64, # Maximum number of tokens to process in a single batch
scheduler="prefill_first",
)
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
train_dataset = train_dataset.select(range(5))
# Create a dataset from a list with one element
# from datasets import Dataset
# single_element_list = [{"question": "give me a very long story"}]
# train_dataset = Dataset.from_list(single_element_list)
# --- Example 1: Simple Version using generate_batch ---
print("--- Running CB Generation Example ---")
@@ -48,6 +56,7 @@ batch_outputs = model.generate_batch(
)
end_time_simple = time.time()
generated_tokens: int = 0
for request in batch_outputs:
input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
try:
@@ -58,6 +67,7 @@ for request in batch_outputs:
if len(output_text) > 0:
print("-" * 20)
print(f"{request} Input: {input_text}")
generated_tokens += len(batch_outputs[request].generated_tokens)
print(f"{request} Output: {output_text}")
else:
print("", end="\r\r\r\r")
@@ -65,8 +75,9 @@ print("-" * 20)
print("--- Finished CB Generation Example ---\n\n")
print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
print(
f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {generated_tokens} generated tokens. So {generated_tokens / (end_time_simple - start_time_simple)} tok/s\n"
)
# train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
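The new config trades cache capacity for a smaller footprint: with num_blocks=128 and block_size=32 the paged cache holds 128 * 32 = 4096 token slots in total, while max_batch_tokens=64 caps how many of them are written per scheduler step. A standalone sketch of that budget arithmetic (values copied from the config above, not part of the diff):

num_blocks = 128       # blocks allocated for the paged KV cache (new config value)
block_size = 32        # token slots per block (new config value)
max_batch_tokens = 64  # tokens processed per scheduler step (new config value)

total_kv_slots = num_blocks * block_size            # 4096 token positions shared by all requests
steps_to_fill = total_kv_slots // max_batch_tokens  # ~64 max-size steps before blocks must be freed
print(f"{total_kv_slots} KV slots, filled after about {steps_to_fill} max-size steps")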

View File

@@ -205,7 +205,8 @@ class PagedAttentionCache:
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
num_key_value_heads //= tp_size
self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
# self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
self.cache_shape = (num_blocks, self.block_size, self.num_key_value_heads, self.head_dim)
self.dtype = dtype
self.device = device
@@ -307,15 +308,47 @@ class PagedAttentionCache:
layer_idx: int,
read_index,
write_index,
reshaping_function,
kernel=True,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
# Reshape cache for easier indexing
total_slots = self.num_blocks * self.block_size
k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
k_cache_flat[:, write_index, :] = key_states[0]
v_cache_flat[:, write_index, :] = value_states[0]
return k_cache_flat[None, :, read_index, :], v_cache_flat[None, :, read_index, :]
batch_size, num_heads, seq_len, head_size = key_states.shape
key = key_states.transpose(1, 2).view(batch_size * seq_len, num_heads, head_size)
value = value_states.transpose(1, 2).view(batch_size * seq_len, num_heads, head_size)
if kernel:
# Pre-create scale tensors to avoid CUDA graph capture issues
if not hasattr(self, "_k_scale_tensor") or self._k_scale_tensor.device != key.device:
self._k_scale_tensor = torch.tensor(1.0, device=key.device, dtype=key.dtype)
if not hasattr(self, "_v_scale_tensor") or self._v_scale_tensor.device != value.device:
self._v_scale_tensor = torch.tensor(1.0, device=value.device, dtype=value.dtype)
reshaping_function(
key,
value,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
write_index.to(torch.int64).flatten(),
"auto", # kv_cache_dtype
self._k_scale_tensor, # k_scale
self._v_scale_tensor, # v_scale
)
if kwargs.get("is_decoding", False):
return self.key_cache[layer_idx], self.value_cache[layer_idx]
else:
k = self.key_cache[layer_idx].view(total_slots, self.num_key_value_heads, self.head_dim)
v = self.value_cache[layer_idx].view(total_slots, self.num_key_value_heads, self.head_dim)
return k[read_index, :, :], v[read_index, :, :]
else:
k_cache_flat = self.key_cache[layer_idx].view(total_slots, self.num_key_value_heads, self.head_dim)
v_cache_flat = self.value_cache[layer_idx].view(total_slots, self.num_key_value_heads, self.head_dim)
k_cache_flat[write_index, :, :] = key
v_cache_flat[write_index, :, :] = value
if kwargs.get("is_decoding", False):
return self.key_cache[layer_idx], self.value_cache[layer_idx]
return k_cache_flat[read_index, :, :], v_cache_flat[read_index, :, :]
class Scheduler(ABC):
@@ -785,26 +818,32 @@ class ContinuousBatchProcessor:
self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
self.decode_stream = DecodeStream(skip_special_tokens=True)
self.is_decoding = False
@traced(standalone=True)
def setup_static_tensors(self):
T = self.max_batch_tokens
max_token_budget = self.cache.num_blocks * self.cache.block_size
tensor_metadata = {"dtype": torch.int32, "device": self.model_device}
tensor_metadata = {"dtype": torch.int32, "pin_memory": True if torch.cuda.is_available() else False}
self.tensor_metadata = tensor_metadata
self.input_ids = torch.zeros((1, T), **tensor_metadata)
self.position_ids = torch.zeros((1, T), **tensor_metadata)
self.attention_mask = torch.zeros(
(1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device
self.input_ids = torch.zeros((1, T), **tensor_metadata).to(self.model_device, non_blocking=True)
self.position_ids = torch.zeros((1, T), **tensor_metadata).to(self.model_device, non_blocking=True)
self.attention_mask = torch.zeros((1, 1, T, max_token_budget), dtype=self.model_dtype).to(
self.model_device, non_blocking=True
)
self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata)
self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata)
self.write_index = torch.zeros((T,), **tensor_metadata)
self.read_index = torch.zeros((max_token_budget,), **tensor_metadata)
self.logits_indices = torch.full((T,), -1, **tensor_metadata)
self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata).to(self.model_device, non_blocking=True)
self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata).to(self.model_device, non_blocking=True)
self.write_index = torch.zeros((T,), **tensor_metadata).to(self.model_device, non_blocking=True)
self.read_index = torch.zeros((max_token_budget,), **tensor_metadata).to(self.model_device, non_blocking=True)
self.logits_indices = torch.full((T,), -1, **tensor_metadata).to(self.model_device, non_blocking=True)
self.max_seqlen_q = 0
self.max_seqlen_k = 0
self.output_ids = torch.full((1, T), -1, **tensor_metadata)
self.output_ids = torch.full((1, T), -1, **tensor_metadata).to(self.model_device, non_blocking=True)
self.block_tables = torch.full(
(T, 200),
fill_value=-1,
dtype=torch.int32,
).to(self.model_device, non_blocking=True)
@traced
@torch.no_grad()
@@ -821,6 +860,7 @@ class ContinuousBatchProcessor:
self.max_seqlen_q = 0
self.max_seqlen_k = 0
self.output_ids.zero_()
self.block_tables.fill_(-1)
def get_model_kwargs(self) -> PagedAttentionArgs:
"""Get model keyword arguments for the current batch."""
@@ -836,9 +876,10 @@ class ContinuousBatchProcessor:
"logits_indices": self.logits_indices,
"max_seqlen_q": self.max_seqlen_q,
"max_seqlen_k": self.max_seqlen_k,
"block_tables": self.cache._block_tables,
"block_tables": self.block_tables,
"cache": self.cache,
"use_cache": False,
"is_decoding": self.is_decoding,
}
def __repr__(self):
@@ -947,7 +988,12 @@ class ContinuousBatchProcessor:
self.max_seqlen_k = max(self.max_seqlen_k, key_length)
state.position_offset += query_length
logger.info(
block_list = self.cache.get_block_table(state.request_id)
self.block_tables[len(cumulative_seqlens_q) - 2, : len(block_list)] = torch.tensor(
block_list, dtype=torch.int32, device=self.model_device
)
logger.warning(
f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}"
)
self._build_tensors(
@@ -962,6 +1008,7 @@ class ContinuousBatchProcessor:
self.metrics.record_kv_cache_memory_metrics(self.cache)
self.is_decoding = (self.max_seqlen_q == 1)
@traced
def _build_tensors(
self,
@@ -982,7 +1029,10 @@ class ContinuousBatchProcessor:
self.cumulative_seqlens_k[: len(cumulative_seqlens_k)] = to_tensor(cumulative_seqlens_k)
self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices)
min_value = torch.finfo(self.model_dtype).min
if self.config._attn_implementation != "paged_attention": # we set `is_causal` to True in paged call`
if (
self.config._attn_implementation != "paged_attention"
): # we set `is_causal` to True in paged call`
# when decoding with sdpa paged, no need for a mask
for i in range(len(cumulative_seqlens_q) - 1):
if (
cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
@@ -1118,9 +1168,8 @@ class ContinuousBatchingManager:
self._request_lock = threading.Lock()
self.model.generation_config.top_p = None
self.do_sample = getattr(generation_config, "do_sample", True)
generation_config = model.generation_config if generation_config is None else generation_config
self.logit_processor = self.model._get_logits_processor(generation_config)
self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", False)
self.profile = getattr(generation_config, "profile", False)
self.manual_eviction = manual_eviction
self.batch_processor: Optional[ContinuousBatchProcessor] = None
@@ -1328,7 +1377,7 @@ class ContinuousBatchingManager:
if self.profile:
tracing_schedule = schedule(skip_first=2, warmup=3, active=200, repeat=100, wait=1)
trace_handler = tensorboard_trace_handler(
dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
dir_name="/fsx/mohamed", use_gzip=True, worker_name="paged_compile"
)
activities = [
torch.profiler.ProfilerActivity.CPU,
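For reference, the per-request block_tables rows added above feed the flat-slot indexing used in PagedAttentionCache.update: the cache is viewed as (num_blocks * block_size, num_kv_heads, head_dim), so a logical token position maps to slot block_id * block_size + offset. A minimal sketch of that mapping (hypothetical helper and block assignment, not part of the diff):

import torch

block_size = 32
block_table = [7, 2, 11]  # physical blocks assigned to one request, in allocation order

def slots_for_positions(positions: list[int]) -> torch.Tensor:
    """Flat KV-cache slot index for each logical token position of this request."""
    slots = [block_table[p // block_size] * block_size + p % block_size for p in positions]
    return torch.tensor(slots, dtype=torch.int64)

# Tokens 0-31 land in block 7, tokens 32-63 in block 2 of the shared cache.
write_index = slots_for_positions(list(range(64)))
print(write_index[0].item(), write_index[32].item())  # 224 (7*32+0) and 64 (2*32+0)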

View File

@@ -3,12 +3,13 @@ import torch
from ..generation.continuous_batching import PagedAttentionCache
from ..utils import is_flash_attn_2_available
from kernels import get_kernel
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func # noqa: F401
def paged_attention_forward(
def paged_attention_forward_(
module: torch.nn.Module,
q: torch.Tensor,
k: torch.Tensor,
@@ -65,3 +66,73 @@ def paged_attention_forward(
)
return attn_output, None
paged_attention_kernel = get_kernel("kernels-community/paged-attention")
def paged_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor = None,
cache: PagedAttentionCache = None,
cumulative_seqlens_q=None,
cumulative_seqlens_k=None,
max_seqlen_q=None,
max_seqlen_k=None,
block_tables=None,
**kwargs,
) -> tuple[torch.Tensor, None]:
"""Wrapper for paged attention forward that uses flash attention."""
reshaping_function = paged_attention_kernel.reshape_and_cache_flash
is_decoding = kwargs.get("max_seqlen_q", -1) == 1
if not is_decoding:
return paged_attention_forward_(
module,
query,
key,
value,
attention_mask,
reshaping_function=reshaping_function,
**kwargs,
)
else:
num_kv_heads = key.shape[1]
cache = kwargs.pop("cache", None)
key, value = cache.update(key, value, module.layer_idx, reshaping_function=reshaping_function, **kwargs)
batch_size, num_heads, seq_len, head_size = query.shape
query = query.transpose(1, 2).reshape(batch_size * seq_len, num_heads, head_size)
if not hasattr(module, "_attn_out"):
module._attn_output = torch.empty_like(query, device=query.device)
x = 16 // key.element_size()
key = key.view(cache.num_blocks, cache.block_size, num_kv_heads, head_size // x, x).permute(0, 2, 3, 1, 4)
value = value.permute(0, 2, 3, 1).contiguous()
seq_lens = kwargs.get("cumulative_seqlens_k") # .flatten()
block_tables = kwargs.get("block_tables")
block_size = kwargs.get("block_size", 32)
torch.mps.synchronize()
paged_attention_kernel.paged_attention_v1(
module._attn_output,
query,
key, # → [num_blocks, num_kv_heads, head_dim // x, block_size, x], x depends on the dtype
value, # # → [num_blocks, num_kv_heads, head_dim, block_size]
num_kv_heads=num_kv_heads,
block_tables=block_tables,
seq_lens=seq_lens,
block_size=block_size,
max_seq_len=kwargs.get("max_seqlen_k"),
kv_cache_dtype=kwargs.get("kv_cache_dtype", "auto"),
scale=module.scaling,
k_scale=None,
v_scale=None,
alibi_slopes=None,
)
attn_output = module._attn_output.reshape(batch_size, seq_len, num_heads, head_size)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
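As a shape check for the kernel comments above: assuming the cache is stored as [num_blocks, block_size, num_kv_heads, head_dim], the reshapes applied before paged_attention_v1 produce the layouts the comments describe, with x = 16 // element_size (8 for bfloat16). A standalone sketch with arbitrary sizes, not part of the diff:

import torch

num_blocks, block_size, num_kv_heads, head_dim = 4, 32, 8, 128
cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim, dtype=torch.bfloat16)

x = 16 // cache.element_size()  # 16 // 2 = 8 for bfloat16
key = cache.view(num_blocks, block_size, num_kv_heads, head_dim // x, x).permute(0, 2, 3, 1, 4)
value = cache.permute(0, 2, 3, 1).contiguous()

print(key.shape)    # [4, 8, 16, 32, 8]  -> [num_blocks, num_kv_heads, head_dim // x, block_size, x]
print(value.shape)  # [4, 8, 128, 32]    -> [num_blocks, num_kv_heads, head_dim, block_size]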

View File

@@ -2,19 +2,136 @@ from typing import Optional
import torch
from kernels import get_kernel
paged_attention_kernel = get_kernel("kernels-community/paged-attention")
# sdpa_flash_kernel = get_kernel("kernels-community/metal-flash-sdpa")
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
num_kv_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
hidden_states = hidden_states[:, None, :, :].expand(num_kv_heads, n_rep, slen, head_dim)
return hidden_states.reshape(num_kv_heads * n_rep, slen, head_dim)
def sdpa_attention_paged_forward__(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
reshaping_function=None,
is_causal: Optional[bool] = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
cache = kwargs.pop("cache", None)
if cache is not None:
key, value = cache.update(key, value, module.layer_idx, reshaping_function=reshaping_function, **kwargs)
# because of the kernel, the shape of the cache is different
# it return [num_tokens, num_kv_heads, head_dim]
if key.ndim == 3:
key = key.permute(1, 0, 2)
value = value.permute(1, 0, 2)
else:
key = key.view(-1, key.shape[-2], key.shape[-1]).permute(1, 0, 2)
value = value.view(-1, value.shape[-2], value.shape[-1]).permute(1, 0, 2)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
# print(f"causal_mask.shape: {causal_mask.shape}")
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
scale=scaling,
dropout_p=dropout,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
# def sdpa_attention_paged_forward_flash__(
# module: torch.nn.Module,
# query: torch.Tensor,
# key: torch.Tensor,
# value: torch.Tensor,
# attention_mask: Optional[torch.Tensor],
# dropout: float = 0.0,
# scaling: Optional[float] = None,
# reshaping_function=None,
# is_causal: Optional[bool] = None,
# **kwargs,
# ) -> tuple[torch.Tensor, None]:
# cache = kwargs.pop("cache", None)
# if cache is not None:
# key, value = cache.update(key, value, module.layer_idx, reshaping_function=reshaping_function, **kwargs)
# # because of the kernel, the shape of the cache is different
# # it return [num_tokens, num_kv_heads, head_dim]
# if key.ndim == 3:
# key = key.permute(1, 0, 2)
# value = value.permute(1, 0, 2)
# else:
# key = key.view(-1, key.shape[-2], key.shape[-1]).permute(1, 0, 2)
# value = value.view(-1, value.shape[-2], value.shape[-1]).permute(1, 0, 2)
# if hasattr(module, "num_key_value_groups"):
# key = repeat_kv(key, module.num_key_value_groups)
# value = repeat_kv(value, module.num_key_value_groups)
# causal_mask = attention_mask
# # print(f"causal_mask.shape: {causal_mask.shape}")
# cu_seqlen_q = kwargs.get("cumulative_seqlens_q")
# cu_seqlen_k = kwargs.get("cumulative_seqlens_k")
# max_seqlen_q = kwargs.get("max_seqlen_q")
# max_seqlen_k = kwargs.get("max_seqlen_k")
# batch_size, num_heads, seq_len, head_size = query.shape
# query = query.transpose(1, 2).reshape(batch_size * seq_len, num_heads, head_size).contiguous()
# key = key.transpose(0, 1).contiguous()
# value = value.transpose(0, 1).contiguous()
# # print(f"query.shape: {query.shape}")
# # print(f"key.shape: {key.shape}")
# # print(f"value.shape: {value.shape}")
# # print(f"cu_seqlen_q: {cu_seqlen_q}")
# # print(f"cu_seqlen_k: {cu_seqlen_k}")
# # print(f"max_seqlen_q: {max_seqlen_q.item()}")
# # print(f"max_seqlen_k: {max_seqlen_k}")
# if torch.backends.mps.is_available():
# torch.mps.synchronize()
# else:
# torch.cuda.synchronize()
# attn_output =sdpa_flash_kernel.flash_attn_varlen_func(
# query,
# key,
# value,
# cu_seqlen_q,
# cu_seqlen_k,
# max_seqlen_q,
# max_seqlen_k,
# causal=True,
# )
# attn_output = attn_output.view(batch_size, seq_len, num_heads, head_size)
# # attn_output = attn_output.transpose(1, 2).contiguous()
# return attn_output, None
def sdpa_attention_paged_forward(
module: torch.nn.Module,
query: torch.Tensor,
@@ -26,26 +143,162 @@ def sdpa_attention_paged_forward(
is_causal: Optional[bool] = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
cache = kwargs.pop("cache", None)
if cache is not None:
key, value = cache.update(key, value, module.layer_idx, **kwargs)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
reshaping_function = paged_attention_kernel.reshape_and_cache_flash
causal_mask = attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
dropout_p=dropout,
scale=scaling,
is_causal=False,
)
attn_output = attn_output.transpose(1, 2).contiguous()
is_decoding = kwargs.get("is_decoding")
# is_decoding = False
if not is_decoding:
return sdpa_attention_paged_forward__(
module,
query,
key,
value,
attention_mask,
scaling=scaling,
reshaping_function=reshaping_function,
is_causal=False,
**kwargs,
)
else:
num_kv_heads = key.shape[1]
cache = kwargs.pop("cache", None)
key, value = cache.update(key, value, module.layer_idx, reshaping_function=reshaping_function, **kwargs)
batch_size, num_heads, seq_len, head_size = query.shape
query = query.transpose(1, 2).reshape(batch_size * seq_len, num_heads, head_size).contiguous()
# Get max sequence length to determine if we need v2
max_seq_len = kwargs.get("max_seqlen_k", 0)
partition_size = 512 # Standard partition size for v2
use_v2 = max_seq_len > 1024 # Use v2 for longer sequences
# Lazily allocate the persistent output buffer the kernel writes into
if not hasattr(module, "_attn_output"):
module._attn_output = torch.zeros(batch_size * seq_len, num_heads, head_size, device=query.device)
return attn_output, None
x = 16 // key.element_size()
key = key.view(cache.num_blocks, cache.block_size, num_kv_heads, head_size // x, x).permute(0, 2, 3, 1, 4).contiguous()
value = value.permute(0, 2, 3, 1).contiguous()
if hasattr(module, "num_key_value_groups"):
num_kv_heads = num_kv_heads * module.num_key_value_groups
key = torch.repeat_interleave(key, module.num_key_value_groups, dim=1)
value = torch.repeat_interleave(value, module.num_key_value_groups, dim=1)
seq_lens = kwargs.get("cumulative_seqlens_k")
if seq_lens is not None:
seq_lens = torch.diff(seq_lens)
if (seq_lens < 0).any():
seq_lens = torch.clamp(seq_lens, min=0)
block_tables = kwargs.get("block_tables")
if block_tables is None:
raise ValueError("block_tables is required for decoding mode")
if seq_lens is None:
raise ValueError("seq_lens is required for decoding mode")
block_size = cache.block_size
# Pre-create scale tensors to avoid CUDA graph capture issues
if not hasattr(module, "_k_scale_tensor"):
module._k_scale_tensor = torch.tensor(1.0, device=key.device, dtype=key.dtype)
if not hasattr(module, "_v_scale_tensor"):
module._v_scale_tensor = torch.tensor(1.0, device=value.device, dtype=value.dtype)
# Ensure all tensors are on the same device and contiguous
if query.device != key.device:
query = query.to(key.device)
if module._attn_output.device != key.device:
module._attn_output = module._attn_output.to(key.device)
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
elif torch.backends.mps.is_available():
torch.mps.synchronize()
else:
raise RuntimeError("No CUDA or MPS available")
if use_v2:
# Calculate number of partitions for v2
max_num_partitions = (max_seq_len + partition_size - 1) // partition_size
# Create v2-specific tensors
if not hasattr(module, "_exp_sums") or module._exp_sums.shape != (batch_size, num_heads, max_num_partitions):
module._exp_sums = torch.empty(
(batch_size, num_heads, max_num_partitions),
dtype=torch.float32,
device=key.device
)
if not hasattr(module, "_max_logits") or module._max_logits.shape != (batch_size, num_heads, max_num_partitions):
module._max_logits = torch.empty(
(batch_size, num_heads, max_num_partitions),
dtype=torch.float32,
device=key.device
)
if not hasattr(module, "_tmp_out") or module._tmp_out.shape != (batch_size, num_heads, max_num_partitions, head_size):
module._tmp_out = torch.empty(
(batch_size, num_heads, max_num_partitions, head_size),
dtype=query.dtype,
device=key.device
)
paged_attention_kernel.paged_attention_v2(
module._attn_output,
module._exp_sums,
module._max_logits,
module._tmp_out,
query,
key, # → [num_blocks, num_kv_heads, head_dim // x, block_size, x], x depends on the dtype
value, # # → [num_blocks, num_kv_heads, head_dim, block_size]
num_kv_heads=num_kv_heads,
block_tables=block_tables,
seq_lens=seq_lens,
block_size=block_size,
max_seq_len=max_seq_len,
kv_cache_dtype=kwargs.get("kv_cache_dtype", "auto"),
scale=scaling,
k_scale=module._k_scale_tensor,
v_scale=module._v_scale_tensor,
alibi_slopes=None,
)
else:
paged_attention_kernel.paged_attention_v1(
module._attn_output,
query,
key, # → [num_blocks, num_kv_heads, head_dim // x, block_size, x], x depends on the dtype
value, # # → [num_blocks, num_kv_heads, head_dim, block_size]
num_kv_heads=num_kv_heads,
block_tables=block_tables,
seq_lens=seq_lens,
block_size=block_size,
max_seq_len=max_seq_len,
kv_cache_dtype=kwargs.get("kv_cache_dtype", "auto"),
scale=scaling,
k_scale=module._k_scale_tensor,
v_scale=module._v_scale_tensor,
alibi_slopes=None,
)
# if torch.cuda.is_available():
# torch.cuda.synchronize()
# elif torch.backends.mps.is_available():
# torch.mps.synchronize()
# else:
# raise RuntimeError("No CUDA or MPS available")
except RuntimeError as e:
print(f"Error in paged_attention_{'v2' if use_v2 else 'v1'}: {e}")
print(f"Shapes - query: {query.shape}, key: {key.shape}, value: {value.shape}")
print(f"Output shape: {module._attn_output.shape}")
print(f"block_tables shape: {block_tables.shape if block_tables is not None else None}")
print(f"seq_lens shape: {seq_lens.shape if seq_lens is not None else None}")
if use_v2:
print(f"max_num_partitions: {max_num_partitions}")
print(f"exp_sums shape: {module._exp_sums.shape}")
print(f"max_logits shape: {module._max_logits.shape}")
print(f"tmp_out shape: {module._tmp_out.shape}")
raise
module._attn_output = module._attn_output.to(torch.bfloat16)
attn_output = module._attn_output.view(batch_size, seq_len, num_heads, head_size)
return attn_output, None
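To summarize the dispatch in the decode branch above: paged_attention_v1 handles sequences up to 1024 tokens, and paged_attention_v2 takes over beyond that, with scratch tensors sized by the number of 512-token partitions. A short standalone recap (helper name illustrative, not from the diff):

def v2_scratch_shapes(max_seq_len: int, batch_size: int, num_heads: int, head_size: int,
                      partition_size: int = 512):
    """Return the scratch-tensor shapes paged_attention_v2 needs, or None if v1 suffices."""
    if max_seq_len <= 1024:  # short sequences stay on paged_attention_v1
        return None
    max_num_partitions = (max_seq_len + partition_size - 1) // partition_size  # ceil division
    return {
        "exp_sums":   (batch_size, num_heads, max_num_partitions),
        "max_logits": (batch_size, num_heads, max_num_partitions),
        "tmp_out":    (batch_size, num_heads, max_num_partitions, head_size),
    }

print(v2_scratch_shapes(max_seq_len=2048, batch_size=64, num_heads=24, head_size=128))
# 2048 tokens -> 4 partitions of 512; at or below 1024 tokens the function returns None.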