[Kernel] Have rotary embeddings support tensors (#18046)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2025-05-14 18:43:55 -04:00
committed by GitHub
parent 749f792553
commit d93c976a0d
4 changed files with 59 additions and 31 deletions

View File

@ -44,7 +44,8 @@ inline __device__ void apply_rotary_embedding(
// head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride) {
const int64_t query_stride, const int64_t key_stride,
const int64_t head_stride) {
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
@ -52,7 +53,8 @@ inline __device__ void apply_rotary_embedding(
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int64_t token_head =
token_idx * query_stride + head_idx * head_stride;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int64_t token_head =
token_idx * key_stride + head_idx * head_stride;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@ -84,7 +87,8 @@ __global__ void rotary_embedding_kernel(
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
@ -92,7 +96,7 @@ __global__ void rotary_embedding_kernel(
apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
token_idx, query_stride, key_stride, head_stride);
}
template <typename scalar_t, bool IS_NEOX>
@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
// or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
const int64_t head_stride, const int num_heads, const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
token_idx, query_stride, key_stride, head_stride);
}
} // namespace vllm
@ -179,6 +183,12 @@ void rotary_embedding(
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int query_ndim = query.dim();
int64_t head_stride =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@ -190,14 +200,14 @@ void rotary_embedding(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
num_heads, num_kv_heads, head_size);
head_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
key_stride, head_stride, num_heads, num_kv_heads, head_size);
}
});
}
@ -263,6 +273,12 @@ void batched_rotary_embedding(
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int query_ndim = query.dim();
int64_t head_stride =
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@ -276,7 +292,7 @@ void batched_rotary_embedding(
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
key_stride, head_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
@ -284,7 +300,7 @@ void batched_rotary_embedding(
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
key_stride, head_stride, num_heads, num_kv_heads, head_size);
}
});
}

View File

@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
return (batch_size, seq_len, num_heads * head_size)
# For testing sliced tensors
def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size + 64)
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size)
TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
TENSORS_SHAPES_FN = [
_get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape
]
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@ -79,6 +87,10 @@ def test_rotary_embedding(
query = torch.randn(query_shape, dtype=dtype)
key = torch.randn_like(query) if use_key else None
# slice tensor if required, noop otherwise
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key)

View File

@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot,
@pytest.mark.parametrize("head_size", [32, 108])
@pytest.mark.parametrize("seq_len", [11, 1024])
@pytest.mark.parametrize("use_key", [True, False])
@pytest.mark.parametrize("head_stride_is_contingous", [True, False])
def test_rotary_embedding_opcheck(dist_init, device, max_position,
is_neox_style, rotary_dim, head_size,
seq_len, use_key):
seq_len, use_key, head_stride_is_contingous):
batch_size = 1
base = 10000
num_heads = 7
@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device=device)
head_stride = head_size + (64 if head_stride_is_contingous else 0)
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
num_heads,
head_stride,
dtype=torch.float32,
device=device)
key = torch.randn_like(query) if use_key else None
query = query[..., :head_size]
key = key[..., :head_size] if use_key else None
rotary_embedding_opcheck(rot, positions, query, key)
offsets = torch.zeros(batch_size * seq_len,
device=device,
dtype=torch.long)
rotary_embedding_opcheck(rot, positions, query, key, offsets)
# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
if head_stride_is_contingous:
rotary_embedding_opcheck(
rot, positions, query.flatten(start_dim=-2),
key.flatten(start_dim=-2) if use_key else None)

View File

@ -254,14 +254,8 @@ def rotary_embedding(
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous = query.contiguous()
key_contiguous = key.contiguous() if key is not None else None
torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
head_size, cos_sin_cache, is_neox)
query.copy_(query_contiguous)
if key is not None:
key.copy_(key_contiguous)
torch.ops._C.rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@ -269,16 +263,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous = query.contiguous()
key_contiguous = key.contiguous() if key is not None else None
torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
key_contiguous, head_size,
torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
query.copy_(query_contiguous)
if key is not None:
key.copy_(key_contiguous)
# layer norm ops