mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Chore] Remove unused PolyNorm
layer (#27110)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@ -1,155 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def polynorm_naive(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
def norm(x, eps: float):
|
||||
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
||||
|
||||
x = x.float()
|
||||
return (
|
||||
(
|
||||
weight[0] * norm(x**3, eps)
|
||||
+ weight[1] * norm(x**2, eps)
|
||||
+ weight[2] * norm(x, eps)
|
||||
+ bias
|
||||
)
|
||||
.to(weight.dtype)
|
||||
.view(orig_shape)
|
||||
)
|
||||
|
||||
|
||||
def polynorm_vllm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
out = torch.empty_like(x)
|
||||
vllm_ops.poly_norm(out, x, weight, bias, eps)
|
||||
output = out
|
||||
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def calculate_diff(batch_size, seq_len, hidden_dim):
|
||||
dtype = torch.bfloat16
|
||||
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(3, dtype=dtype, device="cuda")
|
||||
bias = torch.ones(1, dtype=dtype, device="cuda")
|
||||
|
||||
output_naive = polynorm_naive(x, weight, bias)
|
||||
output_vllm = polynorm_vllm(x, weight, bias)
|
||||
|
||||
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [2**i for i in range(0, 7, 2)]
|
||||
seq_length_range = [2**i for i in range(6, 11, 1)]
|
||||
dim_range = [2048, 4096]
|
||||
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
|
||||
|
||||
|
||||
def get_benchmark():
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["dim", "batch_size", "seq_len"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["naive", "vllm"],
|
||||
line_names=["Naive", "vLLM"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name="polynorm-perf",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(dim, batch_size, seq_len, provider):
|
||||
dtype = torch.bfloat16
|
||||
hidden_dim = dim * 4
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(3, dtype=dtype, device="cuda")
|
||||
bias = torch.ones(1, dtype=dtype, device="cuda")
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "naive":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: polynorm_naive(x, weight, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else:
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: polynorm_vllm(x, weight, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Sequence length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden-dim",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="Intermediate size of MLP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/polnorm/",
|
||||
help="Path to save polnorm benchmark results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test
|
||||
calculate_diff(
|
||||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
hidden_dim=args.hidden_dim,
|
||||
)
|
||||
|
||||
benchmark = get_benchmark()
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
@ -148,211 +148,6 @@ fused_add_rms_norm_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
/* Function specialization in the case of FP16/BF16 tensors.
|
||||
Additional optimizations we can make in this case are
|
||||
packed and vectorized operations, which help with the
|
||||
memory latency bottleneck.
|
||||
|
||||
_f16VecPN struct extends _f16Vec to add operations specifically required for
|
||||
polynomial normalization (poly norm).
|
||||
The original _f16Vec does not include the sum-of-powers computation or
|
||||
in-place polynomial normalization logic. */
|
||||
template <typename scalar_t, int width>
|
||||
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
|
||||
using Base = _f16Vec<scalar_t, width>;
|
||||
using Converter = typename Base::Converter;
|
||||
using T1 = typename Base::T1;
|
||||
using T2 = typename Base::T2;
|
||||
using Base::data;
|
||||
|
||||
__device__ auto sum_pows() const {
|
||||
float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 z = Converter::convert(T2{data[i], data[i + 1]});
|
||||
float x2 = z.x * z.x;
|
||||
float x4 = x2 * x2;
|
||||
float x6 = x4 * x2;
|
||||
|
||||
float y2 = z.y * z.y;
|
||||
float y4 = y2 * y2;
|
||||
float y6 = y4 * y2;
|
||||
|
||||
s2 += x2 + y2;
|
||||
s4 += x4 + y4;
|
||||
s6 += x6 + y6;
|
||||
}
|
||||
return std::make_tuple(s2, s4, s6);
|
||||
}
|
||||
|
||||
__device__ void poly_norm_inplace(const float w2_inv_std,
|
||||
const float w1_inv_std2,
|
||||
const float w0_inv_std3, const float bias) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 z = Converter::convert(T2{data[i], data[i + 1]});
|
||||
|
||||
float x2 = z.x * z.x;
|
||||
float x3 = x2 * z.x;
|
||||
z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
|
||||
|
||||
float y2 = z.y * z.y;
|
||||
float y3 = y2 * z.y;
|
||||
z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;
|
||||
|
||||
auto out = Converter::convert(z);
|
||||
data[i] = out.x;
|
||||
data[i + 1] = out.y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [3]
|
||||
const scalar_t* __restrict__ bias, // [1]
|
||||
const float epsilon, const int hidden_size) {
|
||||
// Sanity checks on our vector struct and type-punned pointer arithmetic
|
||||
static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
|
||||
static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
not aliased in practice. Argument pointers should not be dereferenced
|
||||
in this kernel as that would be undefined behavior */
|
||||
auto* __restrict__ input_v =
|
||||
reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
|
||||
const int vec_hidden_size = hidden_size / width;
|
||||
float variance = 0.0f;
|
||||
float variance2 = 0.0f;
|
||||
float variance3 = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16VecPN<scalar_t, width> temp = input_v[id];
|
||||
auto [x2, x4, x6] = temp.sum_pows();
|
||||
|
||||
variance += x2;
|
||||
variance2 += x4;
|
||||
variance3 += x6;
|
||||
}
|
||||
|
||||
float3 thread_variances = make_float3(variance, variance2, variance3);
|
||||
|
||||
struct SumOp {
|
||||
__device__ float3 operator()(const float3& a, const float3& b) const {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
};
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float3, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
float3 block_variances =
|
||||
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
|
||||
|
||||
variance = block_variances.x;
|
||||
variance2 = block_variances.y;
|
||||
variance3 = block_variances.z;
|
||||
|
||||
__shared__ float s_w2_inv_std;
|
||||
__shared__ float s_w1_inv_std2;
|
||||
__shared__ float s_w0_inv_std3;
|
||||
__shared__ float s_bias;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
float w0 = (float)weight[0];
|
||||
float w1 = (float)weight[1];
|
||||
float w2 = (float)weight[2];
|
||||
s_bias = (float)bias[0];
|
||||
|
||||
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
|
||||
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
|
||||
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16VecPN<scalar_t, width> temp = input_v[id];
|
||||
temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
|
||||
out_v[id] = temp;
|
||||
}
|
||||
}
|
||||
|
||||
/* Generic poly_norm_kernel
|
||||
The width field is not used here but necessary for other specializations.
|
||||
*/
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [3]
|
||||
const scalar_t* __restrict__ bias, // [1]
|
||||
const float epsilon, const int hidden_size) {
|
||||
float variance = 0.0f;
|
||||
float variance2 = 0.0f;
|
||||
float variance3 = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x2 = x * x;
|
||||
float x4 = x2 * x2;
|
||||
float x6 = x4 * x2;
|
||||
|
||||
variance += x2;
|
||||
variance2 += x4;
|
||||
variance3 += x6;
|
||||
}
|
||||
|
||||
float3 thread_variances = make_float3(variance, variance2, variance3);
|
||||
|
||||
struct SumOp {
|
||||
__device__ float3 operator()(const float3& a, const float3& b) const {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
};
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float3, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
float3 block_variances =
|
||||
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
|
||||
|
||||
variance = block_variances.x;
|
||||
variance2 = block_variances.y;
|
||||
variance3 = block_variances.z;
|
||||
|
||||
__shared__ float s_w2_inv_std;
|
||||
__shared__ float s_w1_inv_std2;
|
||||
__shared__ float s_w0_inv_std3;
|
||||
__shared__ float s_bias;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
float w0 = (float)weight[0];
|
||||
float w1 = (float)weight[1];
|
||||
float w2 = (float)weight[2];
|
||||
s_bias = (float)bias[0];
|
||||
|
||||
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
|
||||
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
|
||||
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x2 = x * x;
|
||||
float x3 = x2 * x;
|
||||
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
(scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
|
||||
s_bias);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
@ -444,50 +239,3 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_POLY_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
|
||||
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
|
||||
hidden_size); \
|
||||
});
|
||||
|
||||
void poly_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [3]
|
||||
torch::Tensor& bias, // [1]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.data_ptr() != input.data_ptr());
|
||||
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
/* This kernel is memory-latency bound in many scenarios.
|
||||
When num_tokens is large, a smaller block size allows
|
||||
for increased block occupancy on CUs and better latency
|
||||
hiding on global mem ops. */
|
||||
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
|
||||
dim3 block(std::min(hidden_size, max_block_size));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
/*If the tensor types are FP16/BF16, try to use the optimized kernel
|
||||
with packed + vectorized ops.
|
||||
Max optimization is achieved with a width-8 vector of FP16/BF16s
|
||||
since we can load at most 128 bits at once in a global memory op.
|
||||
However, this requires each tensor's data to be aligned to 16
|
||||
bytes.
|
||||
*/
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
|
||||
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
|
||||
LAUNCH_FUSED_POLY_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_POLY_NORM(0);
|
||||
}
|
||||
}
|
||||
|
@ -92,9 +92,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, double epsilon);
|
||||
|
||||
void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
torch::Tensor& bias, double epsilon);
|
||||
|
||||
void apply_repetition_penalties_(torch::Tensor& logits,
|
||||
const torch::Tensor& prompt_mask,
|
||||
const torch::Tensor& output_mask,
|
||||
|
@ -175,12 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"float epsilon) -> ()");
|
||||
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
||||
|
||||
// Polynomial Normalization.
|
||||
ops.def(
|
||||
"poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
|
||||
"epsilon) -> ()");
|
||||
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
|
||||
|
||||
// Apply repetition penalties to logits in-place
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
|
@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from tests.kernels.quant_utils import FP8_DTYPE
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
@ -70,38 +70,6 @@ def test_rms_norm(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_poly_norm(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
layer = PolyNorm().to(dtype=dtype)
|
||||
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||
layer.bias.data.normal_(mean=1.0, std=0.1)
|
||||
scale = 1 / (2 * hidden_size)
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
x *= scale
|
||||
|
||||
ref_out = layer.forward_native(x)
|
||||
out = layer(x)
|
||||
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.poly_norm,
|
||||
(out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
|
||||
|
@ -339,18 +339,6 @@ def fused_add_rms_norm(
|
||||
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
|
||||
|
||||
|
||||
def poly_norm(
|
||||
out: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
epsilon: float,
|
||||
) -> None:
|
||||
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
|
||||
input_contiguous = input.contiguous()
|
||||
torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon)
|
||||
|
||||
|
||||
def apply_repetition_penalties_torch(
|
||||
logits: torch.Tensor,
|
||||
prompt_mask: torch.Tensor,
|
||||
|
@ -58,22 +58,6 @@ def fused_add_rms_norm(
|
||||
return x, residual
|
||||
|
||||
|
||||
def poly_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.poly_norm(
|
||||
out,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
variance_epsilon,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def rocm_aiter_rms_norm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
@ -385,53 +369,6 @@ class GemmaRMSNorm(CustomOp):
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
|
||||
@CustomOp.register("poly_norm")
|
||||
class PolyNorm(CustomOp):
|
||||
"""Polynomial normalization.
|
||||
|
||||
Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
|
||||
where w_n is the learned weight and b is the bias.
|
||||
Refer to https://arxiv.org/html/2411.03884v1
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
|
||||
self.bias = torch.nn.Parameter(torch.zeros(1))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward().
|
||||
|
||||
Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
|
||||
"""
|
||||
|
||||
orig_dtype = x.dtype
|
||||
x_float = x.to(torch.float32)
|
||||
output = (
|
||||
self.weight[0] * self._norm(x_float**3)
|
||||
+ self.weight[1] * self._norm(x_float**2)
|
||||
+ self.weight[2] * self._norm(x_float)
|
||||
+ self.bias
|
||||
)
|
||||
return output.to(orig_dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
"""
|
||||
Layer Normalization.
|
||||
|
Reference in New Issue
Block a user