[AOTI] add C shim for _weight_int8pack_mm (#138691)

Fixes the following error when running WOQ-INT8 (weight-only INT8 quantized) LLaMA:
```
E           In file included from /home/user/inductor/pytorch/torch/include/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h:3,
E                            from /tmp/torchinductor_user/sw/csw5gfmlzp5iooqvfwl2gwn574frwdpmtrx2y6nu2m6x76d3xcux.cpp:4:
E           /tmp/torchinductor_user/sw/csw5gfmlzp5iooqvfwl2gwn574frwdpmtrx2y6nu2m6x76d3xcux.cpp: In function ‘void inductor_entry_impl(AtenTensorOpaque**, AtenTensorOpaque**)’:
E           /tmp/torchinductor_user/sw/csw5gfmlzp5iooqvfwl2gwn574frwdpmtrx2y6nu2m6x76d3xcux.cpp:117:33: error: ‘aoti_torch_cpu__weight_int8pack_mm’ was not declared in this scope
E             117 |     AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu__weight_int8pack_mm(convert_arrayref_tensor_to_tensor(arg8_1), _frozen_param0, _frozen_param1, &buf0_handle));
E                 |                                 ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```
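
In ABI-compatible mode, the Inductor-generated CPP wrapper calls every fallback op through an `extern "C"` shim, so `_weight_int8pack_mm` needs a CPU shim entry. This PR adds the header declaration and the fallback-op registration (see the diffs below); the matching definition comes from the shim generator. A rough sketch of what that generated definition looks like — illustrative only, assuming the existing shim helpers (`tensor_handle_to_tensor_pointer`, `new_tensor_handle`, `AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE`), not the checked-in code:

```cpp
#include <ATen/ATen.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

// Sketch of the generated C shim: unwrap the opaque tensor handles, call the
// ATen fallback, and hand the result back through an output handle.
AOTITorchError aoti_torch_cpu__weight_int8pack_mm(
    AtenTensorHandle self,
    AtenTensorHandle mat2,
    AtenTensorHandle scales,
    AtenTensorHandle* ret0) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    auto result = at::_weight_int8pack_mm(
        *torch::aot_inductor::tensor_handle_to_tensor_pointer(self),
        *torch::aot_inductor::tensor_handle_to_tensor_pointer(mat2),
        *torch::aot_inductor::tensor_handle_to_tensor_pointer(scales));
    *ret0 = torch::aot_inductor::new_tensor_handle(std::move(result));
  });
}
```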

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138691
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5, https://github.com/desertfire
Author: Wu, Chunyuan
Date: 2024-10-29 16:31:27 +00:00
Committed by: PyTorch MergeBot
Parent: 69d401d010
Commit: 9af1816974
4 changed files with 28 additions and 0 deletions

@@ -201,6 +201,7 @@ if RUN_CPU:
        BaseTest("test_adding_tensor_offsets"),
        BaseTest("test_inductor_layout_optimization_input_mutations"),
        BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()),
        BaseTest("test_int8_weight_only_quant"),
        BaseTest("test_linear1"),
        BaseTest("test_linear2"),
        *[

@@ -72,6 +72,9 @@ from torch.testing._internal.common_device_type import (
    expectedFailureXPU,
)
from torch.testing._internal.common_dtype import all_types, get_all_dtypes
from torch.testing._internal.common_quantization import (
    _dynamically_quantize_per_channel,
)
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    instantiate_parametrized_tests,
@@ -2154,6 +2157,28 @@ class CommonTemplate:
        packed = torch.cat([data, scales, offsets], dim=-1)
        self.common(fn, [packed])

    @skipCUDAIf(True, "No _weight_int8pack_mm implementation on CUDA")
    def test_int8_weight_only_quant(self):
        def convert_weight_to_int8pack(b):
            b_int8pack, b_scales, _ = _dynamically_quantize_per_channel(
                b, -128, 127, torch.int8
            )
            return b_int8pack, b_scales

        def fn(a, b_int8pack, b_scales, c):
            res = torch._weight_int8pack_mm(a, b_int8pack, b_scales)
            res = res + c
            return res

        m = 32
        k = 32
        n = 48
        a = torch.rand((m, k), dtype=torch.bfloat16)
        b = torch.rand((n, k), dtype=torch.bfloat16)
        c = torch.rand((m, n), dtype=torch.bfloat16)
        b_int8pack, b_scales = convert_weight_to_int8pack(b)
        self.common(fn, (a, b_int8pack, b_scales, c))

    def test_expanded_reduction(self):
        def fn(x, y):
            z = x * y
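
For context on what the new test exercises: `torch._weight_int8pack_mm(a, b_int8pack, b_scales)` takes an [m, k] activation, an [n, k] int8-quantized weight, and per-output-channel scales of shape [n]. A naive ATen-level sketch of the semantics (an illustrative reference, not the optimized kernel):

```cpp
#include <ATen/ATen.h>

// Reference semantics: dequantize the int8 weight per output channel, then do
// a regular matmul against the transposed weight.
// a: [m, k] (bf16/fp16/fp32), b_int8: [n, k] int8, scales: [n] -> out: [m, n]
at::Tensor weight_int8pack_mm_ref(
    const at::Tensor& a,
    const at::Tensor& b_int8,
    const at::Tensor& scales) {
  auto b_dequant = b_int8.to(a.scalar_type()) * scales.unsqueeze(1);
  return at::matmul(a, b_dequant.t());
}
```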

@@ -35,6 +35,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attent
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool2d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, AtenTensorHandle indices, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_adaptive_max_pool3d(AtenTensorHandle self, const int64_t* output_size, int64_t output_size_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1);

@@ -146,5 +146,6 @@ inductor_fallback_ops = {
    "aten.view_as_complex.default",
    "aten.view_as_real.default",
    "aten.view.dtype",
    "aten._weight_int8pack_mm.default",
    "aten.zeros.names",
}