Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)

cuBLAS added support for them in CUDA 12.9. It's rather easy to call into them, the hardest thing is allowing the lhs and rhs operands to have different scaling types, as that changes the whole callstack.

The scaling format is still detected from the sizes of the scale tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158037
Approved by: https://github.com/eqy, https://github.com/drisspg
This commit is contained in:
Luca Wehrstedt
2025-07-24 15:25:09 +00:00
committed by PyTorch MergeBot
parent 0b2ef76e85
commit 5ab0eb28f7
8 changed files with 348 additions and 196 deletions

View File

@ -1843,6 +1843,69 @@ template bool gemm_and_bias(
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) {
switch (scaling_type) {
case ScalingType::BlockWise1x32:
TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu);
#if CUDA_VERSION >= 12080
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
case ScalingType::BlockWise1x16:
TORCH_CHECK(scale_dtype == kFloat8_e4m3fn);
#if CUDA_VERSION >= 12080
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales of 1x16 blocks is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
case ScalingType::RowWise:
TORCH_CHECK(scale_dtype == kFloat);
#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
return CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;
#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
// Return the default, since in old hipblaslt this is activated via
// the SCALE_POINTER_VEC_EXT attributed.
return 0;
#else
TORCH_CHECK(false, "scaled_gemm with rowwise scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090
case ScalingType::BlockWise1x128:
TORCH_CHECK(scale_dtype == kFloat);
TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 1x128 blockwise scaling")
#if CUDA_VERSION >= 12090
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F;
#else
TORCH_CHECK(false, "scaled_gemm with 1x128 blockwise scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090
case ScalingType::BlockWise128x128:
TORCH_CHECK(scale_dtype == kFloat);
TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 128x128 blockwise scaling")
#if CUDA_VERSION >= 12090
return CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
TORCH_CHECK(false, "scaled_gemm with 128x128 blockwise scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090
case ScalingType::TensorWise:
TORCH_CHECK(scale_dtype == kFloat);
#if CUDA_VERSION >= 12080
return CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#else
// The macro isn't defined, thus we inline its value.
return 0;
#endif // if CUDA_VERSION >= 12080
default:
TORCH_CHECK(false);
return -1;
}
}
void scaled_gemm(
char transa,
char transb,
@ -1854,19 +1917,20 @@ void scaled_gemm(
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
ScalingType mat1_scaling_type,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
ScalingType mat2_scaling_type,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
const void *result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum,
bool use_rowwise) {
bool use_fast_accum) {
// Note: see `cublasCommonArgs` for various non-intuitive manupulations
// of input arguments to this function.
#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
@ -1879,19 +1943,15 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER;
#if defined(USE_ROCM)
#if defined(HIPBLASLT_OUTER_VEC)
// this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
#elif defined(HIPBLASLT_VEC_EXT)
if (use_rowwise) {
// hipblaslt supported row-wise before cublas, and did so their own way (via
// the SCALE_POINTERSs), but then migrated to match how cublas does it (via
// the SCALE_MODEs). Here we check for this early custom mode.
#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) {
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
}
#else
// rowwise isn't supported using older hipblaslt
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt");
#endif
#endif // defined(USE_ROCM)
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
if (result_scale_ptr != nullptr) {
@ -1931,30 +1991,14 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
#if CUDA_VERSION >= 12080
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
} else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) {
#if CUDA_VERSION >= 12080
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
} else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) {
#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
// no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090
}
// The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt,
// but we must invoke get_scale_mode anyways to trigger the version checks.
[[maybe_unused]] int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum);
[[maybe_unused]] int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum);
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode);
#endif
CuBlasLtMatmulPreference preference;
auto ltworkspace = CublasLtWorkspace();

View File

@ -136,6 +136,15 @@ void int8_gemm(
int32_t* result_ptr,
int64_t result_ld);
enum class ScalingType : std::uint8_t {
TensorWise, // fp32 scales
RowWise, // fp32 scales
BlockWise1x16, // fp8_e4m3fn scales
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
};
void scaled_gemm(
char transa,
char transb,
@ -147,19 +156,20 @@ void scaled_gemm(
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
ScalingType mat1_scaling_type,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
ScalingType mat2_scaling_type,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
const void* result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum,
bool use_rowwise);
bool use_fast_accum);
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)

View File

