Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[ROCm] Add int4 support (#129710)"
This reverts commit d0ad13fa42fc2e9935bd3bda2937a3491276d274. Reverted https://github.com/pytorch/pytorch/pull/129710 on behalf of https://github.com/jeffdaily due to original ROCm PR did not have ciflow/rocm, missed signal ([comment](https://github.com/pytorch/pytorch/pull/129710#issuecomment-2214558368))
@@ -1,11 +1,9 @@
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#if !defined(USE_ROCM)
#include <mma.h>
#endif
#endif
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
@@ -127,38 +125,9 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
return diff == 0 ? 0 : uint32_t(Align) - diff;
}

#if defined(USE_ROCM)
// TODO: Support RDNA
constexpr int32_t kWarpSize = 64;

template<typename T, uint32_t Rank>
using VecT = T __attribute__((ext_vector_type(Rank)));

static bool isCDNA2orLater(int index) {
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
std::string device_arch = prop->gcnArchName;
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
for (std::string arch : archs) {
size_t substring = device_arch.find(arch);
if (substring != std::string::npos) {
return true;
}
}
return false;
}

#else
constexpr int32_t kWarpSize = 32;
#endif

#if defined (__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define CDNA2_OR_LATER 1
#else
#define CDNA2_OR_LATER 0
#endif

#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))

#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
// f16 vector types
struct __align__(2) f16x1 {
__half vals[1];
@@ -207,19 +176,11 @@ struct __align__(16) bf16x2x4 {
};

struct __align__(16) bf16x2x4_u32 {
#if defined(USE_ROCM)
VecT<short, 4> val[2];
#else
uint32_t vals[4];
#endif
};

struct __align__(8) bf16x2x2_u32 {
#if defined(USE_ROCM)
VecT<short, 4> val;
#else
uint32_t vals[2];
#endif
};

struct __align__(4) bf16x2x1_u32 {
@@ -241,68 +202,38 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
uint32_t const source_i4s = source;

// First, we extract the i4s and construct an intermediate fp16 number.
#if !defined(USE_ROCM)
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
#endif
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

// We don't have enough mantissa to remove as much shift overhead as FP16, so
// we must loop. No shift needed for first item.
uint32_t i4s = source_i4s;

#if defined(USE_ROCM)
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(h[0])
: "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif

#pragma unroll
for (int ii = 1; ii < kElements / 2; ++ii) {
i4s >>= 4; // or is it 8?
// (i4s & 0x000f000f) | 0x43004300
#if defined(USE_ROCM)
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(h[ii])
: "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[ii])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif
}

// This is the BF16 {-136, -136} represented as an integer.
#if defined(USE_ROCM)
#if ROCM_VERSION >= 60200
auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
#else
auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
#endif
#else
static constexpr uint32_t BF16_BIAS = 0xC308C308;
static constexpr uint32_t BF16_ONE = 0x3F803F80;
#endif

// Finally, we construct the output numbers.
#pragma unroll
for (int ii = 0; ii < kElements / 2; ++ii) {
// Since this section is for Ampere+, we use bf16 fma to do the bias
// subtraction
#if defined(USE_ROCM)
result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
#else
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
: "=r"(h[ii])
: "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
#endif
}

return result;
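
The bit trick in convert_i4x8_to_bf16x2x4 is easier to see on the CPU. Below is an illustrative host-side sketch (not part of the diff; plain C++): only the constants 0x4300 (the per-lane half of I4s_TO_BF16s_MAGIC_NUM), 0x3F80 (BF16_ONE) and 0xC308 (BF16_BIAS, bf16 -136) come from the kernel above, the rest is a stand-in. OR-ing a 4-bit value v into the low mantissa bits of bf16 128.0 yields 128 + v, and the fma with 1.0 and -136 then gives v - 8, i.e. the nibble re-centered to [-8, 7].

```cpp
// Hedged CPU emulation of the int4 -> bf16 conversion trick shown above.
#include <cstdint>
#include <cstdio>
#include <cstring>

// Interpret 16 bf16 bits as a float by placing them in the high half of an fp32.
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t u = uint32_t(bits) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  const float kOne = bf16_bits_to_float(0x3F80);   // 1.0, matches BF16_ONE
  const float kBias = bf16_bits_to_float(0xC308);  // -136.0, matches BF16_BIAS
  for (unsigned v = 0; v < 16; ++v) {
    // 0x4300 is bf16 128.0; OR-ing the nibble into the mantissa gives 128 + v,
    // which is what the v_and_or_b32 / lop3.b32 step produces per 16-bit lane.
    float x = bf16_bits_to_float(uint16_t(0x4300 | v));
    // fma(x, 1.0, -136.0) == (128 + v) - 136 == v - 8.
    printf("nibble %2u -> %5.1f\n", v, x * kOne + kBias);
  }
  return 0;
}
```

Run on any CPU, this prints -8.0 through 7.0, matching what the __hfma2 / fma.rn.bf16x2 paths compute per bf16 lane on the GPU.
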
@@ -323,11 +254,7 @@ enum class KReductionType {
template <KReductionType ReduceType>
struct ALayout_RM {
static constexpr int32_t kMTileSize = 16;
#if defined(USE_ROCM)
static constexpr int32_t kNTileSize = 16;
#else
static constexpr int32_t kNTileSize = 8;
#endif
static constexpr int32_t kKTileSize = 16;

template <int KTilesToLoad>
@@ -340,37 +267,22 @@ struct ALayout_RM {
int32_t kTiles,
int32_t kTileStart,
int32_t laneId,
#if defined(USE_ROCM)
bf16x2x2_u32 out[KTilesToLoad]
#else
bf16x2x4_u32 out[KTilesToLoad]
#endif
) {
#if defined(USE_ROCM)
const auto mLane = mTile * kMTileSize + (laneId % kMTileSize);
const auto kLane = kTileStart * kKTileSize + (laneId / kMTileSize) * 4;
#else
bf16x2x4_u32 out[KTilesToLoad]) {
const auto mLane = mTile * kMTileSize + (laneId / 4);
const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2;
#endif

// access
// [mTile * kMTileSize + (laneId / 4)]
// [kTileStart * kKTileSize + (laneId % 4) * 2]
auto aPtr = reinterpret_cast<const __nv_bfloat16*>(A) + mLane * k + kLane;
bool m0InBounds = mLane < m;

#if !defined(USE_ROCM)
auto aPtrPlus8Rows = aPtr + 8 * k;

bool m0InBounds = mLane < m;
bool m1InBounds = (mLane + 8) < m;
#endif

#pragma unroll
for (int i = 0; i < KTilesToLoad; ++i) {
#if defined(USE_ROCM)
out[i].val = m0InBounds ? *((VecT<short, 4> *)(aPtr + i * kKTileSize)) : VecT<short, 4>{0, 0, 0, 0};
#else
out[i].vals[0] = m0InBounds
? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
: uint32_t(0);
@@ -384,7 +296,6 @@ struct ALayout_RM {
out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
aPtrPlus8Rows + i * kKTileSize + 8)
: uint32_t(0);
#endif
}
}

@@ -401,10 +312,6 @@ struct ALayout_RM {
static_assert(ReduceType == KReductionType::None, "");

if constexpr (ReduceType == KReductionType::None) {
#if defined(USE_ROCM)
const int outRow = mTile * kMTileSize + (laneId / kNTileSize) * 4;
const int outCol = nTile * kNTileSize + (laneId % kNTileSize);
#else
// sum.x / sum.y are written at
// [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
// sum.z / sum.w are written at
@@ -412,21 +319,10 @@ struct ALayout_RM {
// i.e., same columns, different row.
const int outRow = mTile * kMTileSize + (laneId / 4);
const int outCol = nTile * kNTileSize + (laneId % 4) * 2;
#endif

// Pointer where sum.x / sum.y is written
auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol;

#if defined(USE_ROCM)
if (outRow < m)
cPtr[0] = __float2bfloat16(out.x);
if ((outRow + 1) < m)
cPtr[n] = __float2bfloat16(out.y);
if ((outRow + 2) < m)
cPtr[2*n] = __float2bfloat16(out.z);
if ((outRow + 3) < m)
cPtr[3*n] = __float2bfloat16(out.w);
#else
auto v01 = __float22bfloat162_rn(float2{out.x, out.y});
auto v23 = __float22bfloat162_rn(float2{out.z, out.w});

@@ -438,7 +334,6 @@ struct ALayout_RM {
if (outRow + 8 < m) {
*reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23;
}
#endif
}
}
};
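
To make the two store layouts above concrete, here is a small illustrative sketch (not from the PR) of where one lane's four fp32 accumulator values land inside a C tile, using tile-relative coordinates (the mTile * kMTileSize and nTile * kNTileSize offsets are dropped); the index math is paraphrased from the #if branches above.

```cpp
// Hedged sketch of the per-lane output mapping in ALayout_RM::store.
#include <cstdio>

struct Coord { int row; int col; };

// CUDA path (m16n8k16, 32 lanes): sum.x/sum.y go to (laneId / 4, (laneId % 4) * 2)
// and the next column; sum.z/sum.w go to the same columns, 8 rows down.
void nv_lane_coords(int laneId, Coord out[4]) {
  int row = laneId / 4;
  int col = (laneId % 4) * 2;
  out[0] = {row, col};
  out[1] = {row, col + 1};
  out[2] = {row + 8, col};
  out[3] = {row + 8, col + 1};
}

// ROCm path (m16n16k16, 64 lanes): four consecutive rows in a single column.
void amd_lane_coords(int laneId, Coord out[4]) {
  int row = (laneId / 16) * 4;
  int col = laneId % 16;
  for (int i = 0; i < 4; ++i) {
    out[i] = {row + i, col};
  }
}

int main() {
  Coord c[4];
  nv_lane_coords(5, c);
  printf("NV lane 5: (%d,%d) (%d,%d) (%d,%d) (%d,%d)\n",
         c[0].row, c[0].col, c[1].row, c[1].col,
         c[2].row, c[2].col, c[3].row, c[3].col);
  amd_lane_coords(21, c);
  printf("AMD lane 21: (%d,%d) .. (%d,%d)\n",
         c[0].row, c[0].col, c[3].row, c[3].col);
  return 0;
}
```
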
@@ -447,19 +342,15 @@ template <int InnerKTiles, int QGroupSize>
struct BLayout_TC_int4 {
static constexpr int32_t kInnerKTiles = InnerKTiles;
static constexpr int32_t kMTileSize = 16;
#if defined(USE_ROCM)
static constexpr int32_t kNTileSize = 16;
#else
static constexpr int32_t kNTileSize = 8;
#endif
static constexpr int32_t kKTileSize = 16;

template <int KTilesToLoad>
static __device__ void load(
// type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2]
// n-tiles: n / 8 for NV, n /16 for AMD
// k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16 for NV, m16n16k16 for AMD)
// value per warp lane: 32 for NV, 64 for AMD
// n / 8: n-tiles (n8)
// k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16)
// 32: value per warp lane
// (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile.
// 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest
// value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a
@@ -532,11 +423,7 @@ struct BLayout_TC_int4 {

__nv_bfloat162 qScaleAndZero[kNumQGroups];
{
#if defined(USE_ROCM)
int32_t laneN = nTile * kNTileSize + (laneId % kNTileSize);
#else
int32_t laneN = nTile * kNTileSize + (laneId / 4);
#endif
int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize;

int32_t n = nTiles * kNTileSize;
@@ -627,15 +514,9 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
int32_t nTiles,
int32_t kTiles) {
constexpr int32_t kMTileSize = 16;
#if defined(USE_ROCM)
constexpr int32_t kNTileSize = 16;
#else
constexpr int32_t kNTileSize = 8;
#endif
constexpr int32_t kKTileSize = 16;

#if !defined(USE_ROCM) || CDNA2_OR_LATER

static_assert(
ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize &&
ALayout::kKTileSize == kKTileSize,
@@ -669,11 +550,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
int32_t mTile = blockIdx.z;
int32_t nTile = blockIdx.y;

#if defined(USE_ROCM)
VecT<float, 4> c{0.0f, 0.0f, 0.0f, 0.0f};
#else
float4 c{0.0f, 0.0f, 0.0f, 0.0f};
#endif

// First, handle whole multiples of KTilesPerIteration
auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);
@@ -685,11 +562,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
//
// Load data from A
//
#if defined(USE_ROCM)
bf16x2x2_u32 a[KTilesPerIteration];
#else
bf16x2x4_u32 a[KTilesPerIteration];
#endif
ALayout::template load<KTilesPerIteration>(
A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);

@@ -723,29 +596,15 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
// We don't simply accumulate into `c` as this creates a too-strong
// execution dependency. Instead, we only periodically accumulate into
// `c`
#if defined(USE_ROCM)
VecT<float, 4> cTmp[2];
#else
float4 cTmp[2];
#endif

#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
cTmp[k] = VecT<float, 4>{0.0f, 0.0f, 0.0f, 0.0f};
#else
cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
#endif
}

#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
a[i * kInnerKTiles + j * 2 + k].val,
b[i][(j * 2 + k) / 2].val[((j * 2 + k) % 2)],
cTmp[k], 0, 0, 0);
#else
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
@@ -763,22 +622,14 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
"f"(cTmp[k].y),
"f"(cTmp[k].z),
"f"(cTmp[k].w));
#endif
}

#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
c[0] += cTmp[k][0];
c[1] += cTmp[k][1];
c[2] += cTmp[k][2];
c[3] += cTmp[k][3];
#else
c.x += cTmp[k].x;
c.y += cTmp[k].y;
c.z += cTmp[k].z;
c.w += cTmp[k].w;
#endif
}
}
}
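
The "too-strong execution dependency" comment above is about instruction-level parallelism: accumulating every mma/mfma result directly into `c` would serialize the FMAs through a single register, while two independent cTmp accumulators that are folded into `c` once per iteration let consecutive matrix ops overlap. A scalar stand-in for the same idea (illustrative only, not the kernel's code):

```cpp
// Hedged scalar sketch of the split-accumulator pattern used above.
#include <cstdio>

float sum_split(const float* x, int n) {  // n assumed to be a multiple of 4
  float c = 0.0f;
  for (int i = 0; i < n; i += 4) {
    // Two independent partial accumulators: the adds into cTmp[0] and cTmp[1]
    // have no dependency on each other or on `c`, so they can overlap.
    float cTmp[2] = {0.0f, 0.0f};
    cTmp[0] += x[i];
    cTmp[1] += x[i + 1];
    cTmp[0] += x[i + 2];
    cTmp[1] += x[i + 3];
    // Periodic accumulation into the long-lived accumulator, as the kernel
    // does with `c` after each chunk of k-tiles.
    c += cTmp[0] + cTmp[1];
  }
  return c;
}

int main() {
  float data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  printf("%.1f\n", sum_split(data, 8));  // 36.0
  return 0;
}
```
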
@@ -795,11 +646,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
// If we have any remainder k-tiles, some warps will handle them, processing
// kInnerKTiles k-tiles at a time
if (kTileBaseRemaining < kTiles) {
#if defined(USE_ROCM)
bf16x2x2_u32 a[kInnerKTiles];
#else
bf16x2x4_u32 a[kInnerKTiles];
#endif
ALayout::template load<kInnerKTiles>(
A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a);

@@ -821,29 +668,15 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
// We don't simply accumulate into `c` as this creates a too-strong
// execution dependency. Instead, we only periodically accumulate into
// `c`
#if defined(USE_ROCM)
VecT<float, 4> cTmp[2];
#else
float4 cTmp[2];
#endif

#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
cTmp[k] = VecT<float, 4>{0.0f, 0.0f, 0.0f, 0.0f};
#else
cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
#endif
}

#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
a[j * 2 + k].val,
b[0][(j * 2 + k) / 2].val[((j * 2 + k) % 2)],
cTmp[k], 0, 0, 0);
#else
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
@@ -858,22 +691,14 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
"f"(cTmp[k].y),
"f"(cTmp[k].z),
"f"(cTmp[k].w));
#endif
}

#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
c[0] += cTmp[k][0];
c[1] += cTmp[k][1];
c[2] += cTmp[k][2];
c[3] += cTmp[k][3];
#else
c.x += cTmp[k].x;
c.y += cTmp[k].y;
c.z += cTmp[k].z;
c.w += cTmp[k].w;
#endif
}
}
}
@@ -886,14 +711,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
// FIXME: this likely doesn't need to be a true reduction tree, can just be a
// serial sum, maybe (unless nvcc/ptxas goes back to its old ways)
// smem_sum[warpId][laneId] = TreeReduce4<KTilesPerIteration>::reduce(c);
#if defined(USE_ROCM)
smem_sum[warpId][laneId].x = c[0];
smem_sum[warpId][laneId].y = c[1];
smem_sum[warpId][laneId].z = c[2];
smem_sum[warpId][laneId].w = c[3];
#else
smem_sum[warpId][laneId] = c;
#endif

__syncthreads();

@@ -923,9 +741,6 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
laneId,
sum_f32);
}
#else
printf("__builtin_amdgcn_mfma_f32_16x16x16bf16_1k is only supported on AMD gpu arch greater than or equal to CDNA2\n");
#endif
}

@@ -983,12 +798,7 @@ void launch_tinygemm_kernel(
cudaFuncAttributes funcAttr;
C10_CUDA_CHECK(cudaFuncGetAttributes(
&funcAttr,
#if defined(USE_ROCM)
(void *)func
#else
func
#endif
));
func));
}

// FIXME: parallelize better, smem staging etc?
@@ -1003,11 +813,7 @@ __global__ void matrix_to_m16n8k16_Bint4_layout(
// innermost k-tiles that we can use is 2.
static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), "");

#if defined(USE_ROCM)
constexpr int32_t kNTileSize = 16;
#else
constexpr int32_t kNTileSize = 8;
#endif
constexpr int32_t kKTileSize = 16;

// gridDim.x corresponds to the number of k-tiles divided by InnerKTiles
@@ -1019,30 +825,13 @@ __global__ void matrix_to_m16n8k16_Bint4_layout(
#pragma unroll
for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) {
// n dimension that this lane loads from
#if defined(USE_ROCM)
auto n0 = nTile * kNTileSize + (t % kNTileSize);
#else
auto n0 = nTile * kNTileSize + (t / 4);
#endif

bool n0Valid = n0 < in.size(0);

int32_t ks[8];

auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize;

#if defined(USE_ROCM)
ks[0] = kBase0 + (t / kNTileSize) * 4;
ks[1] = ks[0] + 1;
ks[2] = ks[0] + 2;
ks[3] = ks[0] + 3;

auto kBase1 = kBase0 + kKTileSize;
ks[4] = kBase1 + (t / kNTileSize) * 4;
ks[5] = ks[4] + 1;
ks[6] = ks[4] + 2;
ks[7] = ks[4] + 3;
#else
ks[0] = kBase0 + (t % 4) * 2;
ks[1] = ks[0] + 1;
ks[2] = ks[0] + 8;
@@ -1053,7 +842,6 @@ __global__ void matrix_to_m16n8k16_Bint4_layout(
ks[5] = ks[4] + 1;
ks[6] = ks[4] + 8;
ks[7] = ks[4] + 8 + 1;
#endif

auto pIn = &in[n0][0];

@@ -1067,19 +855,7 @@ __global__ void matrix_to_m16n8k16_Bint4_layout(
(v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0];

// inner k-tiles pack two at a time
#if defined(USE_ROCM)
// The output tensor shape is [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2], which is specific to Nvidia
// But AMD needs [ceil(n / 16)][ceil(k / (InnerKTiles * 16))][64][InnerKTiles / 2]
// So construct the pointer accordingly
auto bPtr = out.data() +
((nTile * out.size(1) * kWarpSize * (InnerKTiles / 2)) +
(kOuterTile * kWarpSize * (InnerKTiles / 2)) +
(t * (InnerKTiles / 2)) +
(innerKTile / 2));
*bPtr = pack;
#else
out[nTile][kOuterTile][t][innerKTile / 2] = pack;
#endif
}
}

@@ -1096,30 +872,16 @@ at::Tensor _weight_int4pack_mm_cuda(
TORCH_CHECK(
A.device() == B.device() && A.device() == qScaleAndZeros.device());

#if defined(USE_ROCM)
if (!isCDNA2orLater(A.device().index())) {
TORCH_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
}
#endif

constexpr int32_t kMTileSize = 16;
#if defined(USE_ROCM)
constexpr int32_t kNTileSize = 16;
#else
constexpr int32_t kNTileSize = 8;
#endif
constexpr int32_t kKTileSize = 16;

// row major layout
auto m = A.size(0);
auto mTiles = divUp(m, kMTileSize);

// To convert the nTiles from tensor storage layout to the actual matrix core layout
constexpr int32_t kNTileSizeTensor = 8;
auto nTileScaleFactor = (kNTileSize / kNTileSizeTensor);

// tensor core layout
auto nTiles = (B.size(0) / nTileScaleFactor);
auto nTiles = B.size(0);
auto n = nTiles * kNTileSize;

// row major layout
@@ -1142,7 +904,7 @@ at::Tensor _weight_int4pack_mm_cuda(
TORCH_CHECK(B.is_contiguous());
TORCH_CHECK(B.dim() == 4);
TORCH_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize));
TORCH_CHECK(B.size(2) == 32);
TORCH_CHECK(B.size(2) == kWarpSize);

// Validate the scale and zero point tensor for dequantization
// These are the only versions handled at the moment
@@ -1162,7 +924,7 @@ at::Tensor _weight_int4pack_mm_cuda(
auto C_final = at::empty(
{m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device()));

#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
auto stream = at::cuda::getCurrentCUDAStream();
#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
do { \
@@ -1291,27 +1053,10 @@ at::Tensor _convert_weight_to_int4pack_cuda(
// which is the maximum vectorized load/store size
TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8);

#if defined(USE_ROCM)
if (!isCDNA2orLater(in.device().index())) {
TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
}
#endif

#if defined(USE_ROCM)
constexpr int32_t kNTileSize = 16;
#else
constexpr int32_t kNTileSize = 8;
#endif
constexpr int32_t kKTileSize = 16;

// GPT-FAST assumes nTileSize of 8 for quantized weight tensor.
// See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510
// Torch dynamo also requires the torch ops has the same output shape for each device.
// See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263
constexpr int32_t kNTileSizeTensor = 8;

auto nTiles = divUp(in.size(0), kNTileSize);
auto nTilesTensor = divUp(in.size(0), kNTileSizeTensor);

// k-tiles are packed back to back in the innermost dimension in order to
// allow for 4/8/16 byte loads
@@ -1321,14 +1066,11 @@ at::Tensor _convert_weight_to_int4pack_cuda(

// each block handles `innerKTiles` k-tiles.
// 2 k-tiles are a single int32
//
// We use the same shape for AMD gpus also to match the GPT-FAST spec.
// Will index it correctly when dereferencing the quantized weight tensor pointer.
auto out = at::empty(
{nTilesTensor, kSuperTiles, 32, innerKTiles / 2},
{nTiles, kSuperTiles, 32, innerKTiles / 2},
at::TensorOptions().dtype(at::kInt).device(in.device()));

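
For a sense of the packed shape allocated here, a minimal host-side sketch (not from the diff; in particular, kSuperTiles is assumed to be the number of k-tiles divided by innerKTiles, rounded up, which this hunk suggests but does not show):

```cpp
// Hedged sketch of the packed int4 weight shape produced by
// _convert_weight_to_int4pack_cuda for an [n, k] weight matrix.
#include <cstdio>

constexpr int divUp(int a, int b) { return (a + b - 1) / b; }

int main() {
  int n = 4096, k = 4096, innerKTiles = 8;   // example sizes, not from the diff
  constexpr int kKTileSize = 16;
  constexpr int kNTileSizeTensor = 8;        // GPT-FAST / meta-registration n-tile size
  int nTilesTensor = divUp(n, kNTileSizeTensor);
  int kTiles = divUp(k, kKTileSize);             // assumption: one k-tile per 16 columns
  int kSuperTiles = divUp(kTiles, innerKTiles);  // assumption: innerKTiles k-tiles per super-tile
  // Mirrors the at::empty({nTilesTensor, kSuperTiles, 32, innerKTiles / 2}, kInt) above.
  printf("packed B: [%d][%d][32][%d] int32\n",
         nTilesTensor, kSuperTiles, innerKTiles / 2);
  return 0;
}
```
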
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(kSuperTiles, nTiles);

@@ -31,7 +31,7 @@ from torch.testing._internal.common_dtype import (
floating_and_complex_types_and, floating_types_and, complex_types,
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
_get_torch_cuda_version, CDNA2OrLater
_get_torch_cuda_version
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
@@ -6127,8 +6127,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.skipTest("requires SM80 or later")

if TEST_WITH_ROCM:
if not CDNA2OrLater():
self.skipTest("_int4_mm is supported only for CDNA2 or later")
self.skipTest("_int4_mm not compiled for ROCM")

q_group = 32
inner_k_tiles = 2
@@ -6176,8 +6175,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.skipTest("requires SM80 or later")

if TEST_WITH_ROCM:
if not CDNA2OrLater():
self.skipTest("_int4_mm is supported only for CDNA2 or later")
self.skipTest("_int4_mm not compiled for ROCM")

q_group = 32
inner_k_tiles = 2

@@ -33,12 +33,6 @@ SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic

IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])

def CDNA2OrLater():
    if TEST_WITH_ROCM:
        gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
        return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"})
    return False

def evaluate_gfx_arch_exact(matching_arch):
    if not torch.cuda.is_available():
        return False

@@ -537,8 +537,6 @@ CUDA_TYPE_NAME_MAP = collections.OrderedDict(
    ("CUuuid", ("hipUUID", CONV_TYPE, API_RUNTIME)),
    ("cudaGraph_t", ("hipGraph_t", CONV_TYPE, API_RAND)),
    ("cudaGraphExec_t", ("hipGraphExec_t", CONV_TYPE, API_RAND)),
    ("__nv_bfloat16", ("__hip_bfloat16", CONV_TYPE, API_RUNTIME)),
    ("__nv_bfloat162", ("__hip_bfloat162", CONV_TYPE, API_RUNTIME)),
]
)