Revert "Support DeepSeek-style blockwise scaling scaled-mm for fp8 on Hopper+ (#158037)"

This reverts commit 39ac189808c61588f3594dbc2fc1d69bb6194c47.

Reverted https://github.com/pytorch/pytorch/pull/158037 on behalf of https://github.com/jithunnair-amd because ROCm failures were ignored while ROCm was unstable, but HUD clearly shows this PR introduced failures on trunk ([comment](https://github.com/pytorch/pytorch/pull/158037#issuecomment-3087982975))
PyTorch MergeBot
2025-07-18 07:47:46 +00:00
parent be896d6b41
commit 32aade9d8d
8 changed files with 201 additions and 360 deletions

View File

@ -7,15 +7,8 @@ namespace at {
/**
Computes ceil(a / b)
*/
template <
typename Res = void,
typename T,
typename U,
typename = std::enable_if_t<
std::conjunction_v<std::is_integral<T>, std::is_integral<U>>>>
C10_ALWAYS_INLINE C10_HOST_DEVICE
std::conditional_t<std::is_same_v<Res, void>, std::common_type_t<T, U>, Res>
ceil_div(T a, U b) {
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) {
return (a + b - 1) / b;
}
@ -23,10 +16,8 @@ C10_ALWAYS_INLINE C10_HOST_DEVICE
Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest
multiple of b
*/
template <typename Res = void, typename T, typename U>
C10_ALWAYS_INLINE C10_HOST_DEVICE
std::conditional_t<std::is_same_v<Res, void>, std::common_type_t<T, U>, Res>
round_up(T a, U b) {
template <typename T>
C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) {
return ceil_div(a, b) * b;
}
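For intuition, the restored helpers are plain integer ceiling-division and round-up; a minimal Python sketch of the same arithmetic (function names mirror the C++ templates above, chosen here for illustration):

def ceil_div(a: int, b: int) -> int:
    # ceil(a / b) using only integer ops, as in the C++ helper above
    return (a + b - 1) // b

def round_up(a: int, b: int) -> int:
    # smallest multiple of b that is >= a
    return ceil_div(a, b) * b

assert ceil_div(7, 3) == 3   # ceil(7/3)
assert round_up(7, 3) == 9   # next multiple of 3 at or above 7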

View File

@ -1843,69 +1843,6 @@ template bool gemm_and_bias(
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) {
switch (scaling_type) {
case ScalingType::BlockWise1x32:
TORCH_CHECK(scale_dtype == kFloat8_e8m0fnu);
#if CUDA_VERSION >= 12080
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales of 1x32 blocks is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
case ScalingType::BlockWise1x16:
TORCH_CHECK(scale_dtype == kFloat8_e4m3fn);
#if CUDA_VERSION >= 12080
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales of 1x16 blocks is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
case ScalingType::RowWise:
TORCH_CHECK(scale_dtype == kFloat);
#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
return CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F;
#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
// Return the default, since in older hipblaslt this is activated via
// the SCALE_POINTER_VEC_EXT attribute.
return 0;
#else
TORCH_CHECK(false, "scaled_gemm with rowwise scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090
case ScalingType::BlockWise1x128:
TORCH_CHECK(scale_dtype == kFloat);
TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 1x128 blockwise scaling")
#if CUDA_VERSION >= 12090
return CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F;
#else
TORCH_CHECK(false, "scaled_gemm with 1x128 blockwise scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090
case ScalingType::BlockWise128x128:
TORCH_CHECK(scale_dtype == kFloat);
TORCH_CHECK(!use_fast_accum, "scaled_gemm doesn't support fast accum with 128x128 blockwise scaling")
#if CUDA_VERSION >= 12090
return CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F;
#else
TORCH_CHECK(false, "scaled_gemm with 128x128 blockwise scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090
case ScalingType::TensorWise:
TORCH_CHECK(scale_dtype == kFloat);
#if CUDA_VERSION >= 12080
return CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F;
#else
// The macro isn't defined, thus we inline its value.
return 0;
#endif // if CUDA_VERSION >= 12080
default:
TORCH_CHECK(false);
return -1;
}
}
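The get_scale_mode helper removed by this revert centralized the mapping from scaling scheme and scale dtype to a cuBLASLt matrix-scale mode, with per-branch CUDA/hipBLASLt version gates. A hedged Python sketch of that dispatch table (enum and constant names are the ones referenced in the C++ above; version checks elided):

# scheme -> (required scale dtype, cuBLASLt scale-mode constant)
SCALE_MODES = {
    "BlockWise1x32":    ("float8_e8m0fnu", "CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0"),    # CUDA >= 12.8
    "BlockWise1x16":    ("float8_e4m3fn",  "CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3"),    # CUDA >= 12.8
    "RowWise":          ("float32",        "CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F"),  # CUDA >= 12.9 or HIPBLASLT_OUTER_VEC
    "BlockWise1x128":   ("float32",        "CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F"),     # CUDA >= 12.9, no fast accum
    "BlockWise128x128": ("float32",        "CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F"), # CUDA >= 12.9, no fast accum
    "TensorWise":       ("float32",        "CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F"),     # default (0) before 12.8
}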
void scaled_gemm(
char transa,
char transb,
@ -1917,20 +1854,19 @@ void scaled_gemm(
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
ScalingType mat1_scaling_type,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
ScalingType mat2_scaling_type,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
const void *result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum) {
bool use_fast_accum,
bool use_rowwise) {
// Note: see `cublasCommonArgs` for various non-intuitive manipulations
// of input arguments to this function.
#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
@ -1943,15 +1879,19 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb));
cublasLtMatmulDescAttributes_t matmulDescA = CUBLASLT_MATMUL_DESC_A_SCALE_POINTER;
cublasLtMatmulDescAttributes_t matmulDescB = CUBLASLT_MATMUL_DESC_B_SCALE_POINTER;
// hipblaslt supported row-wise scaling before cublas, and did so in its own way (via
// the SCALE_POINTER_VEC_EXT attributes), but then migrated to match how cublas does it (via
// the SCALE_MODEs). Here we check for this early custom mode.
#if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
if (mat1_scaling_type == ScalingType::RowWise && mat2_scaling_type == ScalingType::RowWise) {
#if defined(USE_ROCM)
#if defined(HIPBLASLT_OUTER_VEC)
// this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
#elif defined(HIPBLASLT_VEC_EXT)
if (use_rowwise) {
matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
}
#endif // if defined(USE_ROCM) && !defined(HIPBLASLT_OUTER_VEC) && defined(HIPBLASLT_VEC_EXT)
#else
// rowwise isn't supported using older hipblaslt
TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with older hipblaslt");
#endif
#endif // defined(USE_ROCM)
computeDesc.setAttribute(matmulDescA, mat1_scale_ptr);
computeDesc.setAttribute(matmulDescB, mat2_scale_ptr);
if (result_scale_ptr != nullptr) {
@ -1991,14 +1931,30 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
// The SCALE_MODE attrs only exist in cuBLAS 12.8+ or in recent hipblaslt,
// but we must invoke get_scale_mode anyway to trigger the version checks.
int a_scale_mode = get_scale_mode(mat1_scaling_type, mat1_scale_dtype, use_fast_accum);
int b_scale_mode = get_scale_mode(mat2_scaling_type, mat2_scale_dtype, use_fast_accum);
#if CUDA_VERSION >= 12080 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, a_scale_mode);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, b_scale_mode);
#endif
if (mat1_scale_dtype == kFloat8_e8m0fnu && mat2_scale_dtype == kFloat8_e8m0fnu) {
#if CUDA_VERSION >= 12080
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0);
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e8m0fnu` scales is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
} else if (mat1_scale_dtype == kFloat8_e4m3fn && mat2_scale_dtype == kFloat8_e4m3fn) {
#if CUDA_VERSION >= 12080
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3);
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float8_e4m3fn` scales is only supported for CUDA 12.8 and above");
#endif // if CUDA_VERSION >= 12080
} else if (mat1_scale_dtype == kFloat && mat2_scale_dtype == kFloat && use_rowwise) {
#if CUDA_VERSION >= 12090 || (defined(USE_ROCM) && defined(HIPBLASLT_OUTER_VEC))
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_MODE, CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(USE_ROCM) && defined(HIPBLASLT_VEC_EXT)
// no-op here for older hipblaslt ext enums, to avoid TORCH_CHECK below
#else
TORCH_CHECK(false, "scaled_gemm with `torch.float` outer vector scaling is only supported for CUDA 12.9 and above");
#endif // if CUDA_VERSION >= 12090
}
CuBlasLtMatmulPreference preference;
auto ltworkspace = CublasLtWorkspace();

View File

@ -136,15 +136,6 @@ void int8_gemm(
int32_t* result_ptr,
int64_t result_ld);
enum class ScalingType : std::uint8_t {
TensorWise, // fp32 scales
RowWise, // fp32 scales
BlockWise1x16, // fp8_e4m3fn scales
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
};
void scaled_gemm(
char transa,
char transb,
@ -156,20 +147,19 @@ void scaled_gemm(
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
ScalingType mat1_scaling_type,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
ScalingType mat2_scaling_type,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,
const void* result_scale_ptr,
int64_t result_ld,
ScalarType result_dtype,
bool use_fast_accum);
bool use_fast_accum,
bool use_rowwise);
#define CUDABLAS_BGEMM_ARGTYPES(Dtype) CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(Dtype, Dtype)

View File

@ -29,8 +29,6 @@
namespace at::cuda::tunable {
using at::cuda::blas::ScalingType;
enum class BlasOp {
N = 0,
T = 1
@ -600,8 +598,7 @@ struct ScaledGemmParams : OpParams {
//
// In TunableOp, we must distinguish these two cases in the param signature: with and without a bias vector.
return fmt::sprintf("%c%c_%ld_%ld_%ld_ld_%ld_%ld_%ld_rw_%d_bias_%s",
transa, transb, m, n, k, lda, ldb, ldc,
a_scaling_type == ScalingType::RowWise && b_scaling_type == ScalingType::RowWise,
transa, transb, m, n, k, lda, ldb, ldc, use_rowwise,
bias_ptr == nullptr ? "None" : at::toString(bias_dtype));
}
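The signature above keys TunableOp results on the transposes, problem shape, leading dimensions, the row-wise flag, and the bias dtype; roughly equivalent to this Python sketch (variable names are illustrative):

def scaled_gemm_signature(transa, transb, m, n, k, lda, ldb, ldc, use_rowwise, bias_dtype=None):
    # e.g. "tn_1024_2048_512_ld_1024_512_1024_rw_1_bias_BFloat16"
    bias = "None" if bias_dtype is None else str(bias_dtype)
    return f"{transa}{transb}_{m}_{n}_{k}_ld_{lda}_{ldb}_{ldc}_rw_{int(use_rowwise)}_bias_{bias}"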
@ -676,13 +673,11 @@ struct ScaledGemmParams : OpParams {
int64_t lda{};
ScalarType a_dtype{};
ScalarType a_scale_dtype{};
ScalingType a_scaling_type{};
const void* b{};
const void* b_scale_ptr{};
int64_t ldb{};
ScalarType b_dtype{};
ScalarType b_scale_dtype{};
ScalingType b_scaling_type{};
const void* bias_ptr{};
ScalarType bias_dtype{};
void* c{};
@ -691,6 +686,7 @@ struct ScaledGemmParams : OpParams {
ScalarType c_dtype{};
void* amax_ptr{};
bool use_fast_accum{};
bool use_rowwise{};
private:
bool duplicate_inputs_{false};
};

View File

@ -206,43 +206,23 @@ float GetBetaFromParams(const ScaledGemmParams<T>* params) {
}
template <typename T>
ScalingType GetAScalingTypeFromParams(const GemmParams<T>* params) {
return ScalingType::TensorWise;
bool GetUseRowwiseFromParams(const GemmParams<T>* params) {
return false;
}
template <typename T>
ScalingType GetBScalingTypeFromParams(const GemmParams<T>* params) {
return ScalingType::TensorWise;
bool GetUseRowwiseFromParams(const GemmAndBiasParams<T>* params) {
return false;
}
template <typename T>
ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
return ScalingType::TensorWise;
bool GetUseRowwiseFromParams(const GemmStridedBatchedParams<T>* params) {
return false;
}
template <typename T>
ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams<T>* params) {
return ScalingType::TensorWise;
}
template <typename T>
ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
return ScalingType::TensorWise;
}
template <typename T>
ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams<T>* params) {
return ScalingType::TensorWise;
}
template <typename T>
ScalingType GetAScalingTypeFromParams(const ScaledGemmParams<T>* params) {
return params->a_scaling_type;
}
template <typename T>
ScalingType GetBScalingTypeFromParams(const ScaledGemmParams<T>* params) {
return params->b_scaling_type;
bool GetUseRowwiseFromParams(const ScaledGemmParams<T>* params) {
return params->use_rowwise;
}
template <typename T>
@ -509,24 +489,23 @@ class HipblasltGemmOp : public Callable<ParamsT> {
const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
if (mat1_scale_ptr && mat2_scale_ptr) {
hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER;
hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER;
if (GetAScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
#if defined(HIPBLASLT_OUTER_VEC)
#ifdef HIPBLASLT_VEC_EXT
if (GetUseRowwiseFromParams<CT>(params)) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT, mat1_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT, mat2_scale_ptr);
}
else
#endif
{
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
}
#ifdef HIPBLASLT_OUTER_VEC
if (GetUseRowwiseFromParams<CT>(params)) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(HIPBLASLT_VEC_EXT)
a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT;
#endif
}
if (GetBScalingTypeFromParams<CT>(params) == ScalingType::RowWise) {
#if defined(HIPBLASLT_OUTER_VEC)
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F);
#elif defined(HIPBLASLT_VEC_EXT)
b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT;
#endif
}
matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr);
matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr);
#endif
}
if (result_scale_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);

View File

@ -96,20 +96,19 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
params->lda,
params->a_dtype,
params->a_scale_dtype,
params->a_scaling_type,
params->b,
params->b_scale_ptr,
params->ldb,
params->b_dtype,
params->b_scale_dtype,
params->b_scaling_type,
params->bias_ptr,
params->bias_dtype,
params->c,
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->use_fast_accum);
params->use_fast_accum,
params->use_rowwise);
return OK;
}
};

View File

@ -19,7 +19,6 @@
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#include <ATen/native/cuda/GroupMM.h>
#include <ATen/ceil_div.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -100,7 +99,6 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
}
}
using at::cuda::blas::ScalingType;
/**
* @brief Prepares matrices for CUBLAS operation
@ -142,9 +140,7 @@ struct cublasCommonArgs {
Tensor& c,
const std::optional<Tensor>& scale_a = std::nullopt,
const std::optional<Tensor>& scale_b = std::nullopt,
const std::optional<Tensor>& scale_result = std::nullopt,
const std::optional<ScalingType>& scaling_choice_a = std::nullopt,
const std::optional<ScalingType>& scaling_choice_b = std::nullopt) {
const std::optional<Tensor>& scale_result = std::nullopt) {
bool transpose_result = false, transpose_a = false, transpose_b = false;
result = prepare_matrix_for_cublas(c, transpose_result);
mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result);
@ -156,10 +152,8 @@ struct cublasCommonArgs {
// as B.T @ A.T, check transpose_result to determine if we flip the scales
scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a;
scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b;
}
if (scale_result) {
@ -205,9 +199,7 @@ struct cublasCommonArgs {
void* scale_matb_ptr = nullptr;
void* scale_result_ptr = nullptr;
std::optional<c10::ScalarType> scale_mata_dtype;
std::optional<ScalingType> scaling_mata_type;
std::optional<c10::ScalarType> scale_matb_dtype;
std::optional<ScalingType> scaling_matb_type;
std::optional<c10::ScalarType> scale_result_dtype;
};
} // namespace
@ -1083,114 +1075,133 @@ static bool _scaled_mm_is_fnuz() {
namespace{
enum class ScalingType : std::uint8_t {
TensorWise,
RowWise,
BlockWise,
Error
};
/*
* Scaling Type Determination:
* ---------------------------
* Conditions and corresponding Scaling Types:
*
* - If scale tensor is `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - If scale tensors are both `Float8_e8m0fnu` or `Float8_e4m3fn`:
* - Returns BlockWise (with additional size checks).
*
* - Else if scale.numel() == 1:
* - If scale_a.numel() == 1 && scale_b.numel() == 1:
* - Returns TensorWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == 1:
* - Else if scale_a.dim() == 2 && scale_a.size(0) == dim_m && scale_b.size(0) == dim_n:
* - Returns RowWise.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == inner_dim / 128:
* - Returns BlockWise 1x128.
*
* - Else if scale.dim() == 2 && scale.size(0) == outer_dim / 128 && scale.size(1) == inner_dim / 128:
* - Returns BlockWise 128x128.
*
* - Otherwise:
* - Returns Error.
*/
using at::cuda::blas::ScalingType;
// Validates the scale tensors passed to scaled_mm
// and returns the type of scaling / which kernel to use
ScalingType get_scaling_type(
const at::Tensor& scale_a,
const at::Tensor& scale_b,
int64_t dim_m,
int64_t dim_k,
int64_t dim_n) {
// Check for BlockWise scaling (FP8_E8M0 and FP8_E4M3 types)
if ((scale_a.scalar_type() == scale_b.scalar_type()) &&
((scale_a.scalar_type() == at::kFloat8_e8m0fnu) || (scale_a.scalar_type() == at::kFloat8_e4m3fn))) {
const bool is_nvfp4 = scale_a.scalar_type() == at::kFloat8_e4m3fn;
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1;
}
// cuBLAS's mxfp8 gemm: block_size is 1 scale per 32 elements
// cuBLAS's nvfp4 gemm: block_size is 1 scale per 16 unpacked elements.
const auto BLOCK_SIZE_K = is_nvfp4 ? 16 : 32;
bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
&& scale.size(0) == t.size(0) && scale.size(1) == 1
&& scale.is_contiguous());
}
constexpr int64_t BLOCK_SIZE_MN = 128;
// 1x16 blocks for packed nvfp4 data and fp8_e4m3fn scales
bool is_blockwise_1x16_scaling(const at::Tensor& t, const at::Tensor& scale) {
// Multiply t.size(1) by 2 to adjust for fp4x2 packing
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors
return (t.scalar_type() == ScalarType::Float4_e2m1fn_x2 && scale.scalar_type() == at::kFloat8_e4m3fn
&& scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1) * 2, 16), 4)
&& scale.is_contiguous());
}
// adjust for fp4x2 packing if necessary
const auto dim_k_unpacked = is_nvfp4 ? dim_k * 2 : dim_k;
// 1x32 blocks for microscaled fp8 data and fp8_e8m0fnu scales
bool is_blockwise_1x32_scaling(const at::Tensor& t, const at::Tensor& scale) {
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat8_e8m0fnu
&& scale.numel() == round_up(t.size(0), 128) * round_up(ceil_div(t.size(1), 32), 4)
&& scale.is_contiguous());
}
auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
auto num_k_blocks = ceil_div(dim_k_unpacked, BLOCK_SIZE_K);
auto padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4;
bool is_blockwise_1x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
&& scale.size(0) == t.size(0) && scale.size(1) == ceil_div(t.size(1), 128)
&& scale.stride(0) == 1 && scale.stride(1) == t.size(0));
}
// TODO: We might want to enforce some structure on the shapes of the scale
// tensors
bool is_blockwise_128x128_scaling(const at::Tensor& t, const at::Tensor& scale) {
return (isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.dim() == 2
&& scale.size(0) == ceil_div(t.size(0), 128) && scale.size(1) == ceil_div(t.size(1), 128)
&& scale.stride(0) == round_up(ceil_div(t.size(1), 128), 4) && scale.stride(1) == 1);
}
// Check expected sizes for block-wise scaling
auto expected_a_size =
BLOCK_SIZE_MN * ceil_div(dim_m, BLOCK_SIZE_MN) * padded_num_k_blocks;
auto expected_b_size =
BLOCK_SIZE_MN * ceil_div(dim_n, BLOCK_SIZE_MN) * padded_num_k_blocks;
bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingType desired_scaling) {
switch (desired_scaling) {
case ScalingType::TensorWise:
return is_tensorwise_scaling(t, scale);
case ScalingType::RowWise:
return is_rowwise_scaling(t, scale);
case ScalingType::BlockWise1x16:
return is_blockwise_1x16_scaling(t, scale);
case ScalingType::BlockWise1x32:
return is_blockwise_1x32_scaling(t, scale);
case ScalingType::BlockWise1x128:
return is_blockwise_1x128_scaling(t, scale);
case ScalingType::BlockWise128x128:
return is_blockwise_128x128_scaling(t, scale);
default:
TORCH_CHECK(false);
return false;
}
}
std::pair<ScalingType, ScalingType> get_joint_scaling(
std::initializer_list<std::pair<ScalingType, ScalingType>> options,
const at::Tensor& a, const at::Tensor& b,
const at::Tensor& scale_a, const at::Tensor& scale_b) {
for (auto [lhs, rhs] : options) {
if (is_desired_scaling(a, scale_a, lhs) && is_desired_scaling(b.t(), scale_b.t(), rhs)) {
return {lhs, rhs};
}
TORCH_CHECK(scale_a.numel() == expected_a_size,
"For BlockWise scaling: Expected scale_a size to be ",
expected_a_size, " but got ", scale_a.numel());
TORCH_CHECK(scale_b.numel() == expected_b_size,
"For BlockWise scaling: Expected scale_b size to be ",
expected_b_size, " but got ", scale_b.numel());
TORCH_CHECK(
scale_a.is_contiguous() && scale_b.is_contiguous(),
"For BlockWise scaling: Both scale_a and scale_b must be contiguous");
return ScalingType::BlockWise;
}
// Both Per-Tensor and Row-wise scaling expect fp32 tensors
TORCH_CHECK(
false,
"Invalid scaling configuration.\n"
"- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"
"- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", 1) and scale_b should be (1, ", b.size(1), "), and both should be contiguous.\n"
"- For BlockWise 1x128 scaling, a and b should be float8, scales should be float, scale_a should be (", a.size(0), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", b.size(1), "), and both should be outer-dim-major.\n"
"- For BlockWise 128x128 scaling, a and b should be float8, scales should be float, scale_a should be (", ceil_div(a.size(0), 128), ", ", ceil_div(a.size(1), 128), ") and scale_b should be (", ceil_div(b.size(0), 128), ", ", ceil_div(b.size(1), 128), "), and both should be near-inner-dim-major (with 16-byte aligned strides).\n"
"- For Blockwise 1x32 scaling, a and b should be float8, scales should be float8_e8m0fnu, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1), 32), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0), 32), 4), " elements, and both should be contiguous.\n"
"- For Blockwise 1x16 scaling, a and b should be float4 (packed 2x), scales should be float8_e4m3fn, scale_a should have ", round_up(a.size(0), 128) * round_up(ceil_div(a.size(1) * 2, 16), 4), " elements and scale_b should have ", round_up(b.size(1), 128) * round_up(ceil_div(b.size(0) * 2, 16), 4), " elements, and both should be contiguous.\n"
"Got a.dtype()=", a.scalar_type(), ", scale_a.dtype()=", scale_a.scalar_type(), ", scale_a.size()=", scale_a.sizes(), ", scale_a.stride()=", scale_a.strides(), ", ",
"b.dtype()=", b.scalar_type(), ", scale_b.dtype()=", scale_b.scalar_type(), ", scale_b.size()=", scale_b.sizes(), " and scale_b.stride()=", scale_b.strides()
);
scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"Both scale_a and scale_b must be float (fp32) tensors.");
// Check the singular scale case for per-tensor scaling
if (scale_a.numel() == 1 && scale_b.numel() == 1) {
return ScalingType::TensorWise;
}
// For non-TensorWise scaling, enforce 2D scale tensors
TORCH_CHECK(
scale_a.dim() == 2 && scale_b.dim() == 2,
"For non-TensorWise scaling, scale tensors must be 2-dimensional, "
"but got scale_a.dim()=",
scale_a.dim(),
" and scale_b.dim()=",
scale_b.dim());
// Check for RowWise scaling
if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 &&
scale_b.size(0) == 1 && scale_b.size(1) == dim_n) {
#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || \
(defined(USE_ROCM) && (defined(HIPBLASLT_VEC_EXT) || defined(HIPBLASLT_OUTER_VEC)))
TORCH_CHECK(
scale_a.is_contiguous() && scale_b.is_contiguous(),
"Both scale_a and scale_b must be contiguous for RowWise scaling.");
return ScalingType::RowWise;
#else
TORCH_CHECK(false, "Per-row scaling is not supported for this platform!");
return ScalingType::Error;
#endif
}
// If we reach here, the input doesn't match any valid scaling type
TORCH_CHECK(
false,
"Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. "
"For RowWise scaling, scale_a should be (",
dim_m,
", 1) and scale_b should be (1, ",
dim_n,
"). "
"Got scale_a.size()=(",
scale_a.size(0),
", ",
scale_a.size(1),
") and ",
"scale_b.size()=(",
scale_b.size(0),
", ",
scale_b.size(1),
")");
return ScalingType::Error;
}
} // namespace
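Both the removed per-scheme validators and the restored get_scaling_type agree on the basic scale layouts. A hedged Python sketch of the expected number of scale elements for an (m, k) operand under each scheme (helper names are illustrative; the 1x32/1x16 sizes follow the padded cuBLAS layout checked above):

def ceil_div(a, b):
    return (a + b - 1) // b

def round_up(a, b):
    return ceil_div(a, b) * b

def expected_scale_numel(m, k, scheme):
    if scheme == "tensorwise":   # one fp32 scale per tensor
        return 1
    if scheme == "rowwise":      # fp32, shape (m, 1), contiguous
        return m
    if scheme == "1x128":        # fp32, shape (m, ceil(k/128)), outer-dim-major
        return m * ceil_div(k, 128)
    if scheme == "128x128":      # fp32, shape (ceil(m/128), ceil(k/128))
        return ceil_div(m, 128) * ceil_div(k, 128)
    if scheme == "1x32":         # fp8_e8m0fnu scales for mx fp8 data
        return round_up(m, 128) * round_up(ceil_div(k, 32), 4)
    if scheme == "1x16":         # fp8_e4m3fn scales for fp4x2 data; k is the unpacked width
        return round_up(m, 128) * round_up(ceil_div(k, 16), 4)
    raise ValueError(scheme)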
@ -1208,8 +1219,8 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
// - `mat2`: the second operand of the matrix multiply, can be type `torch.float8_e4m3fn` or `torch.float8_e5m2`
// - `bias`: the bias, can be type `torch.float16` or `torch.bfloat16`
// - `out_dtype`: the output dtype, can either be a float8 or a higher precision floating point type
// - `scale_a`: a tensor with the inverse scale of `mat1`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_b`: a tensor with the inverse scale of `mat2`, whose shape/strides/dtype depend on the scaling scheme
// - `scale_a`: a scalar or 1-dimensional tensor with the inverse scale of `mat1`, only needed if `mat1` is a float8 type
// - `scale_b`: a scalar or 1-dimensional tensor with the inverse scale of `mat2`, only needed if `mat2` is a float8 type
// - `scale_result`: a scalar tensor with the scale of the output, only utilized if the output is a float8 type
// - `use_fast_accum`: if true, enables fast float8 accumulation
// - `out`: a reference to the output tensor
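// For reference, a minimal tensor-wise call through the public entry point looks like the sketch below (hedged; shapes and scale values are illustrative, and mat2 must be a column-major view as required further down):

import torch

x = torch.randn(1024, 512, device="cuda")
y = torch.randn(2048, 512, device="cuda")
x_fp8 = x.to(torch.float8_e4m3fn)            # row-major mat1
y_fp8 = y.to(torch.float8_e4m3fn).t()        # column-major mat2

scale_a = torch.tensor(1.0, device="cuda")   # tensor-wise inverse scale of mat1 (fp32 scalar)
scale_b = torch.tensor(1.0, device="cuda")   # tensor-wise inverse scale of mat2 (fp32 scalar)

out = torch._scaled_mm(x_fp8, y_fp8, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)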
@ -1232,21 +1243,9 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
// Check what type of scaling we are doing based on inputs. This list is sorted
// by decreasing priority. We prefer "simpler" schemes as they are supported
// more broadly (more GPU archs, more CUDA versions) and because they are more
// efficient. This tends to matter only for small matmuls (e.g., 1x1x128).
auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling(
{
std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise),
std::make_pair(ScalingType::RowWise, ScalingType::RowWise),
std::make_pair(ScalingType::BlockWise128x128, ScalingType::BlockWise1x128),
std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise128x128),
std::make_pair(ScalingType::BlockWise1x128, ScalingType::BlockWise1x128),
std::make_pair(ScalingType::BlockWise1x32, ScalingType::BlockWise1x32),
std::make_pair(ScalingType::BlockWise1x16, ScalingType::BlockWise1x16)
},
mat1, mat2, scale_a, scale_b);
// Check what type of scaling we are doing based on inputs
ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat1.size(1), mat2.size(1));
TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
"scale_result must be a float scalar");
@ -1317,7 +1316,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
#ifndef USE_ROCM
// We are doing row-wise scaling
auto dprops = at::cuda::getCurrentDeviceProperties();
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise
if (scaling_choice == ScalingType::RowWise
&& (dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
@ -1331,7 +1330,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
return out;
}
#else
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
if (scaling_choice == ScalingType::RowWise) {
// For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
Tensor b = mat2;
if (_scaled_mm_is_fnuz()) {
@ -1346,7 +1345,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
}
#endif
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@ -1423,14 +1422,10 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
params.a_scale_ptr = args.scale_mata_ptr;
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
@ -1438,6 +1433,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
params.use_rowwise = scaling_choice == ScalingType::RowWise;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
@ -1471,20 +1467,19 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
use_fast_accum,
scaling_choice == ScalingType::RowWise);
}
return out;

View File

@ -785,7 +785,7 @@ def amax_to_scale(
if float8_dtype == e4m3_type:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == e5m2_type:
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
@ -806,20 +806,6 @@ def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
return amax_to_scale(amax, float8_dtype, x.dtype)
def tensor_to_scale_block(
x: torch.Tensor,
float8_dtype: torch.dtype,
block_outer: int,
block_inner: int,
) -> tuple[torch.Tensor, torch.Tensor]:
x = x.unflatten(1, (-1, block_inner)).unflatten(0, (-1, block_outer))
amax = x.abs().amax(dim=[1, 3], keepdim=True).float()
scale = torch.finfo(float8_dtype).max / amax
x = x.mul(scale).to(float8_dtype)
x = x.flatten(2, 3).flatten(0, 1)
scale = scale.flatten(2, 3).flatten(0, 1)
return x, scale
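As used by the blockwise test removed below: for a 256 x 512 input with block_outer=1 and block_inner=128, the helper returns per-block fp32 scales of shape (256, 4), which the test then makes outer-dim-major before calling the scaled matmul (sketch):

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
x_fp8, x_scales = tensor_to_scale_block(x, torch.float8_e4m3fn, 1, 128)
assert x_fp8.shape == (256, 512)
assert x_scales.shape == (256, 4)            # one scale per 1x128 block
x_scales = x_scales.t().contiguous().t()     # 1x128 scales must be outer-dim-major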
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
# naive implementation: dq -> op -> q
x_fp32 = x.to(torch.float) / x_scale
@ -828,17 +814,6 @@ def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
return out_fp32.to(out_dtype)
def mm_float8_emulated_block(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
x = x.unflatten(1, (x_scale.shape[1], -1)).unflatten(0, (x_scale.shape[0], -1))
y = y.unflatten(1, (y_scale.shape[1], -1)).unflatten(0, (y_scale.shape[0], -1))
x_fp32 = x.to(torch.float) / x_scale[:, None, :, None]
y_fp32 = y.to(torch.float) / y_scale[:, None, :, None]
x_fp32 = x_fp32.flatten(2, 3).flatten(0, 1)
y_fp32 = y_fp32.flatten(2, 3).flatten(0, 1)
out_fp32 = torch.mm(x_fp32, y_fp32)
return out_fp32.to(out_dtype)
def addmm_float8_unwrapped(
a_data: torch.Tensor,
a_scale: torch.Tensor,
@ -1262,7 +1237,11 @@ class TestFP8Matmul(TestCase):
y_fp8 = y.to(e4m3_type).t()
with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
RuntimeError,
re.escape(
"For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
),
):
torch._scaled_mm(
x_fp8,
@ -1273,7 +1252,11 @@ class TestFP8Matmul(TestCase):
)
with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
RuntimeError,
re.escape(
" For RowWise scaling, scale_a should be (1024, 1) and scale_b "
"should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
),
):
torch._scaled_mm(
x_fp8,
@ -1283,18 +1266,22 @@ class TestFP8Matmul(TestCase):
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
RuntimeError,
re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
):
torch._scaled_mm(
x_fp8,
y_fp8,
scale_a=torch.ones((M), device="cuda"),
scale_b=torch.ones((N, N, 1), device="cuda"),
scale_b=torch.ones((N, N), device="cuda"),
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
RuntimeError,
re.escape(
"Both scale_a and scale_b must be contiguous for RowWise scaling."
),
):
torch._scaled_mm(
x_fp8,
@ -1359,58 +1346,6 @@ class TestFP8Matmul(TestCase):
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not SM90OrLater, "cuBLAS blockwise scaling requires sm90+")
@unittest.skipIf(
_get_torch_cuda_version() < (12, 9),
"cuBLAS blockwise scaling added in CUDA 12.9",
)
@parametrize("output_dtype", [torch.bfloat16, torch.float32])
@parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block):
torch.manual_seed(42)
x = torch.randn(256, 512, device="cuda", dtype=output_dtype).pow(3)
y = torch.randn(768, 512, device="cuda", dtype=output_dtype).pow(3)
x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128)
y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128)
# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
x_scales = x_scales.t().contiguous().t()
if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype
)
# Calculate emulated F8 mm
out_emulated = mm_float8_emulated_block(
x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype
)
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), out_emulated.flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
if output_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 6e-1, 7e-2
else:
atol, rtol = 7e-1, 2e-3
self.assertEqual(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
# One last check against the full-precision reference, to ensure we
# didn't mess up the scaling itself and made the test trivial.
cosine_sim = torch.nn.functional.cosine_similarity(
out_scaled_mm.flatten().float(), (x @ y.t()).flatten().float(), dim=0
)
self.assertGreaterEqual(float(cosine_sim), 0.999)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@parametrize("which_dim_zero", [0, 1, 2])
@parametrize("use_torch_compile", [False, True])