@ -29,6 +29,8 @@
namespace at::cuda::tunable {
using at::cuda::blas::ScalingType;
enum class BlasOp {
N = 0,
T = 1
@ -598,7 +600,8 @@ struct ScaledGemmParams : OpParams {
//
// In TunableOp, we must distinguish in param signature these two cases: with and without a bias vector.
return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s",
transa, transb, m, n, k, lda, ldb, ldc, use_rowwise,
transa, transb, m, n, k, lda, ldb, ldc,
a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise,
bias_ptr == nullptr ? "None" : at::toString(bias_dtype));
}
@ -673,11 +676,13 @@ struct ScaledGemmParams : OpParams {
int64_t lda{};
ScalarType a_dtype{};
ScalarType a_scale_dtype{};
ScalingType a_scaling_type{};
const void* b{};
const void* b_scale_ptr{};
int64_t ldb{};
ScalarType b_dtype{};
ScalarType b_scale_dtype{};
ScalingType b_scaling_type{};
const void* bias_ptr{};
ScalarType bias_dtype{};
void* c{};
@ -686,7 +691,6 @@ struct ScaledGemmParams : OpParams {
ScalarType c_dtype{};
void* amax_ptr{};
bool use_fast_accum{};
bool use_rowwise{};
private:
bool duplicate_inputs_{false};
};

View File

@ -206,23 +206,43 @@ float GetBetaFromParams(const ScaledGemmParams<T>* params) {
}
template <typename T>
bool GetUseRowwiseFromParams(const GemmParams<T>* params) {
return false;
ScalingType GetAScalingTypeFromParams(const GemmParams<T>* params) {
return ScalingType::TensorWise;
}
template <typename T>
bool GetUseRowwiseFromParams(const GemmAndBiasParams<T>* params) {
return false;
ScalingType GetBScalingTypeFromParams(const GemmParams<T>* params) {
return ScalingType::TensorWise;
}
template <typename T>
bool GetUseRowwiseFromParams(const GemmStridedBatchedParams<T>* params) {
return false;
ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
return ScalingType::TensorWise;
}
template <typename T>
bool GetUseRowwiseFromParams(const ScaledGemmParams<T>* params) {
return params->use_rowwise;
ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
return ScalingType::TensorWise;
}
template <typename T>
ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
return ScalingType::TensorWise;
}
template <typename T>
ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
return ScalingType::TensorWise;
}
template <typename T>
ScalingType GetAScalingTypeFromParams(const ScaledGemmParams<T>* params) {
return params->a_scaling_type;
}
template <typename T>
ScalingType GetBScalingTypeFromParams(const ScaledGemmParams<T>* params) {
return params->b_scaling_type;
}
template <typename T>
@ -489,23 +509,24 @@ class HipblasltGemmOp : public Callable<ParamsT> {
const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
if (mat1_scale_ptr && mat2_scale_ptr) {
#ifdef HIPBLASLT_VEC_EXT
if (GetUseRowwiseFromParams<CT>(params)) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr);
}
else
#endif
{
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
}
#ifdef HIPBLASLT_OUTER_VEC
if (GetUseRowwiseFromParams<CT>(params)) {
hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER;
hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER;
if (GetAScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
#if defined(HIPBLASLT_OUTER_VEC)
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
}
#elif defined(HIPBLASLT_VEC_EXT)
a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
#endif
}
if (GetBScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
#if defined(HIPBLASLT_OUTER_VEC)
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(HIPBLASLT_VEC_EXT)
b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
#endif
}
matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr);
matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr);
}
if (result_scale_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);

View File

@ -96,19 +96,20 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
params->lda,
params->a_dtype,
params->a_scale_dtype,
params->a_scaling_type,
params->b,
params->b_scale_ptr,
params->ldb,
params->b_dtype,
params->b_scale_dtype,
params->b_scaling_type,
params->bias_ptr,
params->bias_dtype,
params->c,
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->use_fast_accum,
params->use_rowwise);
params->use_fast_accum);
return OK;
}
};

View File

