Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add scaled_mm python API, test (#164142)
Summary:

* Add `torch.nn.functional.scaled_mm` as an abstraction around the C++ methods.
* Wraps the `torch._scaled_mm_v2` API by default, but the user can force use of the older `torch._scaled_mm` interface.
* Scaled MM tests now run on the new API.

Test Plan: `pytest test/test_scaled_matmul_cuda.py`

Reviewers:
Subscribers:
Tasks:
Tags:

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164142
Approved by: https://github.com/drisspg
ghstack dependencies: #164141
Committed by PyTorch MergeBot
parent 512b6b59f0
commit 6a7f5c0d21
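The commit message above describes the new public entry point; a minimal usage sketch follows (not part of the diff — the device, shapes, and dtypes are illustrative assumptions; an fp8-capable GPU such as SM89+/SM90+ or ROCm MI300+ is required):

# Minimal sketch (not part of the diff): tensor-wise scaled fp8 matmul via the
# new torch.nn.functional.scaled_mm API. Assumes a CUDA device with fp8 support.
import torch
from torch.nn.functional import scaled_mm, ScalingType, SwizzleType

device = "cuda"
M, K, N = 64, 128, 32
mat_a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
mat_b = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()  # column-major second operand
scale_a = torch.tensor(1.0, device=device)  # one fp32 scale for the whole tensor
scale_b = torch.tensor(1.0, device=device)

out = scaled_mm(
    mat_a, mat_b,
    scale_a, ScalingType.TensorWise,
    scale_b, ScalingType.TensorWise,
    swizzle_a=SwizzleType.NO_SWIZZLE,
    swizzle_b=SwizzleType.NO_SWIZZLE,
    output_dtype=torch.bfloat16,
)
print(out.shape)  # torch.Size([64, 32])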
@@ -28,4 +28,19 @@ inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) {
   return stream << BlasBackendToString(backend);
 }
 
+namespace blas {
+
+enum class ScalingType : std::uint8_t {
+  TensorWise, // fp32 scales
+  RowWise, // fp32 scales
+  BlockWise1x16, // fp8_e4m3fn scales
+  BlockWise1x32, // fp8_e8m0fnu scales
+  BlockWise1x128, // fp32 scales
+  BlockWise128x128, // fp32 scales
+};
+
+enum class SwizzleType : std::uint8_t { NO_SWIZZLE = 0, SWIZZLE_32_4_4 = 1 };
+
+} // namespace blas
+
 } // namespace at
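The comments on each enumerator above note the scale dtype each recipe expects; an illustrative Python summary (not part of the diff; it uses the Python-visible enum that later hunks add to torch.nn.functional):

# Illustrative only: scale dtypes noted in the enum comments above, keyed by
# the Python-visible enum exported from torch.nn.functional later in this diff.
import torch
from torch.nn.functional import ScalingType

EXPECTED_SCALE_DTYPE = {
    ScalingType.TensorWise: torch.float32,
    ScalingType.RowWise: torch.float32,
    ScalingType.BlockWise1x16: torch.float8_e4m3fn,   # nvfp4-style block scales
    ScalingType.BlockWise1x32: torch.float8_e8m0fnu,  # mxfp8-style block scales
    ScalingType.BlockWise1x128: torch.float32,
    ScalingType.BlockWise128x128: torch.float32,
}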
@@ -1861,6 +1861,8 @@ template bool gemm_and_bias(
     int64_t result_ld,
     GEMMAndBiasActivationEpilogue activation);
 
+using at::blas::ScalingType;
+
 int get_scale_mode(ScalingType scaling_type, ScalarType scale_dtype, bool use_fast_accum) {
   switch (scaling_type) {
     case ScalingType::BlockWise1x32:

@@ -14,6 +14,7 @@
  */
 
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/BlasBackend.h>
 #include <ATen/OpMathType.h>
 
 namespace at::cuda::blas {

@@ -136,15 +137,6 @@ void int8_gemm(
     int32_t* result_ptr,
     int64_t result_ld);
 
-enum class ScalingType : std::uint8_t {
-  TensorWise, // fp32 scales
-  RowWise, // fp32 scales
-  BlockWise1x16, // fp8_e4m3fn scales
-  BlockWise1x32, // fp8_e8m0fnu scales
-  BlockWise1x128, // fp32 scales
-  BlockWise128x128, // fp32 scales
-};
-
 void scaled_gemm(
     char transa,
     char transb,

@@ -156,13 +148,13 @@ void scaled_gemm(
     int64_t mat1_ld,
     ScalarType mat1_dtype,
     ScalarType mat1_scale_dtype,
-    ScalingType mat1_scaling_type,
+    at::blas::ScalingType mat1_scaling_type,
     const void* mat2_ptr,
     const void* mat2_scale_ptr,
     int64_t mat2_ld,
     ScalarType mat2_dtype,
     ScalarType mat2_scale_dtype,
-    ScalingType mat2_scaling_type,
+    at::blas::ScalingType mat2_scaling_type,
     const void* bias_ptr,
     ScalarType bias_dtype,
     void* result_ptr,
@@ -29,7 +29,7 @@
 
 namespace at::cuda::tunable {
 
-using at::cuda::blas::ScalingType;
+using at::blas::ScalingType;
 
 enum class BlasOp {
   N = 0,

@@ -106,7 +106,8 @@ c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, b
   }
 }
 
-using at::cuda::blas::ScalingType;
+using at::blas::ScalingType;
+using at::blas::SwizzleType;
 
 /**
  * @brief Prepares matrices for CUBLAS operation

@@ -1093,7 +1094,7 @@ namespace{
  * - Returns Error.
  */
 
-using at::cuda::blas::ScalingType;
+using at::blas::ScalingType;
 
 bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) {
   return isFloat8Type(t.scalar_type()) && scale.scalar_type() == kFloat && scale.numel() == 1;

@@ -1546,7 +1547,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
         args.scale_result_ptr,
         args.result_ld,
         out_dtype_,
-        false /*use_fast_accum*/);
+        use_fast_accum);
   }
 
   return out;

@@ -1664,11 +1665,6 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
   return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
 }
 
-enum class SwizzleType {
-  NO_SWIZZLE = 0,
-  SWIZZLE_32_4_4 = 1
-};
-
 /**
  * Track concrete implementations available
  */
@@ -1890,9 +1886,9 @@ std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 8>
       ScaledGemmImplementation::BLOCK_128x128_1x128},
     { "block_1x128_1x128", std::bind(check_deepseek_recipe, ScalingType::BlockWise1x128, ScalingType::BlockWise1x128, _1, _2, _3, _4, _5, _6),
       ScaledGemmImplementation::BLOCK_1x128_1x128},
-    { "nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
-    { "nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
-    { "mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
+    { "nvfp4_nvfp4", check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4},
+    { "nvfp4_nvfp4_single_scale", check_nvfp4_recipe_single_scale, ScaledGemmImplementation::NVFP4_NVFP4_SINGLE_SCALE },
+    { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
 
 Tensor&
 _cutlass_scaled_gemm(

@@ -218,3 +218,13 @@ DataParallel functions (multi-GPU, distributed)
     :nosignatures:
 
     torch.nn.parallel.data_parallel
+
+Low-Precision functions
+-----------------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    ScalingType
+    SwizzleType
+    scaled_mm

@@ -35,7 +35,6 @@ and supported quantized modules and functions.
 
     quantization-support
 
-
 .. torch.ao is missing documentation. Since part of it is mentioned here, adding them here for now.
 .. They are here for tracking purposes until they are more permanently fixed.
 .. py:module:: torch.ao
@@ -11,15 +11,17 @@ from typing import Optional
 import torch
 
 
+from torch.nn.functional import scaled_mm, ScalingType, SwizzleType
 from torch.testing._internal.common_cuda import (
-    SM89OrLater,
-    SM90OrLater,
+    IS_SM90,
     _get_torch_cuda_version,
     PLATFORM_SUPPORTS_FP8,
     PLATFORM_SUPPORTS_FP8_GROUPED_GEMM,
     PLATFORM_SUPPORTS_MX_GEMM,
     PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM,
-    IS_SM90,
+    SM100OrLater,
+    SM89OrLater,
+    SM90OrLater,
     with_tf32_off,
 )
 from torch.testing._internal.common_device_type import (
@@ -107,6 +109,112 @@ def tensor_to_scale_block(
     scale = scale.flatten(2, 3).flatten(0, 1)
     return x, scale
 
+
+def round_up(x: int, y: int) -> int:
+    return ((x + y - 1) // y) * y
+
+
+def infer_scale_swizzle(mat, scale):
+    # Tensor-wise
+    if scale.numel() == 1:
+        return ScalingType.TensorWise, SwizzleType.NO_SWIZZLE
+
+    # Row-wise
+    if (scale.shape[0] == mat.shape[0] and scale.shape[1] == 1) or (
+        scale.shape[0] == 1 and scale.shape[1] == mat.shape[1]
+    ):
+        return ScalingType.RowWise, SwizzleType.NO_SWIZZLE
+
+    # deepgemm 1x128 / 128x1
+    if len(scale.shape) > 1:
+        if (
+            scale.shape[0] == mat.shape[0]
+            and scale.shape[1] == math.ceil(mat.shape[1] // 128)
+            or scale.shape[1] == mat.shape[1]
+            and scale.shape[0] == math.ceil(mat.shape[0] // 128)
+        ):
+            return ScalingType.BlockWise1x128, SwizzleType.NO_SWIZZLE
+
+    # deepgemm 128x128
+    if scale.shape[0] == math.ceil(mat.shape[0] // 128) and scale.shape[
+        1
+    ] == math.ceil(mat.shape[1] // 128):
+        return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
+
+    # NVFP4
+    if (
+        scale.numel()
+        == round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4)
+        or scale.numel()
+        == round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4)
+        and mat.dtype == torch.float4_e2m1fn_x2
+        and scale.dtype == torch.float8_e4m3fn
+    ):
+        return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
+
+    # MX
+    if (
+        scale.numel()
+        == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4)
+        or scale.numel()
+        == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)
+        and scale.dtype == torch.float8_e8m0fnu
+    ):
+        return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
+
+    return None, None
+
+
+wrap: bool = True
+
+def scaled_mm_wrap(
+    a,
+    b,
+    scale_a,
+    scale_b,
+    scale_recipe_a=None,
+    scale_recipe_b=None,
+    swizzle_a=SwizzleType.NO_SWIZZLE,
+    swizzle_b=SwizzleType.NO_SWIZZLE,
+    scale_result=None,
+    out_dtype=torch.bfloat16,
+    use_fast_accum=False,
+    bias=None,
+    wrap_v2=wrap,
+):
+    if not wrap_v2:
+        return torch._scaled_mm(
+            a,
+            b,
+            scale_a,
+            scale_b,
+            scale_result=scale_result,
+            out_dtype=out_dtype,
+            bias=bias,
+            use_fast_accum=use_fast_accum,
+        )
+    else:
+        # infer scalingtype and swizzle from scales
+        if scale_recipe_a is None:
+            scale_recipe_a, swizzle_a = infer_scale_swizzle(a, scale_a)
+        if scale_recipe_b is None:
+            scale_recipe_b, swizzle_b = infer_scale_swizzle(b, scale_b)
+
+        out = scaled_mm(
+            a,
+            b,
+            scale_a,
+            scale_recipe_a,
+            scale_b,
+            scale_recipe_b,
+            swizzle_a=swizzle_a,
+            swizzle_b=swizzle_b,
+            bias=bias,
+            output_dtype=out_dtype,
+            use_fast_accum=use_fast_accum,
+        )
+        return out
+
 def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
     # naive implementation: dq -> op -> q
     x_fp32 = x.to(torch.float) / x_scale
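The infer_scale_swizzle helper added above derives a recipe purely from the scale tensor's shape and dtype relative to the operand; a rough illustrative check (not part of the diff; assumes infer_scale_swizzle is in scope as in test/test_scaled_matmul_cuda.py):

# Rough sketch of the shape-based inference above; infer_scale_swizzle is the
# test-local helper defined in test/test_scaled_matmul_cuda.py.
import torch
from torch.nn.functional import ScalingType, SwizzleType

M, K = 256, 512
mat = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)

scalar_scale = torch.tensor(1.0, device="cuda")             # numel == 1 -> TensorWise
row_scale = torch.ones(M, 1, device="cuda")                 # one scale per row -> RowWise
tile_scale = torch.ones(M // 128, K // 128, device="cuda")  # one scale per 128x128 tile

assert infer_scale_swizzle(mat, scalar_scale) == (ScalingType.TensorWise, SwizzleType.NO_SWIZZLE)
assert infer_scale_swizzle(mat, row_scale) == (ScalingType.RowWise, SwizzleType.NO_SWIZZLE)
assert infer_scale_swizzle(mat, tile_scale) == (ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE)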
@@ -139,7 +247,7 @@ def addmm_float8_unwrapped(
     b_inverse_scale = b_scale.reciprocal()
     if output_dtype == torch.float32 and bias is not None:
         # Bias is not supported by _scaled_mm when output is fp32
-        output = torch._scaled_mm(
+        output = scaled_mm_wrap(
             a_data,
             b_data,
             scale_a=a_inverse_scale,

@@ -149,7 +257,7 @@ def addmm_float8_unwrapped(
         )
         output += bias
         return output
-    output = torch._scaled_mm(
+    output = scaled_mm_wrap(
         a_data,
         b_data,
         bias=bias,

@@ -290,7 +398,7 @@ class TestFP8Matmul(TestCase):
         out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
         scale_a = torch.tensor(1.0, device=device)
         scale_b = torch.tensor(1.0, device=device)
-        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
+        out_fp8 = scaled_mm_wrap(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
         if out_dtype is not None:
             self.assertEqual(out_dtype, out_fp8.dtype)
         self.assertEqual(out_fp32, out_fp8.to(torch.float))

@@ -301,7 +409,7 @@ class TestFP8Matmul(TestCase):
         self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
         # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
         # supported on ROCm but fails on CUDA
-        ctx = self.assertRaises(RuntimeError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext()
+        ctx = self.assertRaises(ValueError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext()
         with ctx:
             self._test_tautological_mm(device, e5m2_type, e5m2_type)
 

@@ -326,9 +434,9 @@ class TestFP8Matmul(TestCase):
         scale_one = torch.tensor(1.0, device=device)
         scale_a = torch.tensor(1.5, device=device)
         scale_b = torch.tensor(0.66, device=device)
-        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_one, scale_b=scale_one)
+        out_fp8 = scaled_mm_wrap(x, y, scale_a=scale_one, scale_b=scale_one)
         self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
-        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
+        out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
         self.assertEqual(out_fp8, out_fp8_s)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM, mxfp8_grouped_mm_skip_msg)

@@ -506,6 +614,7 @@ class TestFP8Matmul(TestCase):
         x_scale = tensor_to_scale(x, input_dtype).float()
         y_scale = tensor_to_scale(y, input_dtype).float()
 
+
         x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
         y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)
 

@@ -600,11 +709,11 @@ class TestFP8Matmul(TestCase):
         (k, l, m) = (16, 48, 32)
         x = torch.ones((k, l), device=device).to(e4m3_type)
         y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
-        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
+        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
         scale_a = torch.tensor(1.0, device=device)
         scale_b = torch.tensor(1.0, device=device)
-        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
-        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
+        out_fp8 = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b)
+        outb_fp8 = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
         # this fails on ROCm currently because hipblaslt doesn't have amax op
         out_fp32 = out_fp8.to(torch.float32)
         outb_fp32 = outb_fp8.to(torch.float32)

@@ -621,8 +730,8 @@ class TestFP8Matmul(TestCase):
         scale_b = torch.tensor(1.0, device=device)
         input_bias = None
         if bias:
-            input_bias = torch.rand((16,), device=device).to(torch.half)
-        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)
+            input_bias = torch.rand((16,), device=device).to(torch.bfloat16)
+        _ = scaled_mm_wrap(x, y, scale_a, scale_b, bias=input_bias)
 
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)

@@ -630,10 +739,10 @@ class TestFP8Matmul(TestCase):
         (k, l, m) = (16, 48, 32)
         x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
         y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
-        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
+        bias = torch.full((m,), -3.0, device=device, dtype=torch.bfloat16)
         scale_a = torch.tensor(1.0, device=device)
         scale_b = torch.tensor(1.0, device=device)
-        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
+        outb_fp8 = scaled_mm_wrap(x, y, scale_a, scale_b, bias=bias)
         outb_fp32 = outb_fp8.to(torch.float32)
         self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))
 

@@ -647,9 +756,9 @@ class TestFP8Matmul(TestCase):
         scale_b = torch.tensor(1.0, device=device)
         bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
         self.assertRaisesRegex(
-            RuntimeError,
+            ValueError,
             "Bias is not supported when out_dtype is set to Float32",
-            lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
+            lambda: scaled_mm_wrap(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
         )
 
     @onlyCUDA

@@ -663,10 +772,11 @@ class TestFP8Matmul(TestCase):
         self.assertRaisesRegex(
             RuntimeError,
             r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
-            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
+            lambda: scaled_mm_wrap(x, y, scale_a, scale_b, out_dtype=torch.float32),
         )
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
+    @unittest.skipIf(SM100OrLater, "fast_accum is SM90-only")
     def test_float8_scale_fast_accum(self, device) -> None:
         size = (16, 16)
         x = torch.full(size, .5, device=device, dtype=e4m3_type)

@@ -675,9 +785,9 @@ class TestFP8Matmul(TestCase):
         y = torch.full(size, .5, device=device, dtype=y_type).t()
         scale_a = torch.tensor(1.5, device=device)
         scale_b = torch.tensor(0.66, device=device)
-        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
+        out_fp8 = scaled_mm_wrap(x, y, scale_a, scale_b, out_dtype=torch.float8_e4m3fn, use_fast_accum=True)
         self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
-        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
+        out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn, use_fast_accum=True)
         self.assertEqual(out_fp8, out_fp8_s)
 
     @onlyCUDA

@@ -696,7 +806,7 @@ class TestFP8Matmul(TestCase):
         x_fp8 = x.to(e4m3_type)
         y_fp8 = y.to(e4m3_type).t()
 
-        out_fp8 = torch._scaled_mm(
+        out_fp8 = scaled_mm_wrap(
             x_fp8,
             y_fp8,
             scale_a=x_scales,
@@ -720,50 +830,58 @@ class TestFP8Matmul(TestCase):
         y_fp8 = y.to(e4m3_type).t()
 
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            ValueError, re.escape("scale_b must have 1 Float element")
         ):
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((1, 1), device="cuda"),
                 scale_b=torch.ones((1, 2), device="cuda"),
+                scale_recipe_a=ScalingType.TensorWise,
+                scale_recipe_b=ScalingType.TensorWise,
                 out_dtype=torch.bfloat16,
             )
 
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            ValueError, re.escape(f"scale_b must have {N} Float elements, got {N + 1}"),
         ):
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((M, 1), device="cuda"),
                 scale_b=torch.ones((1, N + 1), device="cuda"),
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
                 out_dtype=torch.bfloat16,
             )
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            IndexError, re.escape("Dimension out of range")
         ):
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((M), device="cuda"),
                 scale_b=torch.ones((N, 1), device="cuda"),
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
                 out_dtype=torch.bfloat16,
             )
 
         with self.assertRaisesRegex(
-            RuntimeError, re.escape("Invalid scaling configuration")
+            ValueError, re.escape("expected scale_b.stride(1) to be 1, but got 2"),
         ):
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_fp8,
                 y_fp8,
                 scale_a=torch.ones((M, 1), device="cuda"),
                 scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
                 out_dtype=torch.bfloat16,
             )
 
         def e5m2():
-            out = torch._scaled_mm(
+            out = scaled_mm_wrap(
                 x_fp8,
                 y_fp8.to(e5m2_type),
                 scale_a=torch.ones((M, 1), device="cuda"),
@@ -838,6 +956,7 @@ class TestFP8Matmul(TestCase):
         else:
             test()
 
+    # Note: Removed parameterization over M,N,K from #163829 as it failed tests as-is
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
     @unittest.skipIf(not IS_SM90, "cuBLAS blockwise scaling requires sm90+")
     @unittest.skipIf(

@@ -859,12 +978,20 @@ class TestFP8Matmul(TestCase):
         # 1x128 blocks need scales to be outer-dim-major
         if lhs_block == 1:
             x_scales = x_scales.t().contiguous().t()
+            lhs_recipe = ScalingType.BlockWise1x128
+        else:
+            lhs_recipe = ScalingType.BlockWise128x128
         if rhs_block == 1:
             y_scales = y_scales.t().contiguous().t()
+            rhs_recipe = ScalingType.BlockWise1x128
+        else:
+            rhs_recipe = ScalingType.BlockWise128x128
+
 
         # Calculate actual F8 mm
-        out_scaled_mm = mm_float8(
-            x_fp8, y_fp8.t(), a_scale=x_scales, b_scale=y_scales.t(), output_dtype=output_dtype
+        out_scaled_mm = scaled_mm_wrap(
+            x_fp8, y_fp8.t(), scale_a=x_scales.reciprocal(), scale_b=y_scales.reciprocal().t(), out_dtype=output_dtype,
+            scale_recipe_a=lhs_recipe, scale_recipe_b=rhs_recipe
         )
 
         # Calculate emulated F8 mm

@@ -912,9 +1039,9 @@ class TestFP8Matmul(TestCase):
         out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
         scale_a = torch.tensor(float('-inf'), device=device)
         scale_b = torch.tensor(float('-inf'), device=device)
-        f = torch._scaled_mm
+        f = scaled_mm_wrap
         if use_torch_compile:
-            f = torch.compile(torch._scaled_mm)
+            f = torch.compile(scaled_mm_wrap)
         out_fp8 = f(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
         self.assertEqual(out_dtype, out_fp8.dtype)
         self.assertEqual(out_fp32, out_fp8.to(torch.float))

@@ -938,16 +1065,16 @@ class TestFP8Matmul(TestCase):
         with tempfile.NamedTemporaryFile() as f:
             with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
                 self.assertIsNone(torch._C._get_sm_carveout_experimental())
-                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
+                scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                 torch._C._set_sm_carveout_experimental(0)
                 self.assertEqual(torch._C._get_sm_carveout_experimental(), 0)
-                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
+                scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                 torch._C._set_sm_carveout_experimental(66)
                 self.assertEqual(torch._C._get_sm_carveout_experimental(), 66)
-                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
+                scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                 torch._C._set_sm_carveout_experimental(None)
                 self.assertIsNone(torch._C._get_sm_carveout_experimental())
-                torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
+                scaled_mm_wrap(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
 
             prof.export_chrome_trace(f.name)
             if torch.version.hip:
@@ -1244,7 +1371,7 @@ class TestFP8Matmul(TestCase):
         A_scale = to_blocked(A_scale)
         B_scale = to_blocked(B_scale)
 
-        C = torch._scaled_mm(
+        C = scaled_mm_wrap(
             A,
             B.t(),
             A_scale,

@@ -1283,49 +1410,65 @@ class TestFP8Matmul(TestCase):
         expected_a_size = BLOCK_SIZE_MN * ceil_div(M, BLOCK_SIZE_MN) * padded_num_k_blocks
         expected_b_size = BLOCK_SIZE_MN * ceil_div(N, BLOCK_SIZE_MN) * padded_num_k_blocks
 
+        block = (
+            ScalingType.BlockWise1x16
+            if recipe == "nvfp4"
+            else ScalingType.BlockWise1x32
+        )
+        swizzle = SwizzleType.SWIZZLE_32_4_4
+
         # Test wrong scale tensor size for scale_a with correct dtype
         with self.assertRaisesRegex(
-            RuntimeError,
+            ValueError,
            f".*For Block[W,w]ise.*scaling.*scale_a should have {expected_a_size} "
            f"elements.*"
            ,
         ):
             incorrect_size_a = torch.ones(expected_a_size - 1, device=device, dtype=scale_dtype)
             correct_size_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype)
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_lowp,
                 y_lowp,
                 scale_a=incorrect_size_a,
+                scale_recipe_a=block,
                 scale_b=correct_size_b,
+                scale_recipe_b=block,
+                swizzle_a=swizzle,
+                swizzle_b=swizzle,
                 out_dtype=torch.bfloat16,
             )
 
         # Test wrong scale tensor size for scale_b with correct dtype
         with self.assertRaisesRegex(
-            RuntimeError,
+            ValueError,
            f"For Block[W,w]ise.*scaling.*scale_b should have {expected_b_size} "
            f"elements.*"
            ,
         ):
             correct_size_a = torch.ones(expected_a_size, device=device, dtype=scale_dtype)
             incorrect_size_b = torch.ones(expected_b_size + 1, device=device, dtype=scale_dtype)
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_lowp,
                 y_lowp,
                 scale_a=correct_size_a,
+                scale_recipe_a=block,
                 scale_b=incorrect_size_b,
+                scale_recipe_b=block,
+                swizzle_a=swizzle,
+                swizzle_b=swizzle,
                 out_dtype=torch.bfloat16,
             )
 
         # Test non-contiguous scale tensors with correct dtype
         with self.assertRaisesRegex(
-            RuntimeError,
-            "For Block[W,w]ise.*scaling.*both should be contiguous"
+            ValueError,
+            "For Block[W,w]ise.*scaling.*both scales should be contiguous"
            ,
         ):
             non_contiguous_a = torch.ones(expected_a_size * 2, device=device, dtype=scale_dtype)[::2]
             contiguous_b = torch.ones(expected_b_size, device=device, dtype=scale_dtype)
-            torch._scaled_mm(
+            scaled_mm_wrap(
                 x_lowp,
                 y_lowp,
                 scale_a=non_contiguous_a,
@@ -1335,8 +1478,8 @@ class TestFP8Matmul(TestCase):
 
     def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_fast_accum):
         for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
-            out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
+            out_ref = scaled_mm_wrap(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
                                        out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
             self.assertEqual(out, out_ref, atol=5e-2, rtol=5e-4)
 
     # Testing only _scaled_grouped_mm() with multiple shapes, as

@@ -1485,7 +1628,7 @@ class TestFP8Matmul(TestCase):
         B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e8m0fnu)
         C_ref = A_ref @ B_ref.t()
 
-        compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor")
+        compiled_scaled_mm = torch.compile(scaled_mm_wrap, backend="inductor")
         C = compiled_scaled_mm(
             A,
             B.t(),

@@ -1514,8 +1657,8 @@ class TestFP8Matmul(TestCase):
         B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=fp4_scaling_dtype)
         C_ref = A_ref @ B_ref.t()
 
-        compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor")
-        # C = torch._scaled_mm(
+        compiled_scaled_mm = torch.compile(scaled_mm_wrap, backend="inductor")
+        # C = scaled_mm_wrap(
         C = compiled_scaled_mm(
             A,
             B.t(),

@@ -113,6 +113,7 @@
 
 #ifdef USE_CUDA
 #include <ATen/ROCmFABackend.h>
+#include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/CUDAConfig.h>
 #include <ATen/native/transformers/cuda/sdp_utils.h>
 #include <torch/csrc/inductor/static_cuda_launcher.h>
@@ -2504,6 +2505,39 @@ Call this whenever a new thread is created in order to propagate values from
         return at::globalContext().blasPreferredBackend();
       });
 
+  py::enum_<at::blas::ScalingType>(
+      py_module, "_ScalingType", "Supported Tensor scaling types")
+      .value(
+          "TensorWise",
+          at::blas::ScalingType::TensorWise,
+          "Single scale per-tensor")
+      .value(
+          "RowWise", at::blas::ScalingType::RowWise, "Scale per-row of tensor")
+      .value(
+          "BlockWise1x16",
+          at::blas::ScalingType::BlockWise1x16,
+          "Scale per 16 contiguous values")
+      .value(
+          "BlockWise1x32",
+          at::blas::ScalingType::BlockWise1x32,
+          "Scale per 32 contiguous values")
+      .value(
+          "BlockWise1x128",
+          at::blas::ScalingType::BlockWise1x128,
+          "Scale per 128 contiguous values")
+      .value(
+          "BlockWise128x128",
+          at::blas::ScalingType::BlockWise128x128,
+          "Scale per 128x128 tile");
+
+  py::enum_<at::blas::SwizzleType>(
+      py_module, "_SwizzleType", "Supported scale swizzle types")
+      .value("NO_SWIZZLE", at::blas::SwizzleType::NO_SWIZZLE, "No swizzling")
+      .value(
+          "SWIZZLE_32_4_4",
+          at::blas::SwizzleType::SWIZZLE_32_4_4,
+          "Blackwell-stype 32x4x4 swizzle");
+
   py::enum_<at::ROCmFABackend>(py_module, "_ROCmFABackend")
       .value("Default", at::ROCmFABackend::Default)
       .value("AOTriton", at::ROCmFABackend::AOTriton)
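The bindings above surface the C++ enums to Python as torch._C._ScalingType and torch._C._SwizzleType; the torch/nn/functional.py hunks below re-export them. An illustrative check (not part of the diff):

# Illustrative only: the bound enums and their integral values; the .value
# attribute is what torch.nn.functional.scaled_mm forwards to torch._scaled_mm_v2.
import torch

recipe = torch._C._ScalingType.BlockWise1x32
swizzle = torch._C._SwizzleType.SWIZZLE_32_4_4
print(int(recipe.value), int(swizzle.value))  # 3 1, given the declaration order above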
@@ -4,11 +4,16 @@ import importlib
 import math
 import warnings
 from collections.abc import Callable
-from typing import Optional, TYPE_CHECKING, Union
+from typing import Any as _Any, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch import _VF, sym_int as _sym_int, Tensor
-from torch._C import _add_docstr, _infer_size
+from torch._C import (
+    _add_docstr,
+    _infer_size,
+    _ScalingType as ScalingType,
+    _SwizzleType as SwizzleType,
+)
 from torch._jit_internal import (
     _overload,
     boolean_dispatch,

@@ -27,6 +32,10 @@ from torch.overrides import (
 )
 
 
+# Set visibility of the bound enums to this module
+ScalingType.__module__ = "torch.nn.functional"
+SwizzleType.__module__ = "torch.nn.functional"
+
 if TYPE_CHECKING:
     from torch.types import _dtype as DType
 else:
@@ -6618,3 +6627,87 @@ def multi_head_attention_forward(
         # squeeze the output if input was unbatched
         attn_output = attn_output.squeeze(1)
         return attn_output, None
+
+
+def scaled_mm(
+    mat_a: Tensor,
+    mat_b: Tensor,
+    scale_a: Tensor | list[Tensor],
+    scale_recipe_a: ScalingType | list[ScalingType],
+    scale_b: Tensor | list[Tensor],
+    scale_recipe_b: ScalingType | list[ScalingType],
+    swizzle_a: SwizzleType | list[SwizzleType] | None = None,
+    swizzle_b: SwizzleType | list[SwizzleType] | None = None,
+    bias: Optional[Tensor] = None,
+    output_dtype: Optional[torch.dtype] = torch.bfloat16,
+    contraction_dim: list[int] | tuple[int] = (),
+    use_fast_accum: bool = False,
+) -> Tensor:
+    r"""
+    scaled_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a, swizzle_b, bias, output_dtype,
+              contraction_dim, use_fast_accum)
+
+    Applies a scaled matrix-multiply, mm(mat_a, mat_b) where the scaling of mat_a and mat_b are described by
+    scale_recipe_a and scale_recipe_b respectively.
+
+    Args:
+        scale_a: Tensor containing decoding scaling factors for mat_a
+        scale_recipe_a: Enum describing how mat_a has been scaled
+        scale_b: Tensor containing decoding scaling factors for mat_b
+        scale_recipe_b: Enum describing how mat_b has been scaled
+        swizzle_a: Enum describing the swizzling pattern (if any) of scale_a
+        swizzle_b: Enum describing the swizzling pattern (if any) of scale_b
+        bias: optional bias term to be added to the output
+        output_dtype: dtype used for the output tensor
+        contraction_dim: describe which dimensions are :math:`K` in the matmul.
+        use_fast_accum: enable/disable tensor-core fast accumulation (Hopper-GPUs only)
+    """
+
+    def expand_single_value(v: _Any | list[_Any] | None) -> list[_Any]:
+        if v is None:
+            return []
+        elif not isinstance(v, (list)):
+            return [
+                v,
+            ]
+        else:
+            return v
+
+    scale_a = expand_single_value(scale_a)
+    scale_recipe_a = expand_single_value(scale_recipe_a)
+    scale_b = expand_single_value(scale_b)
+    scale_recipe_b = expand_single_value(scale_recipe_b)
+    swizzle_a = expand_single_value(swizzle_a)
+    swizzle_b = expand_single_value(swizzle_b)
+
+    # native_functions has restrictions on what can be defined
+    # & passed through - std::optional<ArrayRef<Tensor>> for instance
+    # *cannot* be passed, but an empty vector (list) can.
+    # So, we need to convert None arguments for lists in python
+    # explicitly into empty lists.
+    def list_or_empty(l: list[_Any] | None) -> list[_Any]:
+        return [] if not l else l
+
+    def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
+        if not isinstance(l, list):
+            l = [
+                l,
+            ]
+        return [li.value for li in l]
+
+    out = torch._scaled_mm_v2(
+        mat_a,
+        mat_b,
+        scale_a,
+        enum_list_as_int_list(scale_recipe_a),
+        enum_list_as_int_list(list_or_empty(swizzle_a)),
+        scale_b,
+        enum_list_as_int_list(scale_recipe_b),
+        enum_list_as_int_list(list_or_empty(swizzle_b)),
+        bias,
+        output_dtype,
+        contraction_dim,
+        use_fast_accum,
+    )
+
+    return out
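To complement the docstring above, an illustrative row-wise call through the new wrapper (not part of the diff; shapes, device, and dtypes are assumptions, and an fp8-capable GPU is required). scale_a carries one fp32 scale per row of mat_a and scale_b one per column of mat_b, matching ScalingType.RowWise:

# Illustrative only: row-wise scaled fp8 matmul through the new public API; the
# wrapper normalizes the single enum arguments to lists and calls torch._scaled_mm_v2.
import torch
from torch.nn.functional import scaled_mm, ScalingType

M, K, N = 128, 256, 64
mat_a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
mat_b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.ones(M, 1, device="cuda")  # one scale per row of mat_a
scale_b = torch.ones(1, N, device="cuda")  # one scale per column of mat_b

out = scaled_mm(
    mat_a, mat_b,
    scale_a, ScalingType.RowWise,
    scale_b, ScalingType.RowWise,
    output_dtype=torch.bfloat16,
)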
@@ -249,6 +249,7 @@ def get_ignored_functions() -> set[Callable]:
         torch.nn.functional.has_torch_function_unary,
         torch.nn.functional.has_torch_function_variadic,
         torch.nn.functional.handle_torch_function,
+        torch.nn.functional.scaled_mm,
         torch.nn.functional.sigmoid,
         torch.nn.functional.hardsigmoid,
         torch.nn.functional.tanh,