[Model] New model support for Motif-1-Tiny (#23414)

Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
TaehyunKim
2025-09-11 15:29:40 +09:00
committed by GitHub
parent e2b1f863aa
commit 9bd831f501
13 changed files with 871 additions and 4 deletions

View File

@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import torch
from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
def polynorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
def norm(x, eps: float):
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x.float()
return (
(
weight[0] * norm(x**3, eps)
+ weight[1] * norm(x**2, eps)
+ weight[2] * norm(x, eps)
+ bias
)
.to(weight.dtype)
.view(orig_shape)
)
def polynorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
out = torch.empty_like(x)
vllm_ops.poly_norm(out, x, weight, bias, eps)
output = out
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_dim):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
output_naive = polynorm_naive(x, weight, bias)
output_vllm = polynorm_vllm(x, weight, bias)
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
dim_range = [2048, 4096]
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
def get_benchmark():
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["dim", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["naive", "vllm"],
line_names=["Naive", "vLLM"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name="polynorm-perf",
args={},
)
)
def benchmark(dim, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_dim = dim * 4
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_naive(x, weight, bias),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_vllm(x, weight, bias),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Batch size",
)
parser.add_argument(
"--seq-len",
type=int,
default=128,
help="Sequence length",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=8192,
help="Intermediate size of MLP",
)
parser.add_argument(
"--save-path",
type=str,
default="./configs/polnorm/",
help="Path to save polnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=args.batch_size,
seq_len=args.seq_len,
hidden_dim=args.hidden_dim,
)
benchmark = get_benchmark()
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)

View File

