Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Add scaled_mm python API, test (#164142)
Summary:

* Add `torch.nn.functional.scaled_mm` as an abstraction around the C++ methods.
* Wraps the `torch._scaled_mm_v2` API by default, but the user can force use of the older `torch._scaled_mm` interface.
* Scaled MM tests now run on the new API.

Test Plan: `pytest test/test_scaled_matmul_cuda.py`

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164142
Approved by: https://github.com/drisspg
ghstack dependencies: #164141
Committed by: PyTorch MergeBot
Parent commit: 512b6b59f0
Commit: 6a7f5c0d21
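Before the diff, a minimal usage sketch of the API this commit adds. The function, enum, and keyword names come from the patch below; the shapes, scale values, device string, and the assumption of an fp8-capable CUDA GPU (compute capability 8.9+) are illustrative only:

import torch
from torch.nn.functional import scaled_mm, ScalingType, SwizzleType

# Illustrative fp8 matmul with a single (tensor-wise) fp32 decoding scale per operand.
# mat_b is passed as a column-major view, matching the existing torch._scaled_mm convention.
mat_a = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn)
mat_b = torch.randn(32, 128, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

out = scaled_mm(
    mat_a, mat_b,
    scale_a, ScalingType.TensorWise,
    scale_b, ScalingType.TensorWise,
    swizzle_a=SwizzleType.NO_SWIZZLE,
    swizzle_b=SwizzleType.NO_SWIZZLE,
    output_dtype=torch.bfloat16,
)

Unlike `torch._scaled_mm`, which infers the scaling scheme from the scale tensor shapes, the new wrapper takes the recipe explicitly and forwards everything to `torch._scaled_mm_v2`.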
@@ -28,4 +28,19 @@ inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
   return stream << BlasBackendToString(backend);
 }
 
+namespace blas {
+
+enum class ScalingType : std::uint8_t {
+  TensorWise, // fp32 scales
+  RowWise, // fp32 scales
+  BlockWise1x16, // fp8_e4m3fn scales
+  BlockWise1x32, // fp8_e8m0fnu scales
+  BlockWise1x128, // fp32 scales
+  BlockWise128x128, // fp32 scales
+};
+
+enum class SwizzleType : std::uint8_t { NO_SWIZZLE = 0, SWIZZLE_32_4_4 = 1 };
+
+} // namespace blas
+
 } // namespace at
@@ -1861,6 +1861,8 @@ template bool gemm_and_bias(
     int64_t result_ld,
     GEMMAndBiasActivationEpilogue activation);
 
+using at::blas::ScalingType;
+
 int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) {
   switch (scaling_type) {
     case ScalingType::BlockWise1x32:
@@ -14,6 +14,7 @@
  */
 
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/BlasBackend.h>
 #include <ATen/OpMathType.h>
 
 namespace at::cuda::blas {
@@ -136,15 +137,6 @@ void int8_gemm(
     int32_t* result_ptr,
     int64_t result_ld);
 
-enum class ScalingType : std::uint8_t {
-  TensorWise, // fp32 scales
-  RowWise, // fp32 scales
-  BlockWise1x16, // fp8_e4m3fn scales
-  BlockWise1x32, // fp8_e8m0fnu scales
-  BlockWise1x128, // fp32 scales
-  BlockWise128x128, // fp32 scales
-};
-
 void scaled_gemm(
     char transa,
     char transb,
@@ -156,13 +148,13 @@ void scaled_gemm(
     int64_t mat1_ld,
     ScalarType mat1_dtype,
     ScalarType mat1_scale_dtype,
-    ScalingType mat1_scaling_type,
+    at::blas::ScalingType mat1_scaling_type,
     const void* mat2_ptr,
     const void* mat2_scale_ptr,
     int64_t mat2_ld,
     ScalarType mat2_dtype,
     ScalarType mat2_scale_dtype,
-    ScalingType mat2_scaling_type,
+    at::blas::ScalingType mat2_scaling_type,
     const void* bias_ptr,
     ScalarType bias_dtype,
     void* result_ptr,
@@ -29,7 +29,7 @@
 
 namespace at::cuda::tunable {
 
-using at::cuda::blas::ScalingType;
+using at::blas::ScalingType;
 
 enum class BlasOp {
   N = 0,
@@ -106,7 +106,8 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
   }
 }
 
-using at::cuda::blas::ScalingType;
+using at::blas::ScalingType;
+using at::blas::SwizzleType;
 
 /**
  * @brief Prepares matrices for CUBLAS operation
@@ -1093,7 +1094,7 @@ namespace{
  * - Returns Error.
  */
 
-using at::cuda::blas::ScalingType;
+using at::blas::ScalingType;
 
 bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
   return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1;
@@ -1546,7 +1547,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
         args.scale_result_ptr,
         args.result_ld,
         out_dtype_,
-        false /*use_fast_accum*/);
+        use_fast_accum);
   }
 
   return out;
@@ -1664,11 +1665,6 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
   return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
 }
 
-enum class SwizzleType {
-  NO_SWIZZLE = 0,
-  SWIZZLE_32_4_4 = 1
-};
-
 /**
  * Track concrete implementations available
  */
@@ -1890,9 +1886,9 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
      ScaledGemmImplementation::BLOCK_128x128_1x128},
     { "block_1x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
      ScaledGemmImplementation::BLOCK_1x128_1x128},
-    { "nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
-    { "nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
-    { "mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
+    { "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
+    { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
+    { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
 
 Tensor&
 _cutlass_scaled_gemm(
@@ -218,3 +218,13 @@ DataParallel functions (multi-GPU, distributed)
     :nosignatures:
 
     torch.nn.parallel.data_parallel
+
+Low-Precision functions
+-----------------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    ScalingType
+    SwizzleType
+    scaled_mm
@@ -35,7 +35,6 @@ and supported quantized modules and functions.
 
     quantization-support
 
-
 .. torch.ao is missing documentation. Since part of it is mentioned here, adding them here for now.
 .. They are here for tracking purposes until they are more permanently fixed.
 .. py:module:: torch.ao
@@ -11,15 +11,17 @@ from typing import Optional
 import torch
 
 
+from torch.nn.functional import scaled_mm, ScalingType, SwizzleType
 from torch.testing._internal.common_cuda import (
-    SM89OrLater,
-    SM90OrLater,
-    IS_SM90,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
     PLATFORM_SUPPORTS_MX_GEMM,
     PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
+    IS_SM90,
+    SM100OrLater,
+    SM89OrLater,
+    SM90OrLater,
     with_tf32_off,
 )
 from torch.testing._internal.common_device_type import (
@@ -107,6 +109,112 @@ def tensor_to_scale_block(
     scale = scale.flatten(2, 3).flatten(0, 1)
     return x, scale
 
+
+def round_up(x: int, y: int) -> int:
+    return ((x + y - 1) // y) * y
+
+
+def infer_scale_swizzle(mat, scale):
+    # Tensor-wise
+    if scale.numel() == 1:
+        return ScalingType.TensorWise, SwizzleType.NO_SWIZZLE
+
+    # Row-wise
+    if (scale.shape[0] == mat.shape[0] and scale.shape[1] == 1) or (
+        scale.shape[0] == 1 and scale.shape[1] == mat.shape[1]
+    ):
+        return ScalingType.RowWise, SwizzleType.NO_SWIZZLE
+
+    # deepgemm 1x128 / 128x1
+    if len(scale.shape) > 1:
+        if (
+            scale.shape[0] == mat.shape[0]
+            and scale.shape[1] == math.ceil(mat.shape[1] // 128)
+            or scale.shape[1] == mat.shape[1]
+            and scale.shape[0] == math.ceil(mat.shape[0] // 128)
+        ):
+            return ScalingType.BlockWise1x128, SwizzleType.NO_SWIZZLE
+
+    # deepgemm 128x128
+    if scale.shape[0] == math.ceil(mat.shape[0] // 128) and scale.shape[
+        1
+    ] == math.ceil(mat.shape[1] // 128):
+        return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
+
+    # NVFP4
+    if (
+        scale.numel()
+        == round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
+        or scale.numel()
+        == round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4)
+        and mat.dtype == torch.float4_e2m1fn_x2
+        and scale.dtype == torch.float8_e4m3fn
+    ):
+        return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
+
+    # MX
+    if (
+        scale.numel()
+        == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
+        or scale.numel()
+        == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
+        and scale.dtype == torch.float8_e8m0fnu
+    ):
+        return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
+
+    return None, None
+
+
+wrap: bool = True
+
+def scaled_mm_wrap(
+    a,
+    b,
+    scale_a,
+    scale_b,
+    scale_recipe_a=None,
+    scale_recipe_b=None,
+    swizzle_a=SwizzleType.NO_SWIZZLE,
+    swizzle_b=SwizzleType.NO_SWIZZLE,
+    scale_result=None,
+    out_dtype=torch.bfloat16,
+    use_fast_accum=False,
+    bias=None,
+    wrap_v2=wrap,
+):
+    if not wrap_v2:
+        return torch._scaled_mm(
+            a,
+            b,
+            scale_a,
+            scale_b,
+            scale_result=scale_result,
+            out_dtype=out_dtype,
+            bias=bias,
+            use_fast_accum=use_fast_accum,
+        )
+    else:
+        # infer scalingtype and swizzle from scales
+        if scale_recipe_a is None:
+            scale_recipe_a, swizzle_a = infer_scale_swizzle(a, scale_a)
+        if scale_recipe_b is None:
+            scale_recipe_b, swizzle_b = infer_scale_swizzle(b, scale_b)
+
+        out = scaled_mm(
+            a,
+            b,
+            scale_a,
+            scale_recipe_a,
+            scale_b,
+            scale_recipe_b,
+            swizzle_a=swizzle_a,
+            swizzle_b=swizzle_b,
+            bias=bias,
+            output_dtype=out_dtype,
+            use_fast_accum=use_fast_accum,
+        )
+        return out
+
 def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
     # naive implementation: dq -> op -> q
     x_fp32 = x.to(torch.float) / x_scale
@@ -139,7 +247,7 @@ def addmm_float8_unwrapped(
     b_inverse_scale = b_scale.reciprocal()
     if output_dtype == torch.float32 and bias is not None:
         # Bias is not supported by _scaled_mm when output is fp32
-        output = torch._scaled_mm(
+        output = scaled_mm_wrap(
             a_data,
             b_data,
             scale_a=a_inverse_scale,
@@ -149,7 +257,7 @@ def addmm_float8_unwrapped(
         )
         output += bias
         return output
-    output = torch._scaled_mm(
+    output = scaled_mm_wrap(
         a_data,
         b_data,
         bias=bias,
@@ -290,7 +398,7 @@ class TestFP8Matmul(TestCase):
         out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
         scale_a = torch.tensor(1.0, device=device)
         scale_b = torch.tensor(1.0, device=device)
-        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
+        out_fp8 = scaled_mm_wrap(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
         if out_dtype is not None:
             self.assertEqual(out_dtype, out_fp8.dtype)
         self.assertEqual(out_fp32, out_fp8.to(torch.float))
@@ -301,7 +409,7 @@ class TestFP8Matmul(TestCase):
         self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
         # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
         # supported on ROCm but fails on CUDA
-        ctx = self.assertRaises(RuntimeError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext()
+        ctx = self.assertRaises(ValueError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext()
         with ctx:
             self._test_tautological_mm(device, e5m2_type, e5m2_type)
 
@@ -326,9 +434,9 @@ class TestFP8Matmul(TestCase):
         scale_one = torch.tensor(1.0, device=device)
         scale_a = torch.tensor(1.5, device=device)
         scale_b = torch.tensor(0.66, device=device)
-        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_one, scale_b=scale_one)
+        out_fp8 = scaled_mm_wrap(x, y, scale_a=scale_one, scale_b=scale_one)
         self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
-        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
+        out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
         self.assertEqual(out_fp8, out_fp8_s)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)
@@ -506,6 +614,7 @@ class TestFP8Matmul(TestCase):
         x_scale = tensor_to_scale(x, input_dtype).float()
         y_scale = tensor_to_scale(y, input_dtype).float()
 
+
         x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
         y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
 
@@ -600,11 +709,11 @@ class TestFP8Matmul(TestCase):
         (k, l, m) = (16, 48, 32)
         x = torch.ones((k, l), device=device).to(e4m3_type)
         y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
-        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
+        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
         scale_a = torch.tensor(1.0, device=device)
         scale_b = torch.tensor(1.0, device=device)
-        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
-        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
+        out_fp8 = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
+        outb_fp8 = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
         # this fails on ROCm currently because hipblaslt doesn't have amax op
         out_fp32 = out_fp8.to(torch.float32)
         outb_fp32 = outb_fp8.to(torch.float32)
@@ -621,8 +730,8 @@ class TestFP8Matmul(TestCase):
         scale_b = torch.tensor(1.0, device=device)
         input_bias = None
         if bias:
-            input_bias = torch.rand((16,), device=device).to(torch.half)
-        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)
+            input_bias = torch.rand((16,), device=device).to(torch.bfloat16)
+        _ = scaled_mm_wrap(x, y, scale_a, scale_b, bias=input_bias)
 
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@@ -630,10 +739,10 @@ class TestFP8Matmul(TestCase):
         (k, l, m) = (16, 48, 32)
         x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
         y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
-        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
+        bias = torch.full((m,), -3.0, device=device, dtype=torch.bfloat16)
         scale_a = torch.tensor(1.0, device=device)
         scale_b = torch.tensor(1.0, device=device)
-        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
+        outb_fp8 = scaled_mm_wrap(x, y, scale_a, scale_b, bias=bias)
         outb_fp32 = outb_fp8.to(torch.float32)
         self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))
 
@@ -647,9 +756,9 @@ class TestFP8Matmul(TestCase):
         scale_b = torch.tensor(1.0, device=device)
         bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
         self.assertRaisesRegex(
-            RuntimeError,
+            ValueError,
             "Bias is not supported when out_dtype is set to Float32",
-            lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
+            lambda: scaled_mm_wrap(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
         )
 
     @onlyCUDA
@@ -663,10 +772,11 @@ class TestFP8Matmul(TestCase):
         self.assertRaisesRegex(
             RuntimeError,
             r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
-            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
+            lambda: scaled_mm_wrap(x, y, scale_a, scale_b, out_dtype=torch.float32),
         )
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
+    @unittest.skipIf(SM100OrLater, "fast_accum is SM90-only")
     def test_float8_scale_fast_accum(self, device) -> None:
         size = (16, 16)
         x = torch.full(size, .5, device=device, dtype=e4m3_type)
@@ -675,9 +785,9 @@ class TestFP8Matmul(TestCase):
         y = torch.full(size, .5, device=device, dtype=y_type).t()
         scale_a = torch.tensor(1.5, device=device)
         scale_b = torch.tensor(0.66, device=device)
-        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
+        out_fp8 = scaled_mm_wrap(x, y, scale_a, scale_b, out_dtype=torch.float8_e4m3fn, use_fast_accum=True)
         self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
-        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
+        out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn, use_fast_accum=True)
         self.assertEqual(out_fp8, out_fp8_s)
 
     @onlyCUDA
@@ -696,7 +806,7 @@ class TestFP8Matmul(TestCase):
         x_fp8 = x.to(e4m3_type)
         y_fp8 = y.to(e4m3_type).t()
 
-        out_fp8 = torch._scaled_mm(
+        out_fp8 = scaled_mm_wrap(
             x_fp8,
             y_fp8,
             scale_a=x_scales,
@@ -720,50 +830,58 @@ class TestFP8Matmul(TestCase):
         y_fp8 = y.to(e4m3_type).t()
 
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            ValueError, re.escape("scale_b must have 1 Float element")
         ):
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((1, 1), device="cuda"),
                 scale_b=torch.ones((1, 2), device="cuda"),
+                scale_recipe_a=ScalingType.TensorWise,
+                scale_recipe_b=ScalingType.TensorWise,
                 out_dtype=torch.bfloat16,
             )
 
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            ValueError, re.escape(f"scale_b must have {N} Float elements, got {N + 1}"),
         ):
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((M, 1), device="cuda"),
                 scale_b=torch.ones((1, N + 1), device="cuda"),
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
                 out_dtype=torch.bfloat16,
             )
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            IndexError, re.escape("Dimension out of range")
         ):
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((M), device="cuda"),
                 scale_b=torch.ones((N, 1), device="cuda"),
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
                 out_dtype=torch.bfloat16,
            )
 
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            ValueError, re.escape("expected scale_b.stride(1) to be 1, but got 2"),
         ):
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((M, 1), device="cuda"),
                 scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
                 out_dtype=torch.bfloat16,
             )
 
         def e5m2():
-            out = torch._scaled_mm(
+            out = scaled_mm_wrap(
                 x_fp8,
                 y_fp8.to(e5m2_type),
                 scale_a=torch.ones((M, 1), device="cuda"),
@@ -838,6 +956,7 @@ class TestFP8Matmul(TestCase):
         else:
             test()
 
+    # Note: Removed parameterization over M,N,K from #163829 as it failed tests as-is
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
     @unittest.skipIf(
@@ -859,12 +978,20 @@ class TestFP8Matmul(TestCase):
         # 1x128 blocks need scales to be outer-dim-major
         if lhs_block == 1:
             x_scales = x_scales.t().contiguous().t()
+            lhs_recipe = ScalingType.BlockWise1x128
+        else:
+            lhs_recipe = ScalingType.BlockWise128x128
         if rhs_block == 1:
             y_scales = y_scales.t().contiguous().t()
+            rhs_recipe = ScalingType.BlockWise1x128
+        else:
+            rhs_recipe = ScalingType.BlockWise128x128
 
+
         # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype
+        out_scaled_mm = scaled_mm_wrap(
+            x_fp8, y_fp8.t(), scale_a=x_scales.reciprocal(), scale_b=y_scales.reciprocal().t(), out_dtype=output_dtype,
+            scale_recipe_a=lhs_recipe, scale_recipe_b=rhs_recipe
         )
 
         # Calculate emulated F8 mm
@@ -912,9 +1039,9 @@ class TestFP8Matmul(TestCase):
         out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
         scale_a = torch.tensor(float('-inf'), device=device)
         scale_b = torch.tensor(float('-inf'), device=device)
-        f = torch._scaled_mm
+        f = scaled_mm_wrap
         if use_torch_compile:
-            f = torch.compile(torch._scaled_mm)
+            f = torch.compile(scaled_mm_wrap)
         out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
         self.assertEqual(out_dtype, out_fp8.dtype)
         self.assertEqual(out_fp32, out_fp8.to(torch.float))
@@ -938,16 +1065,16 @@ class TestFP8Matmul(TestCase):
         with tempfile.NamedTemporaryFile() as f:
             with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
                 self.assertIsNone(torch._C._get_sm_carveout_experimental())
-                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
+                scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                 torch._C._set_sm_carveout_experimental(0)
                 self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
-                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
+                scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                 torch._C._set_sm_carveout_experimental(66)
                 self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
-                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
+                scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                 torch._C._set_sm_carveout_experimental(None)
                 self.assertIsNone(torch._C._get_sm_carveout_experimental())
-                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
+                scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
 
             prof.export_chrome_trace(f.name)
             if torch.version.hip:
@@ -1244,7 +1371,7 @@ class TestFP8Matmul(TestCase):
         A_scale = to_blocked(A_scale)
         B_scale = to_blocked(B_scale)
 
-        C = torch._scaled_mm(
+        C = scaled_mm_wrap(
             A,
             B.t(),
             A_scale,
@@ -1283,49 +1410,65 @@ class TestFP8Matmul(TestCase):
         expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks
         expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks
 
+        block = (
+            ScalingType.BlockWise1x16
+            if recipe == "nvfp4"
+            else ScalingType.BlockWise1x32
+        )
+        swizzle = SwizzleType.SWIZZLE_32_4_4
+
         # Test wrong scale tensor size for scale_a with correct dtype
         with self.assertRaisesRegex(
-            RuntimeError,
+            ValueError,
             f".*For Block[W,w]ise.*scaling.*scale_a should have {expected_a_size} "
             f"elements.*"
             ,
         ):
            incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=scale_dtype)
            correct_size_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype)
-            torch._scaled_mm(
+
+            scaled_mm_wrap(
                x_lowp,
                y_lowp,
                scale_a=incorrect_size_a,
+                scale_recipe_a=block,
                scale_b=correct_size_b,
+                scale_recipe_b=block,
+                swizzle_a=swizzle,
+                swizzle_b=swizzle,
                out_dtype=torch.bfloat16,
            )
 
         # Test wrong scale tensor size for scale_b with correct dtype
         with self.assertRaisesRegex(
-            RuntimeError,
+            ValueError,
             f"For Block[W,w]ise.*scaling.*scale_b should have {expected_b_size} "
             f"elements.*"
             ,
         ):
            correct_size_a = torch.ones(expected_a_size, device=device, dtype=scale_dtype)
            incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=scale_dtype)
-            torch._scaled_mm(
+            scaled_mm_wrap(
                x_lowp,
                y_lowp,
                scale_a=correct_size_a,
+                scale_recipe_a=block,
                scale_b=incorrect_size_b,
+                scale_recipe_b=block,
+                swizzle_a=swizzle,
+                swizzle_b=swizzle,
                out_dtype=torch.bfloat16,
            )
 
         # Test non-contiguous scale tensors with correct dtype
         with self.assertRaisesRegex(
-            RuntimeError,
-            "For Block[W,w]ise.*scaling.*both should be contiguous"
+            ValueError,
+            "For Block[W,w]ise.*scaling.*both scales should be contiguous"
             ,
         ):
            non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=scale_dtype)[::2]
            contiguous_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype)
-            torch._scaled_mm(
+            scaled_mm_wrap(
                x_lowp,
                y_lowp,
                scale_a=non_contiguous_a,
@@ -1335,8 +1478,8 @@ class TestFP8Matmul(TestCase):
 
     def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_fast_accum):
         for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
-            out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
-                                       out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
+            out_ref = scaled_mm_wrap(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
+                                     out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
             self.assertEqual(out, out_ref, atol=5e-2, rtol=5e-4)
 
     # Testing only _scaled_grouped_mm() with multiple shapes, as
@@ -1485,7 +1628,7 @@ class TestFP8Matmul(TestCase):
         B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
         C_ref = A_ref @ B_ref.t()
 
-        compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor")
+        compiled_scaled_mm = torch.compile(scaled_mm_wrap, backend="inductor")
         C = compiled_scaled_mm(
             A,
             B.t(),
@@ -1514,8 +1657,8 @@ class TestFP8Matmul(TestCase):
         B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
         C_ref = A_ref @ B_ref.t()
 
-        compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor")
-        # C = torch._scaled_mm(
+        compiled_scaled_mm = torch.compile(scaled_mm_wrap, backend="inductor")
+        # C = scaled_mm_wrap(
         C = compiled_scaled_mm(
             A,
             B.t(),
@@ -113,6 +113,7 @@
 
 #ifdef USE_CUDA
 #include <ATen/ROCmFABackend.h>
+#include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/CUDAConfig.h>
 #include <ATen/native/transformers/cuda/sdp_utils.h>
 #include <torch/csrc/inductor/static_cuda_launcher.h>
@@ -2504,6 +2505,39 @@ Call this whenever a new thread is created in order to propagate values from
         return at::globalContext().blasPreferredBackend();
       });
 
+  py::enum_<at::blas::ScalingType>(
+      py_module, "_ScalingType", "Supported Tensor scaling types")
+      .value(
+          "TensorWise",
+          at::blas::ScalingType::TensorWise,
+          "Single scale per-tensor")
+      .value(
+          "RowWise", at::blas::ScalingType::RowWise, "Scale per-row of tensor")
+      .value(
+          "BlockWise1x16",
+          at::blas::ScalingType::BlockWise1x16,
+          "Scale per 16 contiguous values")
+      .value(
+          "BlockWise1x32",
+          at::blas::ScalingType::BlockWise1x32,
+          "Scale per 32 contiguous values")
+      .value(
+          "BlockWise1x128",
+          at::blas::ScalingType::BlockWise1x128,
+          "Scale per 128 contiguous values")
+      .value(
+          "BlockWise128x128",
+          at::blas::ScalingType::BlockWise128x128,
+          "Scale per 128x128 tile");
+
+  py::enum_<at::blas::SwizzleType>(
+      py_module, "_SwizzleType", "Supported scale swizzle types")
+      .value("NO_SWIZZLE", at::blas::SwizzleType::NO_SWIZZLE, "No swizzling")
+      .value(
+          "SWIZZLE_32_4_4",
+          at::blas::SwizzleType::SWIZZLE_32_4_4,
+          "Blackwell-style 32x4x4 swizzle");
+
   py::enum_<at::ROCmFABackend>(py_module, "_ROCmFABackend")
       .value("Default", at::ROCmFABackend::Default)
       .value("AOTriton", at::ROCmFABackend::AOTriton)
@@ -4,11 +4,16 @@ import importlib
 import math
 import warnings
 from collections.abc import Callable
-from typing import Optional, TYPE_CHECKING, Union
+from typing import Any as _Any, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch import _VF, sym_int as _sym_int, Tensor
-from torch._C import _add_docstr, _infer_size
+from torch._C import (
+    _add_docstr,
+    _infer_size,
+    _ScalingType as ScalingType,
+    _SwizzleType as SwizzleType,
+)
 from torch._jit_internal import (
     _overload,
     boolean_dispatch,
@@ -27,6 +32,10 @@ from torch.overrides import (
 )
 
 
+# Set visibility of the bound enums to this module
+ScalingType.__module__ = "torch.nn.functional"
+SwizzleType.__module__ = "torch.nn.functional"
+
 if TYPE_CHECKING:
     from torch.types import _dtype as DType
 else:
@@ -6618,3 +6627,87 @@ def multi_head_attention_forward(
     # squeeze the output if input was unbatched
     attn_output = attn_output.squeeze(1)
     return attn_output, None
+
+
+def scaled_mm(
+    mat_a: Tensor,
+    mat_b: Tensor,
+    scale_a: Tensor | list[Tensor],
+    scale_recipe_a: ScalingType | list[ScalingType],
+    scale_b: Tensor | list[Tensor],
+    scale_recipe_b: ScalingType | list[ScalingType],
+    swizzle_a: SwizzleType | list[SwizzleType] | None = None,
+    swizzle_b: SwizzleType | list[SwizzleType] | None = None,
+    bias: Optional[Tensor] = None,
+    output_dtype: Optional[torch.dtype] = torch.bfloat16,
+    contraction_dim: list[int] | tuple[int] = (),
+    use_fast_accum: bool = False,
+) -> Tensor:
+    r"""
+    scaled_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a, swizzle_b, bias, output_dtype,
+              contraction_dim, use_fast_accum)
+
+    Applies a scaled matrix-multiply, mm(mat_a, mat_b), where the scaling of mat_a and mat_b is described by
+    scale_recipe_a and scale_recipe_b respectively.
+
+    Args:
+        scale_a: Tensor containing decoding scaling factors for mat_a
+        scale_recipe_a: Enum describing how mat_a has been scaled
+        scale_b: Tensor containing decoding scaling factors for mat_b
+        scale_recipe_b: Enum describing how mat_b has been scaled
+        swizzle_a: Enum describing the swizzling pattern (if any) of scale_a
+        swizzle_b: Enum describing the swizzling pattern (if any) of scale_b
+        bias: optional bias term to be added to the output
+        output_dtype: dtype used for the output tensor
+        contraction_dim: describe which dimensions are :math:`K` in the matmul.
+        use_fast_accum: enable/disable tensor-core fast accumulation (Hopper-GPUs only)
+    """
+
+    def expand_single_value(v: _Any | list[_Any] | None) -> list[_Any]:
+        if v is None:
+            return []
+        elif not isinstance(v, (list)):
+            return [
+                v,
+            ]
+        else:
+            return v
+
+    scale_a = expand_single_value(scale_a)
+    scale_recipe_a = expand_single_value(scale_recipe_a)
+    scale_b = expand_single_value(scale_b)
+    scale_recipe_b = expand_single_value(scale_recipe_b)
+    swizzle_a = expand_single_value(swizzle_a)
+    swizzle_b = expand_single_value(swizzle_b)
+
+    # native_functions has restrictions on what can be defined
+    # & passed through - std::optional<ArrayRef<Tensor>> for instance
+    # *cannot* be passed, but an empty vector (list) can.
+    # So, we need to convert None arguments for lists in python
+    # explicitly into empty lists.
+    def list_or_empty(l: list[_Any] | None) -> list[_Any]:
+        return [] if not l else l
+
+    def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
+        if not isinstance(l, list):
+            l = [
+                l,
+            ]
+        return [li.value for li in l]
+
+    out = torch._scaled_mm_v2(
+        mat_a,
+        mat_b,
+        scale_a,
+        enum_list_as_int_list(scale_recipe_a),
+        enum_list_as_int_list(list_or_empty(swizzle_a)),
+        scale_b,
+        enum_list_as_int_list(scale_recipe_b),
+        enum_list_as_int_list(list_or_empty(swizzle_b)),
+        bias,
+        output_dtype,
+        contraction_dim,
+        use_fast_accum,
+    )
+
+    return out
@@ -249,6 +249,7 @@ def get_ignored_functions() -> set[Callable]:
         torch.nn.functional.has_torch_function_unary,
         torch.nn.functional.has_torch_function_variadic,
         torch.nn.functional.handle_torch_function,
+        torch.nn.functional.scaled_mm,
         torch.nn.functional.sigmoid,
        torch.nn.functional.hardsigmoid,
        torch.nn.functional.tanh,
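For call sites migrating off the private op, a hedged before/after sketch modeled on the `scaled_mm_wrap` test helper in the diff above (not an official migration shim); the shapes, scale values, device string, and the assumption of hardware with row-wise fp8 scaling support are illustrative only:

import torch
from torch.nn.functional import scaled_mm, ScalingType

a = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(16, 64, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.ones(32, 1, device="cuda")  # one fp32 scale per row of a
scale_b = torch.ones(1, 16, device="cuda")  # one fp32 scale per column of b

# Old, private interface: the row-wise recipe is inferred from the scale shapes.
out_v1 = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

# New public wrapper: the recipe is stated explicitly and dispatched to torch._scaled_mm_v2.
out_v2 = scaled_mm(
    a, b,
    scale_a, ScalingType.RowWise,
    scale_b, ScalingType.RowWise,
    output_dtype=torch.bfloat16,
)

# Assumption: both paths reach the same underlying kernels, so the results should agree.
torch.testing.assert_close(out_v1, out_v2)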