Revert "[ROCm] Add int4 support (#129710)"

This reverts commit d0ad13fa42fc2e9935bd3bda2937a3491276d274.

Reverted https://github.com/pytorch/pytorch/pull/129710 on behalf of https://github.com/jeffdaily because the original ROCm PR did not have the ciflow/rocm label, so the ROCm signal was missed ([comment](https://github.com/pytorch/pytorch/pull/129710#issuecomment-2214558368))
Author: PyTorch MergeBot
Date:   2024-07-08 16:07:52 +00:00
Parent: c8ab2e8b63
Commit: d7b7f8b79f
4 changed files with 16 additions and 284 deletions

View File

@@ -1,11 +1,9 @@
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#if !defined(USE_ROCM)
#include <mma.h>
#endif
#endif
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
@@ -127,38 +125,9 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
return diff == 0 ? 0 : uint32_t(Align) - diff;
}
#if defined(USE_ROCM)
// TODO: Support RDNA
constexpr int32_t kWarpSize = 64;
template<typename T, uint32_t Rank>
using VecT = T __attribute__((ext_vector_type(Rank)));
static bool isCDNA2orLater(int index) {
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
std::string device_arch = prop->gcnArchName;
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
for (std::string arch : archs) {
size_t substring = device_arch.find(arch);
if (substring != std::string::npos) {
return true;
}
}
return false;
}
#else
constexpr int32_t kWarpSize = 32;
#endif
#if defined (__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define CDNA2_OR_LATER 1
#else
#define CDNA2_OR_LATER 0
#endif
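For context on the block removed above: the ROCm path assumed a 64-lane wavefront and gated the kernels to CDNA2-or-later parts by substring-matching the reported gcnArchName (gfx90a/gfx940/gfx941/gfx942), the same test as the CDNA2OrLater helper removed from the Python test utilities further down. A minimal host-side sketch of that check; the standalone helper name here is hypothetical:

```cpp
#include <array>
#include <cassert>
#include <string>

// Hypothetical standalone version of the removed isCDNA2orLater() check:
// true when the reported gcnArchName names a CDNA2-or-later part.
static bool archIsCDNA2OrLater(const std::string& gcnArchName) {
  static const std::array<const char*, 4> archs = {
      "gfx90a", "gfx940", "gfx941", "gfx942"};
  for (const char* arch : archs) {
    if (gcnArchName.find(arch) != std::string::npos) {
      return true;
    }
  }
  return false;
}

int main() {
  // gcnArchName strings often carry feature suffixes, hence substring matching.
  assert(archIsCDNA2OrLater("gfx90a:sramecc+:xnack-"));
  assert(!archIsCDNA2OrLater("gfx908"));   // CDNA1: not covered by this path
  assert(!archIsCDNA2OrLater("gfx1100"));  // RDNA3: see the "TODO: Support RDNA"
  return 0;
}
```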
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
// f16 vector types
struct __align__(2) f16x1 {
__half vals[1];
@@ -207,19 +176,11 @@ struct __align__(16) bf16x2x4 {
};
struct __align__(16) bf16x2x4_u32 {
#if defined(USE_ROCM)
VecT<short, 4> val[2];
#else
uint32_t vals[4];
#endif
};
struct __align__(8) bf16x2x2_u32 {
#if defined(USE_ROCM)
VecT<short, 4> val;
#else
uint32_t vals[2];
#endif
};
struct __align__(4) bf16x2x1_u32 {
@@ -241,68 +202,38 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) {
uint32_t const source_i4s = source;
// First, we extract the i4s and construct an intermediate fp16 number.
#if !defined(USE_ROCM)
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
#endif
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
// We don't have enough mantissa to remove as much shift overhead as FP16, so
// we must loop. No shift needed for first item.
uint32_t i4s = source_i4s;
#if defined(USE_ROCM)
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(h[0])
: "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif
#pragma unroll
for (int ii = 1; ii < kElements / 2; ++ii) {
i4s >>= 4; // or is it 8?
// (i4s & 0x000f000f) | 0x43004300
#if defined(USE_ROCM)
asm volatile("v_and_or_b32 %0, %1, %2, %3"
: "=v"(h[ii])
: "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM));
#else
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[ii])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
#endif
}
// This is the BF16 {-136, -136} represented as an integer.
#if defined(USE_ROCM)
#if ROCM_VERSION >= 60200
auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308}));
auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80}));
#else
auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308});
auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80});
#endif
#else
static constexpr uint32_t BF16_BIAS = 0xC308C308;
static constexpr uint32_t BF16_ONE = 0x3F803F80;
#endif
// Finally, we construct the output numbers.
#pragma unroll
for (int ii = 0; ii < kElements / 2; ++ii) {
// Since this section is for Ampere+, we use bf16 fma to do the bias
// subtraction
#if defined(USE_ROCM)
result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS);
#else
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
: "=r"(h[ii])
: "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
#endif
}
return result;
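The magic constants above encode the usual nibble-to-bf16 trick: OR-ing a 4-bit value v into the low mantissa bits of 0x4300 (bf16 for 128.0) yields the bf16 value 128 + v, and the bias 0xC308 (bf16 for -136.0, the "{-136, -136}" mentioned in the comment) then maps the unsigned range [0, 15] onto the signed range [-8, 7]. A host-side sketch that checks the arithmetic, independent of the PTX/GCN inline assembly:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Decode a bf16 bit pattern by widening it into the top half of a float.
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

int main() {
  const uint16_t kMagic = 0x4300;  // bf16 128.0, low half of I4s_TO_BF16s_MAGIC_NUM
  const uint16_t kBias = 0xC308;   // bf16 -136.0, low half of BF16_BIAS

  for (uint16_t v = 0; v < 16; ++v) {
    // (v & 0xF) | 0x4300 drops v into the mantissa: the value becomes 128 + v.
    float biased = bf16_bits_to_float(static_cast<uint16_t>(kMagic | v));
    assert(biased == 128.0f + v);

    // The fma with BF16_ONE and BF16_BIAS then subtracts 136: the result is v - 8.
    float result = biased * 1.0f + bf16_bits_to_float(kBias);
    assert(result == static_cast<float>(v) - 8.0f);
  }
  return 0;
}
```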
@@ -323,11 +254,7 @@ enum class KReductionType {
template <KReductionType ReduceType>
struct ALayout_RM {
static constexpr int32_t kMTileSize = 16;
#if defined(USE_ROCM)
static constexpr int32_t kNTileSize = 16;
#else
static constexpr int32_t kNTileSize = 8;
#endif
static constexpr int32_t kKTileSize = 16;
template <int KTilesToLoad>
@@ -340,37 +267,22 @@ struct ALayout_RM {
int32_t kTiles,
int32_t kTileStart,
int32_t laneId,
#if defined(USE_ROCM)
bf16x2x2_u32 out[KTilesToLoad]
#else
bf16x2x4_u32 out[KTilesToLoad]
#endif
) {
#if defined(USE_ROCM)
const auto mLane = mTile * kMTileSize + (laneId % kMTileSize);
const auto kLane = kTileStart * kKTileSize + (laneId / kMTileSize) * 4;
#else
bf16x2x4_u32 out[KTilesToLoad]) {
const auto mLane = mTile * kMTileSize + (laneId / 4);
const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2;
#endif
// access
// [mTile * kMTileSize + (laneId / 4)]
// [kTileStart * kKTileSize + (laneId % 4) * 2]
auto aPtr = reinterpret_cast<const __nv_bfloat16*>(A) + mLane * k + kLane;
bool m0InBounds = mLane < m;
#if !defined(USE_ROCM)
auto aPtrPlus8Rows = aPtr + 8 * k;
bool m0InBounds = mLane < m;
bool m1InBounds = (mLane + 8) < m;
#endif
#pragma unroll
for (int i = 0; i < KTilesToLoad; ++i) {
#if defined(USE_ROCM)
out[i].val = m0InBounds ? *((VecT<short, 4> *)(aPtr + i * kKTileSize)) : VecT<short, 4>{0, 0, 0, 0};
#else
out[i].vals[0] = m0InBounds
? *reinterpret_cast<const uint32_t*>(aPtr + i * kKTileSize)
: uint32_t(0);
@@ -384,7 +296,6 @@ struct ALayout_RM {
out[i].vals[3] = m1InBounds ? *reinterpret_cast<const uint32_t*>(
aPtrPlus8Rows + i * kKTileSize + 8)
: uint32_t(0);
#endif
}
}
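The two branches above encode different lane-to-element mappings for the A fragment: on the CUDA m16n8k16 path a lane reads row laneId / 4 (plus the same row shifted by 8) at columns (laneId % 4) * 2 and the next one, while the removed ROCm MFMA path had each of 64 lanes read row laneId % 16 and four consecutive columns starting at (laneId / 16) * 4. A small host-side sketch of the index math, assuming the expressions above describe the intended fragment mapping:

```cpp
#include <cstdio>

int main() {
  const int kMTileSize = 16;
  const int kKTileSize = 16;
  const int mTile = 0, kTileStart = 0;  // look at the first tile only

  // CUDA mma.m16n8k16 path: 32 lanes, 2 columns per lane in 2 row groups.
  std::printf("CUDA (warp of 32), sample lanes:\n");
  for (int laneId = 0; laneId < 32; laneId += 8) {
    int mLane = mTile * kMTileSize + laneId / 4;
    int kLane = kTileStart * kKTileSize + (laneId % 4) * 2;
    std::printf("  lane %2d -> rows %d and %d, cols %d..%d\n",
                laneId, mLane, mLane + 8, kLane, kLane + 1);
  }

  // Removed ROCm MFMA 16x16x16 path: 64 lanes, 4 columns of one row per lane.
  std::printf("ROCm (wavefront of 64), sample lanes:\n");
  for (int laneId = 0; laneId < 64; laneId += 16) {
    int mLane = mTile * kMTileSize + laneId % kMTileSize;
    int kLane = kTileStart * kKTileSize + (laneId / kMTileSize) * 4;
    std::printf("  lane %2d -> row %d, cols %d..%d\n",
                laneId, mLane, kLane, kLane + 3);
  }
  return 0;
}
```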
@@ -401,10 +312,6 @@ struct ALayout_RM {
static_assert(ReduceType == KReductionType::None, "");
if constexpr (ReduceType == KReductionType::None) {
#if defined(USE_ROCM)
const int outRow = mTile * kMTileSize + (laneId / kNTileSize) * 4;
const int outCol = nTile * kNTileSize + (laneId % kNTileSize);
#else
// sum.x / sum.y are written at
// [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1]
// sum.z / sum.w are written at
@@ -412,21 +319,10 @@ struct ALayout_RM {
// i.e., same columns, different row.
const int outRow = mTile * kMTileSize + (laneId / 4);
const int outCol = nTile * kNTileSize + (laneId % 4) * 2;
#endif
// Pointer where sum.x / sum.y is written
auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol;
#if defined(USE_ROCM)
if (outRow < m)
cPtr[0] = __float2bfloat16(out.x);
if ((outRow + 1) < m)
cPtr[n] = __float2bfloat16(out.y);
if ((outRow + 2) < m)
cPtr[2*n] = __float2bfloat16(out.z);
if ((outRow + 3) < m)
cPtr[3*n] = __float2bfloat16(out.w);
#else
auto v01 = __float22bfloat162_rn(float2{out.x, out.y});
auto v23 = __float22bfloat162_rn(float2{out.z, out.w});
@@ -438,7 +334,6 @@ struct ALayout_RM {
if (outRow + 8 < m) {
*reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23;
}
#endif
}
}
};
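The store side mirrors that split: the CUDA path writes the float4 accumulator as two bf16x2 pairs at rows laneId / 4 and laneId / 4 + 8, columns (laneId % 4) * 2 and the next one, while the removed ROCm path wrote four consecutive rows starting at (laneId / 16) * 4 in column laneId % 16, each element bounds-checked against m. A host-side sketch of where one output tile's values land, assuming the index expressions shown above:

```cpp
#include <cstdio>

int main() {
  const int kMTileSize = 16;
  const int mTile = 0, nTile = 0;

  // CUDA path (kNTileSize == 8): sum.x/sum.y and sum.z/sum.w, 8 rows apart.
  const int kNTileSizeCuda = 8;
  for (int laneId = 0; laneId < 32; laneId += 8) {
    int outRow = mTile * kMTileSize + laneId / 4;
    int outCol = nTile * kNTileSizeCuda + (laneId % 4) * 2;
    std::printf("CUDA lane %2d: xy -> (%d, %d..%d), zw -> (%d, %d..%d)\n",
                laneId, outRow, outCol, outCol + 1,
                outRow + 8, outCol, outCol + 1);
  }

  // Removed ROCm path (kNTileSize == 16): four consecutive rows, one column.
  const int kNTileSizeRocm = 16;
  for (int laneId = 0; laneId < 64; laneId += 16) {
    int outRow = mTile * kMTileSize + (laneId / kNTileSizeRocm) * 4;
    int outCol = nTile * kNTileSizeRocm + laneId % kNTileSizeRocm;
    std::printf("ROCm lane %2d: xyzw -> rows %d..%d, col %d\n",
                laneId, outRow, outRow + 3, outCol);
  }
  return 0;
}
```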
@@ -447,19 +342,15 @@ template <int InnerKTiles, int QGroupSize>
struct BLayout_TC_int4 {
static constexpr int32_t kInnerKTiles = InnerKTiles;
static constexpr int32_t kMTileSize = 16;
#if defined(USE_ROCM)
static constexpr int32_t kNTileSize = 16;
#else
static constexpr int32_t kNTileSize = 8;
#endif
static constexpr int32_t kKTileSize = 16;
template <int KTilesToLoad>
static __device__ void load(
// type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2]
// n-tiles: n / 8 for NV, n /16 for AMD
// k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16 for NV, m16n16k16 for AMD)
// value per warp lane: 32 for NV, 64 for AMD
// n / 8: n-tiles (n8)
// k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16)
// 32: value per warp lane
// (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile.
// 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest
// value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a
@@ -532,11 +423,7 @@ struct BLayout_TC_int4 {
__nv_bfloat162 qScaleAndZero[kNumQGroups];
{
#if defined(USE_ROCM)
int32_t laneN = nTile * kNTileSize + (laneId % kNTileSize);
#else
int32_t laneN = nTile * kNTileSize + (laneId / 4);
#endif
int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize;
int32_t n = nTiles * kNTileSize;
@@ -627,15 +514,9 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
int32_t nTiles,
int32_t kTiles) {
constexpr int32_t kMTileSize = 16;
#if defined(USE_ROCM)
constexpr int32_t kNTileSize = 16;
#else
constexpr int32_t kNTileSize = 8;
#endif
constexpr int32_t kKTileSize = 16;
#if !defined(USE_ROCM) || CDNA2_OR_LATER
static_assert(
ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize &&
ALayout::kKTileSize == kKTileSize,
@@ -669,11 +550,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
int32_t mTile = blockIdx.z;
int32_t nTile = blockIdx.y;
#if defined(USE_ROCM)
VecT<float, 4> c{0.0f, 0.0f, 0.0f, 0.0f};
#else
float4 c{0.0f, 0.0f, 0.0f, 0.0f};
#endif
// First, handle whole multiples of KTilesPerIteration
auto kTilesLimit = roundDown(kTiles, KTilesPerIteration);
@@ -685,11 +562,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
//
// Load data from A
//
#if defined(USE_ROCM)
bf16x2x2_u32 a[KTilesPerIteration];
#else
bf16x2x4_u32 a[KTilesPerIteration];
#endif
ALayout::template load<KTilesPerIteration>(
A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a);
@@ -723,29 +596,15 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
// We don't simply accumulate into `c` as this creates a too-strong
// execution dependency. Instead, we only periodically accumulate into
// `c`
#if defined(USE_ROCM)
VecT<float, 4> cTmp[2];
#else
float4 cTmp[2];
#endif
#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
cTmp[k] = VecT<float, 4>{0.0f, 0.0f, 0.0f, 0.0f};
#else
cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
#endif
}
#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
a[i * kInnerKTiles + j * 2 + k].val,
b[i][(j * 2 + k) / 2].val[((j * 2 + k) % 2)],
cTmp[k], 0, 0, 0);
#else
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
@@ -763,22 +622,14 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
"f"(cTmp[k].y),
"f"(cTmp[k].z),
"f"(cTmp[k].w));
#endif
}
#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
c[0] += cTmp[k][0];
c[1] += cTmp[k][1];
c[2] += cTmp[k][2];
c[3] += cTmp[k][3];
#else
c.x += cTmp[k].x;
c.y += cTmp[k].y;
c.z += cTmp[k].z;
c.w += cTmp[k].w;
#endif
}
}
}
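The cTmp[2] split above implements a standard latency-hiding idiom: accumulate matrix-multiply results into independent temporaries so consecutive mma/MFMA instructions do not form one long dependency chain, then fold the temporaries into c periodically. A scalar host-side sketch of the same idea (the kernel does this with float4 / 4-wide MFMA accumulators):

```cpp
#include <cstdio>
#include <vector>

// Sum `terms` using two independent partial accumulators, mirroring how the
// kernel keeps cTmp[0] and cTmp[1] separate and only periodically folds them
// into the long-lived accumulator `c`.
static float sum_with_two_accumulators(const std::vector<float>& terms) {
  float c = 0.0f;
  float cTmp[2] = {0.0f, 0.0f};
  for (size_t i = 0; i + 1 < terms.size(); i += 2) {
    cTmp[0] += terms[i];      // independent of cTmp[1]: both can be in flight
    cTmp[1] += terms[i + 1];
  }
  if (terms.size() % 2 != 0) {
    cTmp[0] += terms.back();
  }
  c += cTmp[0] + cTmp[1];  // periodic fold-in (here, just once at the end)
  return c;
}

int main() {
  std::vector<float> terms = {1, 2, 3, 4, 5, 6, 7};
  std::printf("sum = %f\n", sum_with_two_accumulators(terms));  // prints 28
  return 0;
}
```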
@@ -795,11 +646,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
// If we have any remainder k-tiles, some warps will handle them, processing
// kInnerKTiles k-tiles at a time
if (kTileBaseRemaining < kTiles) {
#if defined(USE_ROCM)
bf16x2x2_u32 a[kInnerKTiles];
#else
bf16x2x4_u32 a[kInnerKTiles];
#endif
ALayout::template load<kInnerKTiles>(
A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a);
@@ -821,29 +668,15 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
// We don't simply accumulate into `c` as this creates a too-strong
// execution dependency. Instead, we only periodically accumulate into
// `c`
#if defined(USE_ROCM)
VecT<float, 4> cTmp[2];
#else
float4 cTmp[2];
#endif
#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
cTmp[k] = VecT<float, 4>{0.0f, 0.0f, 0.0f, 0.0f};
#else
cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f};
#endif
}
#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
a[j * 2 + k].val,
b[0][(j * 2 + k) / 2].val[((j * 2 + k) % 2)],
cTmp[k], 0, 0, 0);
#else
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
@@ -858,22 +691,14 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
"f"(cTmp[k].y),
"f"(cTmp[k].z),
"f"(cTmp[k].w));
#endif
}
#pragma unroll
for (int k = 0; k < 2; ++k) {
#if defined(USE_ROCM)
c[0] += cTmp[k][0];
c[1] += cTmp[k][1];
c[2] += cTmp[k][2];
c[3] += cTmp[k][3];
#else
c.x += cTmp[k].x;
c.y += cTmp[k].y;
c.z += cTmp[k].z;
c.w += cTmp[k].w;
#endif
}
}
}
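The split between the main loop and this remainder block comes from rounding kTiles down to a multiple of KTilesPerIteration; the leftover k-tiles are then handled kInnerKTiles at a time. A small sketch of that bookkeeping with helpers in the spirit of the divUp/roundDown this kernel relies on (the concrete sizes below are arbitrary example values):

```cpp
#include <cassert>
#include <cstdint>

// Integer helpers in the spirit of the ones used throughout this file.
static constexpr int32_t divUp(int32_t a, int32_t b) { return (a + b - 1) / b; }
static constexpr int32_t roundDown(int32_t a, int32_t b) { return (a / b) * b; }

int main() {
  const int32_t kTiles = 37;             // total k-tiles for this problem (example)
  const int32_t KTilesPerIteration = 8;  // handled per main-loop iteration (example)
  const int32_t kInnerKTiles = 2;        // granularity of the remainder path

  // The main loop covers whole multiples of KTilesPerIteration...
  int32_t kTilesLimit = roundDown(kTiles, KTilesPerIteration);
  assert(kTilesLimit == 32);

  // ...and the remainder [kTilesLimit, kTiles) is walked kInnerKTiles at a time.
  int32_t remainder = kTiles - kTilesLimit;
  int32_t remainderSteps = divUp(remainder, kInnerKTiles);
  assert(remainder == 5 && remainderSteps == 3);
  return 0;
}
```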
@@ -886,14 +711,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
// FIXME: this likely doesn't need to be a true reduction tree, can just be a
// serial sum, maybe (unless nvcc/ptxas goes back to its old ways)
// smem_sum[warpId][laneId] = TreeReduce4<KTilesPerIteration>::reduce(c);
#if defined(USE_ROCM)
smem_sum[warpId][laneId].x = c[0];
smem_sum[warpId][laneId].y = c[1];
smem_sum[warpId][laneId].z = c[2];
smem_sum[warpId][laneId].w = c[3];
#else
smem_sum[warpId][laneId] = c;
#endif
__syncthreads();
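For the cross-warp combine, each warp parks its float4 partial sum in shared memory and, per the FIXME above, a plain serial sum over the warps is expected to be enough rather than a true tree reduction. A host-side sketch of that final combine, assuming a simple float4-like struct and made-up partial values:

```cpp
#include <cstdio>

struct Float4 {  // stand-in for the CUDA float4 used for smem_sum
  float x, y, z, w;
};

int main() {
  const int kWarps = 8;
  // smem_sum[warp] for one lane column: pretend each warp produced a partial sum.
  Float4 smem_sum[kWarps];
  for (int w = 0; w < kWarps; ++w) {
    smem_sum[w] = Float4{1.0f * w, 2.0f * w, 3.0f * w, 4.0f * w};
  }

  // Serial sum across warps, as the FIXME suggests is sufficient here.
  Float4 total{0.0f, 0.0f, 0.0f, 0.0f};
  for (int w = 0; w < kWarps; ++w) {
    total.x += smem_sum[w].x;
    total.y += smem_sum[w].y;
    total.z += smem_sum[w].z;
    total.w += smem_sum[w].w;
  }
  std::printf("total = (%g, %g, %g, %g)\n", total.x, total.y, total.z, total.w);
  return 0;
}
```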
@@ -923,9 +741,6 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel(
laneId,
sum_f32);
}
#else
printf("__builtin_amdgcn_mfma_f32_16x16x16bf16_1k is only supported on AMD gpu arch greater than or equal to CDNA2\n");
#endif
}
@@ -983,12 +798,7 @@ void launch_tinygemm_kernel(
cudaFuncAttributes funcAttr;
C10_CUDA_CHECK(cudaFuncGetAttributes(
&funcAttr,
#if defined(USE_ROCM)
(void *)func
#else
func
#endif
));
func));
}
// FIXME: parallelize better, smem staging etc?
@@ -1003,11 +813,7 @@ __global__ void matrix_to_m16n8k16_Bint4_layout(
// innermost k-tiles that we can use is 2.
static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), "");
#if defined(USE_ROCM)
constexpr int32_t kNTileSize = 16;
#else
constexpr int32_t kNTileSize = 8;
#endif
constexpr int32_t kKTileSize = 16;
// gridDim.x corresponds to the number of k-tiles divided by InnerKTiles
@@ -1019,30 +825,13 @@ __global__ void matrix_to_m16n8k16_Bint4_layout(
#pragma unroll
for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) {
// n dimension that this lane loads from
#if defined(USE_ROCM)
auto n0 = nTile * kNTileSize + (t % kNTileSize);
#else
auto n0 = nTile * kNTileSize + (t / 4);
#endif
bool n0Valid = n0 < in.size(0);
int32_t ks[8];
auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize;
#if defined(USE_ROCM)
ks[0] = kBase0 + (t / kNTileSize) * 4;
ks[1] = ks[0] + 1;
ks[2] = ks[0] + 2;
ks[3] = ks[0] + 3;
auto kBase1 = kBase0 + kKTileSize;
ks[4] = kBase1 + (t / kNTileSize) * 4;
ks[5] = ks[4] + 1;
ks[6] = ks[4] + 2;
ks[7] = ks[4] + 3;
#else
ks[0] = kBase0 + (t % 4) * 2;
ks[1] = ks[0] + 1;
ks[2] = ks[0] + 8;
@@ -1053,7 +842,6 @@ __global__ void matrix_to_m16n8k16_Bint4_layout(
ks[5] = ks[4] + 1;
ks[6] = ks[4] + 8;
ks[7] = ks[4] + 8 + 1;
#endif
auto pIn = &in[n0][0];
@@ -1067,19 +855,7 @@ __global__ void matrix_to_m16n8k16_Bint4_layout(
(v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0];
// inner k-tiles pack two at a time
#if defined(USE_ROCM)
// The output tensor shape is [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2], which is specific to Nvidia
// But AMD needs [ceil(n / 16)][ceil(k / (InnerKTiles * 16))][64][InnerKTiles / 2]
// So construct the pointer accordingly
auto bPtr = out.data() +
((nTile * out.size(1) * kWarpSize * (InnerKTiles / 2)) +
(kOuterTile * kWarpSize * (InnerKTiles / 2)) +
(t * (InnerKTiles / 2)) +
(innerKTile / 2));
*bPtr = pack;
#else
out[nTile][kOuterTile][t][innerKTile / 2] = pack;
#endif
}
}
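This conversion kernel is where the 4-bit weights are packed eight at a time into a uint32 (four values per lane per k-tile, two k-tiles per word), and where the removed ROCm branch indexed the output through a flat pointer: per the comments above, it kept the GPT-FAST shape [ceil(n / 8)][k / (InnerKTiles * 16)][32][InnerKTiles / 2] while addressing it for 64-lane wavefronts. A host-side sketch of the shape arithmetic only, with assumed matrix sizes (the exact nibble interleaving is only partially visible in this hunk, so it is not reproduced here):

```cpp
#include <cassert>
#include <cstdint>

static constexpr int64_t divUp(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  // Assumed weight matrix size and packing factor for illustration.
  const int64_t n = 4096, k = 4096;
  const int64_t innerKTiles = 2;  // 2, 4, or 8 are the supported values
  const int64_t kKTileSize = 16;

  // GPT-FAST / CUDA shape: [ceil(n/8)][k/(innerKTiles*16)][32][innerKTiles/2]
  const int64_t nTiles = divUp(n, 8);
  const int64_t kSuperTiles = divUp(k, innerKTiles * kKTileSize);
  const int64_t lanes = 32;
  const int64_t wordsPerLane = innerKTiles / 2;

  // Each int32 word holds eight 4-bit values (4 per lane per k-tile, two
  // k-tiles packed), so the packed tensor stores exactly n*k/8 words.
  assert(nTiles * kSuperTiles * lanes * wordsPerLane == n * k / 8);

  // The AMD-native view [ceil(n/16)][kSuperTiles][64][innerKTiles/2] holds the
  // same word count, which is why flat pointer math over the GPT-FAST shape works.
  assert(divUp(n, 16) * kSuperTiles * 64 * wordsPerLane == n * k / 8);
  return 0;
}
```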
@@ -1096,30 +872,16 @@ at::Tensor _weight_int4pack_mm_cuda(
TORCH_CHECK(
A.device() == B.device() && A.device() == qScaleAndZeros.device());
#if defined(USE_ROCM)
if (!isCDNA2orLater(A.device().index())) {
TORCH_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
}
#endif
constexpr int32_t kMTileSize = 16;
#if defined(USE_ROCM)
constexpr int32_t kNTileSize = 16;
#else
constexpr int32_t kNTileSize = 8;
#endif
constexpr int32_t kKTileSize = 16;
// row major layout
auto m = A.size(0);
auto mTiles = divUp(m, kMTileSize);
// To convert the nTiles from tensor storage layout to the actual matrix core layout
constexpr int32_t kNTileSizeTensor = 8;
auto nTileScaleFactor = (kNTileSize / kNTileSizeTensor);
// tensor core layout
auto nTiles = (B.size(0) / nTileScaleFactor);
auto nTiles = B.size(0);
auto n = nTiles * kNTileSize;
// row major layout
@@ -1142,7 +904,7 @@ at::Tensor _weight_int4pack_mm_cuda(
TORCH_CHECK(B.is_contiguous());
TORCH_CHECK(B.dim() == 4);
TORCH_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize));
TORCH_CHECK(B.size(2) == 32);
TORCH_CHECK(B.size(2) == kWarpSize);
// Validate the scale and zero point tensor for dequantization
// These are the only versions handled at the moment
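Reading these two hunks as deletions followed by additions: the ROCm-era code pinned B.size(2) to the literal 32 (the GPT-FAST storage layout keeps 32 lanes even though the wavefront is 64) and rescaled the stored n-tile count by kNTileSize / 8 to get matrix-core n-tiles, while the restored code checks B.size(2) == kWarpSize and takes nTiles = B.size(0) directly, which is equivalent when only the 32-lane, 8-wide-n-tile layout exists. A small sketch of the two readings of B.size(0), with assumed sizes:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Assumed sizes: a weight with n = 4096 columns, stored in the GPT-FAST int4
  // layout whose leading dimension counts n-tiles of size 8.
  const int64_t n = 4096;
  const int64_t B_size0 = n / 8;  // what the packed tensor actually stores

  // Restored NVIDIA-only reading: n-tile size 8, nTiles taken as-is.
  {
    const int64_t kNTileSize = 8;
    const int64_t nTiles = B_size0;
    assert(nTiles * kNTileSize == n);
  }

  // Removed ROCm-era reading: matrix-core n-tile of 16, so the stored n-tile
  // count is divided by a scale factor before recovering n.
  {
    const int64_t kNTileSize = 16, kNTileSizeTensor = 8;
    const int64_t nTileScaleFactor = kNTileSize / kNTileSizeTensor;  // == 2
    const int64_t nTiles = B_size0 / nTileScaleFactor;
    assert(nTiles * kNTileSize == n);
  }
  return 0;
}
```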
@@ -1162,7 +924,7 @@ at::Tensor _weight_int4pack_mm_cuda(
auto C_final = at::empty(
{m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device()));
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
auto stream = at::cuda::getCurrentCUDAStream();
#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \
do { \
@@ -1291,27 +1053,10 @@ at::Tensor _convert_weight_to_int4pack_cuda(
// which is the maximum vectorized load/store size
TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8);
#if defined(USE_ROCM)
if (!isCDNA2orLater(in.device().index())) {
TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
}
#endif
#if defined(USE_ROCM)
constexpr int32_t kNTileSize = 16;
#else
constexpr int32_t kNTileSize = 8;
#endif
constexpr int32_t kKTileSize = 16;
// GPT-FAST assumes nTileSize of 8 for quantized weight tensor.
// See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510
// Torch dynamo also requires the torch ops has the same output shape for each device.
// See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263
constexpr int32_t kNTileSizeTensor = 8;
auto nTiles = divUp(in.size(0), kNTileSize);
auto nTilesTensor = divUp(in.size(0), kNTileSizeTensor);
// k-tiles are packed back to back in the innermost dimension in order to
// allow for 4/8/16 byte loads
@@ -1321,14 +1066,11 @@ at::Tensor _convert_weight_to_int4pack_cuda(
// each block handles `innerKTiles` k-tiles.
// 2 k-tiles are a single int32
//
// We use the same shape for AMD gpus also to match the GPT-FAST spec.
// Will index it correctly when dereferencing the quantized weight tensor pointer.
auto out = at::empty(
{nTilesTensor, kSuperTiles, 32, innerKTiles / 2},
{nTiles, kSuperTiles, 32, innerKTiles / 2},
at::TensorOptions().dtype(at::kInt).device(in.device()));
#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)))
#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(kSuperTiles, nTiles);
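The grid above launches one block per super-tile of InnerKTiles k-tiles per n-tile, and the "packed back to back" comment is about load width: with 4 int4 values (16 bits) per lane per k-tile, packing 2/4/8 k-tiles gives each lane 4/8/16 bytes of contiguous B data, i.e. a uint32, uint32x2, or uint32x4 load. A tiny arithmetic check of that claim:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Per k-tile, each lane owns 4 int4 values = 16 bits = 2 bytes of B data.
  const int64_t bitsPerLanePerKTile = 4 /*values*/ * 4 /*bits each*/;

  for (int64_t innerKTiles : {2, 4, 8}) {
    int64_t bytesPerLane = innerKTiles * bitsPerLanePerKTile / 8;
    // 2 k-tiles -> 4 bytes (uint32), 4 -> 8 (uint32x2), 8 -> 16 (uint32x4).
    assert(bytesPerLane == 4 * (innerKTiles / 2));
    // Equivalently, the innermost output dimension is innerKTiles / 2 int32 words.
    assert(bytesPerLane / 4 == innerKTiles / 2);
  }
  return 0;
}
```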

View File

@@ -31,7 +31,7 @@ from torch.testing._internal.common_dtype import (
floating_and_complex_types_and, floating_types_and, complex_types,
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
_get_torch_cuda_version, CDNA2OrLater
_get_torch_cuda_version
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
@@ -6127,8 +6127,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.skipTest("requires SM80 or later")
if TEST_WITH_ROCM:
if not CDNA2OrLater():
self.skipTest("_int4_mm is supported only for CDNA2 or later")
self.skipTest("_int4_mm not compiled for ROCM")
q_group = 32
inner_k_tiles = 2
@@ -6176,8 +6175,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.skipTest("requires SM80 or later")
if TEST_WITH_ROCM:
if not CDNA2OrLater():
self.skipTest("_int4_mm is supported only for CDNA2 or later")
self.skipTest("_int4_mm not compiled for ROCM")
q_group = 32
inner_k_tiles = 2

View File

@@ -33,12 +33,6 @@ SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])
def CDNA2OrLater():
if TEST_WITH_ROCM:
gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"})
return False
def evaluate_gfx_arch_exact(matching_arch):
if not torch.cuda.is_available():
return False

View File

@@ -537,8 +537,6 @@ CUDA_TYPE_NAME_MAP = collections.OrderedDict(
("CUuuid", ("hipUUID", CONV_TYPE, API_RUNTIME)),
("cudaGraph_t", ("hipGraph_t", CONV_TYPE, API_RAND)),
("cudaGraphExec_t", ("hipGraphExec_t", CONV_TYPE, API_RAND)),
("__nv_bfloat16", ("__hip_bfloat16", CONV_TYPE, API_RUNTIME)),
("__nv_bfloat162", ("__hip_bfloat162", CONV_TYPE, API_RUNTIME)),
]
)