[WOQ] Add XPU kernel for _weight_int8pack_mm (#160938)
Summary: This PR implements an XPU kernel for aten._weight_int8pack_mm, a weight-only quantized (WOQ) linear operation that was previously supported only on CPU, CUDA, and MPS.
Motivation: Same as https://github.com/pytorch/pytorch/pull/159325.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160938
Approved by: https://github.com/EikanWang, https://github.com/ZhiweiYan-96, https://github.com/liangan1, https://github.com/jerryzh168
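For context, aten._weight_int8pack_mm takes a floating-point activation A of shape [M, K], a per-channel int8 weight B of shape [N, K], and a 1-D scale tensor of size N, and computes roughly C[m, n] = sum_k A[m, k] * B[n, k] * scales[n]. The Python sketch below is illustrative only and not part of this PR: it assumes an XPU-enabled build with a reachable "xpu" device, and it hand-rolls symmetric per-channel quantization in place of the test helper _dynamically_quantize_per_channel used by the new unit test further down.

import torch

device = "xpu"  # assumption: PyTorch built with XPU support and a device present
M, K, N = 32, 64, 48

a = torch.rand((M, K), dtype=torch.bfloat16, device=device)  # activation, [M, K]
b = torch.rand((N, K), dtype=torch.bfloat16, device=device)  # fp weight, [N, K]

# Symmetric per-channel (per output row) quantization of the weight to int8.
scales = b.abs().amax(dim=1, keepdim=True).float() / 127.0
b_int8 = torch.clamp(torch.round(b.float() / scales), -128, 127).to(torch.int8)
scales = scales.squeeze(1).to(a.dtype)  # 1-D scales of size N

res = torch._weight_int8pack_mm(a, b_int8, scales)  # fused WOQ matmul
ref = torch.mm(a, b.t())                            # full-precision reference
print(((res - ref).abs() / ref).mean())             # small mean relative error expected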
committed by PyTorch MergeBot
parent 0815091d86
commit ab5086a7ae
@@ -559,4 +559,60 @@ Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) {
      at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
  return _int_mm_out_xpu(self, mat2, result);
}

Tensor _weight_int8pack_mm_xpu(
    const Tensor& A,
    const Tensor& B,
    const Tensor& scales) {
  auto M = A.size(0);
  auto N = B.size(0);
  auto K = A.size(1);

  TORCH_CHECK(
      A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
      " : expect A to be either 32-bit or 16-bit float tensor.");
  TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");
  TORCH_CHECK(
      A.stride(1) == 1, " : A must be contiguous on the last dimension.");
  TORCH_CHECK(B.dtype() == kChar, " : expect B to be int8 tensor.");
  TORCH_CHECK(B.is_contiguous(), " : expect B to be contiguous.");
  TORCH_CHECK(B.size(1) == K, " : expect B.size(1) == ", K);

  TORCH_CHECK(
      scales.dim() == 1 && scales.size(0) == N,
      " : expect scales to be 1d tensor with size ",
      N);

  auto C = at::empty({M, N}, A.options());

  // --- Launch kernel ---
  Tensor bias = at::Tensor();
  Tensor mat2_zero_points = at::Tensor();
  Tensor non_const_scales = scales;
  auto post_op_args = torch::List<std::optional<at::Scalar>>();

  at::native::onednn::quantized_matmul(
      A.contiguous(),
      1.0,
      0,
      B,
      non_const_scales,
      mat2_zero_points,
      bias,
      C,
      1.0,
      0,
      C.scalar_type(),
      /*other*/ std::nullopt,
      /*other scale*/ 1.0,
      /*other zp*/ 0,
      /*binary post op*/ "none",
      /*binary alpha*/ 1.0,
      /*post_op_name*/ "none",
      post_op_args,
      /*post_op_algorithm*/ "none",
      /*m2_trans*/ false);

  return C;
}
} // namespace at::native
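Taken together, the checks above pin down the kernel's contract: A is a 2-D fp32/fp16/bf16 tensor of shape [M, K] that is contiguous in its last dimension, B is a contiguous int8 tensor of shape [N, K], and scales is a 1-D tensor with N elements; the oneDNN call then performs the weight-only-quantized matmul with per-channel weight scales, no zero points, no bias, and no post-ops. As a rough, unoptimized reference (an illustration, not part of this PR), the computation corresponds to dequantize-then-matmul:

import torch

def weight_int8pack_mm_reference(A, B_int8, scales):
    # A: [M, K] float/half/bfloat16; B_int8: [N, K] int8; scales: [N]
    # Dequantize the weight per output channel, then do a plain matmul.
    W = B_int8.to(A.dtype) * scales.to(A.dtype).unsqueeze(1)  # [N, K]
    return torch.mm(A, W.t())                                 # [M, N]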
@@ -110,8 +110,9 @@ void quantized_matmul(
  // [Note] Quantized Matrix Multiplication at XPU
  // The following code integrates oneDNN quantized gemm. The quantization
  // config we support:
  // activation: s8&u8; per tensor calibrated; symmetric&asymmetric
  // weight: s8; per_tensor/per_channel calibrated; symmetric
  // activation: s8, u8, fp16, bf16, fp32; per tensor calibrated;
  // symmetric&asymmetric weight: s8; per_tensor/per_channel calibrated;
  // symmetric
  auto attr = Attr(static_cast<float>(1.0 / output_scale), output_zero_point);
  construct_attr_by_post_op(
      binary_post_op,
@@ -4243,6 +4243,7 @@
    CPU: _weight_int8pack_mm_cpu
    CUDA: _weight_int8pack_mm_cuda
    MPS: _weight_int8pack_mm_mps
    XPU: _weight_int8pack_mm_xpu

- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
  python_module: sparse
@@ -19,8 +19,12 @@ from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyNativeDeviceTypes,
    precisionOverride,
)
from torch.testing._internal.common_quantization import (
    _dynamically_quantize_per_channel,
)
from torch.testing._internal.common_utils import (
    iter_indices,
    parametrize,
@@ -1446,6 +1450,50 @@ def forward(self, x_1, w_1):
    return out_dtype""",
        )

    @onlyNativeDeviceTypes
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    @parametrize("compile", [True, False])
    @parametrize("slice", [True, False])
    def test__int8_mm(self, device, m, k, n, compile, slice):
        torch.manual_seed(1)
        if slice:
            # logits are generated from LLaMA LM head like this -
            # the activation to LM head is a slice of final hidden state
            # of shape (batch_size, sequence_length, hidden dim),
            # but is non-contiguous
            # Using arbitrary batch-size here, since it'd be converted to 2D
            batch_size = 4
            a = torch.rand((batch_size, m, k), dtype=torch.bfloat16, device=device)
            # Make a non-contiguous
            a = a[:, -1:, :]
            a = a.view(-1, a.size(-1))
        else:
            a = torch.rand((m, k), dtype=torch.bfloat16, device=device)

        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)

        def convert_weight_to_int8pack(b):
            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
                b, -128, 127, torch.int8
            )
            return b_int8pack, b_scales

        def weight_int8pack_mm(a, b_int8pack, b_scales):
            return torch._weight_int8pack_mm(a, b_int8pack, b_scales)

        b_int8pack, b_scales = convert_weight_to_int8pack(b)
        if compile:
            mod = torch.compile(weight_int8pack_mm)
        else:
            mod = weight_int8pack_mm
        res = mod(a, b_int8pack, b_scales)
        ref = torch.mm(a, b.transpose(0, 1))

        mean_err = ((res - ref).abs() / ref).mean()
        self.assertTrue(mean_err < 0.05)


instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True)
@@ -19,6 +19,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attent
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int4pack_mm_with_scales_and_zeros(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScale, AtenTensorHandle qZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_abs(AtenTensorHandle self, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_add_Scalar(AtenTensorHandle self, double other, double alpha, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_addbmm(AtenTensorHandle self, AtenTensorHandle batch1, AtenTensorHandle batch2, double beta, double alpha, AtenTensorHandle* ret0);