@ -19,6 +19,7 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -99,6 +100,7 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
}
}
using at::cuda::blas::ScalingType;
/**
* @brief Prepares matrices for CUBLAS operation
@ -140,7 +142,9 @@ struct cublasCommonArgs {
Tensor& c,
const std::optional<Tensor>& scale_a = std::nullopt,
const std::optional<Tensor>& scale_b = std::nullopt,
const std::optional<Tensor>& scale_result = std::nullopt) {
const std::optional<Tensor>& scale_result = std::nullopt,
const std::optional<ScalingType>& scaling_choice_a = std::nullopt,
const std::optional<ScalingType>& scaling_choice_b = std::nullopt) {
bool transpose_result = false, transpose_a = false, transpose_b = false;
result = prepare_matrix_for_cublas(c, transpose_result);
mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result);
@ -152,8 +156,10 @@ struct cublasCommonArgs {
// as B.T @ A.T, check transpose_result to determine if we flip the scales
scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a;
scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b;
}
if (scale_result) {
@ -199,7 +205,9 @@ struct cublasCommonArgs {
void* scale_matb_ptr = nullptr;
void* scale_result_ptr = nullptr;
std::optional<c10::ScalarType> scale_mata_dtype;
std::optional<ScalingType> scaling_mata_type;
std::optional<c10::ScalarType> scale_matb_dtype;
std::optional<ScalingType> scaling_matb_type;
std::optional<c10::ScalarType> scale_result_dtype;
};
} // namespace
@ -1075,133 +1083,114 @@ static bool _scaled_mm_is_fnuz() {
namespace{
enum class ScalingType : std::uint8_t {
TensorWise,
RowWise,
BlockWise,
Error
};
/*
* Scaling Type Determination:
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale tensors are both `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - Returns BlockWise (with additional size checks).
*
* - If scale_a.numel() == 1 && scale_b.numel() == 1:
* - Else if scale.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1:
* - Returns RowWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == inner_dim / 128:
* - Returns BlockWise 1x128.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim / 128 && scale.size(1) == inner_dim / 128:
* - Returns BlockWise 128x128.
*
* - Otherwise:
* - Returns Error.
*/
// Validates the scale tensors to scaled_mm
// And returns the type of scaling/which kernel to use
ScalingType get_scaling_type(
const at::Tensor& scale_a,
const at::Tensor& scale_b,
int64_t dim_m,
int64_t dim_k,
int64_t dim_n) {
// Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types)
if ((scale_a.scalar_type() == scale_b.scalar_type()) &&
((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) {
const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn;
using at::cuda::blas::ScalingType;
// cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements
// cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements.
const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32;
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1;
}
constexpr int64_t BLOCK_SIZE_MN = 128;
bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
&& scale.size(0) == t.size(0) && scale.size(1) == 1
&& scale.is_contiguous());
}
// adjust for fp4x2 packing if necessary
const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k;
// 1x16 blocks for packed nvfp4 data and fp8_e4m3fn scales
bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) {
// Multiply t.size(1) by 2 to adjust for fp4x2 packing
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors
return (t.scalar_type() == ScalarType::Float4_e2m1fn_x2 && scale.scalar_type() == at::kFloat8_e4m3fn
&& scale.numel() == round_up<int64_t>(t.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(t.size(1) * 2, 16), 4)
&& scale.is_contiguous());
}
auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K);
auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;
// 1x32 blocks for microscaled fp8 data and fp8_e8m0fnu scales
bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) {
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu
&& scale.numel() == round_up<int64_t>(t.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(t.size(1), 32), 4)
&& scale.is_contiguous());
}
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors
bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
&& scale.size(0) == t.size(0) && scale.size(1) == ceil_div<int64_t>(t.size(1), 128)
&& scale.stride(0) == 1 && scale.stride(1) == t.size(0));
}
// Check expected sizes for block-wise scaling
auto expected_a_size =
BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks;
auto expected_b_size =
BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;
bool is_blockwise_128x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
&& scale.size(0) == ceil_div<int64_t>(t.size(0), 128) && scale.size(1) == ceil_div<int64_t>(t.size(1), 128)
&& scale.stride(0) == round_up<int64_t>(ceil_div<int64_t>(t.size(1), 128), 4) && scale.stride(1) == 1);
}
TORCH_CHECK(scale_a.numel() == expected_a_size,
"For BlockWise scaling: Expected scale_a size to be ",
expected_a_size, " but got ", scale_a.numel());
TORCH_CHECK(scale_b.numel() == expected_b_size,
"For BlockWise scaling: Expected scale_b size to be ",
expected_b_size, " but got ", scale_b.numel());
TORCH_CHECK(
scale_a.is_contiguous() && scale_b.is_contiguous(),
"For BlockWise scaling: Both scale_a and scale_b must be contiguous");
return ScalingType::BlockWise;
bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingType desired_scaling) {
switch (desired_scaling) {
case ScalingType::TensorWise:
return is_tensorwise_scaling(t, scale);
case ScalingType::RowWise:
return is_rowwise_scaling(t, scale);
case ScalingType::BlockWise1x16:
return is_blockwise_1x16_scaling(t, scale);
case ScalingType::BlockWise1x32:
return is_blockwise_1x32_scaling(t, scale);
case ScalingType::BlockWise1x128:
return is_blockwise_1x128_scaling(t, scale);
case ScalingType::BlockWise128x128:
return is_blockwise_128x128_scaling(t, scale);
default:
TORCH_CHECK(false);
return false;
}
// Both Per-Tensor and Row-wise scaling expect fp32 tensors
TORCH_CHECK(
scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"Both scale_a and scale_b must be float (fp32) tensors.");
}
// Check the singluar scale case for per-tensor scaling
if (scale_a.numel() == 1 && scale_b.numel() == 1) {
return ScalingType::TensorWise;
std::pair<ScalingType, ScalingType> get_joint_scaling(
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
const at::Tensor& a, const at::Tensor& b,
const at::Tensor& scale_a, const at::Tensor& scale_b) {
for (auto [lhs, rhs] : options) {
if (is_desired_scaling(a, scale_a, lhs) && is_desired_scaling(b.t(), scale_b.t(), rhs)) {
return {lhs, rhs};
}
}
// For non-TensorWise scaling, enforce 2D input tensors
TORCH_CHECK(
scale_a.dim() == 2 && scale_b.dim() == 2,
"For non-TensorWise scaling, scale tensors must be 2-dimensional, "
"but got scale_a.dim()=",
scale_a.dim(),
" and scale_b.dim()=",
scale_b.dim());
// Check for RowWise scaling
if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 &&
scale_b.size(0) == 1 && scale_b.size(1) == dim_n) {
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \
(defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC)))
TORCH_CHECK(
scale_a.is_contiguous() && scale_b.is_contiguous(),
"Both scale_a and scale_b must be contiguous for RowWise scaling.");
return ScalingType::RowWise;
#else
TORCH_CHECK(false, "Per-row scaling is not supported for this platform!");
return ScalingType::Error;
#endif
}
// If we reach here, the input doesn't match any valid scaling type
TORCH_CHECK(
false,
"Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. "
"For RowWise scaling, scale_a should be (",
dim_m,
", 1) and scale_b should be (1, ",
dim_n,
"). "
"Got scale_a.size()=(",
scale_a.size(0),
", ",
scale_a.size(1),
") and ",
"scale_b.size()=(",
scale_b.size(0),
", ",
scale_b.size(1),
")");
return ScalingType::Error;
false,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", 1) and scale_b should be (1, ", b.size(1), "), and both should be contiguous.\n"
"- For BlockWise 1x128 scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", ", ceil_div<int64_t>(a.size(1), 128), ") and scale_b should be (", ceil_div<int64_t>(b.size(0), 128), ", ", b.size(1), "), and both should be outer-dim-major.\n"
"- For BlockWise 128x128 scaling, a and b should be float8, scales should be float, scale_a should be (", ceil_div<int64_t>(a.size(0), 128), ", ", ceil_div<int64_t>(a.size(1), 128), ") and scale_b should be (", ceil_div<int64_t>(b.size(0), 128), ", ", ceil_div<int64_t>(b.size(1), 128), "), and both should be near-inner-dim-major (with 16-byte aligned strides).\n"
"- For Blockwise 1x32 scaling, a and b should be float8, scales should be float8_e8m0fnu, scale_a should have ", round_up<int64_t>(a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(a.size(1), 32), 4), " elements and scale_b should have ", round_up<int64_t>(b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(b.size(0), 32), 4), " elements, and both should be contiguous.\n"
"- For Blockwise 1x16 scaling, a and b should be float4 (packed 2x), scales should be float8_e4m3fn, scale_a should have ", round_up<int64_t>(a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(a.size(1) * 2, 16), 4), " elements and scale_b should have ", round_up<int64_t>(b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(b.size(0) * 2, 16), 4), " elements, and both should be contiguous.\n"
"Got a.dtype()=", a.scalar_type(), ", scale_a.dtype()=", scale_a.scalar_type(), ", scale_a.size()=", scale_a.sizes(), ", scale_a.stride()=", scale_a.strides(), ", ",
"b.dtype()=", b.scalar_type(), ", scale_b.dtype()=", scale_b.scalar_type(), ", scale_b.size()=", scale_b.sizes(), " and scale_b.stride()=", scale_b.strides()
);
}
} // namespace
@ -1219,8 +1208,8 @@ ScalingType get_scaling_type(
// - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
// - `use_fast_accum`: if true, enables fast float8 accumulation
// - `out`: a reference to the output tensor
@ -1243,9 +1232,21 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
// Check what type of scaling we are doing based on inputs
ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1));
TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
// Check what type of scaling we are doing based on inputs. This list is sorted
// by decreasing priority. We prefer "simpler" schemes as they are supported
// more broadly (more GPU archs, more CUDA versions) and because they are more
// efficient. This tends to matter only for small matmuls (e.g., 1x1x128).
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
{
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
std::make_pair(ScalingType::BlockWise128x128, ScalingType::BlockWise1x128),
std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise128x128),
std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise1x128),
std::make_pair(ScalingType::BlockWise1x32, ScalingType::BlockWise1x32),
std::make_pair(ScalingType::BlockWise1x16, ScalingType::BlockWise1x16)
},
mat1, mat2, scale_a, scale_b);
TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
@ -1316,7 +1317,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
#ifndef USE_ROCM
// We are doing row-wise scaling
auto dprops = at::cuda::getCurrentDeviceProperties();
if (scaling_choice == ScalingType::RowWise
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise
&& (dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
@ -1330,7 +1331,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
return out;
}
#else
if (scaling_choice == ScalingType::RowWise) {
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
// For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
Tensor b = mat2;
if (_scaled_mm_is_fnuz()) {
@ -1345,7 +1346,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
}
#endif
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result);
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@ -1422,10 +1423,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
params.a_scale_ptr = args.scale_mata_ptr;
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
@ -1433,7 +1438,6 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
params.use_rowwise = scaling_choice == ScalingType::RowWise;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
@ -1467,19 +1471,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum,
scaling_choice == ScalingType::RowWise);
use_fast_accum);
}
return out;