@ -140,6 +140,211 @@ fused_add_rms_norm_kernel(
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck.
_f16VecPN struct extends _f16Vec to add operations specifically required for
polynomial normalization (poly norm).
The original _f16Vec does not include the sum-of-powers computation or
in-place polynomial normalization logic. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
using Base = _f16Vec<scalar_t, width>;
using Converter = typename Base::Converter;
using T1 = typename Base::T1;
using T2 = typename Base::T2;
using Base::data;
__device__ auto sum_pows() const {
float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x4 = x2 * x2;
float x6 = x4 * x2;
float y2 = z.y * z.y;
float y4 = y2 * y2;
float y6 = y4 * y2;
s2 += x2 + y2;
s4 += x4 + y4;
s6 += x6 + y6;
}
return std::make_tuple(s2, s4, s6);
}
__device__ void poly_norm_inplace(const float w2_inv_std,
const float w1_inv_std2,
const float w0_inv_std3, const float bias) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x3 = x2 * z.x;
z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
float y2 = z.y * z.y;
float y3 = y2 * z.y;
z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;
auto out = Converter::convert(z);
data[i] = out.x;
data[i + 1] = out.y;
}
}
};
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
const int vec_hidden_size = hidden_size / width;
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
auto [x2, x4, x6] = temp.sum_pows();
variance += x2;
variance2 += x4;
variance3 += x6;
}
float3 thread_variances = make_float3(variance, variance2, variance3);
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
__syncthreads();
auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
out_v[id] = temp;
}
}
/* Generic poly_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x4 = x2 * x2;
float x6 = x4 * x2;
variance += x2;
variance2 += x4;
variance3 += x6;
}
float3 thread_variances = make_float3(variance, variance2, variance3);
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x3 = x2 * x;
out[blockIdx.x * hidden_size + idx] =
(scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
s_bias);
}
}
} // namespace vllm
void rms_norm(torch::Tensor& out, // [..., hidden_size]
@ -219,3 +424,49 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
void poly_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [3]
torch::Tensor& bias, // [1]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.data_ptr() != input.data_ptr());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_POLY_NORM(8);
} else {
LAUNCH_FUSED_POLY_NORM(0);
}
}

View File

@ -92,6 +92,9 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& bias, double epsilon);
void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,
@ -353,4 +356,4 @@ void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#endif
#endif

View File

@ -168,6 +168,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Polynomial Normalization.
ops.def(
"poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
"epsilon) -> ()");
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
// Apply repetition penalties to logits in-place
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "

View File

@ -383,6 +383,7 @@ th {
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MotifForCausalLM` | Motif-1-Tiny | `Motif-Technologies/Motif-2.6B`, `Motif-Technologies/Motif-2.6b-v1.1-LC`, etc. | ✅︎ | ✅︎ | |
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |

View File

@ -6,7 +6,7 @@ import torch
from tests.kernels.quant_utils import FP8_DTYPE
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -70,6 +70,37 @@ def test_rms_norm(
(out, x, layer.weight.data, layer.variance_epsilon))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_poly_norm(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
layer = PolyNorm().to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
layer.bias.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
x *= scale
ref_out = layer.forward_native(x)
out = layer(x)
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
opcheck(
torch.ops._C.poly_norm,
(out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)

View File

@ -291,6 +291,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
"MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", # noqa: E501
{"tiny": "TitanML/tiny-mixtral"}), # noqa: E501
"MotifForCausalLM": _HfExamplesInfo("Motif-Technologies/Motif-2.6B",
trust_remote_code=True,
v0_only=True),
"MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
"MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
"NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),

View File

@ -63,8 +63,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
_initialize_kv_caches_v1), monkeypatch.context() as m):
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
if model_arch == "Phi4FlashForCausalLM":
# Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"):
# Phi4FlashForCausalLM and MotifForCausalLM
# only supports DIFFERENTIAL_FLASH_ATTN backend
m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
if model_arch == "GptOssForCausalLM":
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU

View File

@ -280,6 +280,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
def poly_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
bias: torch.Tensor, epsilon: float) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
input_contiguous = input.contiguous()
torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon)
def apply_repetition_penalties_torch(
logits: torch.Tensor, prompt_mask: torch.Tensor,
output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
@ -710,6 +717,7 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability)
def cutlass_sparse_compress(a: torch.Tensor) \
-> tuple[torch.Tensor, torch.Tensor]:
"""

View File

@ -734,6 +734,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
)
assert prefill_output.shape == output[:
num_prefill_tokens].shape
@ -755,6 +756,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
).squeeze(1)
except Exception as e:
logger.error("Error in PagedAttention.forward_decode: %s",
@ -787,6 +789,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
).squeeze(1)
return output

View File

@ -43,6 +43,20 @@ def fused_add_rms_norm(
return x, residual
def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.poly_norm(
out,
x,
weight,
bias,
variance_epsilon,
)
return out
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
import aiter as rocm_aiter
@ -320,3 +334,48 @@ class GemmaRMSNorm(CustomOp):
self.forward_static)
self._is_compiled = True
return self.forward_native(x, residual)
@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
"""Polynomial normalization.
Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
where w_n is the learned weight and b is the bias.
Refer to https://arxiv.org/html/2411.03884v1
"""
def __init__(
self,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
self.bias = torch.nn.Parameter(torch.zeros(1))
self.variance_epsilon = eps
def _norm(self, x):
return x / torch.sqrt(
x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
def forward_native(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().
Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
"""
orig_dtype = x.dtype
x_float = x.to(torch.float32)
output = (self.weight[0] * self._norm(x_float**3) +
self.weight[1] * self._norm(x_float**2) +
self.weight[2] * self._norm(x_float) + self.bias)
return output.to(orig_dtype)
def forward_cuda(
self,
x: torch.Tensor,
) -> torch.Tensor:
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)

View File

@ -0,0 +1,345 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/LICENSE
"""Inference-only Motif model compatible with HuggingFace weights."""
import math
from typing import Any, Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.llama import LlamaForCausalLM
from .adapters import as_seq_cls_model
from .interfaces import SupportsV0Only
from .utils import extract_layer_index
class MotifMLP(nn.Module):
"""MLP for the language component of the Motif model, which contains a
MergedColumnParallelLinear merging 2 outputs via PolyNorm activation."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str = "poly_norm",
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "poly_norm":
raise NotImplementedError(f"Unsupported activation: {hidden_act}. "
"Only poly_norm is supported for now.")
self.act_fn = PolyNorm()
self.intermediate_size = intermediate_size
tp_size = get_tensor_model_parallel_world_size()
if hidden_act == "poly_norm" and tp_size > 1:
raise NotImplementedError(
"Tensor parallelism for poly_norm is not supported yet. "
"Support will be added in the future.")
def forward(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(
x[..., :self.intermediate_size]) * x[..., self.intermediate_size:]
x, _ = self.down_proj(x)
return x
class MotifAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
super().__init__()
layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = self.hidden_size // self.total_num_heads
self.head_dim = head_dim
# Phi models introduced a partial_rotary_factor parameter in the config
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
1)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
assert self.num_heads % 2 == 0, 'num_heads should be even'
assert self.num_kv_heads % 2 == 0, 'num_heads should be even'
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self._init_rotary_emb(config,
rope_scaling=rope_scaling,
quant_config=quant_config)
sliding_window = None
self.lambda_init = self.lambda_init_fn(layer_idx)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
std=0.1))
self.subln = RMSNorm(2 * self.head_dim, eps=config.attn_rms_norm_eps)
params = {
'differential_flash_attention_config': {
'lambda_init': self.lambda_init,
'lambda_q1': self.lambda_q1,
'lambda_k1': self.lambda_k1,
'lambda_q2': self.lambda_q2,
'lambda_k2': self.lambda_k2,
"subln": self.subln,
}
}
diff_attn_err_msg = (
'Set VLLM_ATTENTION_BACKEND="DIFFERENTIAL_FLASH_ATTN" '
'to enable Differential Flash Attention.')
try:
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
attn_type=attn_type,
prefix=f"{prefix}.attn",
**params,
)
except TypeError as e:
raise ValueError(diff_attn_err_msg) from e
assert (self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN
), diff_attn_err_msg
def lambda_init_fn(self, depth):
return 0.8 - 0.6 * math.exp(-0.3 * (depth - 1))
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
def _init_rotary_emb(self, config: PretrainedConfig,
rope_scaling: Optional[dict[str, Any]],
quant_config: Optional[QuantizationConfig]) -> None:
is_neox_style = True
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
partial_rotary_factor=self.partial_rotary_factor,
)
class MotifDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "use_bias", False)
bias_o_proj = attention_bias
if hasattr(config, 'qkv_bias'):
attention_bias = config.qkv_bias
# By default, Motif uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. parasail-ai/GritLM-7B-vllm)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = MotifAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
bias_o_proj=bias_o_proj,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
self.mlp = MotifMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "use_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
# Motif model uses differential attention
# Only supported in v0 (no chunked prefill support)
class MotifForCausalLM(LlamaForCausalLM, SupportsV0Only):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = MotifDecoderLayer):
# Prefix caching and chunked prefill is not supported for this model.
assert not vllm_config.cache_config.enable_prefix_caching, \
"Motif currently does not support prefix caching"
assert not vllm_config.scheduler_config.chunked_prefill_enabled, \
"Motif currently does not support chunked prefill"
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
MotifForSequenceClassification = as_seq_cls_model(MotifForCausalLM)

View File

@ -110,6 +110,7 @@ _TEXT_GENERATION_MODELS = {
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"MotifForCausalLM": ("motif", "MotifForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),