diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu
index ef6dd1c097..266f2a0667 100644
--- a/csrc/pos_encoding_kernels.cu
+++ b/csrc/pos_encoding_kernels.cu
@@ -44,7 +44,8 @@ inline __device__ void apply_rotary_embedding(
                                    // head_size]
     const scalar_t* cache_ptr, const int head_size, const int num_heads,
     const int num_kv_heads, const int rot_dim, const int token_idx,
-    const int64_t query_stride, const int64_t key_stride) {
+    const int64_t query_stride, const int64_t key_stride,
+    const int64_t head_stride) {
   const int embed_dim = rot_dim / 2;
   const scalar_t* cos_ptr = cache_ptr;
   const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -52,7 +53,8 @@ inline __device__ void apply_rotary_embedding(
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+    const int64_t token_head =
+        token_idx * query_stride + head_idx * head_stride;
     const int rot_offset = i % embed_dim;
     apply_token_rotary_embedding<scalar_t, IS_NEOX>(
         query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
     const int nk = num_kv_heads * embed_dim;
     for (int i = threadIdx.x; i < nk; i += blockDim.x) {
       const int head_idx = i / embed_dim;
-      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+      const int64_t token_head =
+          token_idx * key_stride + head_idx * head_stride;
       const int rot_offset = i % embed_dim;
       apply_token_rotary_embedding<scalar_t, IS_NEOX>(
           key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
@@ -84,7 +87,8 @@ __global__ void rotary_embedding_kernel(
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                  // 2]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-    const int num_heads, const int num_kv_heads, const int head_size) {
+    const int64_t head_stride, const int num_heads, const int num_kv_heads,
+    const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
@@ -92,7 +96,7 @@ __global__ void rotary_embedding_kernel(
 
   apply_rotary_embedding<scalar_t, IS_NEOX>(
       query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-      token_idx, query_stride, key_stride);
+      token_idx, query_stride, key_stride, head_stride);
 }
 
 template <typename scalar_t, bool IS_NEOX>
@@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
     const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                  // 2]
     const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
-                                                        // or [num_tokens]
     const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-    const int num_heads, const int num_kv_heads, const int head_size) {
+    const int64_t head_stride, const int num_heads, const int num_kv_heads,
+    const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
   int64_t pos = positions[token_idx];
@@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
 
   apply_rotary_embedding<scalar_t, IS_NEOX>(
       query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-      token_idx, query_stride, key_stride);
+      token_idx, query_stride, key_stride, head_stride);
 }
 
 }  // namespace vllm
@@ -179,6 +183,12 @@ void rotary_embedding(
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
   int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
+  // Determine head stride: for [*, heads, head_size] use the stride of the
+  // heads dim; for flat [*, heads * head_size], head blocks are packed
+  // contiguously, so the head stride is head_size
+  int query_ndim = query.dim();
+  int64_t head_stride =
+      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -190,14 +200,14 @@ void rotary_embedding(
           positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
           key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
           cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
-          num_heads, num_kv_heads, head_size);
+          head_stride, num_heads, num_kv_heads, head_size);
     } else {
       vllm::rotary_embedding_kernel<scalar_t, false>
           <<<grid, block, 0, stream>>>(
               positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
               key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
               cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
-              key_stride, num_heads, num_kv_heads, head_size);
+              key_stride, head_stride, num_heads, num_kv_heads, head_size);
     }
   });
 }
@@ -263,6 +273,12 @@ void batched_rotary_embedding(
   int seq_dim_idx = positions_ndim - 1;
   int64_t query_stride = query.stride(seq_dim_idx);
   int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
+  // Determine head stride: for [*, heads, head_size] use the stride of the
+  // heads dim; for flat [*, heads * head_size], head blocks are packed
+  // contiguously, so the head stride is head_size
+  int query_ndim = query.dim();
+  int64_t head_stride =
+      (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -276,7 +292,7 @@ void batched_rotary_embedding(
               key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
               cos_sin_cache.data_ptr<scalar_t>(),
               cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, num_heads, num_kv_heads, head_size);
+              key_stride, head_stride, num_heads, num_kv_heads, head_size);
     } else {
       vllm::batched_rotary_embedding_kernel<scalar_t, false>
           <<<grid, block, 0, stream>>>(
@@ -284,7 +300,7 @@ void batched_rotary_embedding(
               key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
               cos_sin_cache.data_ptr<scalar_t>(),
               cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
-              key_stride, num_heads, num_kv_heads, head_size);
+              key_stride, head_stride, num_heads, num_kv_heads, head_size);
     }
   });
 }
diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py
index d81c7487b8..383a3c83b8 100644
--- a/tests/kernels/core/test_pos_encoding.py
+++ b/tests/kernels/core/test_pos_encoding.py
@@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
     return (batch_size, seq_len, num_heads * head_size)
 
 
+# For testing sliced tensors
+def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
+                             head_size: int) -> tuple[int, ...]:
+    return (batch_size, seq_len, num_heads, head_size + 64)
+
+
 def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
                             head_size: int) -> tuple[int, ...]:
     return (batch_size, seq_len, num_heads, head_size)
 
 
-TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
+TENSORS_SHAPES_FN = [
+    _get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape
+]
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -79,6 +87,10 @@ def test_rotary_embedding(
     query = torch.randn(query_shape, dtype=dtype)
     key = torch.randn_like(query) if use_key else None
 
+    # slice tensor if required, noop otherwise
+    query = query[..., :head_size]
+    key = key[..., :head_size] if use_key else None
+
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
     ref_query, ref_key = rope.forward_native(positions, query, key)
diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py
index 4e54861005..8383f943b9 100644
--- a/tests/kernels/core/test_rotary_embedding.py
+++ b/tests/kernels/core/test_rotary_embedding.py
@@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot,
 @pytest.mark.parametrize("head_size", [32, 108])
 @pytest.mark.parametrize("seq_len", [11, 1024])
 @pytest.mark.parametrize("use_key", [True, False])
+@pytest.mark.parametrize("head_stride_is_contiguous", [True, False])
 def test_rotary_embedding_opcheck(dist_init, device, max_position,
                                   is_neox_style, rotary_dim, head_size,
-                                  seq_len, use_key):
+                                  seq_len, use_key, head_stride_is_contiguous):
     batch_size = 1
     base = 10000
     num_heads = 7
@@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
     positions = torch.randint(0,
                               max_position, (batch_size, seq_len),
                               device=device)
 
+    head_stride = head_size + (0 if head_stride_is_contiguous else 64)
+
     query = torch.randn(batch_size,
                         seq_len,
-                        num_heads * head_size,
+                        num_heads,
+                        head_stride,
                         dtype=torch.float32,
                         device=device)
     key = torch.randn_like(query) if use_key else None
+    query = query[..., :head_size]
+    key = key[..., :head_size] if use_key else None
     rotary_embedding_opcheck(rot, positions, query, key)
     offsets = torch.zeros(batch_size * seq_len,
                           device=device,
                           dtype=torch.long)
     rotary_embedding_opcheck(rot, positions, query, key, offsets)
+
+    # if we have a contiguous head stride, test the alternate
+    # [..., num_heads * head_dim] shape/layout
+    if head_stride_is_contiguous:
+        rotary_embedding_opcheck(
+            rot, positions, query.flatten(start_dim=-2),
+            key.flatten(start_dim=-2) if use_key else None)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index c81300db56..e74d139ab9 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -254,14 +254,8 @@ def rotary_embedding(
     cos_sin_cache: torch.Tensor,
     is_neox: bool,
 ) -> None:
-    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
-    query_contiguous = query.contiguous()
-    key_contiguous = key.contiguous() if key is not None else None
-    torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
-                                  head_size, cos_sin_cache, is_neox)
-    query.copy_(query_contiguous)
-    if key is not None:
-        key.copy_(key_contiguous)
+    torch.ops._C.rotary_embedding(positions, query, key, head_size,
+                                  cos_sin_cache, is_neox)
 
 
 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@@ -269,16 +263,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                              cos_sin_cache: torch.Tensor, is_neox: bool,
                              rot_dim: int,
                              cos_sin_cache_offsets: torch.Tensor) -> None:
-    # TODO: Remove this contiguous call when the kernel is updated to support tensor slices
-    query_contiguous = query.contiguous()
-    key_contiguous = key.contiguous() if key is not None else None
-    torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
-                                          key_contiguous, head_size,
+    torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
                                           cos_sin_cache, is_neox, rot_dim,
                                           cos_sin_cache_offsets)
-    query.copy_(query_contiguous)
-    if key is not None:
-        key.copy_(key_contiguous)
 
 
 # layer norm ops
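
For context (not part of the patch): with the explicit head_stride, a query/key view produced by slicing a padded [*, num_heads, head_size + pad] tensor can be handed to vllm._custom_ops.rotary_embedding directly, which is why the .contiguous()/copy_ round trip above can be removed. A minimal PyTorch sketch of the two layouts the head-stride selection distinguishes (sizes and variable names are illustrative, assuming only a working torch install):

# Illustrative sketch, not part of the diff.
import torch

num_tokens, num_heads, head_size, pad = 4, 7, 64, 64

# Padded-then-sliced layout: [num_tokens, num_heads, head_size + pad][..., :head_size]
padded = torch.randn(num_tokens, num_heads, head_size + pad)
query_sliced = padded[..., :head_size]             # non-contiguous view, no copy
assert not query_sliced.is_contiguous()
assert query_sliced.stride(-2) == head_size + pad  # kernel reads the heads-dim stride

# Flat layout: [num_tokens, num_heads * head_size]; heads are packed back to back,
# so the effective head stride is simply head_size.
query_flat = torch.randn(num_tokens, num_heads * head_size)
assert query_flat.stride(-1) == 1

In the C++ host code this corresponds to: when the query has an explicit heads dim (query.dim() == positions_ndim + 2) the kernel uses query.stride(-2) as head_stride, otherwise it falls back to head_size, which matches the previous behaviour for flat, contiguous inputs.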