Refactor and unify v1/v2 _scaled_mm codes (#165436)

Summary:

* Refactor out some core routines (scaled_gemm, auto-tuned scaled_gemm)
* Unify v1/v2 dispatch calls where possible
* Simplify the call pattern w.r.t. CUDA/ROCm for readability (see the sketch below).
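
A rough sketch of the unified structure (simplified; only the signature below is literal, the exact definitions are in the diff):

```
// Both the v1 entry point (_scaled_mm_out_cuda) and the v2 recipe handlers
// (_scaled_tensorwise_tensorwise, _scaled_rowwise_rowwise, _scaled_block*,
// _scaled_mxfp8_mxfp8, _scaled_nvfp4_nvfp4) now funnel into a single helper:
Tensor&
_scaled_gemm(
    const Tensor& mat1, const Tensor& mat2,
    const Tensor& scale_a, const Tensor& scale_b,
    const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
    const std::optional<Tensor>& bias,
    const bool use_fast_accum,
    Tensor& out);
// _scaled_gemm either takes the ROCm-only auto-tuned path
// (_tunable_scaled_gemm_rocm) or calls at::cuda::blas::scaled_gemm directly.
```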

Test Plan:

```
pytest -svv test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165436
Approved by: https://github.com/drisspg
Simon Layton
2025-10-15 06:37:31 -07:00
committed by PyTorch MergeBot
parent 14af1dc3da
commit 066f818eea
2 changed files with 305 additions and 374 deletions

@@ -1230,8 +1230,205 @@ std::pair<ScalingType, ScalingType> get_joint_scaling(
);
}
Tensor&
_tunable_scaled_gemm_rocm(
cublasCommonArgs& args,
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
const at::ScalarType out_dtype,
Tensor& out) {
#ifdef USE_ROCM
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype) ? at::ScalarType::Half : out_dtype;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_gemm_rocm only callable on ROCM devices");
#endif
}
Tensor&
_scaled_gemm(
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
// ROCm is the only platform that enables the TunableOp path,
// but it can fall back to at::cuda::blas::scaled_gemm
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
bool tunable_op_enabled = tuning_ctx->IsTunableOpEnabled();
#else
bool tunable_op_enabled = false;
#endif
if (tunable_op_enabled) {
// Only available on ROCM
return _tunable_scaled_gemm_rocm(
args,
mat1, mat2,
scale_a, scale_b,
scaling_choice_a, scaling_choice_b,
bias,
use_fast_accum,
out_dtype_,
out);
}
else
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
return out;
}
}
} // namespace
// NOTE(slayton58): This is defined as part of the _v2 code (way) below - declare the signature here
// to help clean up the v1 call structure.
Tensor&
_scaled_rowwise_rowwise(
const Tensor&, const Tensor&,
const Tensor&, const Tensor&,
const std::optional<Tensor>&,
const c10::ScalarType,
bool,
Tensor&);
// Computes matrix multiply + bias while applying scaling to input and output matrices
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
// If output matrix type is 16 or 32-bit type, scale_result is not applied.
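A minimal, self-contained usage sketch of the behavior described above (editorial illustration, not part of this diff; it assumes the public at::_scaled_mm dispatcher signature, with bias, scale_result, out_dtype and use_fast_accum optional):

```
#include <ATen/ATen.h>

// FP8 e4m3 row-major A times FP8 e4m3 column-major B with per-tensor scales,
// producing a bf16 result. Since the output is a 16-bit type, scale_result
// would not be applied (per the note above), so it is omitted entirely.
at::Tensor scaled_mm_example() {
  auto a = at::randn({16, 32}, at::kCUDA).to(at::kFloat8_e4m3fn);      // row-major [16, 32]
  auto b = at::randn({64, 32}, at::kCUDA).to(at::kFloat8_e4m3fn).t();  // column-major [32, 64]
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  auto scale_a = at::ones({1}, opts);                                  // per-tensor scales
  auto scale_b = at::ones({1}, opts);
  return at::_scaled_mm(a, b, scale_a, scale_b,
                        /*bias=*/std::nullopt, /*scale_result=*/std::nullopt,
                        /*out_dtype=*/at::kBFloat16, /*use_fast_accum=*/true);
}
```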
@@ -1309,7 +1506,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(isFloat8Type(mat2.scalar_type()) || mat2.scalar_type() == ScalarType::Float4_e2m1fn_x2, "Expected mat2 to be Float8 or Float4_x2 matrix got ", mat2.scalar_type());
#ifndef USE_ROCM
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
TORCH_CHECK_VALUE(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
"Multiplication of two Float8_e5m2 matrices is not supported");
#endif
if (use_fast_accum) {
@@ -1375,41 +1572,44 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
// NVIDIA's cuBLAS only started supporting row-wise scaling in version 12.9,
// and only for compute capability 9.0+. In other cases we use CUTLASS.
#ifndef USE_ROCM
// We are doing row-wise scaling
auto dprops = at::cuda::getCurrentDeviceProperties();
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise
&& ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
mat1,
mat2,
scale_a,
scale_b,
bias,
use_fast_accum,
out);
return out;
}
#else
if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) {
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
return _scaled_rowwise_rowwise(
mat1,
mat2,
scale_a,
scale_b,
bias,
out.scalar_type(),
use_fast_accum,
out);
}
#else
// For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes.
Tensor b = mat2;
if (_scaled_mm_is_fnuz()) {
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz);
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fnuz,
"Expected b.dtype() == at::kFloat8_e4m3fnuz, got: ", b.dtype());
}
else {
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
TORCH_CHECK_VALUE(b.dtype() == at::kFloat8_e4m3fn,
"Expected b.dtype() == at::kFloat8_e4m3fn, got: ", b.dtype());
}
// Until more than bf16 is supported.
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16,
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16,
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
#endif
}
else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
#ifdef USE_ROCM
#if ROCM_VERSION >= 70000
TORCH_CHECK(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
int packed_factor = 1;
@@ -1418,163 +1618,20 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
// effectively packing two elements into one byte.
packed_factor = 2;
}
TORCH_CHECK(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
TORCH_CHECK_VALUE(mat1.size(0) % 16 == 0 && (mat1.size(1) * packed_factor) % 128 == 0 &&
mat2.size(1) % 16 == 0,
"M, N must be multiples of 16 and K must be multiple of 128 for block-wise scaling");
TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16 ||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
out.scalar_type() == ScalarType::Half,
"Block-wise scaling only supports BFloat16 or Half output types");
#else
TORCH_CHECK(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
#endif
#endif
}
#endif
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, scale_result, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
}
else
#endif
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
}
return out;
return _scaled_gemm(mat1, mat2, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
}
namespace {
@@ -1914,159 +1971,6 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
Tensor&
_cutlass_scaled_gemm(
const Tensor& mat1, const Tensor& mat2,
const Tensor& scale_a, const Tensor& scale_b,
const ScalingType scaling_choice_a, const ScalingType scaling_choice_b,
const std::optional<Tensor>& bias,
const bool use_fast_accum,
Tensor& out) {
cublasCommonArgs args(mat1, mat2, out, scale_a, scale_b, std::nullopt, scaling_choice_a, scaling_choice_b);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e4m3fn) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fn, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fn) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e4m3fn, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2, at::Float8_e5m2, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = args.scale_mata_ptr;
params.a_scale_dtype = args.scale_mata_dtype.value();
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.a_scale_dtype = args.scale_mata_dtype.value();
params.a_scaling_type = args.scaling_mata_type.value();
params.b = args.matb->data_ptr();
params.b_scale_ptr = args.scale_matb_ptr;
params.b_scale_dtype = args.scale_matb_dtype.value();
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.b_scale_dtype = args.scale_matb_dtype.value();
params.b_scaling_type = args.scaling_matb_type.value();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
params.c_scale_ptr = args.scale_result_ptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
}
else
#endif
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
args.scale_mata_ptr,
args.lda,
args.mata->scalar_type(),
args.scale_mata_dtype.value(),
args.scaling_mata_type.value(),
args.matb->data_ptr(),
args.scale_matb_ptr,
args.ldb,
args.matb->scalar_type(),
args.scale_matb_dtype.value(),
args.scaling_matb_type.value(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
args.scale_result_ptr,
args.result_ld,
out_dtype_,
use_fast_accum);
}
return out;
}
Tensor&
_scaled_tensorwise_tensorwise(
const Tensor& mat_a, const Tensor& mat_b,
@@ -2086,7 +1990,7 @@ _scaled_tensorwise_tensorwise(
auto scaling_choice_a = ScalingType::TensorWise;
auto scaling_choice_b = ScalingType::TensorWise;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@@ -2122,7 +2026,7 @@ _scaled_rowwise_rowwise(
if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|| (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) {
TORCH_CHECK(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
at::cuda::detail::f8f8bf16_rowwise(
mat_a,
mat_b,
@@ -2148,11 +2052,38 @@ _scaled_rowwise_rowwise(
"hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type());
#endif
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
// and strides become somewhat meaningless
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
if (scale_type == ScalingType::BlockWise1x128) {
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
} else if (scale_type == ScalingType::BlockWise128x128) {
TORCH_CHECK_VALUE(check_size_stride(
scale,
0,
ceil_div<int64_t>(t.size(0), 128),
ceil_div<int64_t>(t.size(1), 128)),
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
TORCH_CHECK(check_size_stride(
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
"shape=", scale.sizes(), ", stride=", scale.strides());
}
}
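A quick worked example of what the helper above accepts (editorial illustration, not part of the diff). For a row-major mat_a of shape [256, 512] (M=256, K=512), the checks reduce to:

```
// BlockWise1x128: dim 0 must have M=256 elements with stride 1, and dim 1 must
// have ceil(512/128)=4 elements with stride M=256, i.e. a column-major
// [M, K/128] tensor:
auto scale_a_1x128   = at::ones({4, 256}, at::kFloat).t();  // sizes [256, 4], strides [1, 256]

// BlockWise128x128: dim 0 must have ceil(256/128)=2 elements with stride
// ceil(512/128)=4, and dim 1 must have 4 elements with stride 1, i.e. an
// ordinary row-major [M/128, K/128] tensor:
auto scale_a_128x128 = at::ones({2, 4}, at::kFloat);        // sizes [2, 4], strides [4, 1]

// The mat_b side is checked on the transposed tensors (scale_b.t() against
// mat_b.t()), as in the _scaled_block* callers below.
```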
Tensor&
_scaled_block1x128_block1x128(
const Tensor& mat_a, const Tensor& mat_b,
@@ -2170,15 +2101,14 @@ _scaled_block1x128_block1x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
TORCH_CHECK(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
TORCH_CHECK(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
TORCH_CHECK(scale_b.stride(0) == scale_b.size(1),
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.size(1));
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@@ -2193,6 +2123,8 @@ _scaled_block128x128_block1x128(
Tensor& out) {
// Restrictions:
// A, B are FP8, scales are fp32, shape K//128
std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
mat_a.scalar_type(), mat_b.scalar_type());
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
@@ -2200,15 +2132,14 @@ _scaled_block128x128_block1x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
TORCH_CHECK_VALUE(scale_a.stride(1) == 1, "expected scale_a.stride(1) to be 1, but got ", scale_a.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(1) == 1, "expected scale_b.stride(1) to be 1, but got ", scale_b.stride(1));
TORCH_CHECK_VALUE(scale_b.stride(0) == scale_b.size(1),
"expected scale_b.stride(0) to be ", scale_b.size(1), ", but got ", scale_b.stride(0));
auto scaling_choice_a = ScalingType::BlockWise128x128;
auto scaling_choice_b = ScalingType::BlockWise1x128;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@@ -2230,15 +2161,14 @@ _scaled_block1x128_block128x128(
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
TORCH_CHECK_VALUE(scale_a.stride(0) == 1, "expected scale_a.stride(0) to be 1, but got ", scale_a.stride(0));
TORCH_CHECK_VALUE(scale_b.stride(0) == 1, "expected scale_b.stride(0) to be 1, but got ", scale_b.stride(0));
TORCH_CHECK_VALUE(scale_b.stride(1) == scale_b.size(0),
"expected scale_b.stride(1) to be ", scale_b.size(0), ", but got ", scale_b.stride(1));
auto scaling_choice_a = ScalingType::BlockWise1x128;
auto scaling_choice_b = ScalingType::BlockWise128x128;
_cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
// Check scale strides (including stride=1 small cases)
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
return out;
}
@@ -2292,7 +2222,7 @@ _scaled_mxfp8_mxfp8(
#endif
#endif
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
Tensor&
@@ -2329,7 +2259,7 @@ _scaled_nvfp4_nvfp4(
auto scaling_choice_a = ScalingType::BlockWise1x16;
auto scaling_choice_b = ScalingType::BlockWise1x16;
return _cutlass_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}