From d7b7f8b79f4ad5ffe0152649aa60a10b15eb744c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 8 Jul 2024 16:07:52 +0000 Subject: [PATCH] Revert "[ROCm] Add int4 support (#129710)" This reverts commit d0ad13fa42fc2e9935bd3bda2937a3491276d274. Reverted https://github.com/pytorch/pytorch/pull/129710 on behalf of https://github.com/jeffdaily due to original ROCm PR did not have ciflow/rocm, missed signal ([comment](https://github.com/pytorch/pytorch/pull/129710#issuecomment-2214558368)) --- aten/src/ATen/native/cuda/int4mm.cu | 284 +-------------------- test/test_linalg.py | 8 +- torch/testing/_internal/common_cuda.py | 6 - torch/utils/hipify/cuda_to_hip_mappings.py | 2 - 4 files changed, 16 insertions(+), 284 deletions(-) diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu index 129b27987997..fcfcd2e5ebbd 100644 --- a/aten/src/ATen/native/cuda/int4mm.cu +++ b/aten/src/ATen/native/cuda/int4mm.cu @@ -1,11 +1,9 @@ -#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) #include #include #include -#if !defined(USE_ROCM) #include #endif -#endif #include #include #include @@ -127,38 +125,9 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { return diff == 0 ? 0 : uint32_t(Align) - diff; } -#if defined(USE_ROCM) -// TODO: Support RDNA -constexpr int32_t kWarpSize = 64; - -template -using VecT = T __attribute__((ext_vector_type(Rank))); - -static bool isCDNA2orLater(int index) { - hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); - std::string device_arch = prop->gcnArchName; - static const std::vector archs = {"gfx90a", "gfx940", "gfx941", "gfx942"}; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; -} - -#else constexpr int32_t kWarpSize = 32; -#endif - -#if defined (__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) -#define CDNA2_OR_LATER 1 -#else -#define CDNA2_OR_LATER 0 -#endif - -#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) // f16 vector types struct __align__(2) f16x1 { __half vals[1]; @@ -207,19 +176,11 @@ struct __align__(16) bf16x2x4 { }; struct __align__(16) bf16x2x4_u32 { -#if defined(USE_ROCM) - VecT val[2]; -#else uint32_t vals[4]; -#endif }; struct __align__(8) bf16x2x2_u32 { -#if defined(USE_ROCM) - VecT val; -#else uint32_t vals[2]; -#endif }; struct __align__(4) bf16x2x1_u32 { @@ -241,68 +202,38 @@ inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { uint32_t const source_i4s = source; // First, we extract the i4s and construct an intermediate fp16 number. -#if !defined(USE_ROCM) static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; -#endif static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; // We don't have enough mantissa to remove as much shift overhead as FP16, so // we must loop. No shift needed for first item. uint32_t i4s = source_i4s; - -#if defined(USE_ROCM) - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(h[0]) - : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); -#else asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); -#endif - #pragma unroll for (int ii = 1; ii < kElements / 2; ++ii) { i4s >>= 4; // or is it 8? // (i4s & 0x000f000f) | 0x43004300 -#if defined(USE_ROCM) - asm volatile("v_and_or_b32 %0, %1, %2, %3" - : "=v"(h[ii]) - : "v"(i4s), "v"(MASK), "v"(I4s_TO_BF16s_MAGIC_NUM)); -#else asm volatile( "lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[ii]) : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); -#endif } // This is the BF16 {-136, -136} represented as an integer. -#if defined(USE_ROCM) -#if ROCM_VERSION >= 60200 - auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0xC308})); - auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16(__hip_bfloat16_raw{0x3F80})); -#else - auto BF16_BIAS = __bfloat162bfloat162(__hip_bfloat16{0xC308}); - auto BF16_ONE = __bfloat162bfloat162(__hip_bfloat16{0x3F80}); -#endif -#else static constexpr uint32_t BF16_BIAS = 0xC308C308; static constexpr uint32_t BF16_ONE = 0x3F803F80; -#endif // Finally, we construct the output numbers. #pragma unroll for (int ii = 0; ii < kElements / 2; ++ii) { // Since this section is for Ampere+, we use bf16 fma to do the bias // subtraction -#if defined(USE_ROCM) - result.vals[ii] = __hfma2(result.vals[ii], BF16_ONE, BF16_BIAS); -#else asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); -#endif } return result; @@ -323,11 +254,7 @@ enum class KReductionType { template struct ALayout_RM { static constexpr int32_t kMTileSize = 16; -#if defined(USE_ROCM) - static constexpr int32_t kNTileSize = 16; -#else static constexpr int32_t kNTileSize = 8; -#endif static constexpr int32_t kKTileSize = 16; template @@ -340,37 +267,22 @@ struct ALayout_RM { int32_t kTiles, int32_t kTileStart, int32_t laneId, -#if defined(USE_ROCM) - bf16x2x2_u32 out[KTilesToLoad] -#else - bf16x2x4_u32 out[KTilesToLoad] -#endif - ) { -#if defined(USE_ROCM) - const auto mLane = mTile * kMTileSize + (laneId % kMTileSize); - const auto kLane = kTileStart * kKTileSize + (laneId / kMTileSize) * 4; -#else + bf16x2x4_u32 out[KTilesToLoad]) { const auto mLane = mTile * kMTileSize + (laneId / 4); const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2; -#endif // access // [mTile * kMTileSize + (laneId / 4)] // [kTileStart * kKTileSize + (laneId % 4) * 2] auto aPtr = reinterpret_cast(A) + mLane * k + kLane; - bool m0InBounds = mLane < m; -#if !defined(USE_ROCM) auto aPtrPlus8Rows = aPtr + 8 * k; + bool m0InBounds = mLane < m; bool m1InBounds = (mLane + 8) < m; -#endif #pragma unroll for (int i = 0; i < KTilesToLoad; ++i) { -#if defined(USE_ROCM) - out[i].val = m0InBounds ? *((VecT *)(aPtr + i * kKTileSize)) : VecT{0, 0, 0, 0}; -#else out[i].vals[0] = m0InBounds ? *reinterpret_cast(aPtr + i * kKTileSize) : uint32_t(0); @@ -384,7 +296,6 @@ struct ALayout_RM { out[i].vals[3] = m1InBounds ? *reinterpret_cast( aPtrPlus8Rows + i * kKTileSize + 8) : uint32_t(0); -#endif } } @@ -401,10 +312,6 @@ struct ALayout_RM { static_assert(ReduceType == KReductionType::None, ""); if constexpr (ReduceType == KReductionType::None) { -#if defined(USE_ROCM) - const int outRow = mTile * kMTileSize + (laneId / kNTileSize) * 4; - const int outCol = nTile * kNTileSize + (laneId % kNTileSize); -#else // sum.x / sum.y are written at // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] // sum.z / sum.w are written at @@ -412,21 +319,10 @@ struct ALayout_RM { // i.e., same columns, different row. const int outRow = mTile * kMTileSize + (laneId / 4); const int outCol = nTile * kNTileSize + (laneId % 4) * 2; -#endif // Pointer where sum.x / sum.y is written auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol; -#if defined(USE_ROCM) - if (outRow < m) - cPtr[0] = __float2bfloat16(out.x); - if ((outRow + 1) < m) - cPtr[n] = __float2bfloat16(out.y); - if ((outRow + 2) < m) - cPtr[2*n] = __float2bfloat16(out.z); - if ((outRow + 3) < m) - cPtr[3*n] = __float2bfloat16(out.w); -#else auto v01 = __float22bfloat162_rn(float2{out.x, out.y}); auto v23 = __float22bfloat162_rn(float2{out.z, out.w}); @@ -438,7 +334,6 @@ struct ALayout_RM { if (outRow + 8 < m) { *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23; } -#endif } } }; @@ -447,19 +342,15 @@ template struct BLayout_TC_int4 { static constexpr int32_t kInnerKTiles = InnerKTiles; static constexpr int32_t kMTileSize = 16; -#if defined(USE_ROCM) - static constexpr int32_t kNTileSize = 16; -#else static constexpr int32_t kNTileSize = 8; -#endif static constexpr int32_t kKTileSize = 16; template static __device__ void load( // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2] - // n-tiles: n / 8 for NV, n /16 for AMD - // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16 for NV, m16n16k16 for AMD) - // value per warp lane: 32 for NV, 64 for AMD + // n / 8: n-tiles (n8) + // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16) + // 32: value per warp lane // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile. // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest // value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a @@ -532,11 +423,7 @@ struct BLayout_TC_int4 { __nv_bfloat162 qScaleAndZero[kNumQGroups]; { -#if defined(USE_ROCM) - int32_t laneN = nTile * kNTileSize + (laneId % kNTileSize); -#else int32_t laneN = nTile * kNTileSize + (laneId / 4); -#endif int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize; int32_t n = nTiles * kNTileSize; @@ -627,15 +514,9 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( int32_t nTiles, int32_t kTiles) { constexpr int32_t kMTileSize = 16; -#if defined(USE_ROCM) - constexpr int32_t kNTileSize = 16; -#else constexpr int32_t kNTileSize = 8; -#endif constexpr int32_t kKTileSize = 16; -#if !defined(USE_ROCM) || CDNA2_OR_LATER - static_assert( ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize && ALayout::kKTileSize == kKTileSize, @@ -669,11 +550,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( int32_t mTile = blockIdx.z; int32_t nTile = blockIdx.y; -#if defined(USE_ROCM) - VecT c{0.0f, 0.0f, 0.0f, 0.0f}; -#else float4 c{0.0f, 0.0f, 0.0f, 0.0f}; -#endif // First, handle whole multiples of KTilesPerIteration auto kTilesLimit = roundDown(kTiles, KTilesPerIteration); @@ -685,11 +562,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // // Load data from A // -#if defined(USE_ROCM) - bf16x2x2_u32 a[KTilesPerIteration]; -#else bf16x2x4_u32 a[KTilesPerIteration]; -#endif ALayout::template load( A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a); @@ -723,29 +596,15 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // We don't simply accumulate into `c` as this creates a too-strong // execution dependency. Instead, we only periodically accumulate into // `c` -#if defined(USE_ROCM) - VecT cTmp[2]; -#else float4 cTmp[2]; -#endif #pragma unroll for (int k = 0; k < 2; ++k) { -#if defined(USE_ROCM) - cTmp[k] = VecT{0.0f, 0.0f, 0.0f, 0.0f}; -#else cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; -#endif } #pragma unroll for (int k = 0; k < 2; ++k) { -#if defined(USE_ROCM) - cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( - a[i * kInnerKTiles + j * 2 + k].val, - b[i][(j * 2 + k) / 2].val[((j * 2 + k) % 2)], - cTmp[k], 0, 0, 0); -#else asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" @@ -763,22 +622,14 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( "f"(cTmp[k].y), "f"(cTmp[k].z), "f"(cTmp[k].w)); -#endif } #pragma unroll for (int k = 0; k < 2; ++k) { -#if defined(USE_ROCM) - c[0] += cTmp[k][0]; - c[1] += cTmp[k][1]; - c[2] += cTmp[k][2]; - c[3] += cTmp[k][3]; -#else c.x += cTmp[k].x; c.y += cTmp[k].y; c.z += cTmp[k].z; c.w += cTmp[k].w; -#endif } } } @@ -795,11 +646,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // If we have any remainder k-tiles, some warps will handle them, processing // kInnerKTiles k-tiles at a time if (kTileBaseRemaining < kTiles) { -#if defined(USE_ROCM) - bf16x2x2_u32 a[kInnerKTiles]; -#else bf16x2x4_u32 a[kInnerKTiles]; -#endif ALayout::template load( A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a); @@ -821,29 +668,15 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // We don't simply accumulate into `c` as this creates a too-strong // execution dependency. Instead, we only periodically accumulate into // `c` -#if defined(USE_ROCM) - VecT cTmp[2]; -#else float4 cTmp[2]; -#endif #pragma unroll for (int k = 0; k < 2; ++k) { -#if defined(USE_ROCM) - cTmp[k] = VecT{0.0f, 0.0f, 0.0f, 0.0f}; -#else cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; -#endif } #pragma unroll for (int k = 0; k < 2; ++k) { -#if defined(USE_ROCM) - cTmp[k] = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( - a[j * 2 + k].val, - b[0][(j * 2 + k) / 2].val[((j * 2 + k) % 2)], - cTmp[k], 0, 0, 0); -#else asm volatile( "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" @@ -858,22 +691,14 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( "f"(cTmp[k].y), "f"(cTmp[k].z), "f"(cTmp[k].w)); -#endif } #pragma unroll for (int k = 0; k < 2; ++k) { -#if defined(USE_ROCM) - c[0] += cTmp[k][0]; - c[1] += cTmp[k][1]; - c[2] += cTmp[k][2]; - c[3] += cTmp[k][3]; -#else c.x += cTmp[k].x; c.y += cTmp[k].y; c.z += cTmp[k].z; c.w += cTmp[k].w; -#endif } } } @@ -886,14 +711,7 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( // FIXME: this likely doesn't need to be a true reduction tree, can just be a // serial sum, maybe (unless nvcc/ptxas goes back to its old ways) // smem_sum[warpId][laneId] = TreeReduce4::reduce(c); -#if defined(USE_ROCM) - smem_sum[warpId][laneId].x = c[0]; - smem_sum[warpId][laneId].y = c[1]; - smem_sum[warpId][laneId].z = c[2]; - smem_sum[warpId][laneId].w = c[3]; -#else smem_sum[warpId][laneId] = c; -#endif __syncthreads(); @@ -923,9 +741,6 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( laneId, sum_f32); } -#else - printf("__builtin_amdgcn_mfma_f32_16x16x16bf16_1k is only supported on AMD gpu arch greater than or equal to CDNA2\n"); -#endif } @@ -983,12 +798,7 @@ void launch_tinygemm_kernel( cudaFuncAttributes funcAttr; C10_CUDA_CHECK(cudaFuncGetAttributes( &funcAttr, -#if defined(USE_ROCM) - (void *)func -#else - func -#endif - )); + func)); } // FIXME: parallelize better, smem staging etc? @@ -1003,11 +813,7 @@ __global__ void matrix_to_m16n8k16_Bint4_layout( // innermost k-tiles that we can use is 2. static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), ""); -#if defined(USE_ROCM) - constexpr int32_t kNTileSize = 16; -#else constexpr int32_t kNTileSize = 8; -#endif constexpr int32_t kKTileSize = 16; // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles @@ -1019,30 +825,13 @@ __global__ void matrix_to_m16n8k16_Bint4_layout( #pragma unroll for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { // n dimension that this lane loads from -#if defined(USE_ROCM) - auto n0 = nTile * kNTileSize + (t % kNTileSize); -#else auto n0 = nTile * kNTileSize + (t / 4); -#endif bool n0Valid = n0 < in.size(0); int32_t ks[8]; auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; - -#if defined(USE_ROCM) - ks[0] = kBase0 + (t / kNTileSize) * 4; - ks[1] = ks[0] + 1; - ks[2] = ks[0] + 2; - ks[3] = ks[0] + 3; - - auto kBase1 = kBase0 + kKTileSize; - ks[4] = kBase1 + (t / kNTileSize) * 4; - ks[5] = ks[4] + 1; - ks[6] = ks[4] + 2; - ks[7] = ks[4] + 3; -#else ks[0] = kBase0 + (t % 4) * 2; ks[1] = ks[0] + 1; ks[2] = ks[0] + 8; @@ -1053,7 +842,6 @@ __global__ void matrix_to_m16n8k16_Bint4_layout( ks[5] = ks[4] + 1; ks[6] = ks[4] + 8; ks[7] = ks[4] + 8 + 1; -#endif auto pIn = &in[n0][0]; @@ -1067,19 +855,7 @@ __global__ void matrix_to_m16n8k16_Bint4_layout( (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0]; // inner k-tiles pack two at a time -#if defined(USE_ROCM) - // The output tensor shape is [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2], which is specific to Nvidia - // But AMD needs [ceil(n / 16)][ceil(k / (InnerKTiles * 16))][64][InnerKTiles / 2] - // So construct the pointer accordingly - auto bPtr = out.data() + - ((nTile * out.size(1) * kWarpSize * (InnerKTiles / 2)) + - (kOuterTile * kWarpSize * (InnerKTiles / 2)) + - (t * (InnerKTiles / 2)) + - (innerKTile / 2)); - *bPtr = pack; -#else out[nTile][kOuterTile][t][innerKTile / 2] = pack; -#endif } } @@ -1096,30 +872,16 @@ at::Tensor _weight_int4pack_mm_cuda( TORCH_CHECK( A.device() == B.device() && A.device() == qScaleAndZeros.device()); -#if defined(USE_ROCM) - if (!isCDNA2orLater(A.device().index())) { - TORCH_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); - } -#endif - constexpr int32_t kMTileSize = 16; -#if defined(USE_ROCM) - constexpr int32_t kNTileSize = 16; -#else constexpr int32_t kNTileSize = 8; -#endif constexpr int32_t kKTileSize = 16; // row major layout auto m = A.size(0); auto mTiles = divUp(m, kMTileSize); - // To convert the nTiles from tensor storage layout to the actual matrix core layout - constexpr int32_t kNTileSizeTensor = 8; - auto nTileScaleFactor = (kNTileSize / kNTileSizeTensor); - // tensor core layout - auto nTiles = (B.size(0) / nTileScaleFactor); + auto nTiles = B.size(0); auto n = nTiles * kNTileSize; // row major layout @@ -1142,7 +904,7 @@ at::Tensor _weight_int4pack_mm_cuda( TORCH_CHECK(B.is_contiguous()); TORCH_CHECK(B.dim() == 4); TORCH_CHECK(B.size(1) == k / (B_innerKTiles * kKTileSize)); - TORCH_CHECK(B.size(2) == 32); + TORCH_CHECK(B.size(2) == kWarpSize); // Validate the scale and zero point tensor for dequantization // These are the only versions handled at the moment @@ -1162,7 +924,7 @@ at::Tensor _weight_int4pack_mm_cuda( auto C_final = at::empty( {m, n}, at::TensorOptions().dtype(at::kBFloat16).device(A.device())); -#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) auto stream = at::cuda::getCurrentCUDAStream(); #define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ do { \ @@ -1291,27 +1053,10 @@ at::Tensor _convert_weight_to_int4pack_cuda( // which is the maximum vectorized load/store size TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); -#if defined(USE_ROCM) - if (!isCDNA2orLater(in.device().index())) { - TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); - } -#endif - -#if defined(USE_ROCM) - constexpr int32_t kNTileSize = 16; -#else constexpr int32_t kNTileSize = 8; -#endif constexpr int32_t kKTileSize = 16; - // GPT-FAST assumes nTileSize of 8 for quantized weight tensor. - // See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510 - // Torch dynamo also requires the torch ops has the same output shape for each device. - // See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263 - constexpr int32_t kNTileSizeTensor = 8; - auto nTiles = divUp(in.size(0), kNTileSize); - auto nTilesTensor = divUp(in.size(0), kNTileSizeTensor); // k-tiles are packed back to back in the innermost dimension in order to // allow for 4/8/16 byte loads @@ -1321,14 +1066,11 @@ at::Tensor _convert_weight_to_int4pack_cuda( // each block handles `innerKTiles` k-tiles. // 2 k-tiles are a single int32 - // - // We use the same shape for AMD gpus also to match the GPT-FAST spec. - // Will index it correctly when dereferencing the quantized weight tensor pointer. auto out = at::empty( - {nTilesTensor, kSuperTiles, 32, innerKTiles / 2}, + {nTiles, kSuperTiles, 32, innerKTiles / 2}, at::TensorOptions().dtype(at::kInt).device(in.device())); -#if (defined(USE_ROCM) && ROCM_VERSION >= 50700) || ((defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))) +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) auto stream = at::cuda::getCurrentCUDAStream(); dim3 grid(kSuperTiles, nTiles); diff --git a/test/test_linalg.py b/test/test_linalg.py index e0ad1b2ede62..81db475f1e3a 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -31,7 +31,7 @@ from torch.testing._internal.common_dtype import ( floating_and_complex_types_and, floating_types_and, complex_types, ) from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \ - _get_torch_cuda_version, CDNA2OrLater + _get_torch_cuda_version from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel from torch.testing._internal.common_mkldnn import bf32_on_and_off from torch.distributions.binomial import Binomial @@ -6127,8 +6127,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: self.skipTest("requires SM80 or later") if TEST_WITH_ROCM: - if not CDNA2OrLater(): - self.skipTest("_int4_mm is supported only for CDNA2 or later") + self.skipTest("_int4_mm not compiled for ROCM") q_group = 32 inner_k_tiles = 2 @@ -6176,8 +6175,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: self.skipTest("requires SM80 or later") if TEST_WITH_ROCM: - if not CDNA2OrLater(): - self.skipTest("_int4_mm is supported only for CDNA2 or later") + self.skipTest("_int4_mm not compiled for ROCM") q_group = 32 inner_k_tiles = 2 diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 01eeac86ae13..7be663e21711 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -33,12 +33,6 @@ SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)]) -def CDNA2OrLater(): - if TEST_WITH_ROCM: - gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName - return any(arch in gcn_arch_name for arch in {"gfx90a", "gfx940", "gfx941", "gfx942"}) - return False - def evaluate_gfx_arch_exact(matching_arch): if not torch.cuda.is_available(): return False diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 034418afa46e..976e12e42d33 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -537,8 +537,6 @@ CUDA_TYPE_NAME_MAP = collections.OrderedDict( ("CUuuid", ("hipUUID", CONV_TYPE, API_RUNTIME)), ("cudaGraph_t", ("hipGraph_t", CONV_TYPE, API_RAND)), ("cudaGraphExec_t", ("hipGraphExec_t", CONV_TYPE, API_RAND)), - ("__nv_bfloat16", ("__hip_bfloat16", CONV_TYPE, API_RUNTIME)), - ("__nv_bfloat162", ("__hip_bfloat162", CONV_TYPE, API_RUNTIME)), ] )