[CPU] Refactor CPU unquantized linear (#24150)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
@@ -22,6 +22,23 @@ void release_dnnl_matmul_handler(int64_t handler) {
  delete ptr;
}

DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
  this->realloc(allocation_unit * 128);
}

void DNNLScratchPadManager::realloc(size_t new_size) {
  new_size = round(new_size);
  if (new_size > size_) {
    ptr_ = std::aligned_alloc(64, new_size);
    size_ = new_size;
  }
}

DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
  static DNNLScratchPadManager manager;
  return &manager;
}

template <typename KT, typename VT>
class DNNLPrimitiveCache {
 public:
@@ -166,6 +183,23 @@ struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
           hash<int>()(static_cast<int>(val.bias_type));
  }
};

template <>
struct hash<MatMulPrimitiveHandler::ClassMatmulCacheKey> {
  size_t operator()(
      const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
    return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size);
  }
};

template <>
struct hash<MatMulPrimitiveHandler::MSizeCacheKey> {
  size_t operator()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const {
    return hash<dnnl_dim_t>()(val.a_m_size) ^
           hash<dnnl_dim_t>()(val.a_m_stride) ^ hash<bool>()(val.use_bias) ^
           hash<int>()(static_cast<int>(val.bias_type));
  }
};
} // namespace std

bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
@@ -181,6 +215,17 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
         l.bias_type == r.bias_type;
}

bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
                const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
  return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size;
}

bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l,
                const MatMulPrimitiveHandler::MSizeCacheKey& r) {
  return l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride &&
         l.use_bias == r.use_bias && l.bias_type == r.bias_type;
}

static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
get_w8a8_class_primitive_cache(
    const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
@@ -239,6 +284,11 @@ void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
  }

  dnnl::matmul matmul = get_matmul_cache(args);

  auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(5);
  scratchpad_storage->set_data_handle(
      DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());

  matmul.execute(default_stream(), memory_cache_);
  default_stream().wait();
}
@@ -257,6 +307,8 @@ dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(

  return m_size_cache_->get_or_create(key, [&]() {
    dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
    auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
    manager->realloc(desc.scratchpad_desc().get_size());
    return dnnl::matmul(desc);
  });
}
@@ -300,6 +352,11 @@ void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());

  memory_cache_[DNNL_ARG_SCRATCHPAD] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(5, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}

dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
@@ -319,6 +376,9 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
                          dnnl::memory::format_tag::ab);

  dnnl::primitive_attr attr;

  attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  // For PER_TOKEN, scales will be applied in outside epilogue
  if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
    attr.set_scales_mask(DNNL_ARG_SRC, 0);
@@ -344,3 +404,120 @@ dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
                                        attr);
  }
}

MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
    : DNNLMatMulPrimitiveHandler(
          static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
      m_size_cache_(nullptr) {
  assert(ab_type_ == dnnl::memory::data_type::f32 ||
         ab_type_ == dnnl::memory::data_type::bf16 ||
         ab_type_ == dnnl::memory::data_type::f16);
  prepack_weight(args.b_ptr,
                 create_primitive_desc(
                     MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
                                   .a_m_stride = DNNL_RUNTIME_DIM_VAL,
                                   .use_bias = false,
                                   .bias_type = dnnl::memory::data_type::undef},
                     true)
                     .weights_desc());
  init_runtime_memory_cache(args);
}

static std::shared_ptr<MatMulPrimitiveHandler::MSizeCache>
get_matul_class_primitive_cache(
    const MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
    int64_t cache_size) {
  static MatMulPrimitiveHandler::ClassMatmulCache cache(128);
  assert(cache_size > 0);
  return cache.get_or_create(key, [&]() {
    return std::make_shared<MatMulPrimitiveHandler::MSizeCache>(cache_size);
  });
}

void MatMulPrimitiveHandler::execute(ExecArgs& args) {
  auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
  auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
  a_storage->set_data_handle((void*)args.a_ptr);
  a_mem_desc->dims[0] = args.a_m_size;
  a_mem_desc->format_desc.blocking.strides[0] = args.a_m_stride;
  c_storage->set_data_handle((void*)args.c_ptr);
  c_mem_desc->dims[0] = args.a_m_size;

  if (args.use_bias) {
    auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
    bias_storage->set_data_handle((void*)args.bias_ptr);
  }

  dnnl::matmul matmul = get_matmul_cache(args);

  auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
  scratchpad_storage->set_data_handle(
      DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());

  matmul.execute(default_stream(), memory_cache_);
  default_stream().wait();
}

dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
    const MSizeCacheKey& key) {
  if (m_size_cache_.get() == nullptr) {
    ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_};
    m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_);
  }
  return m_size_cache_->get_or_create(key, [&]() {
    dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
    auto manager = DNNLScratchPadManager::get_dnnl_scratchpad_manager();
    manager->realloc(desc.scratchpad_desc().get_size());
    return dnnl::matmul(desc);
  });
}

dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
    const MSizeCacheKey& key, bool first_time) {
  dnnl::memory::desc a_md;
  dnnl::memory::desc b_md;
  if (first_time) {
    a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
                              dnnl::memory::format_tag::ab);
    b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
                              dnnl::memory::format_tag::any);
  } else {
    a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
                              {key.a_m_stride, 1});
    b_md = b_target_mem_desc_;
  }
  dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
                          dnnl::memory::format_tag::ab);

  dnnl::primitive_attr attr;
  attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  if (key.use_bias) {
    dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
                                        c_md, attr);
  } else {
    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
                                        attr);
  }
}

void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
  memory_cache_[DNNL_ARG_SRC] = dnnl::memory(
      {{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr);
  set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
  memory_cache_[DNNL_ARG_DST] =
      dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());

  memory_cache_[DNNL_ARG_BIAS] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());

  memory_cache_[DNNL_ARG_SCRATCHPAD] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}
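Editor's note: the MatMulPrimitiveHandler above reuses compiled oneDNN primitives through two nested caches: a process-wide ClassMatmulCache keyed by the weight shape (ClassMatmulCacheKey) hands out a per-handler MSizeCache keyed by the runtime M size, stride, and bias settings (MSizeCacheKey). Below is a minimal Python sketch of that two-level get_or_create lookup, using a plain dict in place of the bounded DNNLPrimitiveCache and hypothetical key types; it only illustrates the lookup structure, not the real C++ cache.

```python
# Minimal sketch of the two-level primitive cache; ClassKey/MSizeKey are
# simplified stand-ins for ClassMatmulCacheKey/MSizeCacheKey, and a dict
# replaces the bounded DNNLPrimitiveCache.
from dataclasses import dataclass


@dataclass(frozen=True)
class ClassKey:          # mirrors ClassMatmulCacheKey (weight shape)
    b_n_size: int
    b_k_size: int


@dataclass(frozen=True)
class MSizeKey:          # mirrors MSizeCacheKey (runtime M size / bias)
    a_m_size: int
    a_m_stride: int
    use_bias: bool


_class_cache: dict[ClassKey, dict[MSizeKey, object]] = {}


def get_or_create_primitive(class_key: ClassKey, m_key: MSizeKey,
                            create_primitive):
    # First level: one sub-cache per weight shape, shared across handlers.
    m_cache = _class_cache.setdefault(class_key, {})
    # Second level: one compiled matmul per distinct runtime shape.
    if m_key not in m_cache:
        m_cache[m_key] = create_primitive()
    return m_cache[m_key]


# Usage: the factory is only invoked on a cache miss.
prim = get_or_create_primitive(ClassKey(32, 64), MSizeKey(4, 64, False),
                               lambda: object())
```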
@@ -59,6 +59,30 @@ constexpr inline dnnl::memory::data_type get_dnnl_type() {
  return DNNLType<std::decay_t<T>>::type;
}

class DNNLScratchPadManager {
 public:
  static constexpr size_t allocation_unit = 4 * 1024 * 1024;  // 4MB

  static DNNLScratchPadManager* get_dnnl_scratchpad_manager();

  DNNLScratchPadManager();

  template <typename T>
  T* get_data() {
    return reinterpret_cast<T*>(ptr_);
  }

  static size_t round(size_t size) {
    return ((size + allocation_unit - 1) / allocation_unit) * allocation_unit;
  }

  void realloc(size_t new_size);

 private:
  size_t size_;
  void* ptr_;
};

class DNNLMatMulPrimitiveHandler {
 public:
  virtual ~DNNLMatMulPrimitiveHandler() = default;
@@ -166,4 +190,54 @@ class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
  std::shared_ptr<MSizeCache> m_size_cache_;
};

class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
 public:
  struct Args : public DNNLMatMulPrimitiveHandler::Args {
    dnnl::memory::data_type ab_type;
  };

  struct ClassMatmulCacheKey {
    dnnl_dim_t b_n_size;
    dnnl_dim_t b_k_size;

    friend bool operator==(const ClassMatmulCacheKey& l,
                           const ClassMatmulCacheKey& r);
  };

  struct MSizeCacheKey {
    dnnl_dim_t a_m_size;
    dnnl_dim_t a_m_stride;
    bool use_bias;
    dnnl::memory::data_type bias_type;

    friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
  };

  using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
  using ClassMatmulCache =
      DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;

  struct ExecArgs : public MSizeCacheKey {
    const void* a_ptr;
    const void* bias_ptr;
    void* c_ptr;
  };

 public:
  MatMulPrimitiveHandler(const Args& args);

  void execute(ExecArgs& args);

 private:
  dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
                                                     bool first_time);

  void init_runtime_memory_cache(const Args& args);

  dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);

 private:
  std::shared_ptr<MSizeCache> m_size_cache_;
};

#endif
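Editor's note: the scratchpad manager declared here keeps one shared buffer that only grows, and it rounds every request up to whole allocation_unit multiples. A small Python sketch of that size policy follows; the constant mirrors the header, while the buffer handling is deliberately simplified (the C++ version uses std::aligned_alloc).

```python
# Sketch of DNNLScratchPadManager's sizing: round requests up to the
# allocation unit and only grow, never shrink. Buffer handling simplified.
ALLOCATION_UNIT = 4 * 1024 * 1024  # 4 MiB, as in the header


def round_up(size: int) -> int:
    return ((size + ALLOCATION_UNIT - 1) // ALLOCATION_UNIT) * ALLOCATION_UNIT


class ScratchPad:
    def __init__(self) -> None:
        self.size = 0
        self.buf = None

    def realloc(self, new_size: int) -> None:
        new_size = round_up(new_size)
        if new_size > self.size:       # grow-only, like the C++ realloc()
            self.buf = bytearray(new_size)
            self.size = new_size


assert round_up(1) == ALLOCATION_UNIT
assert round_up(ALLOCATION_UNIT + 1) == 2 * ALLOCATION_UNIT
```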
@@ -379,6 +379,7 @@ void onednn_scaled_mm(
  exec_args.a_ptr = a.data_ptr<int8_t>();
  exec_args.a_m_size = a.size(0);
  exec_args.bias_ptr = nullptr;
  exec_args.bias_type = get_dnnl_type<void>();
  exec_args.use_bias = false;
  exec_args.a_scales_ptr = nullptr;
  exec_args.a_zero_points_ptr = nullptr;
@@ -492,3 +493,56 @@ void dynamic_scaled_int8_quant(
    }
  });
}

int64_t create_onednn_mm_handler(const torch::Tensor& b,
                                 int64_t primitive_cache_size) {
  TORCH_CHECK(b.dim() == 2);

  MatMulPrimitiveHandler::Args args;
  args.primitive_cache_size = primitive_cache_size;

  args.b_k_size = b.size(0);
  args.b_k_stride = b.stride(0);
  args.b_n_size = b.size(1);
  args.b_n_stride = b.stride(1);
  args.b_ptr = b.data_ptr();

  VLLM_DISPATCH_FLOATING_TYPES(b.scalar_type(), "create_onednn_mm_handler",
                               [&] {
                                 args.c_type = get_dnnl_type<scalar_t>();
                                 args.ab_type = get_dnnl_type<scalar_t>();
                               });

  return reinterpret_cast<int64_t>(new MatMulPrimitiveHandler(args));
}

void onednn_mm(torch::Tensor& c,        // [M, OC], row-major
               const torch::Tensor& a,  // [M, IC], row-major
               const std::optional<torch::Tensor>& bias, int64_t handler) {
  CPU_KERNEL_GUARD_IN(onednn_mm)
  TORCH_CHECK(a.dim() == 2);
  TORCH_CHECK(a.stride(-1) == 1);
  TORCH_CHECK(c.is_contiguous());
  MatMulPrimitiveHandler* ptr =
      reinterpret_cast<MatMulPrimitiveHandler*>(handler);

  MatMulPrimitiveHandler::ExecArgs exec_args;
  exec_args.a_m_size = a.size(0);
  exec_args.a_m_stride = a.stride(0);

  VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] {
    if (bias.has_value()) {
      exec_args.use_bias = true;
      exec_args.bias_type = get_dnnl_type<scalar_t>();
      exec_args.bias_ptr = bias->data_ptr<scalar_t>();
    } else {
      exec_args.use_bias = false;
      exec_args.bias_type = get_dnnl_type<void>();
      exec_args.bias_ptr = nullptr;
    }
    exec_args.a_ptr = a.data_ptr<scalar_t>();
    exec_args.c_ptr = c.data_ptr<scalar_t>();

    ptr->execute(exec_args);
  });
}
@@ -21,6 +21,12 @@ void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                      const std::optional<torch::Tensor>& bias,
                      int64_t handler);

int64_t create_onednn_mm_handler(const torch::Tensor& b,
                                 int64_t primitive_cache_size);

void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
               const std::optional<torch::Tensor>& bias, int64_t handler);

void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
                        torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -153,6 +159,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("release_dnnl_matmul_handler(int handler) -> ()",
          &release_dnnl_matmul_handler);

  // Create oneDNN GEMM handler
  ops.def(
      "create_onednn_mm_handler(Tensor b, int "
      "primitive_cache_size) -> int",
      &create_onednn_mm_handler);

  // oneDNN GEMM
  ops.def(
      "onednn_mm(Tensor! c, Tensor a, Tensor? bias, "
      "int handler) -> ()");
  ops.impl("onednn_mm", torch::kCPU, &onednn_mm);

  // Create oneDNN W8A8 handler
  ops.def(
      "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
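Editor's note: with the bindings above registered, the new ops are reachable from Python as torch.ops._C.create_onednn_mm_handler and torch.ops._C.onednn_mm. The sketch below shows a direct call, assuming a CPU build of vLLM with these kernels compiled; in practice the wrappers added to vllm/_custom_ops.py later in this diff are the intended entry points.

```python
# Direct use of the registered ops; normally you would go through the
# vllm._custom_ops wrappers instead. Assumes the CPU extension is built.
import torch

M, K, N = 4, 64, 32
a = torch.randn(M, K, dtype=torch.bfloat16)
b = torch.randn(N, K, dtype=torch.bfloat16)   # torch.nn.Linear weight layout

# The handler expects the weight as [K, N]; b.t() is a transposed view,
# and the handler prepacks it into oneDNN's preferred layout.
handler = torch.ops._C.create_onednn_mm_handler(b.t(), 128)

c = torch.empty(M, N, dtype=torch.bfloat16)
torch.ops._C.onednn_mm(c, a, None, handler)   # bias is optional (Tensor?)

torch.ops._C.release_dnnl_matmul_handler(handler)
```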
@@ -111,6 +111,49 @@ def onednn_int8_gemm_test_helper(primitive_cache_size: int,
    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


def onednn_gemm_test_helper(primitive_cache_size: int,
                            m: int,
                            n: int,
                            k: int,
                            use_bias: bool,
                            use_stride: bool,
                            dtype: torch.dtype = torch.bfloat16,
                            device: str = "cpu"):
    if use_stride:
        a = torch.rand((m, 2 * k), dtype=dtype, device=device) * 1.5
        a = a[:, :k]
    else:
        a = torch.rand((m, k), dtype=dtype, device=device) * 1.5

    b = torch.rand((n, k), dtype=dtype, device=device) * 1.5

    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=dtype) * 5
        bias_f32 = bias.float()
    else:
        bias = None
        bias_f32 = None

    handler = ops.create_onednn_mm(
        b.t(),
        primitive_cache_size,
    )

    out = ops.onednn_mm(handler, a, bias)
    baseline = torch.nn.functional.linear(a.float(), b.float(),
                                          bias_f32).to(dtype=a.dtype)

    torch.testing.assert_close(out, baseline)

    if use_bias:
        # To test runtime bias setting
        out = ops.onednn_mm(handler, a, None)
        baseline = torch.nn.functional.linear(a.float(), b.float(),
                                              None).to(dtype=a.dtype)

        torch.testing.assert_close(out, baseline)


@pytest.mark.parametrize("n,k", NK_FACTORS)
@pytest.mark.parametrize("m_list", M_FACTORS)
@pytest.mark.parametrize("per_tensor_a_scale", [True, False])
@@ -142,3 +185,30 @@ def test_onednn_int8_scaled_gemm(
        use_azp=use_azp,
        out_dtype=output_type,
    )


@pytest.mark.parametrize("n,k", NK_FACTORS)
@pytest.mark.parametrize("m_list", M_FACTORS)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_stride", [True, False])
@pytest.mark.parametrize("dtype", DTYPE)
@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES)
def test_onednn_gemm(
    n: int,
    k: int,
    m_list: tuple[int],
    use_bias: bool,
    use_stride: bool,
    dtype: torch.dtype,
    primitive_cache_size: int,
):
    for m in m_list:
        onednn_gemm_test_helper(
            primitive_cache_size=primitive_cache_size,
            m=m,
            n=n,
            k=k,
            use_bias=use_bias,
            use_stride=use_stride,
            dtype=dtype,
        )
@@ -1928,6 +1928,35 @@ class CPUDNNLGEMMHandler:
        torch.ops._C.release_dnnl_matmul_handler(self.handler)


if hasattr(torch.ops._C, "create_onednn_mm_handler"):
    _supports_onednn = True
else:
    _supports_onednn = False


def create_onednn_mm(
    weight: torch.Tensor,  # [K, N]
    primitive_cache_size: int = 128,
) -> CPUDNNLGEMMHandler:
    handler = CPUDNNLGEMMHandler()
    handler.k, handler.n = weight.size()
    handler.handler = torch.ops._C.create_onednn_mm_handler(
        weight, primitive_cache_size)
    return handler


def onednn_mm(
    dnnl_handler: CPUDNNLGEMMHandler,
    x: torch.Tensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype)
    torch.ops._C.onednn_mm(output, x.reshape(-1, dnnl_handler.k), bias,
                           dnnl_handler.handler)

    return output


def create_onednn_scaled_mm(
    weight: torch.Tensor,  # [K, N]
    weight_scales: torch.Tensor,
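Editor's note: a short usage sketch of the wrappers added above, assuming a CPU build where ops._supports_onednn is True. It highlights that onednn_mm flattens any leading batch dimensions of x to rows and allocates the output as (*x.shape[:-1], n); tolerances are illustrative for bf16.

```python
# Sketch of the Python-level API added above; assumes the CPU extension
# with the oneDNN ops is available (ops._supports_onednn is True).
import torch

from vllm import _custom_ops as ops

K, N = 64, 32
weight = torch.randn(N, K, dtype=torch.bfloat16)   # torch.nn.Linear layout
handler = ops.create_onednn_mm(weight.t(), primitive_cache_size=128)

x = torch.randn(2, 3, K, dtype=torch.bfloat16)     # leading batch dims
bias = torch.randn(N, dtype=torch.bfloat16)

out = ops.onednn_mm(handler, x, bias)              # shape (2, 3, N)
ref = torch.nn.functional.linear(x.float(), weight.float(),
                                 bias.float()).to(x.dtype)
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)
```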
@@ -9,7 +9,6 @@ import torch
import torch.nn as nn
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import envs
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
@@ -200,26 +199,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # special postprocessing for CPU SGL
        if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
            from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
            N, K = layer.weight.size()
            dtype = layer.weight.dtype
            if check_cpu_sgl_kernel(N, K, dtype):
                packed_weight = torch.ops._C.convert_weight_packed(
                    layer.weight)
                assert packed_weight.size() == layer.weight.size()
                layer.weight.copy_(packed_weight)
                if layer.bias is not None:
                    layer.bias = Parameter(layer.bias.to(torch.float32),
                                           requires_grad=False)
                layer.use_cpu_sgl = True
            else:
                logger.warning(
                    "CPU SGL kernels require Intel AMX support,"
                    " bf16/fp16/int8 weight, IC and OC are divisible by "
                    "32 and 16.")
                layer.use_cpu_sgl = False
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    def apply(self,
              layer: torch.nn.Module,
@@ -142,20 +142,49 @@ direct_register_custom_op(
)


def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype):
def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool:
    return (torch._C._cpu._is_amx_tile_supported()
            and (dtype in (torch.bfloat16, torch.int8)) and k % 32 == 0
            and n % 16 == 0)


def dispatch_cpu_unquantized_gemm(
    layer: torch.nn.Module,
    remove_weight: bool,
) -> None:
    N, K = layer.weight.size()
    dtype = layer.weight.dtype
    if envs.VLLM_CPU_SGL_KERNEL and check_cpu_sgl_kernel(N, K, dtype):
        packed_weight = torch.ops._C.convert_weight_packed(layer.weight)
        if getattr(layer, "bias", None) is not None:
            bias_f32 = layer.bias.to(torch.float32)
        else:
            bias_f32 = None
        layer.cpu_linear = (
            lambda x, weight, bias: torch.ops._C.weight_packed_linear(
                x, packed_weight, bias_f32
                if bias is not None else None, True))
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
    elif ops._supports_onednn:
        origin_weight = layer.weight
        if remove_weight:
            layer.weight = torch.nn.Parameter(torch.empty(0),
                                              requires_grad=False)
        handler = ops.create_onednn_mm(origin_weight.t(), 32)
        layer.cpu_linear = lambda x, weight, bias: ops.onednn_mm(
            handler, x, bias)
    else:
        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
            x, weight, bias)


def cpu_unquantized_gemm(layer: torch.nn.Module,
                         x: torch.Tensor,
                         weight: torch.Tensor,
                         bias: Optional[torch.Tensor] = None):
    if getattr(layer, "use_cpu_sgl", False):
        return torch.ops._C.weight_packed_linear(x, weight, bias, True)
    else:
        return torch.nn.functional.linear(x, weight, bias)
    return layer.cpu_linear(x, weight, bias)


def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
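Editor's note: the new dispatch_cpu_unquantized_gemm picks a backend once per layer at weight-loading time, captures any prepacked state (packed weight or oneDNN handler) in a closure stored as layer.cpu_linear, and leaves the per-call path in cpu_unquantized_gemm as a single indirect call. The sketch below models only that dispatch-once/call-many structure; the backends are illustrative stand-ins, not the real SGL or oneDNN kernels.

```python
# Sketch of the dispatch-once pattern used above: select a backend when the
# layer's weights are processed, capture prepacked state in a closure, and
# route every forward call through layer.cpu_linear.
from typing import Optional

import torch


def dispatch_linear(layer: torch.nn.Module, has_fast_backend: bool) -> None:
    if has_fast_backend:
        # Stand-in for handing the weight to a prepacked/oneDNN handler; the
        # closure captures it so the hot path never re-derives this state.
        packed_weight = layer.weight.detach().clone()
        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
            x, packed_weight, bias)
    else:
        layer.cpu_linear = lambda x, weight, bias: torch.nn.functional.linear(
            x, weight, bias)


def linear_forward(layer: torch.nn.Module, x: torch.Tensor,
                   weight: torch.Tensor,
                   bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Mirrors cpu_unquantized_gemm after the refactor: one indirect call.
    return layer.cpu_linear(x, weight, bias)


layer = torch.nn.Linear(8, 4)
dispatch_linear(layer, has_fast_backend=True)
out = linear_forward(layer, torch.randn(2, 8), layer.weight, layer.bias)
assert out.shape == (2, 4)
```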
@@ -40,6 +40,12 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if current_platform.is_cpu():
            from vllm.model_executor.layers.utils import (
                dispatch_cpu_unquantized_gemm)
            dispatch_cpu_unquantized_gemm(layer, remove_weight=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,