View File

@ -28,6 +28,7 @@ from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_MX_GEMM,
IS_SM90,
)
from torch.testing._internal.common_device_type import (
dtypes,
@ -785,7 +786,7 @@ def amax_to_scale(
if float8_dtype == e4m3_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == e5m2_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
@ -806,6 +807,20 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
return amax_to_scale(amax, float8_dtype, x.dtype)
def tensor_to_scale_block(
x: torch.Tensor,
float8_dtype: torch.dtype,
block_outer: int,
block_inner: int,
) -> tuple[torch.Tensor, torch.Tensor]:
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = torch.finfo(float8_dtype).max / amax
x = x.mul(scale).to(float8_dtype)
x = x.flatten(2, 3).flatten(0, 1)
scale = scale.flatten(2, 3).flatten(0, 1)
return x, scale
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
# naive implementation: dq -> op -> q
x_fp32 = x.to(torch.float) / x_scale
@ -814,6 +829,17 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
return out_fp32.to(out_dtype)
def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1))
y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1))
x_fp32 = x.to(torch.float) / x_scale[:, None, :, None]
y_fp32 = y.to(torch.float) / y_scale[:, None, :, None]
x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1)
y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1)
out_fp32 = torch.mm(x_fp32, y_fp32)
return out_fp32.to(out_dtype)
def addmm_float8_unwrapped(
a_data: torch.Tensor,
a_scale: torch.Tensor,
@ -1237,11 +1263,7 @@ class TestFP8Matmul(TestCase):
y_fp8 = y.to(e4m3_type).t()
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
),
RuntimeError, re.escape("Invalid scaling configuration")
):
torch._scaled_mm(
x_fp8,
@ -1252,11 +1274,7 @@ class TestFP8Matmul(TestCase):
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
" For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
),
RuntimeError, re.escape("Invalid scaling configuration")
):
torch._scaled_mm(
x_fp8,
@ -1266,22 +1284,18 @@ class TestFP8Matmul(TestCase):
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
RuntimeError, re.escape("Invalid scaling configuration")
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M), device="cuda"),
scale_b=torch.ones((N, N), device="cuda"),
scale_b=torch.ones((N, N, 1), device="cuda"),
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"Both scale_a and scale_b must be contiguous for RowWise scaling."
),
RuntimeError, re.escape("Invalid scaling configuration")
):
torch._scaled_mm(
x_fp8,
@ -1346,6 +1360,58 @@ class TestFP8Matmul(TestCase):
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
@unittest.skipIf(
_get_torch_cuda_version() < (12, 9),
"cuBLAS blockwise scaling added in CUDA 12.9",
)
@parametrize("output_dtype", [torch.bfloat16, torch.float32])
@parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block):
torch.manual_seed(42)
x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3)
y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3)
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
x_scales = x_scales.t().contiguous().t()
if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated_block(
x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
)
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
if output_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 6e-1, 7e-2
else:
atol, rtol = 7e-1, 2e-3
self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
# One last check against the full-precision reference, to ensure we
# didn't mess up the scaling itself and made the test trivial.
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("which_dim_zero", [0, 1, 2])
@parametrize("use_torch_compile", [False, True])

View File

@ -39,6 +39,7 @@ IS_THOR = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_ca
and torch.cuda.get_device_capability()[1] > 0)
IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and (torch.cuda.get_device_capability() in [(7, 2), (8, 7)] or IS_THOR))
IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9))
IS_SM90 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0))
def evaluate_gfx_arch_within(arch_list):
if not torch.cuda.is_available():