[KERNEL] Sampler. CUDA kernel for applying repetition penalty (#18437)
CMakeLists.txt
@@ -242,6 +242,7 @@ set(VLLM_EXT_SRC
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
   "csrc/layernorm_quant_kernels.cu"
+  "csrc/sampler.cu"
   "csrc/cuda_view.cu"
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
csrc/ops.h
@@ -92,6 +92,11 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);

+void apply_repetition_penalties_(torch::Tensor& logits,
+                                 const torch::Tensor& prompt_mask,
+                                 const torch::Tensor& output_mask,
+                                 const torch::Tensor& repetition_penalties);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);
csrc/sampler.cu (new file, 86 lines)
@@ -0,0 +1,86 @@
+#include "dispatch_utils.h"
+
+#include <torch/cuda.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#ifndef USE_ROCM
+  #include <cub/cub.cuh>
+#else
+  #include <hipcub/hipcub.hpp>
+#endif
+
+namespace vllm {
+
+template <typename scalar_t>
+__global__ void apply_repetition_penalties_kernel(
+    scalar_t* __restrict__ logits,         // [num_seqs, vocab_size]
+    const bool* __restrict__ prompt_mask,  // [num_seqs, vocab_size]
+    const bool* __restrict__ output_mask,  // [num_seqs, vocab_size]
+    const scalar_t* __restrict__ repetition_penalties,  // [num_seqs]
+    const int num_seqs, const int vocab_size, const int tile_size) {
+  // Each block handles one sequence and a tile of vocab
+  const int seq_idx = blockIdx.x;
+  if (seq_idx >= num_seqs) return;
+
+  const int tile_start = blockIdx.y * tile_size;
+  const int tile_end = min(tile_start + tile_size, vocab_size);
+
+  // Load repetition penalty for this sequence
+  const scalar_t penalty = repetition_penalties[seq_idx];
+
+  // Each thread processes multiple vocab items within the tile
+  for (int vocab_idx = tile_start + threadIdx.x; vocab_idx < tile_end;
+       vocab_idx += blockDim.x) {
+    const int64_t idx = static_cast<int64_t>(seq_idx) * vocab_size + vocab_idx;
+    const bool is_repeated = prompt_mask[idx] || output_mask[idx];
+    if (is_repeated) {
+      scalar_t logit = logits[idx];
+      if (logit > 0) {
+        logits[idx] = logit / penalty;
+      } else {
+        logits[idx] = logit * penalty;
+      }
+    }
+  }
+}
+
+}  // namespace vllm
+
+void apply_repetition_penalties_(
+    torch::Tensor& logits,             // [num_seqs, vocab_size], in-place
+    const torch::Tensor& prompt_mask,  // [num_seqs, vocab_size]
+    const torch::Tensor& output_mask,  // [num_seqs, vocab_size]
+    const torch::Tensor& repetition_penalties) {  // [num_seqs]
+  TORCH_CHECK(logits.is_contiguous());
+  TORCH_CHECK(prompt_mask.is_contiguous());
+  TORCH_CHECK(output_mask.is_contiguous());
+  TORCH_CHECK(repetition_penalties.is_contiguous());
+
+  int vocab_size = logits.size(-1);
+  int num_seqs = logits.size(0);
+
+  // Get number of SMs on the current device
+  int sms = 0;
+  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
+                         logits.get_device());
+
+  // Compute tile_num and tile_size
+  int tile_num =
+      std::min(vocab_size, std::max(1, (sms + num_seqs - 1) / num_seqs));
+  int tile_size = (vocab_size + tile_num - 1) / tile_num;
+
+  // Each block handles one sequence and a tile of vocab
+  dim3 grid(num_seqs, tile_num);
+  dim3 block(std::min(tile_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
+        vllm::apply_repetition_penalties_kernel<scalar_t>
+            <<<grid, block, 0, stream>>>(
+                logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
+                output_mask.data_ptr<bool>(),
+                repetition_penalties.data_ptr<scalar_t>(), num_seqs, vocab_size,
+                tile_size);
+      });
+}
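Note on the launch configuration above: the grid splits each sequence's vocab row into roughly one tile per available share of the SMs, and each block strides over its tile. A minimal Python sketch of the same arithmetic (the 132-SM count is a hypothetical example, not taken from this diff):

# Sketch of the grid/tile sizing performed by the host-side launcher above.
# The default SM count is a hypothetical example value.
def launch_config(num_seqs: int, vocab_size: int, sms: int = 132):
    # One tile per SM-share of a sequence, never more tiles than vocab entries.
    tile_num = min(vocab_size, max(1, (sms + num_seqs - 1) // num_seqs))
    tile_size = (vocab_size + tile_num - 1) // tile_num  # ceil division
    grid = (num_seqs, tile_num)   # blockIdx.x = sequence, blockIdx.y = tile
    block = min(tile_size, 1024)  # threads per block, capped at 1024
    return grid, block, tile_size

# e.g. 8 sequences over a 151936-token vocab on a 132-SM GPU:
# tile_num = min(151936, max(1, ceil(132 / 8))) = 17
# tile_size = ceil(151936 / 17) = 8938, block = 1024 threads
print(launch_config(8, 151936))

Because the block is capped at 1024 threads, each thread may handle several vocab entries, which is why the kernel loops over the tile with a stride of blockDim.x.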
csrc/torch_bindings.cpp
@@ -170,6 +170,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

+  // Apply repetition penalties to logits in-place
+  ops.def(
+      "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
+      "Tensor output_mask, Tensor repetition_penalties) -> ()");
+  ops.impl("apply_repetition_penalties_", torch::kCUDA,
+           &apply_repetition_penalties_);
+
   // Layernorm-quant
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
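For reference, the Tensor! annotation in the schema declares that logits is mutated in place. A hedged sketch of calling the binding directly from Python (assumes the _C extension is built and a CUDA device is available; shapes and values are illustrative):

# Hedged sketch: direct call into the op registered above.
import torch
import vllm._custom_ops  # noqa: F401  (importing this loads the _C extension)

num_seqs, vocab_size = 2, 8
logits = torch.randn(num_seqs, vocab_size, device="cuda")
prompt_mask = torch.zeros_like(logits, dtype=torch.bool)
output_mask = torch.zeros_like(logits, dtype=torch.bool)
prompt_mask[0, 3] = True  # pretend token 3 appeared in sequence 0's prompt
penalties = torch.full((num_seqs,), 1.05, device="cuda")

# "Tensor! logits": the op writes the penalized values back into logits.
torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask,
                                         penalties)

In practice callers go through the vllm._custom_ops wrappers added later in this diff rather than torch.ops._C directly.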
tests/kernels/test_apply_repetition_penalties.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import torch
+
+from tests.kernels.utils import opcheck
+from vllm._custom_ops import (apply_repetition_penalties_cuda,
+                              apply_repetition_penalties_torch)
+from vllm.platforms import current_platform
+
+NUM_SEQS = [1, 2, 3, 4, 8, 13, 17, 32, 37, 256, 1023, 1024, 1025]
+# [stress, stress, stress, Qwen, llama 4]
+VOCAB_SIZES = [17, 256, 1019, 151936, 202048]
+REPETITION_PENALTY_VALUES = [1.05]
+SEEDS = [0]
+DTYPES = [torch.float32, torch.float16]
+
+
+@pytest.mark.parametrize("num_seqs", NUM_SEQS)
+@pytest.mark.parametrize("vocab_size", VOCAB_SIZES)
+@pytest.mark.parametrize("repetition_penalty", REPETITION_PENALTY_VALUES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test checks the CUDA kernel")
+@torch.inference_mode()
+def test_apply_repetition_penalties(
+    num_seqs: int,
+    vocab_size: int,
+    repetition_penalty: float,
+    dtype: torch.dtype,
+    seed: int,
+) -> None:
+    """
+    Test the apply_repetition_penalties custom op
+    against a reference implementation.
+    """
+    current_platform.seed_everything(seed)
+    torch.set_default_device("cuda:0")
+
+    # Create test data
+    logits = torch.randn(num_seqs, vocab_size, dtype=dtype)
+
+    # Create masks with some random tokens marked as repeated
+    prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
+    output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
+
+    # Mark some tokens as repeated in prompt and output
+    prompt_indices = torch.randint(0, vocab_size,
+                                   (num_seqs, max(1, vocab_size // 200)))
+    output_indices = torch.randint(0, vocab_size,
+                                   (num_seqs, max(1, vocab_size // 200)))
+
+    for i in range(num_seqs):
+        prompt_mask[i, prompt_indices[i]] = True
+        output_mask[i, output_indices[i]] = True
+
+    # Create repetition penalties tensor
+    repetition_penalties = torch.full((num_seqs, ),
+                                      repetition_penalty,
+                                      dtype=dtype)
+
+    # Run both implementations (Torch reference and CUDA kernel)
+    logits_torch = logits.clone()
+    logits_cuda = logits.clone()
+
+    apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
+                                     repetition_penalties)
+    apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
+                                    repetition_penalties)
+
+    # Compare the CUDA output against the Torch reference
+    torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)
+
+    # Test the operator by applying the opcheck utility
+    opcheck(torch.ops._C.apply_repetition_penalties_,
+            (logits.clone(), prompt_mask, output_mask, repetition_penalties))
vllm/_custom_ops.py
@@ -282,6 +282,45 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
     torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


+def apply_repetition_penalties_torch(
+        logits: torch.Tensor, prompt_mask: torch.Tensor,
+        output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
+    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
+        1, logits.size(1))
+    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
+                            1.0)
+    # If logits are positive, divide by penalty, otherwise multiply by penalty.
+    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
+    logits *= scaling
+
+
+def apply_repetition_penalties_cuda(
+        logits: torch.Tensor, prompt_mask: torch.Tensor,
+        output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
+    torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask,
+                                             repetition_penalties)
+
+
+def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
+                               output_mask: torch.Tensor,
+                               repetition_penalties: torch.Tensor) -> None:
+    """Apply repetition penalties to logits in-place.
+
+    Args:
+        logits: The logits tensor of shape [num_seqs, vocab_size].
+        prompt_mask: A boolean tensor indicating which tokens appear in the prompt.
+        output_mask: A boolean tensor indicating which tokens appear in the output.
+        repetition_penalties: The repetition penalties of shape (num_seqs, ).
+    """
+    if current_platform.is_cuda() and logits.is_contiguous():
+        apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
+                                        repetition_penalties)
+    else:
+        apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
+                                         repetition_penalties)
+
+
 def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int,
                            input_tokens: torch.Tensor,
                            sampled_token_ids: torch.Tensor,
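A short usage sketch of the new wrapper (values illustrative, assumes a CUDA device): positive logits of repeated tokens are divided by the penalty and negative ones multiplied, while non-CUDA platforms or non-contiguous logits fall back to the Torch path.

# Hedged usage sketch for vllm._custom_ops.apply_repetition_penalties.
import torch
from vllm._custom_ops import apply_repetition_penalties

logits = torch.tensor([[2.0, -2.0, 0.5]], device="cuda")  # [1, vocab_size=3]
prompt_mask = torch.tensor([[True, False, False]], device="cuda")
output_mask = torch.tensor([[False, True, False]], device="cuda")
penalties = torch.tensor([1.05], device="cuda")

apply_repetition_penalties(logits, prompt_mask, output_mask, penalties)
# Token 0 was in the prompt:  2.0 / 1.05 ≈ 1.905
# Token 1 was in the output: -2.0 * 1.05 = -2.1
# Token 2 is untouched:       0.5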
vllm/model_executor/layers/utils.py
@@ -50,16 +50,11 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                                                     vocab_size, num_seqs)
     output_bin_counts, output_mask = get_token_bin_counts_and_mask(
         output_tokens_tensor, vocab_size, num_seqs)
-    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
-        1, vocab_size)
-
-    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
-    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
-                            1.0)
-
-    # If logits are positive, divide by penalty, otherwise multiply by penalty.
-    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
-    logits *= scaling
+
+    # Apply repetition penalties as a custom op
+    from vllm._custom_ops import apply_repetition_penalties
+    apply_repetition_penalties(logits, prompt_mask, output_mask,
+                               repetition_penalties)

     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
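The [num_seqs, vocab_size] boolean masks consumed here come from get_token_bin_counts_and_mask. The following is a rough, hypothetical sketch of how such a mask can be derived from token ids via scatter; it is not the exact vLLM helper:

# Hedged sketch: per-sequence vocab mask from token ids, analogous in spirit
# to get_token_bin_counts_and_mask (not its exact code).
import torch

def token_id_mask(token_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """token_ids: [num_seqs, num_tokens] -> bool mask of [num_seqs, vocab_size]."""
    num_seqs = token_ids.shape[0]
    mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool,
                       device=token_ids.device)
    mask.scatter_(1, token_ids, True)  # mark every token id seen in the row
    return mask

# e.g. sequence 0 produced tokens [5, 5, 2]; sequence 1 produced [0, 7, 7]
output_mask = token_id_mask(torch.tensor([[5, 5, 2], [0, 7, 7]]), vocab_size=8)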