Support transpose and pack for bit8 (#156065)
To be used by CPU INT8 SDPA in torchao: https://github.com/pytorch/ao/pull/2380

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156065
Approved by: https://github.com/mingfeima, https://github.com/ezyang
committed by PyTorch MergeBot
parent 2022588295
commit d26ca5de05

aten/src/ATen/cpu/vec/vec_quant.h (new file, 153 lines)
@@ -0,0 +1,153 @@
#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <c10/util/Exception.h>

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

// Transpose a [4, 64] block to [64, 4] (with contiguous output, ld=4)
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
static inline void transpose_pad_4x64_block(
    const scalar_t* src,
    scalar_t* dst,
    int64_t ld_src,
    int krem = 4,
    int nrem = 64) {
#if defined(CPU_CAPABILITY_AVX512)
  __m512i r[4];
  // Load with mask if partial
  if (nrem < 64) {
    __mmask64 mask = (1ULL << nrem) - 1;
    for (int i = 0; i < krem; ++i) {
      r[i] = _mm512_maskz_loadu_epi8(mask, src + i * ld_src);
    }
    for (int i = krem; i < 4; ++i) {
      r[i] = _mm512_setzero_si512();
    }
  } else {
    for (int i = 0; i < krem; ++i) {
      r[i] = _mm512_loadu_si512(
          reinterpret_cast<const __m512i*>(src + i * ld_src));
    }
    for (int i = krem; i < 4; ++i) {
      r[i] = _mm512_setzero_si512();
    }
  }

  // Transpose 4x64 bytes using unpack and shuffle
  __m512i t0 = _mm512_unpacklo_epi8(r[0], r[1]);
  __m512i t1 = _mm512_unpackhi_epi8(r[0], r[1]);
  __m512i t2 = _mm512_unpacklo_epi8(r[2], r[3]);
  __m512i t3 = _mm512_unpackhi_epi8(r[2], r[3]);

  __m512i u0 = _mm512_unpacklo_epi16(t0, t2);
  __m512i u1 = _mm512_unpackhi_epi16(t0, t2);
  __m512i u2 = _mm512_unpacklo_epi16(t1, t3);
  __m512i u3 = _mm512_unpackhi_epi16(t1, t3);

  __m512i v0 = _mm512_shuffle_i32x4(u0, u1, 0x88);
  __m512i v1 = _mm512_shuffle_i32x4(u0, u1, 0xdd);
  __m512i v2 = _mm512_shuffle_i32x4(u2, u3, 0x88);
  __m512i v3 = _mm512_shuffle_i32x4(u2, u3, 0xdd);

  __m512i r0 = _mm512_shuffle_i32x4(v0, v2, 0x88);
  __m512i r1 = _mm512_shuffle_i32x4(v1, v3, 0x88);
  __m512i r2 = _mm512_shuffle_i32x4(v0, v2, 0xdd);
  __m512i r3 = _mm512_shuffle_i32x4(v1, v3, 0xdd);

  // Store output
  if (nrem < 16) {
    __mmask64 mask = (1ULL << (nrem * 4)) - 1;
    _mm512_mask_storeu_epi8(dst, mask, r0);
  } else if (nrem == 16) {
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
  } else if (nrem < 32) {
    int n_bytes1 = 64;
    int n_bytes2 = (nrem * 4) - n_bytes1;
    __mmask64 mask = (1ULL << n_bytes2) - 1;
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64), mask, r1);
  } else if (nrem == 32) {
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
  } else if (nrem < 48) {
    int n_bytes1 = 64 * 2;
    int n_bytes2 = (nrem * 4) - n_bytes1;
    __mmask64 mask = (1ULL << n_bytes2) - 1;
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
    _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 2), mask, r2);
  } else if (nrem == 48) {
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
  } else if (nrem < 64) {
    int n_bytes1 = 64 * 3;
    int n_bytes2 = (nrem * 4) - n_bytes1;
    __mmask64 mask = (1ULL << n_bytes2) - 1;
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
    _mm512_mask_storeu_epi8(reinterpret_cast<__m512i*>(dst + 64 * 3), mask, r3);
  } else {
    // normal case, nrem == 64
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst), r0);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64), r1);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 2), r2);
    _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 64 * 3), r3);
  }
#else
  TORCH_CHECK(
      false,
      "transpose_pad_4x64_block is only supported when AVX-512 is supported")
#endif
}
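
// Illustrative scalar reference for transpose_pad_4x64_block: a sketch added
// for exposition only; this helper and its name are hypothetical and not part
// of the original change. For n in [0, nrem) and k in [0, 4):
//   dst[n * 4 + k] = (k < krem) ? src[k * ld_src + n] : 0
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
static inline void transpose_pad_4x64_block_scalar_ref(
    const scalar_t* src,
    scalar_t* dst,
    int64_t ld_src,
    int krem = 4,
    int nrem = 64) {
  for (int n = 0; n < nrem; ++n) {
    for (int k = 0; k < 4; ++k) {
      // Rows beyond krem are zero-padded, matching the AVX-512 path above.
      dst[n * 4 + k] = (k < krem) ? src[k * ld_src + n] : scalar_t(0);
    }
  }
}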

// Reorder [K, N] → [K/4, N, 4] (VNNI4-style layout for bit8)
template <typename scalar_t, typename = std::enable_if_t<sizeof(scalar_t) == 1>>
static inline void pack_vnni4(
    const scalar_t* src,
    scalar_t* dst,
    int64_t ld_src,
    int64_t K,
    int64_t N) {
#if defined(CPU_CAPABILITY_AVX512)
  int64_t bk = 0;
  int64_t _K = K / 4 * 4;
  int64_t _N = N / 64 * 64;
  for (; bk < _K; bk += 4) {
    int64_t bn = 0;
    for (; bn < _N; bn += 64) {
      transpose_pad_4x64_block(
          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src);
    }
    int64_t nrem = N - bn;
    if (nrem > 0) {
      transpose_pad_4x64_block(
          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, 4, nrem);
    }
  }

  // Handle leftover K rows (< 4)
  if (K % 4 != 0) {
    int krem = K - bk;
    int64_t bn = 0;
    for (; bn < _N; bn += 64) {
      transpose_pad_4x64_block(
          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem);
    }
    int64_t nrem = N - bn;
    if (nrem > 0) {
      transpose_pad_4x64_block(
          src + bk * ld_src + bn, dst + bk * N + bn * 4, ld_src, krem, nrem);
    }
  }
#else
  TORCH_CHECK(false, "pack_vnni4 is only supported when AVX-512 is supported")
#endif
}
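
// Usage sketch (illustrative only; the buffer names below are hypothetical):
// pack_vnni4 reorders a [K, N] bit8 matrix into a [ceil(K/4), N, 4]
// VNNI4-style layout, so element (k, n) of src ends up at
//   dst[(k / 4) * N * 4 + n * 4 + (k % 4)],
// with rows in a trailing partial group of 4 zero-padded.
//
//   std::vector<uint8_t> src_buf(K * N);
//   std::vector<uint8_t> dst_buf(((K + 3) / 4) * N * 4);
//   at::vec::pack_vnni4(src_buf.data(), dst_buf.data(), /*ld_src=*/N, K, N);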

} // namespace CPU_CAPABILITY
} // namespace at::vec

@@ -165,6 +165,12 @@ inline void transpose<uint16_t>(int64_t M, int64_t N, const uint16_t* src, int64
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  fbgemm::transpose_simd<uint16_t>(M, N, src, ld_src, dst, ld_dst);
}

template <>
inline void transpose<uint8_t>(int64_t M, int64_t N, const uint8_t* src, int64_t ld_src, uint8_t* dst, int64_t ld_dst) {
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  fbgemm::transpose_simd<uint8_t>(M, N, src, ld_src, dst, ld_dst);
}
#endif

template <typename index_t, typename F>

@@ -61,6 +61,8 @@ namespace {
template <typename T>
class QuantizationTests : public ::testing::Test {};
template <typename T>
class Quantization8BitTests : public ::testing::Test {};
template <typename T>
class Quantization8BitWithTailTests : public ::testing::Test {};
template <typename T>
class FunctionalTests : public ::testing::Test {};

@@ -79,6 +81,7 @@ namespace {
using FloatTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vcomplexDbl>;
using ALLTestedTypes = ::testing::Types<vfloat, vdouble, vcomplex, vlong, vint, vshort, vqint8, vquint8, vqint>;
using QuantTestedTypes = ::testing::Types<vqint8, vquint8, vqint>;
using Quantization8BitTestedTypes = ::testing::Types<vqint8, vquint8>;
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
using Quantization8BitWithTailTestedTypes =
    ::testing::Types<vqint8, vquint8>;

@@ -116,6 +119,7 @@ namespace {
TYPED_TEST_SUITE(BitwiseFloatsAdditional, RealFloatReducedFloatTestedTypes);
TYPED_TEST_SUITE(BitwiseFloatsAdditional2, FloatTestedTypes);
TYPED_TEST_SUITE(QuantizationTests, QuantTestedTypes);
TYPED_TEST_SUITE(Quantization8BitTests, Quantization8BitTestedTypes);
TYPED_TEST_SUITE(InfiniteTests, RealFloatTestedTypes);
#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
TYPED_TEST_SUITE(

@@ -1496,6 +1500,68 @@ namespace {
    },
    test_case);
}
#ifndef _WIN32
TYPED_TEST(Quantization8BitTests, Transpose) {
  using VT = ValueType<TypeParam>;
  constexpr auto M = 4;
  constexpr auto N = 64;
  constexpr auto L = M * N;
  constexpr auto ld_src = N;
  constexpr auto ld_dst = M;
  CACHE_ALIGN VT x[L];
  CACHE_ALIGN VT y[L];
  CACHE_ALIGN VT ref[L];
  auto seed = TestSeed();
  ValueGen<VT> generator(VT(-100), VT(100), seed);
  for (const auto i : c10::irange(L)) {
    x[i] = generator.get();
  }
  at::native::utils::transpose<uint8_t>(
      M, N,
      reinterpret_cast<uint8_t*>(x), ld_src,
      reinterpret_cast<uint8_t*>(y), ld_dst);
  for (int64_t j = 0; j < N; j++) {
    for (int64_t i = 0; i < M; i++) {
      ref[j * ld_dst + i] = c10::load(&(x[i * ld_src + j]));
    }
  }
  for (const auto i : c10::irange(L)) {
    ASSERT_EQ(y[i], ref[i])
        << "Failure Details:\nTest Seed to reproduce: " << seed;
  }
}
#endif
#if defined(CPU_CAPABILITY_AVX512)
TYPED_TEST(Quantization8BitTests, PackVNNI4) {
  using VT = ValueType<TypeParam>;
  constexpr auto K = 8;
  constexpr auto N = 128;
  constexpr auto L = K * N;
  constexpr auto ld_src = N;
  CACHE_ALIGN VT x[L];
  CACHE_ALIGN VT y[L];
  CACHE_ALIGN VT ref[L];
  auto seed = TestSeed();
  ValueGen<VT> generator(VT(-100), VT(100), seed);
  for (const auto i : c10::irange(L)) {
    x[i] = generator.get();
  }
  at::vec::pack_vnni4(x, y, ld_src, K, N);
  int64_t _K = K / 4;
  for (int64_t k = 0; k < _K; k++) {
    for (int64_t n = 0; n < N; n++) {
      for (int64_t l = 0; l < 4; l++) {
        ref[k * N * 4 + n * 4 + l] =
            c10::load(&(x[k * ld_src * 4 + l * ld_src + n]));
      }
    }
  }
  for (const auto i : c10::irange(L)) {
    ASSERT_EQ(y[i], ref[i])
        << "Failure Details:\nTest Seed to reproduce: " << seed;
  }
}
#endif
TYPED_TEST(FunctionalTests, Map) {
  using vec = TypeParam;
  using VT = ValueType<TypeParam>;

@@ -1,6 +1,7 @@
#pragma once
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec_quant.h>
#include <c10/util/bit_cast.h>
#include <c10/util/irange.h>
#include <gtest/gtest.h>

@@ -21,7 +22,9 @@
#else
#define CACHE_LINE 32
#endif

#ifndef _WIN32
#include <ATen/native/cpu/utils.h>
#endif
#if defined(__GNUC__)
#define CACHE_ALIGN __attribute__((aligned(CACHE_LINE)))
#define not_inline __attribute__((noinline))