From 066f818eea00e9cfde1c8efbef70190c42453f9b Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Wed, 15 Oct 2025 06:37:31 -0700 Subject: [PATCH] Refactor and unify v1/v2 _scaled_mm codes (#165436) Summary: * Refactor out some core routines (scaled_gemm, auto-tuned scaled_gemm) * Unify v1/v2 dispatch calls where possible * Simplify call pattern w.r.t. CUDA/ROCM for easier readability. Test Plan: ``` pytest -svv test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton Pull Request resolved: https://github.com/pytorch/pytorch/pull/165436 Approved by: https://github.com/drisspg --- aten/src/ATen/native/cuda/Blas.cpp | 624 +++++++++++++---------------- test/test_scaled_matmul_cuda.py | 55 +-- 2 files changed, 305 insertions(+), 374 deletions(-) diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index c95145f0dd1b..67a549165ada 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -1230,8 +1230,205 @@ std::pair<ScalingType, ScalingType> get_joint_scaling( ); } +Tensor& +_tunable_scaled_gemm_rocm( + cublasCommonArgs& args, + const Tensor& mat1, const Tensor& mat2, + const Tensor& scale_a, const Tensor& scale_b, + const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, + const std::optional<Tensor>& bias, + const bool use_fast_accum, + const at::ScalarType out_dtype, + Tensor& out) { +#ifdef USE_ROCM +#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \ + if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ + if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ + static at::cuda::tunable::ScaledGemmTunableOp< \ + at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(&params); \ + } \ + else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ + static at::cuda::tunable::ScaledGemmTunableOp< \ + at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(&params); \ + } \ + } \ + else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ + if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ + static at::cuda::tunable::ScaledGemmTunableOp< \ + at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(&params); \ + } \ + else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ + static at::cuda::tunable::ScaledGemmTunableOp< \ + at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(&params); \ + } \ + } \ + else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \ + if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \ + static at::cuda::tunable::ScaledGemmTunableOp< \ + at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(&params); \ + } \ + else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \ + static at::cuda::tunable::ScaledGemmTunableOp< \ + at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(&params); \ + } \ + } \ + else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \ + if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \ + static at::cuda::tunable::ScaledGemmTunableOp< \ + at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ + scaledgemm(&params); \ + } \ + else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \ + static at::cuda::tunable::ScaledGemmTunableOp< \ + at::Float8_e5m2, at::Float8_e5m2, scalar_t, \ + BLASOP_A, BLASOP_B> scaledgemm{}; \ +
scaledgemm(&params); \ + } \ + } + AT_DISPATCH_V2(out_dtype, "_tunable_scaled_gemm", AT_WRAP([&] { + bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); + bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); + at::cuda::tunable::ScaledGemmParams<scalar_t> params; + params.transa = args.transa; + params.transb = args.transb; + params.m = args.m; + params.n = args.n; + params.k = args.k; + params.a = args.mata->data_ptr(); + params.a_scale_ptr = args.scale_mata_ptr; + params.a_scale_dtype = args.scale_mata_dtype.value(); + params.lda = args.lda; + params.a_dtype = args.mata->scalar_type(); + params.a_scale_dtype = args.scale_mata_dtype.value(); + params.a_scaling_type = args.scaling_mata_type.value(); + params.b = args.matb->data_ptr(); + params.b_scale_ptr = args.scale_matb_ptr; + params.b_scale_dtype = args.scale_matb_dtype.value(); + params.ldb = args.ldb; + params.b_dtype = args.matb->scalar_type(); + params.b_scale_dtype = args.scale_matb_dtype.value(); + params.b_scaling_type = args.scaling_matb_type.value(); + params.bias_ptr = bias ? bias->data_ptr(): nullptr; + params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype) ? at::ScalarType::Half : out_dtype; + params.c = args.result->data_ptr(); + params.c_scale_ptr = args.scale_result_ptr; + params.ldc = args.result_ld; + params.c_dtype = out_dtype; + params.use_fast_accum = use_fast_accum; + if (transa_ && transb_) { + TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) + } + else if (transa_ && !transb_) { + TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N) + } + else if (!transa_ && transb_) { + TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T) + } + else if (!transa_ && !transb_) { + TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N) + } + else { + TORCH_CHECK(false, "unreachable"); + } + }), + kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES)); +#undef TUNABLE_DISPATCH + return out; +#else + TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_gemm_rocm only callable on ROCM devices"); +#endif +} + +Tensor& +_scaled_gemm( + const Tensor& mat1, const Tensor& mat2, + const Tensor& scale_a, const Tensor& scale_b, + const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, + const std::optional<Tensor>& bias, + const bool use_fast_accum, + Tensor& out) { + cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b); + const auto out_dtype_ = args.result->scalar_type(); + TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); + +// ROCM enables the TunableOp path only +// but can fallback to at::cuda::blas::scaled_gemm +#ifdef USE_ROCM + auto tuning_ctx = at::cuda::tunable::getTuningContext(); + bool tunable_op_enabled = tuning_ctx->IsTunableOpEnabled(); +#else + bool tunable_op_enabled = false; +#endif + if (tunable_op_enabled) { + // Only available on ROCM + return _tunable_scaled_gemm_rocm( + args, + mat1, mat2, + scale_a, scale_b, + scaling_choice_a, scaling_choice_b, + bias, + use_fast_accum, + out_dtype_, + out); + } + else + { + at::cuda::blas::scaled_gemm( + args.transa, + args.transb, + args.m, + args.n, + args.k, + args.mata->data_ptr(), + args.scale_mata_ptr, + args.lda, + args.mata->scalar_type(), + args.scale_mata_dtype.value(), + args.scaling_mata_type.value(), + args.matb->data_ptr(), + args.scale_matb_ptr, + args.ldb, + 
args.matb->scalar_type(), + args.scale_matb_dtype.value(), + args.scaling_matb_type.value(), + bias ? bias->data_ptr(): nullptr, + bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_, + args.result->data_ptr(), + args.scale_result_ptr, + args.result_ld, + out_dtype_, + use_fast_accum); + return out; + } +} + } // namespace +// NOTE(slayton58): This is defined as part of the _v2 code (way) below - declare the signature here +// to help cleanup v1 call structure. +Tensor& +_scaled_rowwise_rowwise( + const Tensor&, const Tensor&, + const Tensor&, const Tensor&, + const std::optional<Tensor>&, + const c10::ScalarType, + bool, + Tensor&); + + // Computes matrix multiply + bias while applying scaling to input and output matrices // Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default. // If output matrix type is 16 or 32-bit type, scale_result is not applied. @@ -1309,7 +1506,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type()); #ifndef USE_ROCM // Type restrictions imposed by CuBLASLt as of CUDA-12.1 - TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2, + TORCH_CHECK_VALUE(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2, "Multiplication of two Float8_e5m2 matrices is not supported"); #endif if (use_fast_accum) { @@ -1375,41 +1572,44 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, // NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9, // and only for compute capability 9.0+. In other cases we use CUTLASS. -#ifndef USE_ROCM // We are doing row-wise scaling - auto dprops = at::cuda::getCurrentDeviceProperties(); - if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise - && ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900) - // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales - || (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) { - TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); - at::cuda::detail::f8f8bf16_rowwise( - mat1, - mat2, - scale_a, - scale_b, - bias, - use_fast_accum, - out); - return out; - } -#else if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) { +#ifndef USE_ROCM + auto dprops = at::cuda::getCurrentDeviceProperties(); + if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900) + // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales + || (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) { + TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); + return _scaled_rowwise_rowwise( + mat1, + mat2, + scale_a, + scale_b, + bias, + out.scalar_type(), + use_fast_accum, + out); + } +#else // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. 
Tensor b = mat2; if (_scaled_mm_is_fnuz()) { - TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz); + TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fnuz, + "Expected b.dtype() == at::kFloat8_e4m3fnuz, got: ", b.dtype()); } else { - TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn); + TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fn, + "Expected b.dtype() == at::kFloat8_e4m3fn, got: ", b.dtype()); } // Until more than bf16 is supported. - TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16, + TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16, "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type()); +#endif } else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) { +#ifdef USE_ROCM #if ROCM_VERSION >= 70000 - TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), + TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); int packed_factor = 1; @@ -1418,163 +1618,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, // effectively packing two elements into one byte. packed_factor = 2; } - TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 && + TORCH_CHECK_VALUE(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 && mat2.size(1) % 16 == 0, "M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling"); - TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 || + TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 || out.scalar_type() == ScalarType::Half, "Block-wise scaling only supports BFloat16 or Half output types"); #else - TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later"); + TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later"); +#endif #endif } -#endif - cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b); - const auto out_dtype_ = args.result->scalar_type(); - TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); - -#ifdef USE_ROCM - auto tuning_ctx = at::cuda::tunable::getTuningContext(); - if (tuning_ctx->IsTunableOpEnabled()) { -#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \ - if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ - if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - } \ - else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ - if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - } \ - else if (mat1.scalar_type() == 
ScalarType::Float8_e4m3fn) { \ - if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - } \ - else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \ - if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e5m2, at::Float8_e5m2, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - } - AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] { - bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); - bool transb_ = ((args.transb != 'n') && (args.transb != 'N')); - at::cuda::tunable::ScaledGemmParams<scalar_t> params; - params.transa = args.transa; - params.transb = args.transb; - params.m = args.m; - params.n = args.n; - params.k = args.k; - params.a = args.mata->data_ptr(); - params.a_scale_ptr = args.scale_mata_ptr; - params.a_scale_dtype = args.scale_mata_dtype.value(); - params.lda = args.lda; - params.a_dtype = args.mata->scalar_type(); - params.a_scale_dtype = args.scale_mata_dtype.value(); - params.a_scaling_type = args.scaling_mata_type.value(); - params.b = args.matb->data_ptr(); - params.b_scale_ptr = args.scale_matb_ptr; - params.b_scale_dtype = args.scale_matb_dtype.value(); - params.ldb = args.ldb; - params.b_dtype = args.matb->scalar_type(); - params.b_scale_dtype = args.scale_matb_dtype.value(); - params.b_scaling_type = args.scaling_matb_type.value(); - params.bias_ptr = bias ? bias->data_ptr(): nullptr; - params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; - params.c = args.result->data_ptr(); - params.c_scale_ptr = args.scale_result_ptr; - params.ldc = args.result_ld; - params.c_dtype = out_dtype_; - params.use_fast_accum = use_fast_accum; - if (transa_ && transb_) { - TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) - } - else if (transa_ && !transb_) { - TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N) - } - else if (!transa_ && transb_) { - TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T) - } - else if (!transa_ && !transb_) { - TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N) - } - else { - TORCH_CHECK(false, "unreachable"); - } - }), - kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES)); -#undef TUNABLE_DISPATCH - } - else -#endif - { - at::cuda::blas::scaled_gemm( - args.transa, - args.transb, - args.m, - args.n, - args.k, - args.mata->data_ptr(), - args.scale_mata_ptr, - args.lda, - args.mata->scalar_type(), - args.scale_mata_dtype.value(), - args.scaling_mata_type.value(), - args.matb->data_ptr(), - args.scale_matb_ptr, - args.ldb, - args.matb->scalar_type(), - args.scale_matb_dtype.value(), - args.scaling_matb_type.value(), - bias ? bias->data_ptr(): nullptr, - bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? 
at::ScalarType::Half : out_dtype_, - args.result->data_ptr(), - args.scale_result_ptr, - args.result_ld, - out_dtype_, - use_fast_accum); - } - - return out; + return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); } namespace { @@ -1914,159 +1971,6 @@ std::array, 8> { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE }, { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}}; -Tensor& -_cutlass_scaled_gemm( - const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, const Tensor& scale_b, - const ScalingType scaling_choice_a, const ScalingType scaling_choice_b, - const std::optional<Tensor>& bias, - const bool use_fast_accum, - Tensor& out) { - cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b); - const auto out_dtype_ = args.result->scalar_type(); - TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); - -#ifdef USE_ROCM - auto tuning_ctx = at::cuda::tunable::getTuningContext(); - if (tuning_ctx->IsTunableOpEnabled()) { -#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \ - if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ - if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - } \ - else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ - if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - } \ - else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \ - if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - } \ - else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \ - if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \ - static at::cuda::tunable::ScaledGemmTunableOp< \ - at::Float8_e5m2, at::Float8_e5m2, scalar_t, \ - BLASOP_A, BLASOP_B> scaledgemm{}; \ - scaledgemm(&params); \ - } \ - } - AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] { - bool transa_ = ((args.transa != 'n') && (args.transa != 'N')); - bool transb_ = 
((args.transb != 'n') && (args.transb != 'N')); - at::cuda::tunable::ScaledGemmParams<scalar_t> params; - params.transa = args.transa; - params.transb = args.transb; - params.m = args.m; - params.n = args.n; - params.k = args.k; - params.a = args.mata->data_ptr(); - params.a_scale_ptr = args.scale_mata_ptr; - params.a_scale_dtype = args.scale_mata_dtype.value(); - params.lda = args.lda; - params.a_dtype = args.mata->scalar_type(); - params.a_scale_dtype = args.scale_mata_dtype.value(); - params.a_scaling_type = args.scaling_mata_type.value(); - params.b = args.matb->data_ptr(); - params.b_scale_ptr = args.scale_matb_ptr; - params.b_scale_dtype = args.scale_matb_dtype.value(); - params.ldb = args.ldb; - params.b_dtype = args.matb->scalar_type(); - params.b_scale_dtype = args.scale_matb_dtype.value(); - params.b_scaling_type = args.scaling_matb_type.value(); - params.bias_ptr = bias ? bias->data_ptr(): nullptr; - params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_; - params.c = args.result->data_ptr(); - params.c_scale_ptr = args.scale_result_ptr; - params.ldc = args.result_ld; - params.c_dtype = out_dtype_; - params.use_fast_accum = use_fast_accum; - if (transa_ && transb_) { - TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T) - } - else if (transa_ && !transb_) { - TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N) - } - else if (!transa_ && transb_) { - TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T) - } - else if (!transa_ && !transb_) { - TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N) - } - else { - TORCH_CHECK(false, "unreachable"); - } - }), - kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES)); -#undef TUNABLE_DISPATCH - } - else -#endif - { - at::cuda::blas::scaled_gemm( - args.transa, - args.transb, - args.m, - args.n, - args.k, - args.mata->data_ptr(), - args.scale_mata_ptr, - args.lda, - args.mata->scalar_type(), - args.scale_mata_dtype.value(), - args.scaling_mata_type.value(), - args.matb->data_ptr(), - args.scale_matb_ptr, - args.ldb, - args.matb->scalar_type(), - args.scale_matb_dtype.value(), - args.scaling_matb_type.value(), - bias ? bias->data_ptr(): nullptr, - bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? 
at::ScalarType::Half : out_dtype_, - args.result->data_ptr(), - args.scale_result_ptr, - args.result_ld, - out_dtype_, - use_fast_accum); - } - return out; -} - Tensor& _scaled_tensorwise_tensorwise( const Tensor& mat_a, const Tensor& mat_b, @@ -2086,7 +1990,7 @@ _scaled_tensorwise_tensorwise( auto scaling_choice_a = ScalingType::TensorWise; auto scaling_choice_b = ScalingType::TensorWise; - _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); + _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); return out; } @@ -2122,7 +2026,7 @@ _scaled_rowwise_rowwise( if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900) // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales || (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) { - TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); + TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); at::cuda::detail::f8f8bf16_rowwise( mat_a, mat_b, @@ -2148,11 +2052,38 @@ _scaled_rowwise_rowwise( "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type()); #endif - _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); + _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); return out; } +// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling. +// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1, +// and strides become somewhat meaningless +void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) { + if (scale_type == ScalingType::BlockWise1x128) { + TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1), + "at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ", + "shape=", scale.sizes(), ", stride=", scale.strides()); + auto expected_size = ceil_div(t.size(1), 128); + TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)), + "at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ", + "shape=", scale.sizes(), ", stride=", scale.strides()); + } else if (scale_type == ScalingType::BlockWise128x128) { + TORCH_CHECK_VALUE(check_size_stride( + scale, + 0, + ceil_div(t.size(0), 128), + ceil_div(t.size(1), 128)), + "at dim=0 scale should have ", ceil_div(t.size(0), 128), "elements and stride(0) ", ceil_div(t.size(1), 128), "if ", ceil_div(t.size(0), 128), " > 1 - Got: ", + "shape=", scale.sizes(), ", stride=", scale.strides()); + TORCH_CHECK(check_size_stride( + scale, 1, ceil_div(t.size(1), 128), 1), + "at dim=1 scale should have ", ceil_div(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div(t.size(1), 128), " > 1 - Got: ", + "shape=", scale.sizes(), ", stride=", scale.strides()); + } +} + Tensor& _scaled_block1x128_block1x128( const Tensor& mat_a, const Tensor& mat_b, @@ -2170,15 +2101,14 @@ _scaled_block1x128_block1x128( TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat, "scale_b must have shape ", 
ceil_div(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes()) - TORCH_CHECK(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0)); - TORCH_CHECK(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1)); - TORCH_CHECK(scale_b.stride(0) == scale_b.size(1), - "expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.size(1)); - auto scaling_choice_a = ScalingType::BlockWise1x128; auto scaling_choice_b = ScalingType::BlockWise1x128; - _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); + // Check scale strides (including stride=1 small cases) + _check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a); + _check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b); + + _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); return out; } @@ -2193,6 +2123,8 @@ _scaled_block128x128_block1x128( Tensor& out) { // Restrictions: // A, B are FP8, scales are fp32, shape K//128 + std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl; + std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl; TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()); TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat, @@ -2200,15 +2132,14 @@ _scaled_block128x128_block1x128( TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat, "scale_b must have shape ", ceil_div(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes()) - TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1)); - TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1)); - TORCH_CHECK_VALUE(scale_b.stride(0) == scale_b.size(1), - "expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0)); - auto scaling_choice_a = ScalingType::BlockWise128x128; auto scaling_choice_b = ScalingType::BlockWise1x128; - _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); + // Check scale strides (including stride=1 small cases) + _check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a); + _check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b); + + _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); return out; } @@ -2230,15 +2161,14 @@ _scaled_block1x128_block128x128( TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat, "scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes()) - TORCH_CHECK_VALUE(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0)); - TORCH_CHECK_VALUE(scale_b.stride(0) == 1, "expected scale_b.stride(0) to be 1, but got ", scale_b.stride(0)); - TORCH_CHECK_VALUE(scale_b.stride(1) == 
scale_b.size(0), - "expected scale_b.stride(1) to be ", scale_b.size(0), ", but got ", scale_b.stride(1)); - auto scaling_choice_a = ScalingType::BlockWise1x128; auto scaling_choice_b = ScalingType::BlockWise128x128; - _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); + // Check scale strides (including stride=1 small cases) + _check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a); + _check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b); + + _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); return out; } @@ -2292,7 +2222,7 @@ _scaled_mxfp8_mxfp8( #endif #endif - return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); + return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); } Tensor& @@ -2329,7 +2259,7 @@ _scaled_nvfp4_nvfp4( auto scaling_choice_a = ScalingType::BlockWise1x16; auto scaling_choice_b = ScalingType::BlockWise1x16; - return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); + return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); } diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index e58f3ea8d960..bd7147112e8c 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -311,18 +311,6 @@ def addmm_float8_unwrapped( ) return output -def mm_float8( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - output_dtype: torch.dtype, # output dtype - output_scale: Optional[torch.Tensor] = None, # output scale, precomputed -) -> torch.Tensor: - return addmm_float8_unwrapped( - a, a_scale, b, b_scale, output_dtype, output_scale - ) - def to_fp8_saturated( x: torch.Tensor, fp8_dtype: torch.dtype @@ -674,12 +662,12 @@ class TestFP8Matmul(TestCase): y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) # Calculate actual F8 mm - out_scaled_mm = mm_float8( + out_scaled_mm = scaled_mm_wrap( x_fp8, y_fp8, - a_scale=x_scale, - b_scale=y_scale, - output_dtype=output_dtype + scale_a=x_scale.reciprocal(), + scale_b=y_scale.reciprocal(), + out_dtype=output_dtype ) # Calculate emulated F8 mm @@ -726,12 +714,12 @@ class TestFP8Matmul(TestCase): y_fp8 = to_fp8_saturated(y * y_scale, input_dtype) # Calculate actual F8 mm - out_scaled_mm = mm_float8( + out_scaled_mm = scaled_mm_wrap( x_fp8, y_fp8, - a_scale=x_scale, - b_scale=y_scale, - output_dtype=output_dtype + scale_a=x_scale.reciprocal(), + scale_b=y_scale.reciprocal(), + out_dtype=output_dtype ) # Calculate emulated F8 mm @@ -993,8 +981,12 @@ class TestFP8Matmul(TestCase): def test(): # Calculate actual F8 mm - out_scaled_mm = mm_float8( - x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype + out_scaled_mm = scaled_mm_wrap( + x_fp8, + y_fp8, + scale_a=x_scales.reciprocal(), + scale_b=y_scales.reciprocal(), + out_dtype=output_dtype ) # Calculate emulated F8 mm @@ -1013,7 +1005,7 @@ class TestFP8Matmul(TestCase): # rowwise on SM 9.0 if torch.cuda.get_device_capability() != (9, 0) and output_dtype == torch.float: with self.assertRaisesRegex( - RuntimeError, + ValueError, "Only bf16 high precision output types are supported for row-wise scaling." 
): test() @@ -1105,16 +1097,25 @@ class TestFP8Matmul(TestCase): # 1x128 blocks need scales to be outer-dim-major if lhs_block == 1: x_scales = x_scales.t().contiguous().t() + lhs_recipe = ScalingType.BlockWise1x128 + else: + lhs_recipe = ScalingType.BlockWise128x128 + if rhs_block == 1: y_scales = y_scales.t().contiguous().t() + rhs_recipe = ScalingType.BlockWise1x128 + else: + rhs_recipe = ScalingType.BlockWise128x128 # Verify that actual F8 mm doesn't error - mm_float8( + scaled_mm_wrap( x_fp8, y_fp8.t(), - a_scale=x_scales, - b_scale=y_scales.t(), - output_dtype=output_dtype, + scale_a=x_scales, + scale_recipe_a=lhs_recipe, + scale_b=y_scales.t(), + scale_recipe_b=rhs_recipe, + out_dtype=output_dtype, ) # Verify that emulated F8 mm doesn't error