Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Model] New model support for Motif-1-Tiny (#23414)
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
Signed-off-by: TaehyunKim <73943231+ca1207@users.noreply.github.com>
Co-authored-by: WyldeCat <skan1543@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
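As a usage sketch (not part of the patch): the environment variables mirror what the initialization test below sets for MotifForCausalLM, and the model id and trust_remote_code flag come from the registry entry below.

import os

os.environ["VLLM_USE_V1"] = "0"  # Motif is registered as v0_only in this patch
os.environ["VLLM_ATTENTION_BACKEND"] = "DIFFERENTIAL_FLASH_ATTN"

from vllm import LLM, SamplingParams

llm = LLM(model="Motif-Technologies/Motif-2.6B", trust_remote_code=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)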
benchmarks/kernels/benchmark_polynorm.py (new file, 155 lines)
@@ -0,0 +1,155 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools

import torch

from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton


def polynorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])

    def norm(x, eps: float):
        return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    x = x.float()
    return (
        (
            weight[0] * norm(x**3, eps)
            + weight[1] * norm(x**2, eps)
            + weight[2] * norm(x, eps)
            + bias
        )
        .to(weight.dtype)
        .view(orig_shape)
    )


def polynorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])

    out = torch.empty_like(x)
    vllm_ops.poly_norm(out, x, weight, bias, eps)
    output = out

    output = output.view(orig_shape)
    return output


def calculate_diff(batch_size, seq_len, hidden_dim):
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
    weight = torch.ones(3, dtype=dtype, device="cuda")
    bias = torch.ones(1, dtype=dtype, device="cuda")

    output_naive = polynorm_naive(x, weight, bias)
    output_vllm = polynorm_vllm(x, weight, bias)

    if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
dim_range = [2048, 4096]
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))


def get_benchmark():
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["dim", "batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["naive", "vllm"],
            line_names=["Naive", "vLLM"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
            plot_name="polynorm-perf",
            args={},
        )
    )
    def benchmark(dim, batch_size, seq_len, provider):
        dtype = torch.bfloat16
        hidden_dim = dim * 4

        x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
        weight = torch.ones(3, dtype=dtype, device="cuda")
        bias = torch.ones(1, dtype=dtype, device="cuda")

        quantiles = [0.5, 0.2, 0.8]

        if provider == "naive":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: polynorm_naive(x, weight, bias),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: polynorm_vllm(x, weight, bias),
                quantiles=quantiles,
            )

        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size",
    )
    parser.add_argument(
        "--seq-len",
        type=int,
        default=128,
        help="Sequence length",
    )
    parser.add_argument(
        "--hidden-dim",
        type=int,
        default=8192,
        help="Intermediate size of MLP",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/polynorm/",
        help="Path to save polynorm benchmark results",
    )

    args = parser.parse_args()

    # Run correctness test
    calculate_diff(
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        hidden_dim=args.hidden_dim,
    )

    benchmark = get_benchmark()
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
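Usage note (not part of the patch): once vLLM is built with CUDA support, the script can be run directly, e.g. `python benchmarks/kernels/benchmark_polynorm.py --batch-size 4 --seq-len 128 --hidden-dim 8192`. It first runs the correctness check against the naive implementation, then sweeps the configs defined above and writes the Triton perf report to --save-path.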
@@ -140,6 +140,211 @@ fused_add_rms_norm_kernel(
  }
}

/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck.

   _f16VecPN struct extends _f16Vec to add operations specifically required for
   polynomial normalization (poly norm).
   The original _f16Vec does not include the sum-of-powers computation or
   in-place polynomial normalization logic. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
  using Base = _f16Vec<scalar_t, width>;
  using Converter = typename Base::Converter;
  using T1 = typename Base::T1;
  using T2 = typename Base::T2;
  using Base::data;

  __device__ auto sum_pows() const {
    float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;

#pragma unroll
    for (int i = 0; i < width; i += 2) {
      float2 z = Converter::convert(T2{data[i], data[i + 1]});
      float x2 = z.x * z.x;
      float x4 = x2 * x2;
      float x6 = x4 * x2;

      float y2 = z.y * z.y;
      float y4 = y2 * y2;
      float y6 = y4 * y2;

      s2 += x2 + y2;
      s4 += x4 + y4;
      s6 += x6 + y6;
    }
    return std::make_tuple(s2, s4, s6);
  }

  __device__ void poly_norm_inplace(const float w2_inv_std,
                                    const float w1_inv_std2,
                                    const float w0_inv_std3, const float bias) {
#pragma unroll
    for (int i = 0; i < width; i += 2) {
      float2 z = Converter::convert(T2{data[i], data[i + 1]});

      float x2 = z.x * z.x;
      float x3 = x2 * z.x;
      z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;

      float y2 = z.y * z.y;
      float y3 = y2 * z.y;
      z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;

      auto out = Converter::convert(z);
      data[i] = out.x;
      data[i + 1] = out.y;
    }
  }
};

template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
  static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);

  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
  const int vec_hidden_size = hidden_size / width;
  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    auto [x2, x4, x6] = temp.sum_pows();

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16VecPN<scalar_t, width> temp = input_v[id];
    temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
    out_v[id] = temp;
  }
}

/* Generic poly_norm_kernel
   The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out,           // [..., hidden_size]
                 const scalar_t* __restrict__ input,   // [..., hidden_size]
                 const scalar_t* __restrict__ weight,  // [3]
                 const scalar_t* __restrict__ bias,    // [1]
                 const float epsilon, const int hidden_size) {
  float variance = 0.0f;
  float variance2 = 0.0f;
  float variance3 = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    float x2 = x * x;
    float x4 = x2 * x2;
    float x6 = x4 * x2;

    variance += x2;
    variance2 += x4;
    variance3 += x6;
  }

  float3 thread_variances = make_float3(variance, variance2, variance3);

  struct SumOp {
    __device__ float3 operator()(const float3& a, const float3& b) const {
      return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
    }
  };

  using BlockReduce = cub::BlockReduce<float3, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  float3 block_variances =
      BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);

  variance = block_variances.x;
  variance2 = block_variances.y;
  variance3 = block_variances.z;

  __shared__ float s_w2_inv_std;
  __shared__ float s_w1_inv_std2;
  __shared__ float s_w0_inv_std3;
  __shared__ float s_bias;

  if (threadIdx.x == 0) {
    float w0 = (float)weight[0];
    float w1 = (float)weight[1];
    float w2 = (float)weight[2];
    s_bias = (float)bias[0];

    s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
    s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
    s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    float x2 = x * x;
    float x3 = x2 * x;

    out[blockIdx.x * hidden_size + idx] =
        (scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
                   s_bias);
  }
}

}  // namespace vllm

void rms_norm(torch::Tensor& out,    // [..., hidden_size]
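As a reading aid (not part of the patch), the per-token math of the fused kernel above in plain PyTorch: since RMSNorm(x^k) only needs mean(x^(2k)), a single pass can accumulate sums of x^2, x^4 and x^6 (the three variance accumulators), and each weight folds into one rsqrt factor, which is what the s_w*_inv_std* values broadcast from thread 0 hold.

import torch

def poly_norm_per_token(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor,
                        eps: float = 1e-6) -> torch.Tensor:
    # x: [hidden_size], w: [3], b: [1]
    w2_inv_std = w[2] * torch.rsqrt((x**2).mean() + eps)   # for RMSNorm(x)
    w1_inv_std2 = w[1] * torch.rsqrt((x**4).mean() + eps)  # for RMSNorm(x^2)
    w0_inv_std3 = w[0] * torch.rsqrt((x**6).mean() + eps)  # for RMSNorm(x^3)
    return x * w2_inv_std + x**2 * w1_inv_std2 + x**3 * w0_inv_std3 + b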
@@ -219,3 +424,49 @@ void fused_add_rms_norm(torch::Tensor& input,  // [..., hidden_size]
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}

#define LAUNCH_FUSED_POLY_NORM(width)                                         \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
    vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>(      \
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),                 \
        weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon,      \
        hidden_size);                                                         \
  });

void poly_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double epsilon) {
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.data_ptr() != input.data_ptr());

  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
  */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
  bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_POLY_NORM(8);
  } else {
    LAUNCH_FUSED_POLY_NORM(0);
  }
}
@@ -92,6 +92,9 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, double epsilon);

void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
               torch::Tensor& bias, double epsilon);

void apply_repetition_penalties_(torch::Tensor& logits,
                                 const torch::Tensor& prompt_mask,
                                 const torch::Tensor& output_mask,
@@ -353,4 +356,4 @@ void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                   int64_t quant_level, bool cast_bf2half = false);
int64_t qr_max_size();
#endif
#endif
@@ -168,6 +168,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

  // Polynomial Normalization.
  ops.def(
      "poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
      "epsilon) -> ()");
  ops.impl("poly_norm", torch::kCUDA, &poly_norm);

  // Apply repetition penalties to logits in-place
  ops.def(
      "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
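With the schema above registered, the op can be called directly through torch.ops; a sketch (not part of the patch), with shapes following the comments in the kernel signature (weight has 3 elements, bias has 1):

import torch
from vllm import _custom_ops as ops  # importing this loads the _C extension

x = torch.randn(4, 4096, dtype=torch.bfloat16, device="cuda")
weight = torch.ones(3, dtype=torch.bfloat16, device="cuda") / 3
bias = torch.zeros(1, dtype=torch.bfloat16, device="cuda")
out = torch.empty_like(x)
torch.ops._C.poly_norm(out, x, weight, bias, 1e-6)
# equivalently, through the ops.poly_norm wrapper shown later in the diff:
ops.poly_norm(out, x, weight, bias, 1e-6)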
@@ -383,6 +383,7 @@ th {
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MotifForCausalLM` | Motif-1-Tiny | `Motif-Technologies/Motif-2.6B`, `Motif-Technologies/Motif-2.6b-v1.1-LC`, etc. | ✅︎ | ✅︎ | |
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -6,7 +6,7 @@ import torch

from tests.kernels.quant_utils import FP8_DTYPE
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
from vllm.platforms import current_platform

DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -70,6 +70,37 @@ def test_rms_norm(
        (out, x, layer.weight.data, layer.variance_epsilon))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_poly_norm(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    layer = PolyNorm().to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    layer.bias.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale

    ref_out = layer.forward_native(x)
    out = layer(x)
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)

    opcheck(
        torch.ops._C.poly_norm,
        (out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@@ -291,6 +291,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
    "MistralForCausalLM": _HfExamplesInfo("mistralai/Mistral-7B-Instruct-v0.1"),
    "MixtralForCausalLM": _HfExamplesInfo("mistralai/Mixtral-8x7B-Instruct-v0.1",  # noqa: E501
                                          {"tiny": "TitanML/tiny-mixtral"}),  # noqa: E501
    "MotifForCausalLM": _HfExamplesInfo("Motif-Technologies/Motif-2.6B",
                                        trust_remote_code=True,
                                        v0_only=True),
    "MptForCausalLM": _HfExamplesInfo("mpt", is_available_online=False),
    "MPTForCausalLM": _HfExamplesInfo("mosaicml/mpt-7b"),
    "NemotronForCausalLM": _HfExamplesInfo("nvidia/Minitron-8B-Base"),
@@ -63,8 +63,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
                   _initialize_kv_caches_v1), monkeypatch.context() as m):
        if model_info.v0_only:
            m.setenv("VLLM_USE_V1", "0")
        if model_arch == "Phi4FlashForCausalLM":
            # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend
        if model_arch in ("Phi4FlashForCausalLM", "MotifForCausalLM"):
            # Phi4FlashForCausalLM and MotifForCausalLM
            # only support the DIFFERENTIAL_FLASH_ATTN backend
            m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN")
        if model_arch == "GptOssForCausalLM":
            # FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
@@ -280,6 +280,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def poly_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
              bias: torch.Tensor, epsilon: float) -> None:
    # TODO: Remove this contiguous call when the kernel is updated
    # to support non-contiguous input.
    input_contiguous = input.contiguous()
    torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon)


def apply_repetition_penalties_torch(
        logits: torch.Tensor, prompt_mask: torch.Tensor,
        output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None:
@@ -710,6 +717,7 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
    return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability)


def cutlass_sparse_compress(a: torch.Tensor) \
        -> tuple[torch.Tensor, torch.Tensor]:
    """
@@ -734,6 +734,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                    softcap=self.logits_soft_cap,
                    fa_version=self.vllm_flash_attn_version,
                )
                assert prefill_output.shape == output[:
                                                      num_prefill_tokens].shape
@@ -755,6 +756,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                    softcap=self.logits_soft_cap,
                    fa_version=self.vllm_flash_attn_version,
                ).squeeze(1)
            except Exception as e:
                logger.error("Error in PagedAttention.forward_decode: %s",
@@ -787,6 +789,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
                window_size=self.sliding_window,
                alibi_slopes=self.alibi_slopes,
                softcap=self.logits_soft_cap,
                fa_version=self.vllm_flash_attn_version,
            ).squeeze(1)
        return output
@@ -43,6 +43,20 @@ def fused_add_rms_norm(
    return x, residual


def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
              variance_epsilon: float) -> torch.Tensor:
    from vllm import _custom_ops as ops
    out = torch.empty_like(x)
    ops.poly_norm(
        out,
        x,
        weight,
        bias,
        variance_epsilon,
    )
    return out


def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
                             variance_epsilon: float) -> torch.Tensor:
    import aiter as rocm_aiter
@@ -320,3 +334,48 @@ class GemmaRMSNorm(CustomOp):
                                                   self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)


@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
    """Polynomial normalization.

    Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
    where w_n is the learned weight and b is the bias.
    Refer to https://arxiv.org/html/2411.03884v1
    """

    def __init__(
        self,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.variance_epsilon = eps

    def _norm(self, x):
        return x / torch.sqrt(
            x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward().

        Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
        """

        orig_dtype = x.dtype
        x_float = x.to(torch.float32)
        output = (self.weight[0] * self._norm(x_float**3) +
                  self.weight[1] * self._norm(x_float**2) +
                  self.weight[2] * self._norm(x_float) + self.bias)
        return output.to(orig_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
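A small sketch (not part of the patch) of exercising the new layer, following the test_poly_norm test above; the native and CUDA paths should agree within the same tolerances the test uses:

import torch
from vllm.model_executor.layers.layernorm import PolyNorm

layer = PolyNorm().to(dtype=torch.bfloat16, device="cuda")
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda") * (1 / 8192)
ref = layer.forward_native(x)  # pure PyTorch reference
out = layer(x)                 # dispatches to the fused CUDA kernel
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)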
vllm/model_executor/models/motif.py (new file, 345 lines)
@@ -0,0 +1,345 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/modeling_motif.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/LICENSE
"""Inference-only Motif model compatible with HuggingFace weights."""
import math
from typing import Any, Optional

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionType
from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.llama import LlamaForCausalLM

from .adapters import as_seq_cls_model
from .interfaces import SupportsV0Only
from .utils import extract_layer_index

class MotifMLP(nn.Module):
    """MLP for the language component of the Motif model, which contains a
    MergedColumnParallelLinear merging 2 outputs via PolyNorm activation."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "poly_norm",
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "poly_norm":
            raise NotImplementedError(f"Unsupported activation: {hidden_act}. "
                                      "Only poly_norm is supported for now.")
        self.act_fn = PolyNorm()
        self.intermediate_size = intermediate_size
        tp_size = get_tensor_model_parallel_world_size()
        if hidden_act == "poly_norm" and tp_size > 1:
            raise NotImplementedError(
                "Tensor parallelism for poly_norm is not supported yet. "
                "Support will be added in the future.")

    def forward(self, x):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(
            x[..., :self.intermediate_size]) * x[..., self.intermediate_size:]
        x, _ = self.down_proj(x)
        return x


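For clarity (not part of the patch): gate_up_proj produces the concatenated gate and up projections, so the forward above is equivalent to

# gate, up = torch.split(self.gate_up_proj(x)[0], self.intermediate_size, dim=-1)
# y, _ = self.down_proj(self.act_fn(gate) * up)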
class MotifAttention(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = self.hidden_size // self.total_num_heads
        self.head_dim = head_dim
        # Phi models introduced a partial_rotary_factor parameter in the config
        self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
                                             1)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        assert self.num_heads % 2 == 0, 'num_heads should be even'
        assert self.num_kv_heads % 2 == 0, 'num_kv_heads should be even'

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self._init_rotary_emb(config,
                              rope_scaling=rope_scaling,
                              quant_config=quant_config)
        sliding_window = None

        self.lambda_init = self.lambda_init_fn(layer_idx)
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
                                                                    std=0.1))
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
                                                                    std=0.1))
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
                                                                    std=0.1))
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
                                                                    std=0.1))
        self.subln = RMSNorm(2 * self.head_dim, eps=config.attn_rms_norm_eps)

        params = {
            'differential_flash_attention_config': {
                'lambda_init': self.lambda_init,
                'lambda_q1': self.lambda_q1,
                'lambda_k1': self.lambda_k1,
                'lambda_q2': self.lambda_q2,
                'lambda_k2': self.lambda_k2,
                "subln": self.subln,
            }
        }

        diff_attn_err_msg = (
            'Set VLLM_ATTENTION_BACKEND="DIFFERENTIAL_FLASH_ATTN" '
            'to enable Differential Flash Attention.')
        try:
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                per_layer_sliding_window=sliding_window,
                attn_type=attn_type,
                prefix=f"{prefix}.attn",
                **params,
            )
        except TypeError as e:
            raise ValueError(diff_attn_err_msg) from e
        assert (self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN
                ), diff_attn_err_msg

    def lambda_init_fn(self, depth):
        return 0.8 - 0.6 * math.exp(-0.3 * (depth - 1))

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

    def _init_rotary_emb(self, config: PretrainedConfig,
                         rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        is_neox_style = True
        is_gguf = quant_config and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )


class MotifDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "use_bias", False)
        bias_o_proj = attention_bias
        if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias

        # By default, Motif uses causal attention as it is a decoder-only model.
        # You can override the HF config with `is_causal=False` to enable
        # bidirectional attention, which is used in some embedding models
        # (e.g. parasail-ai/GritLM-7B-vllm)
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = MotifAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = MotifMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "use_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


# Motif model uses differential attention
# Only supported in v0 (no chunked prefill support)
class MotifForCausalLM(LlamaForCausalLM, SupportsV0Only):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = MotifDecoderLayer):

        # Prefix caching and chunked prefill are not supported for this model.
        assert not vllm_config.cache_config.enable_prefix_caching, \
            "Motif currently does not support prefix caching"
        assert not vllm_config.scheduler_config.chunked_prefill_enabled, \
            "Motif currently does not support chunked prefill"

        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=layer_type)


MotifForSequenceClassification = as_seq_cls_model(MotifForCausalLM)
@@ -110,6 +110,7 @@ _TEXT_GENERATION_MODELS = {
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "MotifForCausalLM": ("motif", "MotifForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),