Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits: v4.56.0...add-block- (35 commits)
| SHA1 |
|---|
| fbd288fb04 |
| eb714c4bc1 |
| 028e6d6e76 |
| b6322afd7d |
| 04828074eb |
| 3dc01c46bd |
| d0be498d03 |
| 5e9b251ab7 |
| 255bc3c951 |
| 1fd77c3d54 |
| 5189c26e33 |
| 2721261e1b |
| 571af3e235 |
| 51e115ac74 |
| 962f7c1625 |
| fa3e0db02d |
| 3cbe825443 |
| 6fbe4177fa |
| 9f47da148a |
| da935a6827 |
| f73db2b6d0 |
| c5765160a2 |
| 6eec32cf80 |
| cbe3e6324b |
| c2cf536c3d |
| 4d86a5bb11 |
| ccf8a46c73 |
| f205a9c9d6 |
| 458b35224f |
| 012d8f474d |
| 7381bfabb5 |
| 449dfd135c |
| 1652a2f9f0 |
| 0bfa6cc60a |
| 4c5fbb9b92 |
@@ -9,7 +9,7 @@ from transformers.generation import GenerationConfig

 torch.set_float32_matmul_precision("high")

-model_id = "meta-llama/Llama-3.2-3b-Instruct"
+model_id = "meta-llama/Llama-3.2-1b-Instruct"
 model = AutoModelForCausalLM.from_pretrained(
     model_id, attn_implementation="sdpa_paged", torch_dtype=torch.bfloat16, device_map="auto"
 ).eval()
@@ -20,14 +20,22 @@ generation_config = GenerationConfig(
     eos_token_id=tokenizer.eos_token_id,
     pad_token_id=tokenizer.pad_token_id,
     use_cache=False,
-    num_blocks=2048,
-    block_size=128,
+    num_blocks=128,
+    block_size=32,
     do_sample=True,
-    max_batch_tokens=1024,  # Maximum number of tokens to process in a single batch
+    max_batch_tokens=64,  # Maximum number of tokens to process in a single batch
     scheduler="prefill_first",
 )

 train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
+train_dataset = train_dataset.select(range(5))
+
+# Create a dataset from a list with one element
+# from datasets import Dataset
+
+# single_element_list = [{"question": "give me a very long story"}]
+# train_dataset = Dataset.from_list(single_element_list)
+

 # --- Example 1: Simple Version using generate_batch ---
 print("--- Running CB Generation Example ---")
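Aside (not part of the diff): the smaller `num_blocks`, `block_size`, and `max_batch_tokens` values above shrink both the paged KV cache and the per-step token budget. The sketch below does the capacity arithmetic these knobs imply, under the usual assumption that the cache holds `num_blocks * block_size` token slots per layer; it is an illustration, not code from the commit.

```python
# Illustrative sketch, not part of the diff: rough capacity math for the paged KV cache
# configured above (num_blocks=128, block_size=32, max_batch_tokens=64).

def paged_cache_capacity(num_blocks: int, block_size: int, max_batch_tokens: int) -> dict:
    total_slots = num_blocks * block_size                      # KV slots (tokens) available per layer
    min_forward_passes = -(-total_slots // max_batch_tokens)   # ceil division: passes needed to fill the cache
    return {
        "total_token_slots": total_slots,
        "min_passes_to_fill_cache": min_forward_passes,
    }

if __name__ == "__main__":
    # 128 blocks of 32 tokens -> 4096 cached tokens shared by all in-flight requests.
    print(paged_cache_capacity(num_blocks=128, block_size=32, max_batch_tokens=64))
```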
@@ -48,6 +56,7 @@ batch_outputs = model.generate_batch(
 )
 end_time_simple = time.time()

+generated_tokens: int = 0
 for request in batch_outputs:
     input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=False)
     try:
@@ -58,6 +67,7 @@ for request in batch_outputs:
     if len(output_text) > 0:
         print("-" * 20)
         print(f"{request} Input: {input_text}")
+        generated_tokens += len(batch_outputs[request].generated_tokens)
         print(f"{request} Output: {output_text}")
     else:
         print("", end="\r\r\r\r")
@@ -65,8 +75,9 @@ print("-" * 20)
 print("--- Finished CB Generation Example ---\n\n")

-print(f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds")
+print(
+    f"CB generation took: {end_time_simple - start_time_simple:.2f} seconds for {generated_tokens} generated tokens. So {generated_tokens / (end_time_simple - start_time_simple)} tok/s\n"
+)

 # train_dataset = train_dataset.select(range(5))  # Use only 5 examples for the simple version
@@ -205,7 +205,8 @@ class PagedAttentionCache:
             # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
             num_key_value_heads //= tp_size

-        self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
+        # self.cache_shape = (num_key_value_heads, num_blocks, self.block_size, self.head_dim)
+        self.cache_shape = (num_blocks, self.block_size, self.num_key_value_heads, self.head_dim)

         self.dtype = dtype
         self.device = device
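Aside (not part of the diff): the commit flips the KV-cache layout from heads-first to blocks-first. The sketch below, with toy sizes of my own choosing, shows how the same (block, offset, head) entry is addressed under each layout; the flat-slot arithmetic mirrors the `view(total_slots, ...)` calls that appear later in this change and is my reading of the code, not an official API.

```python
# Illustrative sketch, not part of the diff: the two KV-cache layouts touched here, with toy sizes.
import torch

num_blocks, block_size, num_kv_heads, head_dim = 4, 8, 2, 16

old_layout = torch.zeros(num_kv_heads, num_blocks, block_size, head_dim)  # heads-first (old cache_shape)
new_layout = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)  # blocks-first (new cache_shape)

block, offset, head = 3, 5, 1
slot = block * block_size + offset  # flat slot index used by view(total_slots, ...)

# Same logical entry, addressed under each layout:
old_entry = old_layout[head, block, offset]
new_entry = new_layout.view(num_blocks * block_size, num_kv_heads, head_dim)[slot, head]
assert old_entry.shape == new_entry.shape == (head_dim,)
```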
@@ -307,15 +308,47 @@ class PagedAttentionCache:
         layer_idx: int,
         read_index,
         write_index,
+        reshaping_function,
+        kernel=True,
         **kwargs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         # Reshape cache for easier indexing
         total_slots = self.num_blocks * self.block_size
-        k_cache_flat = self.key_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
-        v_cache_flat = self.value_cache[layer_idx].view(self.num_key_value_heads, total_slots, self.head_dim)
-        k_cache_flat[:, write_index, :] = key_states[0]
-        v_cache_flat[:, write_index, :] = value_states[0]
-        return k_cache_flat[None, :, read_index, :], v_cache_flat[None, :, read_index, :]
+        batch_size, num_heads, seq_len, head_size = key_states.shape
+        key = key_states.transpose(1, 2).view(batch_size * seq_len, num_heads, head_size)
+        value = value_states.transpose(1, 2).view(batch_size * seq_len, num_heads, head_size)
+        if kernel:
+            # Pre-create scale tensors to avoid CUDA graph capture issues
+            if not hasattr(self, "_k_scale_tensor") or self._k_scale_tensor.device != key.device:
+                self._k_scale_tensor = torch.tensor(1.0, device=key.device, dtype=key.dtype)
+            if not hasattr(self, "_v_scale_tensor") or self._v_scale_tensor.device != value.device:
+                self._v_scale_tensor = torch.tensor(1.0, device=value.device, dtype=value.dtype)
+
+            reshaping_function(
+                key,
+                value,
+                self.key_cache[layer_idx],
+                self.value_cache[layer_idx],
+                write_index.to(torch.int64).flatten(),
+                "auto",  # kv_cache_dtype
+                self._k_scale_tensor,  # k_scale
+                self._v_scale_tensor,  # v_scale
+            )
+
+            if kwargs.get("is_decoding", False):
+                return self.key_cache[layer_idx], self.value_cache[layer_idx]
+            else:
+                k = self.key_cache[layer_idx].view(total_slots, self.num_key_value_heads, self.head_dim)
+                v = self.value_cache[layer_idx].view(total_slots, self.num_key_value_heads, self.head_dim)
+                return k[read_index, :, :], v[read_index, :, :]
+        else:
+            k_cache_flat = self.key_cache[layer_idx].view(total_slots, self.num_key_value_heads, self.head_dim)
+            v_cache_flat = self.value_cache[layer_idx].view(total_slots, self.num_key_value_heads, self.head_dim)
+            k_cache_flat[write_index, :, :] = key
+            v_cache_flat[write_index, :, :] = value
+            if kwargs.get("is_decoding", False):
+                return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+            return k_cache_flat[read_index, :, :], v_cache_flat[read_index, :, :]


 class Scheduler(ABC):
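Aside (not part of the diff): `cache.update` consumes flat slot indices (`read_index` / `write_index`). One plausible way such indices are derived from a request's block table is sketched below; the helper name and exact convention are assumptions for illustration only.

```python
# Illustrative sketch, not part of the diff: mapping token positions of a request to flat slots
# in a paged KV cache, consistent with the view(total_slots, ...) indexing above.

def slots_for_positions(block_table: list[int], positions: list[int], block_size: int) -> list[int]:
    """Map token positions of one request to flat slot indices in a paged KV cache."""
    slots = []
    for pos in positions:
        block = block_table[pos // block_size]       # physical block holding this position
        offset = pos % block_size                    # offset inside that block
        slots.append(block * block_size + offset)    # flat index into view(total_slots, ...)
    return slots

# A request whose tokens live in physical blocks 7 and 2, with block_size=4:
print(slots_for_positions([7, 2], positions=[0, 1, 2, 3, 4, 5], block_size=4))
# -> [28, 29, 30, 31, 8, 9]
```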
@@ -785,26 +818,32 @@ class ContinuousBatchProcessor:
         self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
         self.decode_stream = DecodeStream(skip_special_tokens=True)
+        self.is_decoding = False

     @traced(standalone=True)
     def setup_static_tensors(self):
         T = self.max_batch_tokens
         max_token_budget = self.cache.num_blocks * self.cache.block_size
-        tensor_metadata = {"dtype": torch.int32, "device": self.model_device}
+        tensor_metadata = {"dtype": torch.int32, "pin_memory": True if torch.cuda.is_available() else False}
         self.tensor_metadata = tensor_metadata
-        self.input_ids = torch.zeros((1, T), **tensor_metadata)
-        self.position_ids = torch.zeros((1, T), **tensor_metadata)
-        self.attention_mask = torch.zeros(
-            (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device
+        self.input_ids = torch.zeros((1, T), **tensor_metadata).to(self.model_device, non_blocking=True)
+        self.position_ids = torch.zeros((1, T), **tensor_metadata).to(self.model_device, non_blocking=True)
+        self.attention_mask = torch.zeros((1, 1, T, max_token_budget), dtype=self.model_dtype).to(
+            self.model_device, non_blocking=True
         )
-        self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata)
-        self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata)
-        self.write_index = torch.zeros((T,), **tensor_metadata)
-        self.read_index = torch.zeros((max_token_budget,), **tensor_metadata)
-        self.logits_indices = torch.full((T,), -1, **tensor_metadata)
+        self.cumulative_seqlens_q = torch.zeros((T + 1,), **tensor_metadata).to(self.model_device, non_blocking=True)
+        self.cumulative_seqlens_k = torch.zeros((T + 1,), **tensor_metadata).to(self.model_device, non_blocking=True)
+        self.write_index = torch.zeros((T,), **tensor_metadata).to(self.model_device, non_blocking=True)
+        self.read_index = torch.zeros((max_token_budget,), **tensor_metadata).to(self.model_device, non_blocking=True)
+        self.logits_indices = torch.full((T,), -1, **tensor_metadata).to(self.model_device, non_blocking=True)
         self.max_seqlen_q = 0
         self.max_seqlen_k = 0
-        self.output_ids = torch.full((1, T), -1, **tensor_metadata)
+        self.output_ids = torch.full((1, T), -1, **tensor_metadata).to(self.model_device, non_blocking=True)
+        self.block_tables = torch.full(
+            (T, 200),
+            fill_value=-1,
+            dtype=torch.int32,
+        ).to(self.model_device, non_blocking=True)

     @traced
     @torch.no_grad()
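Aside (not part of the diff): the rewritten `setup_static_tensors` allocates pinned host tensors and moves them to the device with `non_blocking=True` instead of allocating directly on the device. A generic sketch of that pattern:

```python
# Illustrative sketch, not part of the diff: the pinned-host + non_blocking copy pattern the
# new setup_static_tensors code relies on. Pinned (page-locked) memory lets the host-to-device
# transfer run asynchronously and overlap with compute.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
pin = torch.cuda.is_available()  # pin_memory is only meaningful with CUDA

host_buf = torch.zeros((1, 64), dtype=torch.int32, pin_memory=pin)
dev_buf = host_buf.to(device, non_blocking=True)  # asynchronous when the source is pinned

# Synchronize before depending on the copied values in a real pipeline.
if torch.cuda.is_available():
    torch.cuda.synchronize()
```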
@@ -821,6 +860,7 @@ class ContinuousBatchProcessor:
         self.max_seqlen_q = 0
         self.max_seqlen_k = 0
         self.output_ids.zero_()
+        self.block_tables.fill_(-1)

     def get_model_kwargs(self) -> PagedAttentionArgs:
         """Get model keyword arguments for the current batch."""
@@ -836,9 +876,10 @@ class ContinuousBatchProcessor:
             "logits_indices": self.logits_indices,
             "max_seqlen_q": self.max_seqlen_q,
             "max_seqlen_k": self.max_seqlen_k,
-            "block_tables": self.cache._block_tables,
+            "block_tables": self.block_tables,
             "cache": self.cache,
             "use_cache": False,
+            "is_decoding": self.is_decoding,
         }

     def __repr__(self):
@@ -947,7 +988,12 @@ class ContinuousBatchProcessor:
             self.max_seqlen_k = max(self.max_seqlen_k, key_length)
             state.position_offset += query_length

-        logger.info(
+            block_list = self.cache.get_block_table(state.request_id)
+            self.block_tables[len(cumulative_seqlens_q) - 2, : len(block_list)] = torch.tensor(
+                block_list, dtype=torch.int32, device=self.model_device
+            )
+
+        logger.warning(
             f"Scheduled: {len(self.requests_in_batch)}, Waiting: {len(self.scheduler.waiting_requests)}, Active: {len(self.scheduler.active_requests)}. cum Q: {cumulative_seqlens_q[-1]}. cum KV: {cumulative_seqlens_k[-1]}, free blocks: {self.cache.get_num_free_blocks()}"
         )
         self._build_tensors(
@@ -962,6 +1008,7 @@ class ContinuousBatchProcessor:

         self.metrics.record_kv_cache_memory_metrics(self.cache)

+        self.is_decoding = (self.max_seqlen_q == 1)
     @traced
     def _build_tensors(
         self,
@@ -982,7 +1029,10 @@ class ContinuousBatchProcessor:
         self.cumulative_seqlens_k[: len(cumulative_seqlens_k)] = to_tensor(cumulative_seqlens_k)
         self.logits_indices[: len(logits_indices)] = to_tensor(logits_indices)
         min_value = torch.finfo(self.model_dtype).min
-        if self.config._attn_implementation != "paged_attention":  # we set `is_causal` to True in paged call`
+        if (
+            self.config._attn_implementation != "paged_attention"
+        ):  # we set `is_causal` to True in paged call`
+            # when decoding with sdpa paged, no need for a mask
             for i in range(len(cumulative_seqlens_q) - 1):
                 if (
                     cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
@@ -1118,9 +1168,8 @@ class ContinuousBatchingManager:
         self._request_lock = threading.Lock()
         self.model.generation_config.top_p = None
         self.do_sample = getattr(generation_config, "do_sample", True)
-        generation_config = model.generation_config if generation_config is None else generation_config
-        self.logit_processor = self.model._get_logits_processor(generation_config)
-        self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", True)
+        self.logit_processor = self.model._get_logits_processor(self.model.generation_config)
+        self.use_cuda_graph = getattr(generation_config, "use_cuda_graph", False)
         self.profile = getattr(generation_config, "profile", False)
         self.manual_eviction = manual_eviction
         self.batch_processor: Optional[ContinuousBatchProcessor] = None
@@ -1328,7 +1377,7 @@ class ContinuousBatchingManager:
         if self.profile:
             tracing_schedule = schedule(skip_first=2, warmup=3, active=200, repeat=100, wait=1)
             trace_handler = tensorboard_trace_handler(
-                dir_name="/fsx/arthur/transformers", use_gzip=True, worker_name="paged_compile"
+                dir_name="/fsx/mohamed", use_gzip=True, worker_name="paged_compile"
             )
             activities = [
                 torch.profiler.ProfilerActivity.CPU,
@@ -3,12 +3,13 @@ import torch
 from ..generation.continuous_batching import PagedAttentionCache
 from ..utils import is_flash_attn_2_available

+from kernels import get_kernel

 if is_flash_attn_2_available():
     from flash_attn import flash_attn_varlen_func  # noqa: F401


-def paged_attention_forward(
+def paged_attention_forward_(
     module: torch.nn.Module,
     q: torch.Tensor,
     k: torch.Tensor,
@@ -65,3 +66,73 @@ def paged_attention_forward(
     )

     return attn_output, None
+
+
+paged_attention_kernel = get_kernel("kernels-community/paged-attention")
+
+
+def paged_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: torch.Tensor = None,
+    cache: PagedAttentionCache = None,
+    cumulative_seqlens_q=None,
+    cumulative_seqlens_k=None,
+    max_seqlen_q=None,
+    max_seqlen_k=None,
+    block_tables=None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """Wrapper for paged attention forward that uses flash attention."""
+    reshaping_function = paged_attention_kernel.reshape_and_cache_flash
+    is_decoding = kwargs.get("max_seqlen_q", -1) == 1
+    if not is_decoding:
+        return paged_attention_forward_(
+            module,
+            query,
+            key,
+            value,
+            attention_mask,
+            reshaping_function=reshaping_function,
+            **kwargs,
+        )
+    else:
+        num_kv_heads = key.shape[1]
+        cache = kwargs.pop("cache", None)
+        key, value = cache.update(key, value, module.layer_idx, reshaping_function=reshaping_function, **kwargs)
+
+        batch_size, num_heads, seq_len, head_size = query.shape
+        query = query.transpose(1, 2).reshape(batch_size * seq_len, num_heads, head_size)
+
+        if not hasattr(module, "_attn_output"):
+            module._attn_output = torch.empty_like(query, device=query.device)
+
+        x = 16 // key.element_size()
+        key = key.view(cache.num_blocks, cache.block_size, num_kv_heads, head_size // x, x).permute(0, 2, 3, 1, 4)
+        value = value.permute(0, 2, 3, 1).contiguous()
+        seq_lens = kwargs.get("cumulative_seqlens_k")  # .flatten()
+        block_tables = kwargs.get("block_tables")
+        block_size = kwargs.get("block_size", 32)
+        torch.mps.synchronize()
+        paged_attention_kernel.paged_attention_v1(
+            module._attn_output,
+            query,
+            key,  # → [num_blocks, num_kv_heads, head_dim // x, block_size, x], x depends on the dtype
+            value,  # → [num_blocks, num_kv_heads, head_dim, block_size]
+            num_kv_heads=num_kv_heads,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            block_size=block_size,
+            max_seq_len=kwargs.get("max_seqlen_k"),
+            kv_cache_dtype=kwargs.get("kv_cache_dtype", "auto"),
+            scale=module.scaling,
+            k_scale=None,
+            v_scale=None,
+            alibi_slopes=None,
+        )
+
+        attn_output = module._attn_output.reshape(batch_size, seq_len, num_heads, head_size)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        return attn_output, None
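Aside (not part of the diff): before calling the paged-attention kernel, the key cache is repacked into 16-byte vectors, hence `x = 16 // element_size`. A standalone shape check with toy sizes:

```python
# Illustrative sketch, not part of the diff: the 16-byte vector packing applied to the key cache
# before the paged-attention kernel call. For bf16, element_size() is 2, so x = 8 values per vector.
import torch

num_blocks, block_size, num_kv_heads, head_dim = 4, 16, 2, 64
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim, dtype=torch.bfloat16)

x = 16 // key_cache.element_size()  # values per 16-byte vector (8 for bf16)
packed = key_cache.view(num_blocks, block_size, num_kv_heads, head_dim // x, x).permute(0, 2, 3, 1, 4)
print(packed.shape)  # torch.Size([4, 2, 8, 16, 8]) -> [num_blocks, kv_heads, head_dim // x, block_size, x]
```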
|
@ -2,19 +2,136 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from kernels import get_kernel
|
||||
|
||||
paged_attention_kernel = get_kernel("kernels-community/paged-attention")
|
||||
# sdpa_flash_kernel = get_kernel("kernels-community/metal-flash-sdpa")
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
num_kv_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
hidden_states = hidden_states[:, None, :, :].expand(num_kv_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(num_kv_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def sdpa_attention_paged_forward__(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
dropout: float = 0.0,
|
||||
scaling: Optional[float] = None,
|
||||
reshaping_function=None,
|
||||
is_causal: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, None]:
|
||||
cache = kwargs.pop("cache", None)
|
||||
if cache is not None:
|
||||
key, value = cache.update(key, value, module.layer_idx, reshaping_function=reshaping_function, **kwargs)
|
||||
|
||||
# because of the kernel, the shape of the cache is different
|
||||
# it return [num_tokens, num_kv_heads, head_dim]
|
||||
|
||||
if key.ndim == 3:
|
||||
key = key.permute(1, 0, 2)
|
||||
value = value.permute(1, 0, 2)
|
||||
else:
|
||||
key = key.view(-1, key.shape[-2], key.shape[-1]).permute(1, 0, 2)
|
||||
value = value.view(-1, value.shape[-2], value.shape[-1]).permute(1, 0, 2)
|
||||
|
||||
if hasattr(module, "num_key_value_groups"):
|
||||
key = repeat_kv(key, module.num_key_value_groups)
|
||||
value = repeat_kv(value, module.num_key_value_groups)
|
||||
causal_mask = attention_mask
|
||||
# print(f"causal_mask.shape: {causal_mask.shape}")
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=causal_mask,
|
||||
scale=scaling,
|
||||
dropout_p=dropout,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, None
|
||||
|
||||
# def sdpa_attention_paged_forward_flash__(
|
||||
# module: torch.nn.Module,
|
||||
# query: torch.Tensor,
|
||||
# key: torch.Tensor,
|
||||
# value: torch.Tensor,
|
||||
# attention_mask: Optional[torch.Tensor],
|
||||
# dropout: float = 0.0,
|
||||
# scaling: Optional[float] = None,
|
||||
# reshaping_function=None,
|
||||
# is_causal: Optional[bool] = None,
|
||||
# **kwargs,
|
||||
# ) -> tuple[torch.Tensor, None]:
|
||||
# cache = kwargs.pop("cache", None)
|
||||
# if cache is not None:
|
||||
# key, value = cache.update(key, value, module.layer_idx, reshaping_function=reshaping_function, **kwargs)
|
||||
|
||||
# # because of the kernel, the shape of the cache is different
|
||||
# # it return [num_tokens, num_kv_heads, head_dim]
|
||||
|
||||
# if key.ndim == 3:
|
||||
# key = key.permute(1, 0, 2)
|
||||
# value = value.permute(1, 0, 2)
|
||||
# else:
|
||||
# key = key.view(-1, key.shape[-2], key.shape[-1]).permute(1, 0, 2)
|
||||
# value = value.view(-1, value.shape[-2], value.shape[-1]).permute(1, 0, 2)
|
||||
|
||||
# if hasattr(module, "num_key_value_groups"):
|
||||
# key = repeat_kv(key, module.num_key_value_groups)
|
||||
# value = repeat_kv(value, module.num_key_value_groups)
|
||||
# causal_mask = attention_mask
|
||||
# # print(f"causal_mask.shape: {causal_mask.shape}")
|
||||
# cu_seqlen_q = kwargs.get("cumulative_seqlens_q")
|
||||
# cu_seqlen_k = kwargs.get("cumulative_seqlens_k")
|
||||
# max_seqlen_q = kwargs.get("max_seqlen_q")
|
||||
# max_seqlen_k = kwargs.get("max_seqlen_k")
|
||||
|
||||
# batch_size, num_heads, seq_len, head_size = query.shape
|
||||
# query = query.transpose(1, 2).reshape(batch_size * seq_len, num_heads, head_size).contiguous()
|
||||
# key = key.transpose(0, 1).contiguous()
|
||||
# value = value.transpose(0, 1).contiguous()
|
||||
# # print(f"query.shape: {query.shape}")
|
||||
# # print(f"key.shape: {key.shape}")
|
||||
# # print(f"value.shape: {value.shape}")
|
||||
# # print(f"cu_seqlen_q: {cu_seqlen_q}")
|
||||
# # print(f"cu_seqlen_k: {cu_seqlen_k}")
|
||||
# # print(f"max_seqlen_q: {max_seqlen_q.item()}")
|
||||
# # print(f"max_seqlen_k: {max_seqlen_k}")
|
||||
# if torch.backends.mps.is_available():
|
||||
# torch.mps.synchronize()
|
||||
# else:
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
# attn_output =sdpa_flash_kernel.flash_attn_varlen_func(
|
||||
# query,
|
||||
# key,
|
||||
# value,
|
||||
# cu_seqlen_q,
|
||||
# cu_seqlen_k,
|
||||
# max_seqlen_q,
|
||||
# max_seqlen_k,
|
||||
# causal=True,
|
||||
# )
|
||||
# attn_output = attn_output.view(batch_size, seq_len, num_heads, head_size)
|
||||
# # attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
# return attn_output, None
|
||||
|
||||
def sdpa_attention_paged_forward(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
@@ -26,26 +143,162 @@ def sdpa_attention_paged_forward(
     is_causal: Optional[bool] = None,
     **kwargs,
 ) -> tuple[torch.Tensor, None]:
-    cache = kwargs.pop("cache", None)
-    if cache is not None:
-        key, value = cache.update(key, value, module.layer_idx, **kwargs)
-    if hasattr(module, "num_key_value_groups"):
-        key = repeat_kv(key, module.num_key_value_groups)
-        value = repeat_kv(value, module.num_key_value_groups)
+    reshaping_function = paged_attention_kernel.reshape_and_cache_flash

-    causal_mask = attention_mask
-    query = query.contiguous()
-    key = key.contiguous()
-    value = value.contiguous()
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query,
-        key,
-        value,
-        attn_mask=causal_mask,
-        dropout_p=dropout,
-        scale=scaling,
-        is_causal=False,
-    )
-    attn_output = attn_output.transpose(1, 2).contiguous()
+    is_decoding = kwargs.get("is_decoding")
+    # is_decoding = False
+    if not is_decoding:
+        return sdpa_attention_paged_forward__(
+            module,
+            query,
+            key,
+            value,
+            attention_mask,
+            scaling=scaling,
+            reshaping_function=reshaping_function,
+            is_causal=False,
+            **kwargs,
+        )
+    else:
+        num_kv_heads = key.shape[1]
+        cache = kwargs.pop("cache", None)
+        key, value = cache.update(key, value, module.layer_idx, reshaping_function=reshaping_function, **kwargs)
+        batch_size, num_heads, seq_len, head_size = query.shape
+        query = query.transpose(1, 2).reshape(batch_size * seq_len, num_heads, head_size).contiguous()
+
+        # Get max sequence length to determine if we need v2
+        max_seq_len = kwargs.get("max_seqlen_k", 0)
+        partition_size = 512  # Standard partition size for v2
+        use_v2 = max_seq_len > 1024  # Use v2 for longer sequences
+
+        # Allocate a persistent output buffer for the paged attention kernel
+        if not hasattr(module, "_attn_output"):
+            module._attn_output = torch.zeros(batch_size * seq_len, num_heads, head_size, device=query.device)
+
+        x = 16 // key.element_size()
+        key = key.view(cache.num_blocks, cache.block_size, num_kv_heads, head_size // x, x).permute(0, 2, 3, 1, 4).contiguous()
+        value = value.permute(0, 2, 3, 1).contiguous()
+
+        if hasattr(module, "num_key_value_groups"):
+            num_kv_heads = num_kv_heads * module.num_key_value_groups
+            key = torch.repeat_interleave(key, module.num_key_value_groups, dim=1)
+            value = torch.repeat_interleave(value, module.num_key_value_groups, dim=1)
+
+        seq_lens = kwargs.get("cumulative_seqlens_k")
+        if seq_lens is not None:
+            seq_lens = torch.diff(seq_lens)
+            if (seq_lens < 0).any():
+                seq_lens = torch.clamp(seq_lens, min=0)
+
+        block_tables = kwargs.get("block_tables")
+        if block_tables is None:
+            raise ValueError("block_tables is required for decoding mode")
+        if seq_lens is None:
+            raise ValueError("seq_lens is required for decoding mode")
+        block_size = cache.block_size
+
+        # Pre-create scale tensors to avoid CUDA graph capture issues
+        if not hasattr(module, "_k_scale_tensor"):
+            module._k_scale_tensor = torch.tensor(1.0, device=key.device, dtype=key.dtype)
+        if not hasattr(module, "_v_scale_tensor"):
+            module._v_scale_tensor = torch.tensor(1.0, device=value.device, dtype=value.dtype)
+
+        # Ensure all tensors are on the same device and contiguous
+        if query.device != key.device:
+            query = query.to(key.device)
+        if module._attn_output.device != key.device:
+            module._attn_output = module._attn_output.to(key.device)
+
+        try:
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+            elif torch.backends.mps.is_available():
+                torch.mps.synchronize()
+            else:
+                raise RuntimeError("No CUDA or MPS available")
+
+            if use_v2:
+                # Calculate number of partitions for v2
+                max_num_partitions = (max_seq_len + partition_size - 1) // partition_size
+
+                # Create v2-specific tensors
+                if not hasattr(module, "_exp_sums") or module._exp_sums.shape != (batch_size, num_heads, max_num_partitions):
+                    module._exp_sums = torch.empty(
+                        (batch_size, num_heads, max_num_partitions),
+                        dtype=torch.float32,
+                        device=key.device,
+                    )
+
+                if not hasattr(module, "_max_logits") or module._max_logits.shape != (batch_size, num_heads, max_num_partitions):
+                    module._max_logits = torch.empty(
+                        (batch_size, num_heads, max_num_partitions),
+                        dtype=torch.float32,
+                        device=key.device,
+                    )
+
+                if not hasattr(module, "_tmp_out") or module._tmp_out.shape != (batch_size, num_heads, max_num_partitions, head_size):
+                    module._tmp_out = torch.empty(
+                        (batch_size, num_heads, max_num_partitions, head_size),
+                        dtype=query.dtype,
+                        device=key.device,
+                    )
+
+                paged_attention_kernel.paged_attention_v2(
+                    module._attn_output,
+                    module._exp_sums,
+                    module._max_logits,
+                    module._tmp_out,
+                    query,
+                    key,  # → [num_blocks, num_kv_heads, head_dim // x, block_size, x], x depends on the dtype
+                    value,  # → [num_blocks, num_kv_heads, head_dim, block_size]
+                    num_kv_heads=num_kv_heads,
+                    block_tables=block_tables,
+                    seq_lens=seq_lens,
+                    block_size=block_size,
+                    max_seq_len=max_seq_len,
+                    kv_cache_dtype=kwargs.get("kv_cache_dtype", "auto"),
+                    scale=scaling,
+                    k_scale=module._k_scale_tensor,
+                    v_scale=module._v_scale_tensor,
+                    alibi_slopes=None,
+                )
+            else:
+                paged_attention_kernel.paged_attention_v1(
+                    module._attn_output,
+                    query,
+                    key,  # → [num_blocks, num_kv_heads, head_dim // x, block_size, x], x depends on the dtype
+                    value,  # → [num_blocks, num_kv_heads, head_dim, block_size]
+                    num_kv_heads=num_kv_heads,
+                    block_tables=block_tables,
+                    seq_lens=seq_lens,
+                    block_size=block_size,
+                    max_seq_len=max_seq_len,
+                    kv_cache_dtype=kwargs.get("kv_cache_dtype", "auto"),
+                    scale=scaling,
+                    k_scale=module._k_scale_tensor,
+                    v_scale=module._v_scale_tensor,
+                    alibi_slopes=None,
+                )
+
+            # if torch.cuda.is_available():
+            #     torch.cuda.synchronize()
+            # elif torch.backends.mps.is_available():
+            #     torch.mps.synchronize()
+            # else:
+            #     raise RuntimeError("No CUDA or MPS available")
+        except RuntimeError as e:
+            print(f"Error in paged_attention_{'v2' if use_v2 else 'v1'}: {e}")
+            print(f"Shapes - query: {query.shape}, key: {key.shape}, value: {value.shape}")
+            print(f"Output shape: {module._attn_output.shape}")
+            print(f"block_tables shape: {block_tables.shape if block_tables is not None else None}")
+            print(f"seq_lens shape: {seq_lens.shape if seq_lens is not None else None}")
+            if use_v2:
+                print(f"max_num_partitions: {max_num_partitions}")
+                print(f"exp_sums shape: {module._exp_sums.shape}")
+                print(f"max_logits shape: {module._max_logits.shape}")
+                print(f"tmp_out shape: {module._tmp_out.shape}")
+            raise
+
+    module._attn_output = module._attn_output.to(torch.bfloat16)
+    attn_output = module._attn_output.view(batch_size, seq_len, num_heads, head_size)
     return attn_output, None
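Aside (not part of the diff): the decode path above picks `paged_attention_v2` for long contexts and sizes its per-partition scratch buffers from a 512-token partition. The dispatch arithmetic, reduced to a sketch:

```python
# Illustrative sketch, not part of the diff: the v1/v2 dispatch and scratch sizing used above.
# v2 splits each sequence into fixed-size partitions and reduces per-partition partial results;
# the 512/1024 values mirror the constants in the hunk.

def paged_attention_plan(max_seq_len: int, partition_size: int = 512, v2_threshold: int = 1024) -> dict:
    use_v2 = max_seq_len > v2_threshold
    max_num_partitions = (max_seq_len + partition_size - 1) // partition_size if use_v2 else 1
    return {"use_v2": use_v2, "max_num_partitions": max_num_partitions}

print(paged_attention_plan(800))   # {'use_v2': False, 'max_num_partitions': 1}
print(paged_attention_plan(3000))  # {'use_v2': True, 'max_num_partitions': 6}
```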