mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add int8 packed gemm support on CPU device (#118056)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118056 Approved by: https://github.com/mikekgfb
committed by PyTorch MergeBot
parent e8e3049f57
commit b3065f6899
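For orientation, here is a minimal usage sketch of the op this PR adds, modeled on the new tests near the bottom of the diff. The inline quantization is a simplified stand-in for the tests' _dynamically_quantize_per_channel helper, the shapes are illustrative only, and it assumes a CPU build that already contains this change:

import torch

# bf16 activations [M, K] and weights [N, K]; the tests use m, k in {32, 64}, n in {48, 64}
a = torch.rand(32, 64, dtype=torch.bfloat16)
b = torch.rand(48, 64, dtype=torch.bfloat16)

# symmetric per-output-channel int8 quantization of the weight (simplified)
eps = torch.finfo(torch.float32).eps
scales = (b.float().abs().amax(dim=1) / 127.5).clamp(min=eps)
b_int8 = torch.round(b.float() / scales.unsqueeze(-1)).clamp(-128, 127).to(torch.int8)

# new op: C[M, N] = A @ (B dequantized with per-row scales)^T, computed by the CPU kernel below
res = torch._weight_int8pack_mm(a, b_int8, scales.to(torch.bfloat16))
ref = torch.mm(a, b.t())  # bf16 reference; the tests allow ~5% mean relative error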
@ -38,6 +38,7 @@
#include <ATen/ops/_linalg_slogdet_native.h>
#include <ATen/ops/_unsafe_view.h>
#include <ATen/ops/_weight_int4pack_mm_native.h>
#include <ATen/ops/_weight_int8pack_mm_native.h>
#include <ATen/ops/addbmm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addr.h>
@ -3394,6 +3395,7 @@ Tensor kron(const Tensor& self, const Tensor& other) {
// Weight Only Quantization Gemm
DEFINE_DISPATCH(weight_to_int4pack_stub);
DEFINE_DISPATCH(int4pack_mm_stub);
DEFINE_DISPATCH(int8pack_mm_stub);

Tensor _convert_weight_to_int4pack_cpu(
    const Tensor& in,
@ -3472,5 +3474,37 @@ Tensor _weight_int4pack_mm_cpu(
  return C;
}

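// Weight-only quantized linear: A is a bfloat16 [M, K] activation matrix,
// B an int8 [N, K] weight quantized per output channel, and scales a 1-D
// [N] tensor of per-channel dequantization scales (read as bfloat16 by the
// CPU kernel). The bfloat16 [M, N] result is produced by int8pack_mm_stub,
// implemented in int8mm_kernel.cpp below.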
Tensor _weight_int8pack_mm_cpu(
    const Tensor& A,
    const Tensor& B,
    const Tensor& scales) {

  auto M = A.size(0);
  auto N = B.size(0);
  auto K = A.size(1);

  TORCH_CHECK(A.dtype() == kBFloat16,
      __func__, " : expect A to be bfloat16 tensor.");
  TORCH_CHECK(A.is_contiguous(),
      __func__, " : expect A to be contiguous.");
  TORCH_CHECK(A.dim() == 2,
      __func__, " : expect A to be 2D tensor.");

  TORCH_CHECK(B.dtype() == kChar,
      __func__, " : expect B to be int8 tensor.");
  TORCH_CHECK(B.is_contiguous(),
      __func__, " : expect B to be contiguous.");
  TORCH_CHECK(B.size(1) == K,
      __func__, " : expect B.size(1) == ", K);

  TORCH_CHECK(scales.dim() == 1 && scales.size(0) == N,
      __func__, " : expect scales to be 1d tensor with size ", N);

  auto C = at::empty({M, N}, A.options());
  int8pack_mm_stub(kCPU, C, A, B, scales);

  return C;
}

} // namespace native
} // namespace at
aten/src/ATen/native/cpu/int8mm_kernel.cpp (new file, 302 lines)
@ -0,0 +1,302 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/irange.h>
#include <c10/util/Unroll.h>

#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
#else
#define RESTRICT __restrict__
#endif

namespace at::native {

namespace {

#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)

// A block : {BLOCK_M, BLOCK_K}, lda = K
// B block : {BLOCK_K, BLOCK_N}, ldb = K
// C block : {BLOCK_M, BLOCK_N}, ldc = N
//
// scales block: {BLOCK_N}
//
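// Each of the ROWS x COLS fp32 accumulators in vc holds the running dot
// product for one output element. A is converted bf16 -> fp32 on load; B is
// sign-extended int8 -> int32 -> fp32 and scaled by its per-column scale
// before the FMA, i.e. the weight is dequantized on the fly.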
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N;

  const int PREFETCH_SIZE_K = 16 * 4;

  __m512 va;
  __m512 vb[COLS];
  __m512 vc[ROWS * COLS];
  __m512 scale[COLS];

  auto load_scale = [&](int i) {
    float ss = static_cast<float>(scales[i]);
    scale[i] = _mm512_set1_ps(ss);
  };
  c10::ForcedUnroll<COLS>{}(load_scale);

  auto loadc = [&](auto i) {
    vc[i] = _mm512_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      __m256i a16 = _mm256_load_si256((__m256i*)(A + row * lda + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      vec::cvtbf16_fp32(a16, va);
    }

    if constexpr (row == 0) {
      __m128i b8 = _mm_load_si128((__m128i*)(B + col * ldb + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      __m512i b32 = _mm512_cvtepi8_epi32(b8);
      vb[col] = _mm512_cvtepi32_ps(b32);
      vb[col] = _mm512_mul_ps(vb[col], scale[col]);
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
  };

  for (int k = 0; k < K; k += 16) {
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    C[row * ldc + col] = static_cast<BFloat16>(_mm512_reduce_add_ps(vc[i]));
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#elif defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

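// Horizontal sum of the 8 fp32 lanes of an AVX2 register; unlike the AVX512
// path above there is no _mm512_reduce_add_ps equivalent, so it is spelled
// out with permutes and shuffles.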
static inline float _mm256_reduce_add_ps(__m256& v) {
  __m256 v1 = _mm256_permute2f128_ps(v, v, 0x1);
  v = _mm256_add_ps(v, v1);
  v1 = _mm256_shuffle_ps(v, v, 0x4E);
  v = _mm256_add_ps(v, v1);
  v1 = _mm256_shuffle_ps(v, v, 0xB1);
  v = _mm256_add_ps(v, v1);
  return _mm256_cvtss_f32(v);
}

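// AVX2 variant of the kernel above: same blocking and on-the-fly
// dequantization, but with 256-bit accumulators, so each step consumes
// 8 elements along K instead of 16.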
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  constexpr int ROWS = BLOCK_M;
  constexpr int COLS = BLOCK_N;

  const int PREFETCH_SIZE_K = 16 * 4;

  __m256 va;
  __m256 vb[COLS];
  __m256 vc[ROWS * COLS];
  __m256 scale[COLS];

  auto load_scale = [&](int i) {
    float ss = static_cast<float>(scales[i]);
    scale[i] = _mm256_set1_ps(ss);
  };
  c10::ForcedUnroll<COLS>{}(load_scale);

  auto loadc = [&](auto i) {
    vc[i] = _mm256_setzero_ps();
  };
  c10::ForcedUnroll<ROWS * COLS>{}(loadc);

  auto compute = [&](auto i, int k) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;

    if constexpr (col == 0) {
      __m128i a16 = _mm_load_si128((__m128i*)(A + row * lda + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      vec::cvtbf16_fp32(a16, va);
    }

    if constexpr (row == 0) {
      __m128i b8 = _mm_loadu_si64((__m128i*)(B + col * ldb + k));
      if (k + PREFETCH_SIZE_K < K) {
        _mm_prefetch(B + col * ldb + k + PREFETCH_SIZE_K, _MM_HINT_T0);
      }
      __m256i b32 = _mm256_cvtepi8_epi32(b8);
      vb[col] = _mm256_cvtepi32_ps(b32);
      vb[col] = _mm256_mul_ps(vb[col], scale[col]);
    }

    constexpr int idx = row * COLS + col;
    vc[idx] = _mm256_fmadd_ps(va, vb[col], vc[idx]);
  };

  for (int k = 0; k < K; k += 8) {
    c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
  }

  auto storec = [&](auto i) {
    constexpr int row = i / COLS;
    constexpr int col = i % COLS;
    C[row * ldc + col] = static_cast<BFloat16>(_mm256_reduce_add_ps(vc[i]));
  };
  c10::ForcedUnroll<ROWS * COLS>{}(storec);
}

#else

// non-vectorized version
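// Scalar fallback: C[m][n] = sum_k A[m][k] * (B[n][k] * scales[n]),
// accumulated in fp32 and narrowed to bfloat16 on the store.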
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
    const BFloat16* RESTRICT A,
    const int8_t* RESTRICT B,
    const BFloat16* RESTRICT scales,
    BFloat16* RESTRICT C,
    int lda,
    int ldb,
    int ldc,
    int K) {

  for (const auto m : c10::irange(BLOCK_M)) {
    for (const auto n : c10::irange(BLOCK_N)) {
      float c_val = 0;
      float scale_val = static_cast<float>(scales[n]);
      for (const auto k : c10::irange(K)) {
        float a_val = static_cast<float>(A[m * lda + k]);
        float b_val = static_cast<float>(B[n * ldb + k]);
        c_val += a_val * (b_val * scale_val);
      }
      C[m * ldc + n] = c_val;
    }
  }
}

#endif

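// Dispatch helpers: map the runtime remainder block sizes (mb_size, nb_size
// in 1..4) onto the matching tinygemm_kernel<BLOCK_M, BLOCK_N> instantiation;
// the trailing arguments are lda = ldb = K, ldc = N.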
#define LAUNCH_TINYGEMM_KERNEL(MB_SIZE, NB_SIZE)                 \
  tinygemm_kernel<MB_SIZE, NB_SIZE>(                             \
      A_ptr, B_ptr, S_ptr, C_ptr,                                \
      K, K, N, K);

#define LAUNCH_TINYGEMM_NB_SIZE(MB_SIZE)                         \
  switch (nb_size) {                                             \
    case 1:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 1);                        \
      break;                                                     \
    case 2:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 2);                        \
      break;                                                     \
    case 3:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 3);                        \
      break;                                                     \
    case 4:                                                      \
      LAUNCH_TINYGEMM_KERNEL(MB_SIZE, 4);                        \
      break;                                                     \
    default:                                                     \
      TORCH_CHECK(false, "Unsupported n block size: ", nb_size); \
      break;                                                     \
  }

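// Top-level kernel registered for int8pack_mm_stub: tile C into 4 x 4
// (BLOCK_M x BLOCK_N) blocks and parallelize over the MB * NB tile grid with
// at::parallel_for; partial edge tiles take the smaller instantiations via
// the macros above.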
void int8pack_mm_kernel(
    const Tensor& C,
    const Tensor& A,
    const Tensor& B,
    const Tensor& scales) {

  const auto* A_data = A.data_ptr<BFloat16>();
  const auto* B_data = B.data_ptr<int8_t>();
  auto* C_data = C.data_ptr<BFloat16>();
  const auto* S_data = scales.data_ptr<BFloat16>();

  int M = A.size(0);
  int N = B.size(0);
  int K = A.size(1);

  constexpr int BLOCK_M = 4;
  constexpr int BLOCK_N = 4;

  const int MB = (M + BLOCK_M - 1) / BLOCK_M;
  const int NB = (N + BLOCK_N - 1) / BLOCK_N;

  at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
    int mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    for (const auto i : c10::irange(begin, end)) {
      (void)i;

      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(BLOCK_M, M - mb_start);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(BLOCK_N, N - nb_start);

      const auto* A_ptr = A_data + mb_start * K;
      const auto* B_ptr = B_data + nb_start * K;
      const auto* S_ptr = S_data + nb_start;
      auto* C_ptr = C_data + mb_start * N + nb_start;

      switch (mb_size) {
        case 1:
          LAUNCH_TINYGEMM_NB_SIZE(1);
          break;
        case 2:
          LAUNCH_TINYGEMM_NB_SIZE(2);
          break;
        case 3:
          LAUNCH_TINYGEMM_NB_SIZE(3);
          break;
        case 4:
          LAUNCH_TINYGEMM_NB_SIZE(4);
          break;
        default:
          TORCH_CHECK(false, "Unsupported m block size: ", mb_size);
      }

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }
  });
}

} // anonymous namespace

ALSO_REGISTER_AVX512_DISPATCH(int8pack_mm_stub, &int8pack_mm_kernel);

} // at::native
@ -7,8 +7,10 @@ namespace at::native {

using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&, int, int);
using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&, int, int);
using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);

DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub);
DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub);
DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub);

} // namespace at::native
@ -4100,6 +4100,10 @@
    CPU: _weight_int4pack_mm_cpu
    CUDA: _weight_int4pack_mm_cuda

- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
  dispatch:
    CPU: _weight_int8pack_mm_cpu

- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
  python_module: sparse

@ -1155,6 +1155,7 @@ aten_native_source_codegen_list = [
    "aten/src/ATen/native/cpu/batch_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/group_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/int4mm_kernel.cpp",
    "aten/src/ATen/native/cpu/int8mm_kernel.cpp",
    "aten/src/ATen/native/cpu/layer_norm_kernel.cpp",
    "aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp",
    "aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",
@ -606,6 +606,7 @@ aten::_values
aten::_values_copy
aten::_values_copy.out
aten::_weight_int4pack_mm
aten::_weight_int8pack_mm
aten::_weight_norm_interface_backward
aten::_weight_norm_interface_backward.out
aten::adaptive_avg_pool2d.out
@ -5990,6 +5990,92 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
        mean_err = ((res - ref).abs() / ref).mean()
        self.assertTrue(mean_err < 0.05)

    def _dynamically_quantize_per_channel(self, x, quant_min, quant_max, target_dtype):
        # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
        # default setup for affine quantization of activations
        x_dtype = x.dtype
        x = x.float()
        eps = torch.finfo(torch.float32).eps

        # get min and max
        min_val, max_val = torch.aminmax(x, dim=1)

        # calculate scales and zero_points based on min and max
        # reference: https://fburl.com/code/srbiybme
        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
        device = min_val_neg.device

        # reference: https://fburl.com/code/4wll53rk
        max_val_pos = torch.max(-min_val_neg, max_val_pos)
        scales = max_val_pos / (float(quant_max - quant_min) / 2)
        # ensure scales is the same dtype as the original tensor
        scales = torch.clamp(scales, min=eps).to(x.dtype)
        zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

        # quantize based on qmin/qmax/scales/zp
        x_div = x / scales.unsqueeze(-1)
        x_round = torch.round(x_div)
        x_zp = x_round + zero_points.unsqueeze(-1)
        quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)

        return quant, scales.to(x_dtype), zero_points

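    # Validates the new CPU op against a bf16 matmul reference: the weight is
    # quantized per output channel with the helper above and the mean relative
    # error is required to stay under 5%.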
    @onlyCPU
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    def test__int8_mm(self, device, m, k, n):
        torch.manual_seed(1)
        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)

        def convert_weight_to_int8pack(b):
            b_int8pack, b_scales, _ = self._dynamically_quantize_per_channel(
                b, -128, 127, torch.int8
            )
            return b_int8pack, b_scales

        def weight_int8pack_mm(a, b_int8pack, b_scales):
            return torch._weight_int8pack_mm(
                a, b_int8pack, b_scales
            )

        b_int8pack, b_scales = convert_weight_to_int8pack(b)
        res = weight_int8pack_mm(a, b_int8pack, b_scales)
        ref = torch.mm(a, b.transpose(0, 1))

        mean_err = ((res - ref).abs() / ref).mean()
        self.assertTrue(mean_err < 0.05)

    @onlyCPU
    @parametrize("m", [32, 64])
    @parametrize("k", [32, 64])
    @parametrize("n", [48, 64])
    def test_compile_int8_mm(self, device, m, k, n):
        if sys.version_info >= (3, 12):
            self.skipTest("Dynamo is not supported on Python 3.12+")

        torch.manual_seed(1)
        a = torch.rand((m, k), dtype=torch.bfloat16, device=device)
        b = torch.rand((n, k), dtype=torch.bfloat16, device=device)

        b_int8pack, b_scales, _ = self._dynamically_quantize_per_channel(
            b, -128, 127, torch.int8
        )

        @torch.compile
        def int8_mm(a, b_int8pack, b_scales):
            return torch._weight_int8pack_mm(
                a, b_int8pack, b_scales
            )

        res = int8_mm(a, b_int8pack, b_scales)
        ref = torch.mm(a, b.transpose(0, 1))

        mean_err = ((res - ref).abs() / ref).mean()
        self.assertTrue(mean_err < 0.05)

    @slowTest
    @onlyNativeDeviceTypes
    # bfloat16 doesn't have sufficient precision to pass this test
@ -1487,6 +1487,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
        "torch._use_cudnn_rnn_flatten_weight",
        "torch._values_copy",
        "torch._weight_int4pack_mm",
        "torch._weight_int8pack_mm",
        "torch._weight_norm_interface",
        "torch._weight_norm",
        "torch.abs_",
@ -3455,6 +3455,21 @@ def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
    return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)


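# Meta kernel for the new op: checks ranks and dtypes and returns an empty
# [x.size(0), w.size(0)] tensor with x's dtype for shape propagation; no
# actual computation happens here.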
@register_meta([aten._weight_int8pack_mm])
def meta__weight_int8pack_mm(x, w, q_scales):
    torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
    torch._check(
        x.dtype is torch.bfloat16,
        lambda: f"expected x to be bf16, got {x.dtype}",
    )
    torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
    torch._check(
        w.dtype is torch.int8,
        lambda: f"expected w to be int8, got {w.dtype}",
    )
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)


@register_meta(aten._cdist_forward.default)
def meta_cdist_forward(x1, x2, p, compute_mode):
    torch._check(