Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)"
This reverts commit 39ac189808c61588f3594dbc2fc1d69bb6194c47. Reverted https://github.com/pytorch/pytorch/pull/158037 on behalf of https://github.com/jithunnair-amd due to Ignored ROCm failures while ROCm was unstable, but HUD clearly shows this PR introduced failures on trunk ([comment](https://github.com/pytorch/pytorch/pull/158037#issuecomment-3087982975))
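For context, the sketch below (not part of this commit) illustrates the usage pattern the reverted PR enabled: a DeepSeek-style fp8 matmul in which the activation carries 1x128 fp32 scales and the weight carries 128x128 fp32 scales, dispatched through `torch._scaled_mm`. The helper `to_fp8_blockwise` is a hypothetical adaptation of `tensor_to_scale_block` from the test diff below; running it assumes the reverted commit is applied, an SM90+ GPU, and CUDA 12.9+.

```python
import torch

def to_fp8_blockwise(x, block_outer, block_inner, dtype=torch.float8_e4m3fn):
    # One fp32 scale per (block_outer x block_inner) tile. Returns fp8 data and
    # the per-block dequantization scale expected by torch._scaled_mm
    # (x ~= x_fp8 * scale).
    xb = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
    amax = xb.abs().amax(dim=[1, 3], keepdim=True).float()
    q = torch.finfo(dtype).max / amax                    # quantization multiplier
    x_fp8 = xb.mul(q).to(dtype).flatten(2, 3).flatten(0, 1)
    scale = q.reciprocal().flatten(2, 3).flatten(0, 1)   # dequant scale per block
    return x_fp8, scale

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)  # activation (M, K)
w = torch.randn(768, 512, device="cuda", dtype=torch.bfloat16)  # weight (N, K)

x_fp8, x_scale = to_fp8_blockwise(x, 1, 128)     # scales of shape (M, K/128)
w_fp8, w_scale = to_fp8_blockwise(w, 128, 128)   # scales of shape (N/128, K/128)

# Per the checks in the diff, 1x128 scales must be outer-dim-major.
x_scale = x_scale.t().contiguous().t()

out = torch._scaled_mm(
    x_fp8, w_fp8.t(),
    scale_a=x_scale, scale_b=w_scale.t(),
    out_dtype=torch.bfloat16,
)  # (256, 768) bf16 result
```

This layout pair corresponds to the (BlockWise1x128, BlockWise128x128) option accepted by `get_joint_scaling` in the code removed below.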
@@ -7,15 +7,8 @@ namespace at {
/**
Computes ceil(a / b)
*/
template <
typename Res = void,
typename T,
typename U,
typename = std::enable_if_t<
std::conjunction_v<std::is_integral<T>, std::is_integral<U>>>>
C10_ALWAYS_INLINE C10_HOST_DEVICE
std::conditional_t<std::is_same_v<Res, void>, std::common_type_t<T, U>, Res>
ceil_div(T a, U b) {
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) {
return (a + b - 1) / b;
}

@@ -23,10 +16,8 @@ C10_ALWAYS_INLINE C10_HOST_DEVICE
Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest
multiple of b
*/
template <typename Res = void, typename T, typename U>
C10_ALWAYS_INLINE C10_HOST_DEVICE
std::conditional_t<std::is_same_v<Res, void>, std::common_type_t<T, U>, Res>
round_up(T a, U b) {
template <typename T>
C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) {
return ceil_div(a, b) * b;
}

@@ -1843,69 +1843,6 @@ template bool gemm_and_bias(
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);

int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) {
switch (scaling_type) {
case ScalingType::BlockWise1x32:
TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu);
#if CUDA_VERSION >= 12080
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080

case ScalingType::BlockWise1x16:
TORCH_CHECK(scale_dtype == kFloat8_e4m3fn);
#if CUDA_VERSION >= 12080
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales of 1x16 blocks is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080

case ScalingType::RowWise:
TORCH_CHECK(scale_dtype == kFloat);
#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
return CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;
#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
// Return the default, since in old hipblaslt this is activated via
// the SCALE_POINTER_VEC_EXT attributed.
return 0;
#else
TORCH_CHECK(false, "scaled_gemm with rowwise scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090

case ScalingType::BlockWise1x128:
TORCH_CHECK(scale_dtype == kFloat);
TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 1x128 blockwise scaling")
#if CUDA_VERSION >= 12090
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F;
#else
TORCH_CHECK(false, "scaled_gemm with 1x128 blockwise scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090

case ScalingType::BlockWise128x128:
TORCH_CHECK(scale_dtype == kFloat);
TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 128x128 blockwise scaling")
#if CUDA_VERSION >= 12090
return CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
TORCH_CHECK(false, "scaled_gemm with 128x128 blockwise scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090

case ScalingType::TensorWise:
TORCH_CHECK(scale_dtype == kFloat);
#if CUDA_VERSION >= 12080
return CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#else
// The macro isn't defined, thus we inline its value.
return 0;
#endif // if CUDA_VERSION >= 12080

default:
TORCH_CHECK(false);
return -1;
}
}

void scaled_gemm(
char transa,
char transb,
@@ -1917,20 +1854,19 @@ void scaled_gemm(
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
ScalingType mat1_scaling_type,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
ScalingType mat2_scaling_type,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
const void *result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum) {
bool use_fast_accum,
bool use_rowwise) {
// Note: see `cublasCommonArgs` for various non-intuitive manupulations
// of input arguments to this function.
#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
@@ -1943,15 +1879,19 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER;
// hipblaslt supported row-wise before cublas, and did so their own way (via
// the SCALE_POINTERSs), but then migrated to match how cublas does it (via
// the SCALE_MODEs). Here we check for this early custom mode.
#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) {
#if defined(USE_ROCM)
#if defined(HIPBLASLT_OUTER_VEC)
// this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
#elif defined(HIPBLASLT_VEC_EXT)
if (use_rowwise) {
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
}
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
#else
// rowwise isn't supported using older hipblaslt
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt");
#endif
#endif // defined(USE_ROCM)
computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
if (result_scale_ptr != nullptr) {
@@ -1991,14 +1931,30 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}

// The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt,
// but we must invoke get_scale_mode anyways to trigger the version checks.
int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum);
int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum);
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode);
#endif
if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
#if CUDA_VERSION >= 12080
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
} else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) {
#if CUDA_VERSION >= 12080
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
} else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) {
#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
// no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090
}

CuBlasLtMatmulPreference preference;
auto ltworkspace = CublasLtWorkspace();

@@ -136,15 +136,6 @@ void int8_gemm(
int32_t* result_ptr,
int64_t result_ld);

enum class ScalingType : std::uint8_t {
TensorWise, // fp32 scales
RowWise, // fp32 scales
BlockWise1x16, // fp8_e4m3fn scales
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
};

void scaled_gemm(
char transa,
char transb,
@@ -156,20 +147,19 @@ void scaled_gemm(
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
ScalingType mat1_scaling_type,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
ScalingType mat2_scaling_type,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
const void* result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum);
bool use_fast_accum,
bool use_rowwise);

#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)

@@ -29,8 +29,6 @@

namespace at::cuda::tunable {

using at::cuda::blas::ScalingType;

enum class BlasOp {
N = 0,
T = 1
@@ -600,8 +598,7 @@ struct ScaledGemmParams : OpParams {
//
// In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector.
return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s",
transa, transb, m, n, k, lda, ldb, ldc,
a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise,
transa, transb, m, n, k, lda, ldb, ldc, use_rowwise,
bias_ptr == nullptr ? "None" : at::toString(bias_dtype));
}

@@ -676,13 +673,11 @@ struct ScaledGemmParams : OpParams {
int64_t lda{};
ScalarType a_dtype{};
ScalarType a_scale_dtype{};
ScalingType a_scaling_type{};
const void* b{};
const void* b_scale_ptr{};
int64_t ldb{};
ScalarType b_dtype{};
ScalarType b_scale_dtype{};
ScalingType b_scaling_type{};
const void* bias_ptr{};
ScalarType bias_dtype{};
void* c{};
@@ -691,6 +686,7 @@ struct ScaledGemmParams : OpParams {
ScalarType c_dtype{};
void* amax_ptr{};
bool use_fast_accum{};
bool use_rowwise{};
private:
bool duplicate_inputs_{false};
};

@@ -206,43 +206,23 @@ float GetBetaFromParams(const ScaledGemmParams<T>* params) {
}

template <typename T>
ScalingType GetAScalingTypeFromParams(const GemmParams<T>* params) {
return ScalingType::TensorWise;
bool GetUseRowwiseFromParams(const GemmParams<T>* params) {
return false;
}

template <typename T>
ScalingType GetBScalingTypeFromParams(const GemmParams<T>* params) {
return ScalingType::TensorWise;
bool GetUseRowwiseFromParams(const GemmAndBiasParams<T>* params) {
return false;
}

template <typename T>
ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
return ScalingType::TensorWise;
bool GetUseRowwiseFromParams(const GemmStridedBatchedParams<T>* params) {
return false;
}

template <typename T>
ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
return ScalingType::TensorWise;
}

template <typename T>
ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
return ScalingType::TensorWise;
}

template <typename T>
ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
return ScalingType::TensorWise;
}

template <typename T>
ScalingType GetAScalingTypeFromParams(const ScaledGemmParams<T>* params) {
return params->a_scaling_type;
}

template <typename T>
ScalingType GetBScalingTypeFromParams(const ScaledGemmParams<T>* params) {
return params->b_scaling_type;
bool GetUseRowwiseFromParams(const ScaledGemmParams<T>* params) {
return params->use_rowwise;
}

template <typename T>
@@ -509,24 +489,23 @@ class HipblasltGemmOp : public Callable<ParamsT> {
const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
if (mat1_scale_ptr && mat2_scale_ptr) {
hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER;
hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER;
if (GetAScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
#if defined(HIPBLASLT_OUTER_VEC)
#ifdef HIPBLASLT_VEC_EXT
if (GetUseRowwiseFromParams<CT>(params)) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr);
}
else
#endif
{
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
}
#ifdef HIPBLASLT_OUTER_VEC
if (GetUseRowwiseFromParams<CT>(params)) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(HIPBLASLT_VEC_EXT)
a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
#endif
}
if (GetBScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
#if defined(HIPBLASLT_OUTER_VEC)
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(HIPBLASLT_VEC_EXT)
b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
#endif
}
matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr);
matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr);
#endif
}
if (result_scale_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);

@@ -96,20 +96,19 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
params->lda,
params->a_dtype,
params->a_scale_dtype,
params->a_scaling_type,
params->b,
params->b_scale_ptr,
params->ldb,
params->b_dtype,
params->b_scale_dtype,
params->b_scaling_type,
params->bias_ptr,
params->bias_dtype,
params->c,
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->use_fast_accum);
params->use_fast_accum,
params->use_rowwise);
return OK;
}
};

@@ -19,7 +19,6 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -100,7 +99,6 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
}
}

using at::cuda::blas::ScalingType;

/**
* @brief Prepares matrices for CUBLAS operation
@@ -142,9 +140,7 @@ struct cublasCommonArgs {
Tensor& c,
const std::optional<Tensor>& scale_a = std::nullopt,
const std::optional<Tensor>& scale_b = std::nullopt,
const std::optional<Tensor>& scale_result = std::nullopt,
const std::optional<ScalingType>& scaling_choice_a = std::nullopt,
const std::optional<ScalingType>& scaling_choice_b = std::nullopt) {
const std::optional<Tensor>& scale_result = std::nullopt) {
bool transpose_result = false, transpose_a = false, transpose_b = false;
result = prepare_matrix_for_cublas(c, transpose_result);
mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result);
@@ -156,10 +152,8 @@ struct cublasCommonArgs {
// as B.T @ A.T, check transpose_result to determine if we flip the scales
scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a;
scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b;
}

if (scale_result) {
@@ -205,9 +199,7 @@ struct cublasCommonArgs {
void* scale_matb_ptr = nullptr;
void* scale_result_ptr = nullptr;
std::optional<c10::ScalarType> scale_mata_dtype;
std::optional<ScalingType> scaling_mata_type;
std::optional<c10::ScalarType> scale_matb_dtype;
std::optional<ScalingType> scaling_matb_type;
std::optional<c10::ScalarType> scale_result_dtype;
};
} // namespace
@@ -1083,114 +1075,133 @@ static bool _scaled_mm_is_fnuz() {

namespace{

enum class ScalingType : std::uint8_t {
TensorWise,
RowWise,
BlockWise,
Error
};
/*
* Scaling Type Determination:
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - If scale tensors are both `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - Returns BlockWise (with additional size checks).
*
* - Else if scale.numel() == 1:
* - If scale_a.numel() == 1 && scale_b.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1:
* - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
* - Returns RowWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == inner_dim / 128:
* - Returns BlockWise 1x128.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim / 128 && scale.size(1) == inner_dim / 128:
* - Returns BlockWise 128x128.
*
* - Otherwise:
* - Returns Error.
*/

using at::cuda::blas::ScalingType;
// Validates the scale tensors to scaled_mm
// And returns the type of scaling/which kernel to use
ScalingType get_scaling_type(
const at::Tensor& scale_a,
const at::Tensor& scale_b,
int64_t dim_m,
int64_t dim_k,
int64_t dim_n) {
// Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types)
if ((scale_a.scalar_type() == scale_b.scalar_type()) &&
((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) {
const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn;

bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1;
}
// cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements
// cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements.
const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32;

bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
&& scale.size(0) == t.size(0) && scale.size(1) == 1
&& scale.is_contiguous());
}
constexpr int64_t BLOCK_SIZE_MN = 128;

// 1x16 blocks for packed nvfp4 data and fp8_e4m3fn scales
bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) {
// Multiply t.size(1) by 2 to adjust for fp4x2 packing
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors
return (t.scalar_type() == ScalarType::Float4_e2m1fn_x2 && scale.scalar_type() == at::kFloat8_e4m3fn
&& scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1) * 2, 16), 4)
&& scale.is_contiguous());
}
// adjust for fp4x2 packing if necessary
const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k;

// 1x32 blocks for microscaled fp8 data and fp8_e8m0fnu scales
bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) {
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu
&& scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1), 32), 4)
&& scale.is_contiguous());
}
auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K);
auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;

bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
&& scale.size(0) == t.size(0) && scale.size(1) == ceil_div(t.size(1), 128)
&& scale.stride(0) == 1 && scale.stride(1) == t.size(0));
}
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors

bool is_blockwise_128x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
&& scale.size(0) == ceil_div(t.size(0), 128) && scale.size(1) == ceil_div(t.size(1), 128)
&& scale.stride(0) == round_up(ceil_div(t.size(1), 128), 4) && scale.stride(1) == 1);
}
// Check expected sizes for block-wise scaling
auto expected_a_size =
BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks;
auto expected_b_size =
BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;

bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingType desired_scaling) {
switch (desired_scaling) {
case ScalingType::TensorWise:
return is_tensorwise_scaling(t, scale);
case ScalingType::RowWise:
return is_rowwise_scaling(t, scale);
case ScalingType::BlockWise1x16:
return is_blockwise_1x16_scaling(t, scale);
case ScalingType::BlockWise1x32:
return is_blockwise_1x32_scaling(t, scale);
case ScalingType::BlockWise1x128:
return is_blockwise_1x128_scaling(t, scale);
case ScalingType::BlockWise128x128:
return is_blockwise_128x128_scaling(t, scale);
default:
TORCH_CHECK(false);
return false;
}
}

std::pair<ScalingType, ScalingType> get_joint_scaling(
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
const at::Tensor& a, const at::Tensor& b,
const at::Tensor& scale_a, const at::Tensor& scale_b) {
for (auto [lhs, rhs] : options) {
if (is_desired_scaling(a, scale_a, lhs) && is_desired_scaling(b.t(), scale_b.t(), rhs)) {
return {lhs, rhs};
}
TORCH_CHECK(scale_a.numel() == expected_a_size,
"For BlockWise scaling: Expected scale_a size to be ",
expected_a_size, " but got ", scale_a.numel());
TORCH_CHECK(scale_b.numel() == expected_b_size,
"For BlockWise scaling: Expected scale_b size to be ",
expected_b_size, " but got ", scale_b.numel());

TORCH_CHECK(
scale_a.is_contiguous() && scale_b.is_contiguous(),
"For BlockWise scaling: Both scale_a and scale_b must be contiguous");

return ScalingType::BlockWise;
}
// Both Per-Tensor and Row-wise scaling expect fp32 tensors
TORCH_CHECK(
false,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", 1) and scale_b should be (1, ", b.size(1), "), and both should be contiguous.\n"
"- For BlockWise 1x128 scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", b.size(1), "), and both should be outer-dim-major.\n"
"- For BlockWise 128x128 scaling, a and b should be float8, scales should be float, scale_a should be (", ceil_div(a.size(0), 128), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", ceil_div(b.size(1), 128), "), and both should be near-inner-dim-major (with 16-byte aligned strides).\n"
"- For Blockwise 1x32 scaling, a and b should be float8, scales should be float8_e8m0fnu, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1), 32), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0), 32), 4), " elements, and both should be contiguous.\n"
"- For Blockwise 1x16 scaling, a and b should be float4 (packed 2x), scales should be float8_e4m3fn, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1) * 2, 16), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0) * 2, 16), 4), " elements, and both should be contiguous.\n"
"Got a.dtype()=", a.scalar_type(), ", scale_a.dtype()=", scale_a.scalar_type(), ", scale_a.size()=", scale_a.sizes(), ", scale_a.stride()=", scale_a.strides(), ", ",
"b.dtype()=", b.scalar_type(), ", scale_b.dtype()=", scale_b.scalar_type(), ", scale_b.size()=", scale_b.sizes(), " and scale_b.stride()=", scale_b.strides()
);
scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"Both scale_a and scale_b must be float (fp32) tensors.");

// Check the singluar scale case for per-tensor scaling
if (scale_a.numel() == 1 && scale_b.numel() == 1) {
return ScalingType::TensorWise;
}

// For non-TensorWise scaling, enforce 2D input tensors
TORCH_CHECK(
scale_a.dim() == 2 && scale_b.dim() == 2,
"For non-TensorWise scaling, scale tensors must be 2-dimensional, "
"but got scale_a.dim()=",
scale_a.dim(),
" and scale_b.dim()=",
scale_b.dim());

// Check for RowWise scaling
if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 &&
scale_b.size(0) == 1 && scale_b.size(1) == dim_n) {
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \
(defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC)))
TORCH_CHECK(
scale_a.is_contiguous() && scale_b.is_contiguous(),
"Both scale_a and scale_b must be contiguous for RowWise scaling.");
return ScalingType::RowWise;
#else
TORCH_CHECK(false, "Per-row scaling is not supported for this platform!");
return ScalingType::Error;
#endif
}

// If we reach here, the input doesn't match any valid scaling type
TORCH_CHECK(
false,
"Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. "
"For RowWise scaling, scale_a should be (",
dim_m,
", 1) and scale_b should be (1, ",
dim_n,
"). "
"Got scale_a.size()=(",
scale_a.size(0),
", ",
scale_a.size(1),
") and ",
"scale_b.size()=(",
scale_b.size(0),
", ",
scale_b.size(1),
")");

return ScalingType::Error;
}

} // namespace
@@ -1208,8 +1219,8 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
// - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
// - `use_fast_accum`: if true, enables fast float8 accumulation
// - `out`: a reference to the output tensor
@@ -1232,21 +1243,9 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

// Check what type of scaling we are doing based on inputs. This list is sorted
// by decreasing priority. We prefer "simpler" schemes as they are supported
// more broadly (more GPU archs, more CUDA versions) and because they are more
// efficient. This tends to matter only for small matmuls (e.g., 1x1x128).
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
{
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
std::make_pair(ScalingType::BlockWise128x128, ScalingType::BlockWise1x128),
std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise128x128),
std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise1x128),
std::make_pair(ScalingType::BlockWise1x32, ScalingType::BlockWise1x32),
std::make_pair(ScalingType::BlockWise1x16, ScalingType::BlockWise1x16)
},
mat1, mat2, scale_a, scale_b);
// Check what type of scaling we are doing based on inputs
ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1));
TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");

TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
@@ -1317,7 +1316,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
#ifndef USE_ROCM
// We are doing row-wise scaling
auto dprops = at::cuda::getCurrentDeviceProperties();
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise
if (scaling_choice == ScalingType::RowWise
&& (dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
@@ -1331,7 +1330,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
return out;
}
#else
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
if (scaling_choice == ScalingType::RowWise) {
// For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
Tensor b = mat2;
if (_scaled_mm_is_fnuz()) {
@@ -1346,7 +1345,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
}
#endif

cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");

@@ -1423,14 +1422,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
params.a_scale_ptr = args.scale_mata_ptr;
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
@@ -1438,6 +1433,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
params.use_rowwise = scaling_choice == ScalingType::RowWise;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
@@ -1471,20 +1467,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
use_fast_accum,
scaling_choice == ScalingType::RowWise);
}

return out;

@@ -785,7 +785,7 @@ def amax_to_scale(
if float8_dtype == e4m3_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == e5m2_type:
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

@@ -806,20 +806,6 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):

return amax_to_scale(amax, float8_dtype, x.dtype)

def tensor_to_scale_block(
x: torch.Tensor,
float8_dtype: torch.dtype,
block_outer: int,
block_inner: int,
) -> tuple[torch.Tensor, torch.Tensor]:
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = torch.finfo(float8_dtype).max / amax
x = x.mul(scale).to(float8_dtype)
x = x.flatten(2, 3).flatten(0, 1)
scale = scale.flatten(2, 3).flatten(0, 1)
return x, scale

def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
# naive implementation: dq -> op -> q
x_fp32 = x.to(torch.float) / x_scale
@@ -828,17 +814,6 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:

return out_fp32.to(out_dtype)

def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1))
y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1))
x_fp32 = x.to(torch.float) / x_scale[:, None, :, None]
y_fp32 = y.to(torch.float) / y_scale[:, None, :, None]
x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1)
y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1)
out_fp32 = torch.mm(x_fp32, y_fp32)

return out_fp32.to(out_dtype)

def addmm_float8_unwrapped(
a_data: torch.Tensor,
a_scale: torch.Tensor,
@@ -1262,7 +1237,11 @@ class TestFP8Matmul(TestCase):
y_fp8 = y.to(e4m3_type).t()

with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
RuntimeError,
re.escape(
"For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
),
):
torch._scaled_mm(
x_fp8,
@@ -1273,7 +1252,11 @@ class TestFP8Matmul(TestCase):
)

with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
RuntimeError,
re.escape(
" For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
),
):
torch._scaled_mm(
x_fp8,
@@ -1283,18 +1266,22 @@ class TestFP8Matmul(TestCase):
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
RuntimeError,
re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M), device="cuda"),
scale_b=torch.ones((N, N, 1), device="cuda"),
scale_b=torch.ones((N, N), device="cuda"),
out_dtype=torch.bfloat16,
)

with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
RuntimeError,
re.escape(
"Both scale_a and scale_b must be contiguous for RowWise scaling."
),
):
torch._scaled_mm(
x_fp8,
@@ -1359,58 +1346,6 @@ class TestFP8Matmul(TestCase):

torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM90OrLater, "cuBLAS blockwise scaling requires sm90+")
@unittest.skipIf(
_get_torch_cuda_version() < (12, 9),
"cuBLAS blockwise scaling added in CUDA 12.9",
)
@parametrize("output_dtype", [torch.bfloat16, torch.float32])
@parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block):
torch.manual_seed(42)

x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3)
y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3)

x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)

# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
x_scales = x_scales.t().contiguous().t()
if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()

# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype
)

# Calculate emulated F8 mm
out_emulated = mm_float8_emulated_block(
x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
)

cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)

if output_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 6e-1, 7e-2
else:
atol, rtol = 7e-1, 2e-3

self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

# One last check against the full-precision reference, to ensure we
# didn't mess up the scaling itself and made the test trivial.
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)

@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("which_dim_zero", [0, 1, 2])
@parametrize("use_torch_compile", [False, True])