Add scaled_mm python API, test (#164142)

Summary:

* Add `torch.nn.functional.scaled_mm` as an abstraction around the C++
  methods (see the usage sketch below)
* Wraps the `torch._scaled_mm_v2` API by default, but the user can force use of
  the older `torch._scaled_mm` interface.
* Scaled MM tests now run on the new API
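
A minimal usage sketch of the new API (illustrative only; assumes a CUDA
device with fp8 support and uses tensor-wise fp32 scales, mirroring the
updated tests):

    import torch
    from torch.nn.functional import scaled_mm, ScalingType, SwizzleType

    # fp8 operands; mat_b is transposed so cuBLAS sees it column-major
    a = torch.full((16, 16), 0.5, device="cuda", dtype=torch.float8_e4m3fn)
    b = torch.full((16, 16), 0.5, device="cuda", dtype=torch.float8_e4m3fn).t()

    # one fp32 decode scale per tensor (ScalingType.TensorWise)
    scale_a = torch.tensor(1.0, device="cuda")
    scale_b = torch.tensor(1.0, device="cuda")

    out = scaled_mm(
        a, b,
        scale_a, ScalingType.TensorWise,
        scale_b, ScalingType.TensorWise,
        swizzle_a=SwizzleType.NO_SWIZZLE,
        swizzle_b=SwizzleType.NO_SWIZZLE,
        output_dtype=torch.bfloat16,
    )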

Test Plan:

`pytest test/test_scaled_matmul_cuda.py`

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164142
Approved by: https://github.com/drisspg
ghstack dependencies: #164141
Simon Layton
2025-10-08 16:44:39 -07:00
committed by PyTorch MergeBot
parent 512b6b59f0
commit 6a7f5c0d21
11 changed files with 363 additions and 78 deletions

View File

@ -28,4 +28,19 @@ inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
return stream << BlasBackendToString(backend);
}
namespace blas {
enum class ScalingType : std::uint8_t {
TensorWise, // fp32 scales
RowWise, // fp32 scales
BlockWise1x16, // fp8_e4m3fn scales
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
};
enum class SwizzleType : std::uint8_t { NO_SWIZZLE = 0, SWIZZLE_32_4_4 = 1 };
} // namespace blas
} // namespace at
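
For orientation, a short sketch (not part of the diff; the name SCALE_LAYOUT is
illustrative) spelling out the scale granularity and the expected scale dtype
that the comments above attach to each enumerator:

    import torch

    # granularity / expected scale dtype per at::blas::ScalingType,
    # per the comments in the enum above
    SCALE_LAYOUT = {
        "TensorWise":       ("one scale per tensor",       torch.float32),
        "RowWise":          ("one scale per row",          torch.float32),
        "BlockWise1x16":    ("one scale per 16 values",    torch.float8_e4m3fn),
        "BlockWise1x32":    ("one scale per 32 values",    torch.float8_e8m0fnu),
        "BlockWise1x128":   ("one scale per 128 values",   torch.float32),
        "BlockWise128x128": ("one scale per 128x128 tile", torch.float32),
    }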

View File

@ -1861,6 +1861,8 @@ template bool gemm_and_bias(
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
using at::blas::ScalingType;
int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) {
switch (scaling_type) {
case ScalingType::BlockWise1x32:

View File

@ -14,6 +14,7 @@
*/
#include <ATen/cuda/CUDAContext.h>
#include <ATen/BlasBackend.h>
#include <ATen/OpMathType.h>
namespace at::cuda::blas {
@ -136,15 +137,6 @@ void int8_gemm(
int32_t* result_ptr,
int64_t result_ld);
enum class ScalingType : std::uint8_t {
TensorWise, // fp32 scales
RowWise, // fp32 scales
BlockWise1x16, // fp8_e4m3fn scales
BlockWise1x32, // fp8_e8m0fnu scales
BlockWise1x128, // fp32 scales
BlockWise128x128, // fp32 scales
};
void scaled_gemm(
char transa,
char transb,
@ -156,13 +148,13 @@ void scaled_gemm(
int64_t mat1_ld,
ScalarType mat1_dtype,
ScalarType mat1_scale_dtype,
ScalingType mat1_scaling_type,
at::blas::ScalingType mat1_scaling_type,
const void* mat2_ptr,
const void* mat2_scale_ptr,
int64_t mat2_ld,
ScalarType mat2_dtype,
ScalarType mat2_scale_dtype,
ScalingType mat2_scaling_type,
at::blas::ScalingType mat2_scaling_type,
const void* bias_ptr,
ScalarType bias_dtype,
void* result_ptr,

View File

@ -29,7 +29,7 @@
namespace at::cuda::tunable {
using at::cuda::blas::ScalingType;
using at::blas::ScalingType;
enum class BlasOp {
N = 0,

View File

@ -106,7 +106,8 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
}
}
using at::cuda::blas::ScalingType;
using at::blas::ScalingType;
using at::blas::SwizzleType;
/**
* @brief Prepares matrices for CUBLAS operation
@ -1093,7 +1094,7 @@ namespace{
* - Returns Error.
*/
using at::cuda::blas::ScalingType;
using at::blas::ScalingType;
bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1;
@ -1546,7 +1547,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
args.scale_result_ptr,
args.result_ld,
out_dtype_,
false /*use_fast_accum*/);
use_fast_accum);
}
return out;
@ -1664,11 +1665,6 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}
enum class SwizzleType {
NO_SWIZZLE = 0,
SWIZZLE_32_4_4 = 1
};
/**
* Track concrete implementations available
*/
@ -1890,9 +1886,9 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
ScaledGemmImplementation::BLOCK_128x128_1x128},
{ "block_1x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
ScaledGemmImplementation::BLOCK_1x128_1x128},
{ "nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
{ "nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
{ "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
{ "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
Tensor&
_cutlass_scaled_gemm(

View File

@ -218,3 +218,13 @@ DataParallel functions (multi-GPU, distributed)
:nosignatures:
torch.nn.parallel.data_parallel
Low-Precision functions
-----------------------
.. autosummary::
:toctree: generated
:nosignatures:
ScalingType
SwizzleType
scaled_mm

View File

@ -35,7 +35,6 @@ and supported quantized modules and functions.
quantization-support
.. torch.ao is missing documentation. Since part of it is mentioned here, adding them here for now.
.. They are here for tracking purposes until they are more permanently fixed.
.. py:module:: torch.ao

View File

@ -11,15 +11,17 @@ from typing import Optional
import torch
from torch.nn.functional import scaled_mm, ScalingType, SwizzleType
from torch.testing._internal.common_cuda import (
SM89OrLater,
SM90OrLater,
IS_SM90,
_get_torch_cuda_version,
PLATFORM_SUPPORTS_FP8,
PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
PLATFORM_SUPPORTS_MX_GEMM,
PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
IS_SM90,
SM100OrLater,
SM89OrLater,
SM90OrLater,
with_tf32_off,
)
from torch.testing._internal.common_device_type import (
@ -107,6 +109,112 @@ def tensor_to_scale_block(
scale = scale.flatten(2, 3).flatten(0, 1)
return x, scale
def round_up(x: int, y: int) -> int:
return ((x + y - 1) // y) * y
def infer_scale_swizzle(mat, scale):
# Tensor-wise
if scale.numel() == 1:
return ScalingType.TensorWise, SwizzleType.NO_SWIZZLE
# Row-wise
if (scale.shape[0] == mat.shape[0] and scale.shape[1] == 1) or (
scale.shape[0] == 1 and scale.shape[1] == mat.shape[1]
):
return ScalingType.RowWise, SwizzleType.NO_SWIZZLE
# deepgemm 1x128 / 128x1
if len(scale.shape) > 1:
if (
scale.shape[0] == mat.shape[0]
and scale.shape[1] == math.ceil(mat.shape[1] // 128)
or scale.shape[1] == mat.shape[1]
and scale.shape[0] == math.ceil(mat.shape[0] // 128)
):
return ScalingType.BlockWise1x128, SwizzleType.NO_SWIZZLE
# deepgemm 128x128
if scale.shape[0] == math.ceil(mat.shape[0] // 128) and scale.shape[
1
] == math.ceil(mat.shape[1] // 128):
return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
# NVFP4
if (
scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4)
and mat.dtype == torch.float4_e2m1fn_x2
and scale.dtype == torch.float8_e4m3fn
):
return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
# MX
if (
scale.numel()
== round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
or scale.numel()
== round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
and scale.dtype == torch.float8_e8m0fnu
):
return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
return None, None
wrap: bool = True
def scaled_mm_wrap(
a,
b,
scale_a,
scale_b,
scale_recipe_a=None,
scale_recipe_b=None,
swizzle_a=SwizzleType.NO_SWIZZLE,
swizzle_b=SwizzleType.NO_SWIZZLE,
scale_result=None,
out_dtype=torch.bfloat16,
use_fast_accum=False,
bias=None,
wrap_v2=wrap,
):
if not wrap_v2:
return torch._scaled_mm(
a,
b,
scale_a,
scale_b,
scale_result=scale_result,
out_dtype=out_dtype,
bias=bias,
use_fast_accum=use_fast_accum,
)
else:
# infer scalingtype and swizzle from scales
if scale_recipe_a is None:
scale_recipe_a, swizzle_a = infer_scale_swizzle(a, scale_a)
if scale_recipe_b is None:
scale_recipe_b, swizzle_b = infer_scale_swizzle(b, scale_b)
out = scaled_mm(
a,
b,
scale_a,
scale_recipe_a,
scale_b,
scale_recipe_b,
swizzle_a=swizzle_a,
swizzle_b=swizzle_b,
bias=bias,
output_dtype=out_dtype,
use_fast_accum=use_fast_accum,
)
return out
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
# naive implementation: dq -> op -> q
x_fp32 = x.to(torch.float) / x_scale
@ -139,7 +247,7 @@ def addmm_float8_unwrapped(
b_inverse_scale = b_scale.reciprocal()
if output_dtype == torch.float32 and bias is not None:
# Bias is not supported by _scaled_mm when output is fp32
output = torch._scaled_mm(
output = scaled_mm_wrap(
a_data,
b_data,
scale_a=a_inverse_scale,
@ -149,7 +257,7 @@ def addmm_float8_unwrapped(
)
output += bias
return output
output = torch._scaled_mm(
output = scaled_mm_wrap(
a_data,
b_data,
bias=bias,
@ -290,7 +398,7 @@ class TestFP8Matmul(TestCase):
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
out_fp8 = scaled_mm_wrap(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
if out_dtype is not None:
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
@ -301,7 +409,7 @@ class TestFP8Matmul(TestCase):
self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
# According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
# supported on ROCm but fails on CUDA
ctx = self.assertRaises(RuntimeError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext()
ctx = self.assertRaises(ValueError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext()
with ctx:
self._test_tautological_mm(device, e5m2_type, e5m2_type)
@ -326,9 +434,9 @@ class TestFP8Matmul(TestCase):
scale_one = torch.tensor(1.0, device=device)
scale_a = torch.tensor(1.5, device=device)
scale_b = torch.tensor(0.66, device=device)
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_one, scale_b=scale_one)
out_fp8 = scaled_mm_wrap(x, y, scale_a=scale_one, scale_b=scale_one)
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
self.assertEqual(out_fp8, out_fp8_s)
@unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@ -506,6 +614,7 @@ class TestFP8Matmul(TestCase):
x_scale = tensor_to_scale(x, input_dtype).float()
y_scale = tensor_to_scale(y, input_dtype).float()
x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
@ -600,11 +709,11 @@ class TestFP8Matmul(TestCase):
(k, l, m) = (16, 48, 32)
x = torch.ones((k, l), device=device).to(e4m3_type)
y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
out_fp8 = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
outb_fp8 = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
# this fails on ROCm currently because hipblaslt doesn't have amax op
out_fp32 = out_fp8.to(torch.float32)
outb_fp32 = outb_fp8.to(torch.float32)
@ -621,8 +730,8 @@ class TestFP8Matmul(TestCase):
scale_b = torch.tensor(1.0, device=device)
input_bias = None
if bias:
input_bias = torch.rand((16,), device=device).to(torch.half)
_ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)
input_bias = torch.rand((16,), device=device).to(torch.bfloat16)
_ = scaled_mm_wrap(x, y, scale_a, scale_b, bias=input_bias)
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@ -630,10 +739,10 @@ class TestFP8Matmul(TestCase):
(k, l, m) = (16, 48, 32)
x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
bias = torch.full((m,), -3.0, device=device, dtype=torch.bfloat16)
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)
outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
outb_fp8 = scaled_mm_wrap(x, y, scale_a, scale_b, bias=bias)
outb_fp32 = outb_fp8.to(torch.float32)
self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))
@ -647,9 +756,9 @@ class TestFP8Matmul(TestCase):
scale_b = torch.tensor(1.0, device=device)
bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
self.assertRaisesRegex(
RuntimeError,
ValueError,
"Bias is not supported when out_dtype is set to Float32",
lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
lambda: scaled_mm_wrap(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
)
@onlyCUDA
@ -663,10 +772,11 @@ class TestFP8Matmul(TestCase):
self.assertRaisesRegex(
RuntimeError,
r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
lambda: scaled_mm_wrap(x, y, scale_a, scale_b, out_dtype=torch.float32),
)
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(SM100OrLater, "fast_accum is SM90-only")
def test_float8_scale_fast_accum(self, device) -> None:
size = (16, 16)
x = torch.full(size, .5, device=device, dtype=e4m3_type)
@ -675,9 +785,9 @@ class TestFP8Matmul(TestCase):
y = torch.full(size, .5, device=device, dtype=y_type).t()
scale_a = torch.tensor(1.5, device=device)
scale_b = torch.tensor(0.66, device=device)
out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
out_fp8 = scaled_mm_wrap(x, y, scale_a, scale_b, out_dtype=torch.float8_e4m3fn, use_fast_accum=True)
self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn, use_fast_accum=True)
self.assertEqual(out_fp8, out_fp8_s)
@onlyCUDA
@ -696,7 +806,7 @@ class TestFP8Matmul(TestCase):
x_fp8 = x.to(e4m3_type)
y_fp8 = y.to(e4m3_type).t()
out_fp8 = torch._scaled_mm(
out_fp8 = scaled_mm_wrap(
x_fp8,
y_fp8,
scale_a=x_scales,
@ -720,50 +830,58 @@ class TestFP8Matmul(TestCase):
y_fp8 = y.to(e4m3_type).t()
with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
ValueError, re.escape("scale_b must have 1 Float element")
):
torch._scaled_mm(
scaled_mm_wrap(
x_fp8,
y_fp8,
scale_a=torch.ones((1, 1), device="cuda"),
scale_b=torch.ones((1, 2), device="cuda"),
scale_recipe_a=ScalingType.TensorWise,
scale_recipe_b=ScalingType.TensorWise,
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
ValueError, re.escape(f"scale_b must have {N} Float elements, got {N + 1}"),
):
torch._scaled_mm(
scaled_mm_wrap(
x_fp8,
y_fp8,
scale_a=torch.ones((M, 1), device="cuda"),
scale_b=torch.ones((1, N + 1), device="cuda"),
scale_recipe_a=ScalingType.RowWise,
scale_recipe_b=ScalingType.RowWise,
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
IndexError, re.escape("Dimension out of range")
):
torch._scaled_mm(
scaled_mm_wrap(
x_fp8,
y_fp8,
scale_a=torch.ones((M), device="cuda"),
scale_b=torch.ones((N, 1), device="cuda"),
scale_recipe_a=ScalingType.RowWise,
scale_recipe_b=ScalingType.RowWise,
out_dtype=torch.bfloat16,
)
with self.assertRaisesRegex(
RuntimeError, re.escape("Invalid scaling configuration")
ValueError, re.escape("expected scale_b.stride(1) to be 1, but got 2"),
):
torch._scaled_mm(
scaled_mm_wrap(
x_fp8,
y_fp8,
scale_a=torch.ones((M, 1), device="cuda"),
scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
scale_recipe_a=ScalingType.RowWise,
scale_recipe_b=ScalingType.RowWise,
out_dtype=torch.bfloat16,
)
def e5m2():
out = torch._scaled_mm(
out = scaled_mm_wrap(
x_fp8,
y_fp8.to(e5m2_type),
scale_a=torch.ones((M, 1), device="cuda"),
@ -838,6 +956,7 @@ class TestFP8Matmul(TestCase):
else:
test()
# Note: Removed parameterization over M,N,K from #163829 as it failed tests as-is
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
@unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
@unittest.skipIf(
@ -859,12 +978,20 @@ class TestFP8Matmul(TestCase):
# 1x128 blocks need scales to be outer-dim-major
if lhs_block == 1:
x_scales = x_scales.t().contiguous().t()
lhs_recipe = ScalingType.BlockWise1x128
else:
lhs_recipe = ScalingType.BlockWise128x128
if rhs_block == 1:
y_scales = y_scales.t().contiguous().t()
rhs_recipe = ScalingType.BlockWise1x128
else:
rhs_recipe = ScalingType.BlockWise128x128
# Calculate actual F8 mm
out_scaled_mm = mm_float8(
x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype
out_scaled_mm = scaled_mm_wrap(
x_fp8, y_fp8.t(), scale_a=x_scales.reciprocal(), scale_b=y_scales.reciprocal().t(), out_dtype=output_dtype,
scale_recipe_a=lhs_recipe, scale_recipe_b=rhs_recipe
)
# Calculate emulated F8 mm
@ -912,9 +1039,9 @@ class TestFP8Matmul(TestCase):
out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
scale_a = torch.tensor(float('-inf'), device=device)
scale_b = torch.tensor(float('-inf'), device=device)
f = torch._scaled_mm
f = scaled_mm_wrap
if use_torch_compile:
f = torch.compile(torch._scaled_mm)
f = torch.compile(scaled_mm_wrap)
out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))
@ -938,16 +1065,16 @@ class TestFP8Matmul(TestCase):
with tempfile.NamedTemporaryFile() as f:
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
self.assertIsNone(torch._C._get_sm_carveout_experimental())
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
torch._C._set_sm_carveout_experimental(0)
self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
torch._C._set_sm_carveout_experimental(66)
self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
torch._C._set_sm_carveout_experimental(None)
self.assertIsNone(torch._C._get_sm_carveout_experimental())
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
prof.export_chrome_trace(f.name)
if torch.version.hip:
@ -1244,7 +1371,7 @@ class TestFP8Matmul(TestCase):
A_scale = to_blocked(A_scale)
B_scale = to_blocked(B_scale)
C = torch._scaled_mm(
C = scaled_mm_wrap(
A,
B.t(),
A_scale,
@ -1283,49 +1410,65 @@ class TestFP8Matmul(TestCase):
expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks
expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks
block = (
ScalingType.BlockWise1x16
if recipe == "nvfp4"
else ScalingType.BlockWise1x32
)
swizzle = SwizzleType.SWIZZLE_32_4_4
# Test wrong scale tensor size for scale_a with correct dtype
with self.assertRaisesRegex(
RuntimeError,
ValueError,
f".*For Block[W,w]ise.*scaling.*scale_a should have {expected_a_size} "
f"elements.*"
,
):
incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=scale_dtype)
correct_size_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype)
torch._scaled_mm(
scaled_mm_wrap(
x_lowp,
y_lowp,
scale_a=incorrect_size_a,
scale_recipe_a=block,
scale_b=correct_size_b,
scale_recipe_b=block,
swizzle_a=swizzle,
swizzle_b=swizzle,
out_dtype=torch.bfloat16,
)
# Test wrong scale tensor size for scale_b with correct dtype
with self.assertRaisesRegex(
RuntimeError,
ValueError,
f"For Block[W,w]ise.*scaling.*scale_b should have {expected_b_size} "
f"elements.*"
,
):
correct_size_a = torch.ones(expected_a_size, device=device, dtype=scale_dtype)
incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=scale_dtype)
torch._scaled_mm(
scaled_mm_wrap(
x_lowp,
y_lowp,
scale_a=correct_size_a,
scale_recipe_a=block,
scale_b=incorrect_size_b,
scale_recipe_b=block,
swizzle_a=swizzle,
swizzle_b=swizzle,
out_dtype=torch.bfloat16,
)
# Test non-contiguous scale tensors with correct dtype
with self.assertRaisesRegex(
RuntimeError,
"For Block[W,w]ise.*scaling.*both should be contiguous"
ValueError,
"For Block[W,w]ise.*scaling.*both scales should be contiguous"
,
):
non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=scale_dtype)[::2]
contiguous_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype)
torch._scaled_mm(
scaled_mm_wrap(
x_lowp,
y_lowp,
scale_a=non_contiguous_a,
@ -1335,8 +1478,8 @@ class TestFP8Matmul(TestCase):
def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_fast_accum):
for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
out_ref = scaled_mm_wrap(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
self.assertEqual(out, out_ref, atol=5e-2, rtol=5e-4)
# Testing only _scaled_grouped_mm() with multiple shapes, as
@ -1485,7 +1628,7 @@ class TestFP8Matmul(TestCase):
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
C_ref = A_ref @ B_ref.t()
compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor")
compiled_scaled_mm = torch.compile(scaled_mm_wrap, backend="inductor")
C = compiled_scaled_mm(
A,
B.t(),
@ -1514,8 +1657,8 @@ class TestFP8Matmul(TestCase):
B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
C_ref = A_ref @ B_ref.t()
compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor")
# C = torch._scaled_mm(
compiled_scaled_mm = torch.compile(scaled_mm_wrap, backend="inductor")
# C = scaled_mm_wrap(
C = compiled_scaled_mm(
A,
B.t(),

View File

@ -113,6 +113,7 @@
#ifdef USE_CUDA
#include <ATen/ROCmFABackend.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <torch/csrc/inductor/static_cuda_launcher.h>
@ -2504,6 +2505,39 @@ Call this whenever a new thread is created in order to propagate values from
return at::globalContext().blasPreferredBackend();
});
py::enum_<at::blas::ScalingType>(
py_module, "_ScalingType", "Supported Tensor scaling types")
.value(
"TensorWise",
at::blas::ScalingType::TensorWise,
"Single scale per-tensor")
.value(
"RowWise", at::blas::ScalingType::RowWise, "Scale per-row of tensor")
.value(
"BlockWise1x16",
at::blas::ScalingType::BlockWise1x16,
"Scale per 16 contiguous values")
.value(
"BlockWise1x32",
at::blas::ScalingType::BlockWise1x32,
"Scale per 32 contiguous values")
.value(
"BlockWise1x128",
at::blas::ScalingType::BlockWise1x128,
"Scale per 128 contiguous values")
.value(
"BlockWise128x128",
at::blas::ScalingType::BlockWise128x128,
"Scale per 128x128 tile");
py::enum_<at::blas::SwizzleType>(
py_module, "_SwizzleType", "Supported scale swizzle types")
.value("NO_SWIZZLE", at::blas::SwizzleType::NO_SWIZZLE, "No swizzling")
.value(
"SWIZZLE_32_4_4",
at::blas::SwizzleType::SWIZZLE_32_4_4,
"Blackwell-stype 32x4x4 swizzle");
py::enum_<at::ROCmFABackend>(py_module, "_ROCmFABackend")
.value("Default", at::ROCmFABackend::Default)
.value("AOTriton", at::ROCmFABackend::AOTriton)

View File

@ -4,11 +4,16 @@ import importlib
import math
import warnings
from collections.abc import Callable
from typing import Optional, TYPE_CHECKING, Union
from typing import Any as _Any, Optional, TYPE_CHECKING, Union
import torch
from torch import _VF, sym_int as _sym_int, Tensor
from torch._C import _add_docstr, _infer_size
from torch._C import (
_add_docstr,
_infer_size,
_ScalingType as ScalingType,
_SwizzleType as SwizzleType,
)
from torch._jit_internal import (
_overload,
boolean_dispatch,
@ -27,6 +32,10 @@ from torch.overrides import (
)
# Set visibility of the bound enums to this module
ScalingType.__module__ = "torch.nn.functional"
SwizzleType.__module__ = "torch.nn.functional"
if TYPE_CHECKING:
from torch.types import _dtype as DType
else:
@ -6618,3 +6627,87 @@ def multi_head_attention_forward(
# squeeze the output if input was unbatched
attn_output = attn_output.squeeze(1)
return attn_output, None
def scaled_mm(
mat_a: Tensor,
mat_b: Tensor,
scale_a: Tensor | list[Tensor],
scale_recipe_a: ScalingType | list[ScalingType],
scale_b: Tensor | list[Tensor],
scale_recipe_b: ScalingType | list[ScalingType],
swizzle_a: SwizzleType | list[SwizzleType] | None = None,
swizzle_b: SwizzleType | list[SwizzleType] | None = None,
bias: Optional[Tensor] = None,
output_dtype: Optional[torch.dtype] = torch.bfloat16,
contraction_dim: list[int] | tuple[int] = (),
use_fast_accum: bool = False,
) -> Tensor:
r"""
scaled_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a, swizzle_b, bias, output_dtype,
contraction_dim, use_fast_accum)
Applies a scaled matrix-multiply, mm(mat_a, mat_b), where the scaling of mat_a and mat_b is described by
scale_recipe_a and scale_recipe_b respectively.
Args:
scale_a: Tensor containing decoding scaling factors for mat_a
scale_recipe_a: Enum describing how mat_a has been scaled
scale_b: Tensor containing decoding scaling factors for mat_b
scale_recipe_b: Enum describing how mat_b has been scaled
swizzle_a: Enum describing the swizzling pattern (if any) of scale_a
swizzle_b: Enum describing the swizzling pattern (if any) of scale_b
bias: optional bias term to be added to the output
output_dtype: dtype used for the output tensor
contraction_dim: describe which dimensions are :math:`K` in the matmul.
use_fast_accum: enable/disable tensor-core fast accumulation (Hopper-GPUs only)
"""
def expand_single_value(v: _Any | list[_Any] | None) -> list[_Any]:
if v is None:
return []
elif not isinstance(v, (list)):
return [
v,
]
else:
return v
scale_a = expand_single_value(scale_a)
scale_recipe_a = expand_single_value(scale_recipe_a)
scale_b = expand_single_value(scale_b)
scale_recipe_b = expand_single_value(scale_recipe_b)
swizzle_a = expand_single_value(swizzle_a)
swizzle_b = expand_single_value(swizzle_b)
# native_functions has restrictions on what can be defined
# & passed through - std::optional<ArrayRef<Tensor>> for instance
# *cannot* be passed, but an empty vector (list) can.
# So, we need to convert None arguments for lists in python
# explicitly into empty lists.
def list_or_empty(l: list[_Any] | None) -> list[_Any]:
return [] if not l else l
def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
if not isinstance(l, list):
l = [
l,
]
return [li.value for li in l]
out = torch._scaled_mm_v2(
mat_a,
mat_b,
scale_a,
enum_list_as_int_list(scale_recipe_a),
enum_list_as_int_list(list_or_empty(swizzle_a)),
scale_b,
enum_list_as_int_list(scale_recipe_b),
enum_list_as_int_list(list_or_empty(swizzle_b)),
bias,
output_dtype,
contraction_dim,
use_fast_accum,
)
return out
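
A second usage sketch through this wrapper, this time with row-wise scales
(illustrative; assumes a CUDA device with fp8 and row-wise scaling support).
Note that the swizzle arguments can be omitted: the helpers above normalize
None to the empty lists that torch._scaled_mm_v2 expects.

    import torch
    from torch.nn.functional import scaled_mm, ScalingType

    M, K, N = 32, 64, 16
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

    # one fp32 decode scale per row of a and per column of b
    scale_a = torch.ones(M, 1, device="cuda")
    scale_b = torch.ones(1, N, device="cuda")

    out = scaled_mm(
        a, b,
        scale_a, ScalingType.RowWise,
        scale_b, ScalingType.RowWise,
        output_dtype=torch.bfloat16,
    )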

View File

@ -249,6 +249,7 @@ def get_ignored_functions() -> set[Callable]:
torch.nn.functional.has_torch_function_unary,
torch.nn.functional.has_torch_function_variadic,
torch.nn.functional.handle_torch_function,
torch.nn.functional.scaled_mm,
torch.nn.functional.sigmoid,
torch.nn.functional.hardsigmoid,
torch.nn.functional.tanh,