[CPU] Refactor CPU unquantized linear (#24150)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Author: Li, Jiang
Date: 2025-09-04 14:28:45 +08:00
Committed by: GitHub
Parent: cb55ad86fe
Commit: 57b1ce94f7
9 changed files with 466 additions and 26 deletions
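For orientation, here is a minimal sketch of how the new oneDNN GEMM path added by this commit is driven from Python. It mirrors the create_onednn_mm / onednn_mm wrappers and the test helper further down; a CPU build of vLLM with the _C extension is assumed, and the shapes are illustrative only.

import torch
from vllm import _custom_ops as ops

dtype = torch.bfloat16
w = torch.rand((512, 256), dtype=dtype)   # [N, K], as nn.Linear stores it
x = torch.rand((8, 256), dtype=dtype)     # [M, K], row-major
bias = torch.rand((512,), dtype=dtype)

# Prepack the [K, N] view of the weight once; primitives are cached per
# runtime M size, so one handler serves any batch/sequence length.
handler = ops.create_onednn_mm(w.t(), primitive_cache_size=32)
out = ops.onednn_mm(handler, x, bias)     # [M, N]
ref = torch.nn.functional.linear(x, w, bias)  # reference, up to bf16 rounding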

View File

@@ -22,6 +22,23 @@ void release_dnnl_matmul_handler(int64_t handler) {
delete ptr;
}
DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
this->realloc(allocation_unit * 128);
}
void DNNLScratchPadManager::realloc(size_t new_size) {
new_size = round(new_size);
if (new_size > size_) {
std::free(ptr_);  // grow-only buffer: release the previous allocation first
ptr_ = std::aligned_alloc(64, new_size);
size_ = new_size;
}
}
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
static DNNLScratchPadManager manager;
return &manager;
}
template <typename KT, typename VT>
class DNNLPrimitiveCache {
public:
@@ -166,6 +183,23 @@ struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
hash<int>()(static_cast<int>(val.bias_type));
}
};
template <>
struct hash<MatMulPrimitiveHandler::ClassMatmulCacheKey> {
size_t operator()(
const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size);
}
};
template <>
struct hash<MatMulPrimitiveHandler::MSizeCacheKey> {
size_t operator()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const {
return hash<dnnl_dim_t>()(val.a_m_size) ^
hash<dnnl_dim_t>()(val.a_m_stride) ^ hash<bool>()(val.use_bias) ^
hash<int>()(static_cast<int>(val.bias_type));
}
};
} // namespace std
bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
@@ -181,6 +215,17 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
l.bias_type == r.bias_type;
}
bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size;
}
bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l,
const MatMulPrimitiveHandler::MSizeCacheKey& r) {
return l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride &&
l.use_bias == r.use_bias && l.bias_type == r.bias_type;
}
static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
get_w8a8_class_primitive_cache(
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
@@ -239,6 +284,11 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
}
dnnl::matmul matmul = get_matmul_cache(args);
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
}
@@ -257,6 +307,8 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
}
@@ -300,6 +352,11 @@ void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
memory_cache_[DNNL_ARG_SCRATCHPAD] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(5, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}
dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
@@ -319,6 +376,9 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
dnnl::memory::format_tag::ab);
dnnl::primitive_attr attr;
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
// For PER_TOKEN, scales are applied in an epilogue outside the primitive
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
attr.set_scales_mask(DNNL_ARG_SRC, 0);
@@ -344,3 +404,120 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
attr);
}
}
MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
: DNNLMatMulPrimitiveHandler(
static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
m_size_cache_(nullptr) {
assert(ab_type_ == dnnl::memory::data_type::f32 ||
ab_type_ == dnnl::memory::data_type::bf16 ||
ab_type_ == dnnl::memory::data_type::f16);
prepack_weight(args.b_ptr,
create_primitive_desc(
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
.a_m_stride = DNNL_RUNTIME_DIM_VAL,
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
true)
.weights_desc());
init_runtime_memory_cache(args);
}
static std::shared_ptr<MatMulPrimitiveHandler::MSizeCache>
get_matmul_class_primitive_cache(
const MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
int64_t cache_size) {
static MatMulPrimitiveHandler::ClassMatmulCache cache(128);
assert(cache_size > 0);
return cache.get_or_create(key, [&]() {
return std::make_shared<MatMulPrimitiveHandler::MSizeCache>(cache_size);
});
}
void MatMulPrimitiveHandler::execute(ExecArgs& args) {
auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
a_storage->set_data_handle((void*)args.a_ptr);
a_mem_desc->dims[0] = args.a_m_size;
a_mem_desc->format_desc.blocking.strides[0] = args.a_m_stride;
c_storage->set_data_handle((void*)args.c_ptr);
c_mem_desc->dims[0] = args.a_m_size;
if (args.use_bias) {
auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
bias_storage->set_data_handle((void*)args.bias_ptr);
}
dnnl::matmul matmul = get_matmul_cache(args);
auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
matmul.execute(default_stream(), memory_cache_);
default_stream().wait();
}
dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
const MSizeCacheKey& key) {
if (m_size_cache_.get() == nullptr) {
ClassMatmulCacheKey class_key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_};
m_size_cache_ = get_matmul_class_primitive_cache(class_key, primitive_cache_size_);
}
return m_size_cache_->get_or_create(key, [&]() {
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
manager->realloc(desc.scratchpad_desc().get_size());
return dnnl::matmul(desc);
});
}
dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
const MSizeCacheKey& key, bool first_time) {
dnnl::memory::desc a_md;
dnnl::memory::desc b_md;
if (first_time) {
a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
dnnl::memory::format_tag::ab);
b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
dnnl::memory::format_tag::any);
} else {
a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
{key.a_m_stride, 1});
b_md = b_target_mem_desc_;
}
dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
dnnl::memory::format_tag::ab);
dnnl::primitive_attr attr;
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
if (key.use_bias) {
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
c_md, attr);
} else {
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
attr);
}
}
void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
memory_cache_[DNNL_ARG_SRC] = dnnl::memory(
{{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr);
set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
memory_cache_[DNNL_ARG_DST] =
dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
default_engine(), nullptr);
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
memory_cache_[DNNL_ARG_BIAS] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
memory_cache_[DNNL_ARG_SCRATCHPAD] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}

View File

@@ -59,6 +59,30 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() {
return DNNLType<std::decay_t<T>>::type;
}
class DNNLScratchPadManager {
public:
static constexpr size_t allocation_unit = 4 * 1024 * 1024;  // 4 MB
static DNNLScratchPadManager* get_dnnl_scratchpad_manager();
DNNLScratchPadManager();
template <typename T>
T* get_data() {
return reinterpret_cast<T*>(ptr_);
}
static size_t round(size_t size) {
return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
}
void realloc(size_t new_size);
private:
size_t size_;
void* ptr_;
};
class DNNLMatMulPrimitiveHandler {
public:
virtual ~DNNLMatMulPrimitiveHandler() = default;
@@ -166,4 +190,54 @@ class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
std::shared_ptr<MSizeCache> m_size_cache_;
};
class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
public:
struct Args : public DNNLMatMulPrimitiveHandler::Args {
dnnl::memory::data_type ab_type;
};
struct ClassMatmulCacheKey {
dnnl_dim_t b_n_size;
dnnl_dim_t b_k_size;
friend bool operator==(const ClassMatmulCacheKey& l,
const ClassMatmulCacheKey& r);
};
struct MSizeCacheKey {
dnnl_dim_t a_m_size;
dnnl_dim_t a_m_stride;
bool use_bias;
dnnl::memory::data_type bias_type;
friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
};
using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
using ClassMatmulCache =
DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;
struct ExecArgs : public MSizeCacheKey {
const void* a_ptr;
const void* bias_ptr;
void* c_ptr;
};
public:
MatMulPrimitiveHandler(const Args& args);
void execute(ExecArgs& args);
private:
dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
bool first_time);
void init_runtime_memory_cache(const Args& args);
dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
private:
std::shared_ptr<MSizeCache> m_size_cache_;
};
#endif
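The DNNLScratchPadManager declared above keeps a single grow-only scratch buffer per process and pads every request up to whole 4 MB allocation units, so primitives of different sizes share one allocation. The rounding, checked in plain Python with the same formula:

ALLOCATION_UNIT = 4 * 1024 * 1024  # 4 MB, matching allocation_unit above

def round_up(size: int) -> int:
    # Same arithmetic as DNNLScratchPadManager::round.
    return ((size + ALLOCATION_UNIT - 1) // ALLOCATION_UNIT) * ALLOCATION_UNIT

assert round_up(1) == ALLOCATION_UNIT                     # small request -> one unit
assert round_up(5 * 1024 * 1024) == 2 * ALLOCATION_UNIT   # 5 MB rounds up to 8 MB
assert round_up(8 * 1024 * 1024) == 2 * ALLOCATION_UNIT   # exact multiples unchanged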

View File

@@ -379,6 +379,7 @@ void onednn_scaled_mm(
exec_args.a_ptr = a.data_ptr<int8_t>();
exec_args.a_m_size = a.size(0);
exec_args.bias_ptr = nullptr;
exec_args.bias_type = get_dnnl_type<void>();
exec_args.use_bias = false;
exec_args.a_scales_ptr = nullptr;
exec_args.a_zero_points_ptr = nullptr;
@@ -492,3 +493,56 @@ void dynamic_scaled_int8_quant(
}
});
}
int64_t create_onednn_mm_handler(const torch::Tensor& b,
int64_t primitive_cache_size) {
TORCH_CHECK(b.dim() == 2);
MatMulPrimitiveHandler::Args args;
args.primitive_cache_size = primitive_cache_size;
args.b_k_size = b.size(0);
args.b_k_stride = b.stride(0);
args.b_n_size = b.size(1);
args.b_n_stride = b.stride(1);
args.b_ptr = b.data_ptr();
VLLM_DISPATCH_FLOATING_TYPES(b.scalar_type(), "create_onednn_mm_handler",
[&] {
args.c_type = get_dnnl_type<scalar_t>();
args.ab_type = get_dnnl_type<scalar_t>();
});
return reinterpret_cast<int64_t>(new MatMulPrimitiveHandler(args));
}
void onednn_mm(torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& a, // [M, IC], row-major
const std::optional<torch::Tensor>& bias, int64_t handler) {
CPU_KERNEL_GUARD_IN(onednn_mm)
TORCH_CHECK(a.dim() == 2);
TORCH_CHECK(a.stride(-1) == 1);
TORCH_CHECK(c.is_contiguous());
MatMulPrimitiveHandler* ptr =
reinterpret_cast<MatMulPrimitiveHandler*>(handler);
MatMulPrimitiveHandler::ExecArgs exec_args;
exec_args.a_m_size = a.size(0);
exec_args.a_m_stride = a.stride(0);
VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] {
if (bias.has_value()) {
exec_args.use_bias = true;
exec_args.bias_type = get_dnnl_type<scalar_t>();
exec_args.bias_ptr = bias->data_ptr<scalar_t>();
} else {
exec_args.use_bias = false;
exec_args.bias_type = get_dnnl_type<void>();
exec_args.bias_ptr = nullptr;
}
exec_args.a_ptr = a.data_ptr<scalar_t>();
exec_args.c_ptr = c.data_ptr<scalar_t>();
ptr->execute(exec_args);
});
}
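The onednn_mm entry point above only requires the activation to be 2-D with a contiguous last dimension, so row-strided views are accepted without a copy (the runtime a_m_stride covers them), while the output tensor must be contiguous. A hedged sketch of preparing such inputs; the direct op call is left commented because it needs a live handler from create_onednn_mm_handler:

import torch

a_full = torch.rand((8, 512), dtype=torch.bfloat16)
a = a_full[:, :256]                          # strided rows, contiguous last dim
assert a.dim() == 2 and a.stride(-1) == 1    # mirrors the TORCH_CHECKs above
assert a.stride(0) != a.size(1)              # runtime a_m_stride differs from K

c = torch.empty((8, 1024), dtype=torch.bfloat16)  # [M, N]; must be contiguous
# torch.ops._C.onednn_mm(c, a, None, handler)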

View File

@@ -21,6 +21,12 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& bias,
int64_t handler);
int64_t create_onednn_mm_handler(const torch::Tensor& b,
int64_t primitive_cache_size);
void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& bias, int64_t handler);
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -153,6 +159,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("release_dnnl_matmul_handler(int handler) -> ()",
&release_dnnl_matmul_handler);
// Create oneDNN GEMM handler
ops.def(
"create_onednn_mm_handler(Tensor b, int "
"primitive_cache_size) -> int",
&create_onednn_mm_handler);
// oneDNN GEMM
ops.def(
"onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
"int handler) -> ()");
ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
// Create oneDNN W8A8 handler
ops.def(
"create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "

View File

@@ -111,6 +111,49 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int,
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
def onednn_gemm_test_helper(primitive_cache_size: int,
m: int,
n: int,
k: int,
use_bias: bool,
use_stride: bool,
dtype: torch.dtype = torch.bfloat16,
device: str = "cpu"):
if use_stride:
a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5
a = a[:, :k]
else:
a = torch.rand((m, k), dtype=dtype, device=device) * 1.5
b = torch.rand((n, k), dtype=dtype, device=device) * 1.5
if use_bias:
bias = torch.rand((n, ), device=device, dtype=dtype) * 5
bias_f32 = bias.float()
else:
bias = None
bias_f32 = None
handler = ops.create_onednn_mm(
b.t(),
primitive_cache_size,
)
out = ops.onednn_mm(handler, a, bias)
baseline = torch.nn.functional.linear(a.float(), b.float(),
bias_f32).to(dtype=a.dtype)
torch.testing.assert_close(out, baseline)
if use_bias:
# To test runtime bias setting
out = ops.onednn_mm(handler, a, None)
baseline = torch.nn.functional.linear(a.float(), b.float(),
None).to(dtype=a.dtype)
torch.testing.assert_close(out, baseline)
@pytest.mark.parametrize("n,k", NK_FACTORS)
@pytest.mark.parametrize("m_list", M_FACTORS)
@pytest.mark.parametrize("per_tensor_a_scale", [True, False])
@@ -142,3 +185,30 @@ def test_onednn_int8_scaled_gemm(
use_azp=use_azp,
out_dtype=output_type,
)
@pytest.mark.parametrize("n,k", NK_FACTORS)
@pytest.mark.parametrize("m_list", M_FACTORS)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_stride", [True, False])
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES)
def test_onednn_gemm(
n: int,
k: int,
m_list: tuple[int],
use_bias: bool,
use_stride: bool,
dtype: torch.dtype,
primitive_cache_size: int,
):
for m in m_list:
onednn_gemm_test_helper(
primitive_cache_size=primitive_cache_size,
m=m,
n=n,
k=k,
use_bias=use_bias,
use_stride=use_stride,
dtype=dtype,
)

View File

@@ -1928,6 +1928,35 @@ class CPUDNNLGEMMHandler:
torch.ops._C.release_dnnl_matmul_handler(self.handler)
if hasattr(torch.ops._C, "create_onednn_mm_handler"):
_supports_onednn = True
else:
_supports_onednn = False
def create_onednn_mm(
weight: torch.Tensor, # [K, N]
primitive_cache_size: int = 128,
) -> CPUDNNLGEMMHandler:
handler = CPUDNNLGEMMHandler()
handler.k, handler.n = weight.size()
handler.handler = torch.ops._C.create_onednn_mm_handler(
weight, primitive_cache_size)
return handler
def onednn_mm(
dnnl_handler: CPUDNNLGEMMHandler,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
torch.ops._C.onednn_mm(output, x.reshape(-1, dnnl_handler.k), bias,
dnnl_handler.handler)
return output
def create_onednn_scaled_mm(
weight: torch.Tensor, # [K, N]
weight_scales: torch.Tensor,

View File

@@ -9,7 +9,6 @@ import torch
import torch.nn as nn
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import envs
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
@@ -200,26 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
set_weight_attrs(weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# special postprocessing for CPU SGL
if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
N, K = layer.weight.size()
dtype = layer.weight.dtype
if check_cpu_sgl_kernel(N, K, dtype):
packed_weight = torch.ops._C.convert_weight_packed(
layer.weight)
assert packed_weight.size() == layer.weight.size()
layer.weight.copy_(packed_weight)
if layer.bias is not None:
layer.bias = Parameter(layer.bias.to(torch.float32),
requires_grad=False)
layer.use_cpu_sgl = True
else:
logger.warning(
"CPU SGL kernels require Intel AMX support,"
" bf16/fp16/int8 weight, IC and OC are divisible by "
"32 and 16.")
layer.use_cpu_sgl = False
if current_platform.is_cpu():
from vllm.model_executor.layers.utils import (
dispatch_cpu_unquantized_gemm)
dispatch_cpu_unquantized_gemm(layer, remove_weight=True)
def apply(self,
layer: torch.nn.Module,

View File

@@ -142,20 +142,49 @@ direct_register_custom_op(
)
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype):
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
return (torch._C._cpu._is_amx_tile_supported()
and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0
and n % 16 == 0)
def dispatch_cpu_unquantized_gemm(
layer: torch.nn.Module,
remove_weight: bool,
) -> None:
N, K = layer.weight.size()
dtype = layer.weight.dtype
if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
if getattr(layer, "bias", None) is not None:
bias_f32 = layer.bias.to(torch.float32)
else:
bias_f32 = None
layer.cpu_linear = (
lambda x, weight, bias: torch.ops._C.weight_packed_linear(
x, packed_weight, bias_f32
if bias is not None else None, True))
if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0),
requires_grad=False)
elif ops._supports_onednn:
origin_weight = layer.weight
if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0),
requires_grad=False)
handler = ops.create_onednn_mm(origin_weight.t(), 32)
layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(
handler, x, bias)
else:
layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
x, weight, bias)
def cpu_unquantized_gemm(layer: torch.nn.Module,
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None):
if getattr(layer, "use_cpu_sgl", False):
return torch.ops._C.weight_packed_linear(x, weight, bias, True)
else:
return torch.nn.functional.linear(x, weight, bias)
return layer.cpu_linear(x, weight, bias)
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:

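After dispatch_cpu_unquantized_gemm has run (from process_weights_after_loading), each unquantized CPU linear carries a layer.cpu_linear callable and cpu_unquantized_gemm reduces to a single indirect call. A hedged sketch of wiring this up on a bare module (CPU build assumed; remove_weight=False keeps layer.weight so the torch.nn.functional.linear fallback also works):

import torch
from vllm.model_executor.layers.utils import (cpu_unquantized_gemm,
                                              dispatch_cpu_unquantized_gemm)

layer = torch.nn.Module()
layer.weight = torch.nn.Parameter(
    torch.rand((512, 256), dtype=torch.bfloat16), requires_grad=False)
layer.bias = None

# Picks SGL kernels, oneDNN, or torch.nn.functional.linear, and stores the
# chosen callable as layer.cpu_linear.
dispatch_cpu_unquantized_gemm(layer, remove_weight=False)

x = torch.rand((4, 256), dtype=torch.bfloat16)
y = cpu_unquantized_gemm(layer, x, layer.weight, None)  # [4, 512]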
View File

@@ -40,6 +40,12 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if current_platform.is_cpu():
from vllm.model_executor.layers.utils import (
dispatch_cpu_unquantized_gemm)
dispatch_cpu_unquantized_gemm(layer, remove_weight=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,