diff --git a/benchmarks/kernels/benchmark_polynorm.py b/benchmarks/kernels/benchmark_polynorm.py
deleted file mode 100644
index 9ac8f5e659..0000000000
--- a/benchmarks/kernels/benchmark_polynorm.py
+++ /dev/null
@@ -1,155 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import itertools
-
-import torch
-
-from vllm import _custom_ops as vllm_ops
-from vllm.triton_utils import triton
-
-
-def polynorm_naive(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    eps: float = 1e-6,
-):
-    orig_shape = x.shape
-    x = x.view(-1, x.shape[-1])
-
-    def norm(x, eps: float):
-        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
-
-    x = x.float()
-    return (
-        (
-            weight[0] * norm(x**3, eps)
-            + weight[1] * norm(x**2, eps)
-            + weight[2] * norm(x, eps)
-            + bias
-        )
-        .to(weight.dtype)
-        .view(orig_shape)
-    )
-
-
-def polynorm_vllm(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    eps: float = 1e-6,
-):
-    orig_shape = x.shape
-    x = x.view(-1, x.shape[-1])
-
-    out = torch.empty_like(x)
-    vllm_ops.poly_norm(out, x, weight, bias, eps)
-    output = out
-
-    output = output.view(orig_shape)
-    return output
-
-
-def calculate_diff(batch_size, seq_len, hidden_dim):
-    dtype = torch.bfloat16
-    x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
-    weight = torch.ones(3, dtype=dtype, device="cuda")
-    bias = torch.ones(1, dtype=dtype, device="cuda")
-
-    output_naive = polynorm_naive(x, weight, bias)
-    output_vllm = polynorm_vllm(x, weight, bias)
-
-    if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
-        print("✅ All implementations match")
-    else:
-        print("❌ Implementations differ")
-
-
-batch_size_range = [2**i for i in range(0, 7, 2)]
-seq_length_range = [2**i for i in range(6, 11, 1)]
-dim_range = [2048, 4096]
-configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
-
-
-def get_benchmark():
-    @triton.testing.perf_report(
-        triton.testing.Benchmark(
-            x_names=["dim", "batch_size", "seq_len"],
-            x_vals=[list(_) for _ in configs],
-            line_arg="provider",
-            line_vals=["naive", "vllm"],
-            line_names=["Naive", "vLLM"],
-            styles=[("blue", "-"), ("red", "-")],
-            ylabel="us",
-            plot_name="polynorm-perf",
-            args={},
-        )
-    )
-    def benchmark(dim, batch_size, seq_len, provider):
-        dtype = torch.bfloat16
-        hidden_dim = dim * 4
-
-        x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
-        weight = torch.ones(3, dtype=dtype, device="cuda")
-        bias = torch.ones(1, dtype=dtype, device="cuda")
-
-        quantiles = [0.5, 0.2, 0.8]
-
-        if provider == "naive":
-            ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: polynorm_naive(x, weight, bias),
-                quantiles=quantiles,
-            )
-        else:
-            ms, min_ms, max_ms = triton.testing.do_bench(
-                lambda: polynorm_vllm(x, weight, bias),
-                quantiles=quantiles,
-            )
-
-        return 1000 * ms, 1000 * max_ms, 1000 * min_ms
-
-    return benchmark
-
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--batch-size",
-        type=int,
-        default=4,
-        help="Batch size",
-    )
-    parser.add_argument(
-        "--seq-len",
-        type=int,
-        default=128,
-        help="Sequence length",
-    )
-    parser.add_argument(
-        "--hidden-dim",
-        type=int,
-        default=8192,
-        help="Intermediate size of MLP",
-    )
-    parser.add_argument(
-        "--save-path",
-        type=str,
-        default="./configs/polnorm/",
-        help="Path to save polnorm benchmark results",
-    )
-
-    args = parser.parse_args()
-
-    # Run correctness test
-    calculate_diff(
-        batch_size=args.batch_size,
-        seq_len=args.seq_len,
-        hidden_dim=args.hidden_dim,
-    )
-
-    benchmark = get_benchmark()
-    # Run performance benchmark
-    benchmark.run(print_data=True, save_path=args.save_path)
diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index 3a8f9bc3b5..8cfcf9f412 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -148,211 +148,6 @@ fused_add_rms_norm_kernel(
   }
 }
 
-/* Function specialization in the case of FP16/BF16 tensors.
-   Additional optimizations we can make in this case are
-   packed and vectorized operations, which help with the
-   memory latency bottleneck.
-
-   _f16VecPN struct extends _f16Vec to add operations specifically required for
-   polynomial normalization (poly norm).
-   The original _f16Vec does not include the sum-of-powers computation or
-   in-place polynomial normalization logic. */
-template <typename scalar_t, int width>
-struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
-  using Base = _f16Vec<scalar_t, width>;
-  using Converter = typename Base::Converter;
-  using T1 = typename Base::T1;
-  using T2 = typename Base::T2;
-  using Base::data;
-
-  __device__ auto sum_pows() const {
-    float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;
-
-#pragma unroll
-    for (int i = 0; i < width; i += 2) {
-      float2 z = Converter::convert(T2{data[i], data[i + 1]});
-      float x2 = z.x * z.x;
-      float x4 = x2 * x2;
-      float x6 = x4 * x2;
-
-      float y2 = z.y * z.y;
-      float y4 = y2 * y2;
-      float y6 = y4 * y2;
-
-      s2 += x2 + y2;
-      s4 += x4 + y4;
-      s6 += x6 + y6;
-    }
-    return std::make_tuple(s2, s4, s6);
-  }
-
-  __device__ void poly_norm_inplace(const float w2_inv_std,
-                                    const float w1_inv_std2,
-                                    const float w0_inv_std3, const float bias) {
-#pragma unroll
-    for (int i = 0; i < width; i += 2) {
-      float2 z = Converter::convert(T2{data[i], data[i + 1]});
-
-      float x2 = z.x * z.x;
-      float x3 = x2 * z.x;
-      z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
-
-      float y2 = z.y * z.y;
-      float y3 = y2 * z.y;
-      z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;
-
-      auto out = Converter::convert(z);
-      data[i] = out.x;
-      data[i + 1] = out.y;
-    }
-  }
-};
-
-template <typename scalar_t, int width>
-__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
-poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
-                 const scalar_t* __restrict__ input,   // [..., hidden_size]
-                 const scalar_t* __restrict__ weight,  // [3]
-                 const scalar_t* __restrict__ bias,    // [1]
-                 const float epsilon, const int hidden_size) {
-  // Sanity checks on our vector struct and type-punned pointer arithmetic
-  static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
-  static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
-
-  /* These and the argument pointers are all declared `restrict` as they are
-     not aliased in practice. Argument pointers should not be dereferenced
-     in this kernel as that would be undefined behavior */
-  auto* __restrict__ input_v =
-      reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
-  const int vec_hidden_size = hidden_size / width;
-  float variance = 0.0f;
-  float variance2 = 0.0f;
-  float variance3 = 0.0f;
-
-  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
-    int id = blockIdx.x * vec_hidden_size + idx;
-    _f16VecPN<scalar_t, width> temp = input_v[id];
-    auto [x2, x4, x6] = temp.sum_pows();
-
-    variance += x2;
-    variance2 += x4;
-    variance3 += x6;
-  }
-
-  float3 thread_variances = make_float3(variance, variance2, variance3);
-
-  struct SumOp {
-    __device__ float3 operator()(const float3& a, const float3& b) const {
-      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
-    }
-  };
-
-  using BlockReduce = cub::BlockReduce<float3, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStore;
-  float3 block_variances =
-      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
-
-  variance = block_variances.x;
-  variance2 = block_variances.y;
-  variance3 = block_variances.z;
-
-  __shared__ float s_w2_inv_std;
-  __shared__ float s_w1_inv_std2;
-  __shared__ float s_w0_inv_std3;
-  __shared__ float s_bias;
-
-  if (threadIdx.x == 0) {
-    float w0 = (float)weight[0];
-    float w1 = (float)weight[1];
-    float w2 = (float)weight[2];
-    s_bias = (float)bias[0];
-
-    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
-    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
-    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
-  }
-  __syncthreads();
-
-  auto* __restrict__ out_v =
-      reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
-
-  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
-    int id = blockIdx.x * vec_hidden_size + idx;
-    _f16VecPN<scalar_t, width> temp = input_v[id];
-    temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
-    out_v[id] = temp;
-  }
-}
-
-/* Generic poly_norm_kernel
-   The width field is not used here but necessary for other specializations.
- */
-template <typename scalar_t, int width>
-__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
-poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
-                 const scalar_t* __restrict__ input,   // [..., hidden_size]
-                 const scalar_t* __restrict__ weight,  // [3]
-                 const scalar_t* __restrict__ bias,    // [1]
-                 const float epsilon, const int hidden_size) {
-  float variance = 0.0f;
-  float variance2 = 0.0f;
-  float variance3 = 0.0f;
-
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
-    float x2 = x * x;
-    float x4 = x2 * x2;
-    float x6 = x4 * x2;
-
-    variance += x2;
-    variance2 += x4;
-    variance3 += x6;
-  }
-
-  float3 thread_variances = make_float3(variance, variance2, variance3);
-
-  struct SumOp {
-    __device__ float3 operator()(const float3& a, const float3& b) const {
-      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
-    }
-  };
-
-  using BlockReduce = cub::BlockReduce<float3, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStore;
-  float3 block_variances =
-      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
-
-  variance = block_variances.x;
-  variance2 = block_variances.y;
-  variance3 = block_variances.z;
-
-  __shared__ float s_w2_inv_std;
-  __shared__ float s_w1_inv_std2;
-  __shared__ float s_w0_inv_std3;
-  __shared__ float s_bias;
-
-  if (threadIdx.x == 0) {
-    float w0 = (float)weight[0];
-    float w1 = (float)weight[1];
-    float w2 = (float)weight[2];
-    s_bias = (float)bias[0];
-
-    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
-    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
-    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
-  }
-  __syncthreads();
-
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
-    float x2 = x * x;
-    float x3 = x2 * x;
-
-    out[blockIdx.x * hidden_size + idx] =
-        (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
-                   s_bias);
-  }
-}
-
 }  // namespace vllm
 
 void rms_norm(torch::Tensor& out,    // [..., hidden_size]
@@ -444,50 +239,3 @@ void fused_add_rms_norm(torch::Tensor& input,  // [..., hidden_size]
     LAUNCH_FUSED_ADD_RMS_NORM(0);
   }
 }
-
-#define LAUNCH_FUSED_POLY_NORM(width)                                         \
-  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
-    vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>(      \
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),                 \
-        weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon,      \
-        hidden_size);                                                         \
-  });
-
-void poly_norm(torch::Tensor& out,     // [..., hidden_size]
-               torch::Tensor& input,   // [..., hidden_size]
-               torch::Tensor& weight,  // [3]
-               torch::Tensor& bias,    // [1]
-               double epsilon) {
-  TORCH_CHECK(out.is_contiguous());
-  TORCH_CHECK(input.is_contiguous());
-  TORCH_CHECK(out.data_ptr() != input.data_ptr());
-
-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-
-  dim3 grid(num_tokens);
-  /* This kernel is memory-latency bound in many scenarios.
-     When num_tokens is large, a smaller block size allows
-     for increased block occupancy on CUs and better latency
-     hiding on global mem ops. */
-  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
-  dim3 block(std::min(hidden_size, max_block_size));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  /*If the tensor types are FP16/BF16, try to use the optimized kernel
-     with packed + vectorized ops.
-     Max optimization is achieved with a width-8 vector of FP16/BF16s
-     since we can load at most 128 bits at once in a global memory op.
-     However, this requires each tensor's data to be aligned to 16
-     bytes.
-   */
-  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
-  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
-  bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
-  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
-  if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
-    LAUNCH_FUSED_POLY_NORM(8);
-  } else {
-    LAUNCH_FUSED_POLY_NORM(0);
-  }
-}
diff --git a/csrc/ops.h b/csrc/ops.h
index 2a9214e7fb..c135a14042 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -92,9 +92,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);
 
-void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
-               torch::Tensor& bias, double epsilon);
-
 void apply_repetition_penalties_(torch::Tensor& logits,
                                  const torch::Tensor& prompt_mask,
                                  const torch::Tensor& output_mask,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index a4a9f87b28..2bc526097d 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -175,12 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
 
-  // Polynomial Normalization.
-  ops.def(
-      "poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
-      "epsilon) -> ()");
-  ops.impl("poly_norm", torch::kCUDA, &poly_norm);
-
   // Apply repetition penalties to logits in-place
   ops.def(
       "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py
index aaa13c0662..49bd77f679 100644
--- a/tests/kernels/core/test_layernorm.py
+++ b/tests/kernels/core/test_layernorm.py
@@ -6,7 +6,7 @@ import torch
 
 from tests.kernels.quant_utils import FP8_DTYPE
 from tests.kernels.utils import opcheck
-from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.platforms import current_platform
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -70,38 +70,6 @@ def test_rms_norm(
     )
 
 
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
-@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@torch.inference_mode()
-def test_poly_norm(
-    num_tokens: int,
-    hidden_size: int,
-    dtype: torch.dtype,
-    seed: int,
-    device: str,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-    layer = PolyNorm().to(dtype=dtype)
-    layer.weight.data.normal_(mean=1.0, std=0.1)
-    layer.bias.data.normal_(mean=1.0, std=0.1)
-    scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
-    x *= scale
-
-    ref_out = layer.forward_native(x)
-    out = layer(x)
-    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
-
-    opcheck(
-        torch.ops._C.poly_norm,
-        (out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon),
-    )
-
-
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index d38b162150..0618451c19 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -339,18 +339,6 @@ def fused_add_rms_norm(
     torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
 
 
-def poly_norm(
-    out: torch.Tensor,
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    epsilon: float,
-) -> None:
-    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
-    input_contiguous = input.contiguous()
-    torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon)
-
-
 def apply_repetition_penalties_torch(
     logits: torch.Tensor,
     prompt_mask: torch.Tensor,
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 75d907521e..dac5d129c3 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -58,22 +58,6 @@ def fused_add_rms_norm(
     return x, residual
 
 
-def poly_norm(
-    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
-) -> torch.Tensor:
-    from vllm import _custom_ops as ops
-
-    out = torch.empty_like(x)
-    ops.poly_norm(
-        out,
-        x,
-        weight,
-        bias,
-        variance_epsilon,
-    )
-    return out
-
-
 def rocm_aiter_rms_norm_impl(
     x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
 ) -> torch.Tensor:
@@ -385,53 +369,6 @@ class GemmaRMSNorm(CustomOp):
         return self.forward_native(x, residual)
 
 
-@CustomOp.register("poly_norm")
-class PolyNorm(CustomOp):
-    """Polynomial normalization.
-
-    Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
-    where w_n is the learned weight and b is the bias.
-    Refer to https://arxiv.org/html/2411.03884v1
-    """
-
-    def __init__(
-        self,
-        eps: float = 1e-6,
-    ) -> None:
-        super().__init__()
-        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
-        self.bias = torch.nn.Parameter(torch.zeros(1))
-        self.variance_epsilon = eps
-
-    def _norm(self, x):
-        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
-
-    def forward_native(
-        self,
-        x: torch.Tensor,
-    ) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward().
-
-        Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
-        """
-
-        orig_dtype = x.dtype
-        x_float = x.to(torch.float32)
-        output = (
-            self.weight[0] * self._norm(x_float**3)
-            + self.weight[1] * self._norm(x_float**2)
-            + self.weight[2] * self._norm(x_float)
-            + self.bias
-        )
-        return output.to(orig_dtype)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-    ) -> torch.Tensor:
-        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
-
-
 class LayerNorm(nn.Module):
     """
     Layer Normalization.
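For reference only: the math removed by this patch can still be expressed in a few lines of plain PyTorch. The sketch below mirrors the deleted forward_native()/polynorm_naive() reference; the function name poly_norm_reference is illustrative and is not part of vLLM.

# Minimal PyTorch sketch of the PolyNorm math that this diff removes:
#   w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
import torch


def poly_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,  # shape [3]
    bias: torch.Tensor,    # shape [1]
    eps: float = 1e-6,
) -> torch.Tensor:
    def _norm(t: torch.Tensor) -> torch.Tensor:
        # RMS normalization over the hidden dimension
        return t / torch.sqrt(t.pow(2).mean(-1, keepdim=True) + eps)

    orig_dtype = x.dtype
    x = x.float()
    out = (
        weight[0] * _norm(x**3)
        + weight[1] * _norm(x**2)
        + weight[2] * _norm(x)
        + bias
    )
    return out.to(orig_dtype)


if __name__ == "__main__":
    # Same initialization as the deleted PolyNorm layer: weights = 1/3, bias = 0
    x = torch.randn(4, 128, 2048)
    weight = torch.ones(3) / 3
    bias = torch.zeros(1)
    print(poly_norm_reference(x, weight, bias).shape)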